colormaps is a module attribute, not a submodule — remove the explicit import. Also relax matplotlib pin back to >=3.5 (minimum for colormaps attribute). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
111 lines
3.4 KiB
Python
111 lines
3.4 KiB
Python
"""OliVAS Processing Cloud Run Service.
|
|
|
|
Handles image post-processing from saliency maps:
|
|
- Heatmap overlay generation
|
|
- Standalone heatmap generation
|
|
- Gaze-sequence visualization (numbered fixation rings joined by lines)
|
|
"""
|
|
import asyncio
import base64
import hmac
import io
import logging
import os

import matplotlib

matplotlib.use("Agg")

import numpy as np
from fastapi import FastAPI, Header, HTTPException
from PIL import Image, ImageDraw, ImageFont
from pydantic import BaseModel
|
|
|
|
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("olivas.processing")

# Shared secret that callers must send in the x-internal-secret header.
# HTTP field values cannot carry leading/trailing whitespace (RFC 9110 §5.5),
# so surrounding whitespace in the env value (such as a trailing newline from
# a mounted secret) would make every request fail auth; strip it here.
INTERNAL_SECRET = os.environ.get("CLOUD_RUN_SECRET", "").strip()
if not INTERNAL_SECRET:
    # _check_auth treats an empty secret as "auth disabled"; make that visible
    # instead of silently running an open endpoint.
    logger.warning(
        "CLOUD_RUN_SECRET is not set; /process accepts unauthenticated requests"
    )

