olivas/cloud_run/processing/main.py
Vadym Samoilenko ce1d10d9b2 Fix matplotlib.colormaps import in processing service
colormaps is a module attribute, not a submodule — remove the explicit import.
Also relax matplotlib pin back to >=3.5 (minimum for colormaps attribute).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-11 13:34:00 +00:00

111 lines
3.4 KiB
Python

"""OliVAS Processing Cloud Run Service.
Handles image post-processing from saliency maps:
- Heatmap overlay generation
- Standalone heatmap generation
- Gaze sequence visualization image
"""
import base64
import io
import logging
import os
import matplotlib
matplotlib.use("Agg")
import numpy as np
from fastapi import FastAPI, Header, HTTPException
from PIL import Image, ImageDraw, ImageFont
from pydantic import BaseModel
# Root logging is configured at import time; records at INFO and above pass.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("olivas.processing")
# Shared secret read from the CLOUD_RUN_SECRET environment variable. It falls
# back to "" when the variable is unset, and _check_auth skips the check for an
# empty secret, so every request is accepted in that case.
INTERNAL_SECRET = os.environ.get("CLOUD_RUN_SECRET", "")
# ASGI application object that the route decorators in this module attach to.
app = FastAPI(title="OliVAS Processing Service")
def _check_auth(x_internal_secret: str | None) -> None:
    """Reject the request unless it carries the configured shared secret.

    Args:
        x_internal_secret: Secret supplied by the caller, or ``None`` when no
            value was sent.

    Raises:
        HTTPException: Status 401 when a secret is configured and the supplied
            value is missing or does not match it.
    """
    import hmac  # function-local so the module's import block is untouched

    # An empty INTERNAL_SECRET disables the check entirely.
    if not INTERNAL_SECRET:
        return

    # compare_digest runs in time independent of where the first mismatch is,
    # so response timing does not reveal how much of a guess was correct.
    # Both sides are encoded because compare_digest rejects non-ASCII str.
    supplied = (x_internal_secret or "").encode("utf-8")
    expected = INTERNAL_SECRET.encode("utf-8")
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Unauthorized")
# Request body for POST /process. Binary payloads travel as base64 text.
class ProcessRequest(BaseModel):
    # Base64-encoded image file; the handler decodes it and opens it with PIL.
    image_b64: str
    # Base64-encoded raw float32 buffer; reshaped to `shape` after decoding.
    saliency_b64: str
    # Dimensions of the saliency map.
    shape: list[int]  # [H, W]
    # Gaze points in drawing order; each dict is read for its "x", "y" and
    # "rank" keys, and consecutive points are joined by a line.
    gaze_sequence: list[dict]
def _img_to_b64(img: Image.Image) -> str:
    """Serialise ``img`` as PNG and return the bytes as base64 text."""
    with io.BytesIO() as png_stream:
        img.save(png_stream, format="PNG")
        png_bytes = png_stream.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode()
def _decode_saliency(saliency_b64: str, shape: list[int]) -> np.ndarray:
    """Decode a base64 float32 buffer into a 2-D saliency array.

    Args:
        saliency_b64: Base64 text whose decoded bytes are float32 values in
            native byte order, laid out row by row.
        shape: ``[H, W]`` dimensions of the map; neither may be negative.

    Returns:
        A read-only ``float32`` array of shape ``(H, W)`` viewing the decoded
        bytes.

    Raises:
        ValueError: If ``shape`` is not two non-negative integers, if the text
            is not valid base64, or if the decoded byte count is not
            ``H * W * 4``.
    """
    if len(shape) != 2:
        raise ValueError(
            f"shape must be [H, W], got {len(shape)} dimension(s)"
        )
    h, w = shape
    # reshape() treats -1 as "infer this axis", so without this guard a
    # negative dimension from the client would be accepted silently.
    if h < 0 or w < 0:
        raise ValueError(f"shape dimensions must not be negative, got [{h}, {w}]")

    raw = base64.b64decode(saliency_b64)
    expected_bytes = h * w * np.dtype(np.float32).itemsize
    if len(raw) != expected_bytes:
        raise ValueError(
            f"saliency buffer holds {len(raw)} bytes, "
            f"expected {expected_bytes} for shape [{h}, {w}]"
        )
    return np.frombuffer(raw, dtype=np.float32).reshape(h, w)
def _generate_heatmap_overlay(image: Image.Image, saliency: np.ndarray) -> Image.Image:
    """Blend a jet-coloured rendering of ``saliency`` over ``image``.

    Args:
        image: Base picture; it is converted to RGB before blending.
        saliency: 2-D array of saliency values. The values are handed to the
            colormap unchanged, so they are presumably already normalised to
            ``[0, 1]`` by the caller — TODO confirm.

    Returns:
        A new RGB image the size of ``image``: an equal-weight blend of the
        picture and the heatmap resized to match it.
    """
    # Index the registry rather than calling colormaps.get_cmap(): the registry
    # exists since matplotlib 3.5, but its get_cmap() method only since 3.6.
    cmap = matplotlib.colormaps["jet"]
    heatmap_rgba = cmap(saliency)
    # Drop the alpha channel and scale the 0..1 floats to 8-bit values.
    heatmap_rgb = (heatmap_rgba[:, :, :3] * 255).astype(np.uint8)
    heatmap_img = Image.fromarray(heatmap_rgb).resize(image.size, Image.LANCZOS)
    return Image.blend(image.convert("RGB"), heatmap_img, 0.5)