app = FastAPI(title="OliVAS Processing Service")
|
|
|
|
|
|
def _check_auth(x_internal_secret: str | None) -> None:
    """Reject the request unless it carries the shared internal secret.

    Auth is enforced only when ``INTERNAL_SECRET`` is non-empty; with no
    secret configured every request is let through.

    Args:
        x_internal_secret: Value of the ``x-internal-secret`` request header,
            or ``None`` if the header was absent.

    Raises:
        HTTPException: 401 if a secret is configured and the header is missing
            or does not match.
    """
    if not INTERNAL_SECRET:
        return
    # Constant-time comparison so response timing does not reveal how much of
    # a guessed secret was correct. Compare UTF-8 bytes because
    # hmac.compare_digest raises TypeError on non-ASCII str arguments.
    if x_internal_secret is None or not hmac.compare_digest(
        x_internal_secret.encode("utf-8"), INTERNAL_SECRET.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Unauthorized")
|
|
|
|
|
|
class ProcessRequest(BaseModel):
    """JSON body accepted by ``POST /process``."""

    # Base64-encoded bytes of the source image, in any format PIL's Image.open
    # can read.
    image_b64: str
    # Base64 of a raw float32 buffer with H*W saliency values, decoded with
    # native byte order and reshaped row-major to (H, W). Presumably values
    # are normalized to [0, 1] upstream — TODO confirm with the producer.
    saliency_b64: str
    shape: list[int]  # [H, W]
    # Ordered fixations; each dict is read for "x", "y" (pixel coordinates in
    # the image) and "rank" (label text) when drawing the gaze path.
    gaze_sequence: list[dict]
|
|
|
|
|
|
def _img_to_b64(img: Image.Image) -> str:
    """Encode *img* as PNG and return those bytes as a base64 text string."""
    with io.BytesIO() as stream:
        img.save(stream, format="PNG")
        png_bytes = stream.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode()
|
|
|
|
|
|
def _decode_saliency(saliency_b64: str, shape: list[int]) -> np.ndarray:
|
|
h, w = shape
|
|
raw = base64.b64decode(saliency_b64)
|
|
return np.frombuffer(raw, dtype=np.float32).reshape(h, w)
|
|
|
|
|
|
def _saliency_to_rgb(saliency: np.ndarray, colormap: str = "jet") -> np.ndarray:
    """Colorize a saliency map into an ``(..., 3)`` uint8 RGB array.

    Matplotlib colormaps map floats in [0, 1]; values below/above are clamped
    to the colormap's end colors and NaNs get the "bad" color, so the map is
    assumed to be normalized upstream — TODO confirm with the producer.
    """
    # Index the registry rather than calling ColormapRegistry.get_cmap(),
    # which only exists from matplotlib 3.6; ``matplotlib.colormaps[name]``
    # works on every release that has the registry (3.5+).
    cmap = matplotlib.colormaps[colormap]
    rgba = cmap(saliency)
    # Drop alpha and scale to 8-bit (truncating, as before).
    return (rgba[..., :3] * 255).astype(np.uint8)


def _generate_heatmap_overlay(
    image: Image.Image, saliency: np.ndarray, alpha: float = 0.5
) -> Image.Image:
    """Blend a colorized saliency heatmap over *image*.

    Args:
        image: Source photo; converted to RGB for blending.
        saliency: 2-D saliency map; may differ in resolution from *image*.
        alpha: Heatmap weight in the blend (0 = photo only, 1 = heatmap only).

    Returns:
        A new RGB image the size of *image*.
    """
    heatmap = Image.fromarray(_saliency_to_rgb(saliency))
    # Image.blend requires equal size and mode, so upsample the heatmap to the
    # photo's size first.
    heatmap = heatmap.resize(image.size, Image.LANCZOS)
    return Image.blend(image.convert("RGB"), heatmap, alpha)


def _generate_standalone_heatmap(saliency: np.ndarray) -> Image.Image:
    """Return the colorized saliency map as an RGB image at its native size."""
    return Image.fromarray(_saliency_to_rgb(saliency))
|
|
|
|
|
|
def _draw_gaze_sequence(image: Image.Image, gaze_seq: list[dict]) -> Image.Image:
    """Render the gaze path on a copy of *image*.

    Each fixation gets a ring labelled with its ``rank``; consecutive
    fixations are joined by a line in the earlier fixation's color. Colors
    cycle through a fixed five-color palette.

    Args:
        image: Source image; left unmodified.
        gaze_seq: Ordered fixations, each a dict with at least ``"x"``, ``"y"``
            (pixel coordinates in *image*) and ``"rank"`` keys.

    Returns:
        A new image with the path drawn on it.

    Raises:
        KeyError: if a fixation lacks ``"x"``, ``"y"`` or ``"rank"``.
    """
    img = image.copy()
    draw = ImageDraw.Draw(img)
    font = ImageFont.load_default(size=24)  # size= needs Pillow >= 10.1
    colors = ["#FF4444", "#FF8800", "#FFCC00", "#44CC44", "#4488FF"]
    radius = 25

    # Pass 1: connectors. Each segment starts at a fixation's center, where
    # its rank label sits; drawing all segments first keeps labels on top
    # instead of being struck through by their outgoing line.
    for i, (cur, nxt) in enumerate(zip(gaze_seq, gaze_seq[1:])):
        color = colors[i % len(colors)]
        draw.line([cur["x"], cur["y"], nxt["x"], nxt["y"]], fill=color, width=2)

    # Pass 2: rings and rank labels.
    for i, point in enumerate(gaze_seq):
        x, y = point["x"], point["y"]
        color = colors[i % len(colors)]
        draw.ellipse(
            [x - radius, y - radius, x + radius, y + radius], outline=color, width=3
        )
        # Offset presumably tuned to center a one-digit label in the 24 px
        # font; multi-digit ranks will sit slightly off-center.
        draw.text((x - 6, y - 12), str(point["rank"]), fill=color, font=font)

    return img
|
|
|
|
|
|
@app.get("/health")
async def health():
    """Liveness probe; always reports the service as up."""
    payload = {"status": "ok"}
    return payload
|
|
|
|
|
|
def _render_outputs(request: ProcessRequest) -> dict[str, str]:
    """Decode the request payload and render all three output images.

    Returns:
        Mapping of output name to base64-encoded PNG.

    Raises:
        HTTPException: 400 if the image, saliency buffer, or shape is malformed.
    """
    try:
        image_data = base64.b64decode(request.image_b64)
        saliency = _decode_saliency(request.saliency_b64, request.shape)
        # convert() forces the lazy Image.open to actually decode, so a
        # truncated or corrupt image is caught here too.
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
    except (ValueError, OSError) as exc:
        # Bad base64 raises binascii.Error (a ValueError), bad shape/size raises
        # ValueError, and PIL's UnidentifiedImageError is an OSError: all of
        # these are client errors, not server faults.
        raise HTTPException(status_code=400, detail=f"Invalid input: {exc}") from exc

    logger.info("Processing image %s, saliency %s", image.size, saliency.shape)

    overlay = _generate_heatmap_overlay(image, saliency)
    standalone = _generate_standalone_heatmap(saliency)
    gaze_img = _draw_gaze_sequence(image, request.gaze_sequence)

    return {
        "heatmap_overlay_b64": _img_to_b64(overlay),
        "heatmap_standalone_b64": _img_to_b64(standalone),
        "gaze_sequence_img_b64": _img_to_b64(gaze_img),
    }


@app.post("/process")
async def process_images(
    request: ProcessRequest,
    x_internal_secret: str | None = Header(None),
):
    """Render heatmap overlay, standalone heatmap and gaze-path images.

    Requires the shared secret in the ``x-internal-secret`` header when one is
    configured (401 otherwise); malformed payloads yield 400.
    """
    _check_auth(x_internal_secret)
    # Decoding and rendering are CPU-bound (PIL/NumPy/matplotlib). Running them
    # directly in this coroutine would block the event loop, stalling /health
    # and concurrent requests, so hand the work to a worker thread.
    return await asyncio.to_thread(_render_outputs, request)
|