def _generate_standalone_heatmap(saliency: np.ndarray) -> Image.Image:
    """Render ``saliency`` as a jet-coloured RGB image of the same dimensions.

    Args:
        saliency: 2-D array of saliency values. The values are handed to the
            colormap unchanged, so they are presumably already normalised to
            ``[0, 1]`` by the caller — TODO confirm.

    Returns:
        An RGB image with one pixel per element of ``saliency``.
    """
    # Index the registry rather than calling colormaps.get_cmap(): the registry
    # exists since matplotlib 3.5, but its get_cmap() method only since 3.6.
    cmap = matplotlib.colormaps["jet"]
    heatmap_rgba = cmap(saliency)
    # Drop the alpha channel and scale the 0..1 floats to 8-bit values.
    heatmap_rgb = (heatmap_rgba[:, :, :3] * 255).astype(np.uint8)
    return Image.fromarray(heatmap_rgb)
def _draw_gaze_sequence(image: Image.Image, gaze_seq: list[dict]) -> Image.Image:
    """Return a copy of ``image`` annotated with the gaze path.

    Each entry of ``gaze_seq`` gets a circle centred on its ``"x"``/``"y"``
    position with its ``"rank"`` written inside, and a line joins it to the
    entry that follows. Colours cycle through a five-entry palette. The input
    image is left untouched.
    """
    canvas = image.copy()
    pen = ImageDraw.Draw(canvas)
    label_font = ImageFont.load_default(size=24)
    palette = ["#FF4444", "#FF8800", "#FFCC00", "#44CC44", "#4488FF"]
    radius = 25
    last_index = len(gaze_seq) - 1

    for idx, fixation in enumerate(gaze_seq):
        cx, cy = fixation["x"], fixation["y"]
        stroke = palette[idx % len(palette)]

        bounding_box = [cx - radius, cy - radius, cx + radius, cy + radius]
        pen.ellipse(bounding_box, outline=stroke, width=3)
        pen.text((cx - 6, cy - 12), str(fixation["rank"]), fill=stroke, font=label_font)

        # The final point has no successor to connect to.
        if idx >= last_index:
            continue
        successor = gaze_seq[idx + 1]
        next_x, next_y = successor["x"], successor["y"]
        pen.line([cx, cy, next_x, next_y], fill=stroke, width=2)

    return canvas
# Health endpoint: replies with a constant payload and reads no other state.
@app.get("/health")
async def health():
    payload = dict(status="ok")
    return payload
@app.post("/process")
async def process_images(
    request: ProcessRequest,
    x_internal_secret: str | None = Header(None),
):
    """Render the heatmap overlay, standalone heatmap and gaze-sequence image.

    Args:
        request: Request body carrying the image, the saliency buffer with its
            shape, and the gaze points.
        x_internal_secret: Shared-secret header value checked by
            ``_check_auth``.

    Returns:
        A dict holding the three rendered PNGs as base64 strings under
        ``heatmap_overlay_b64``, ``heatmap_standalone_b64`` and
        ``gaze_sequence_img_b64``.

    Raises:
        HTTPException: 401 from ``_check_auth``; 400 when the image or saliency
            payload cannot be decoded or a gaze point lacks a required key.
    """
    _check_auth(x_internal_secret)

    # Malformed client payloads are reported as 400 instead of escaping as
    # unhandled exceptions (HTTP 500).
    try:
        image_data = base64.b64decode(request.image_b64)
        saliency = _decode_saliency(request.saliency_b64, request.shape)
        # PIL's unidentified-image and truncated-file errors derive from OSError.
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
    except (ValueError, OSError) as exc:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid image or saliency payload: {exc}",
        ) from exc

    logger.info("Processing image %s, saliency %s", image.size, saliency.shape)

    # NOTE(review): the rendering below is synchronous work inside an async
    # handler, so it occupies the event loop for the duration of the request.
    overlay = _generate_heatmap_overlay(image, saliency)
    standalone = _generate_standalone_heatmap(saliency)
    try:
        gaze_img = _draw_gaze_sequence(image, request.gaze_sequence)
    except KeyError as exc:
        raise HTTPException(
            status_code=400,
            detail=f"Gaze point is missing key {exc}",
        ) from exc

    return {
        "heatmap_overlay_b64": _img_to_b64(overlay),
        "heatmap_standalone_b64": _img_to_b64(standalone),
        "gaze_sequence_img_b64": _img_to_b64(gaze_img),
    }