olivas/cloud_run/saliency/main.py
Vadym Samoilenko 2c5e17c7c4 Add Google Cloud Run offloading for ML inference and image processing
- Create cloud_run/saliency: FastAPI service running DeepGaze I/IIE/III
  on Cloud Run (4 vCPU, 16GB RAM); pre-downloads model weights in Docker
  build to eliminate cold-start delays; returns saliency map + gaze
  sequence + hotspots + design scores
- Create cloud_run/processing: lightweight FastAPI service for heatmap
  generation and gaze sequence visualization (2 vCPU, 4GB RAM)
- Add cloud_run/deploy.sh for gcloud deployment to project optical-414516
  in region europe-west2
- Refactor analysis pipeline to route via Cloud Run when
  CLOUD_RUN_SALIENCY_URL is set, with local fallback for dev mode
- Add cloud_run_client.py with sync httpx wrappers for background tasks
- Split pyproject.toml: base = API-only deps, [ml] = torch/deepgaze for
  local dev; production Dockerfile is now lightweight (~no PyTorch)
- Preserve Dockerfile.full + docker-compose.dev.yml for local ML dev
- Auth via X-Internal-Secret header (CLOUD_RUN_SECRET env var)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-04 19:39:52 +00:00

331 lines
10 KiB
Python

"""OliVAS Saliency Cloud Run Service.
Runs DeepGaze saliency inference and returns:
- saliency map (base64 float32 bytes)
- gaze sequence
- hotspots
- design effectiveness scores
"""
import base64
import io
import logging
import os
from contextlib import asynccontextmanager
import numpy as np
from fastapi import FastAPI, File, Form, Header, HTTPException, UploadFile
from PIL import Image
from scipy.ndimage import gaussian_filter, zoom
from scipy.special import logsumexp
# Configure root logging at import time so the service logs at INFO level.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("olivas.saliency")

# Shared secret expected in the X-Internal-Secret header; when empty,
# _check_auth lets every request through.
INTERNAL_SECRET = os.getenv("CLOUD_RUN_SECRET", "")

# Torch device name, or "auto" to let _resolve_device pick cuda/cpu.
DEVICE = os.getenv("DEVICE", "auto")

# Global model cache:
# {"<model_name>:<device>": {"model": ..., "centerbias": ..., "device": ...}}
_model_cache: dict = {}

# Public model identifier -> (deepgaze_pytorch class name, short variant label).
# Only the class name is consumed by _get_model.
VARIANT_MAP = dict(
    deepgaze_i=("DeepGazeI", "I"),
    deepgaze_iie=("DeepGazeIIE", "IIE"),
    deepgaze_iii=("DeepGazeIII", "III"),
)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: log startup, then drop cached models on shutdown.

    Models are not loaded here; they are loaded lazily by ``_get_model`` on
    first use and only released when the application stops.
    """
    logger.info("OliVAS Saliency service starting")
    # Control returns to FastAPI here; everything below runs at shutdown.
    yield
    _model_cache.clear()
    logger.info("OliVAS Saliency service stopped")


app = FastAPI(
    title="OliVAS Saliency Service",
    lifespan=lifespan,
)
def _check_auth(x_internal_secret: str | None) -> None:
    """Reject the request unless it carries the configured internal secret.

    Authentication is skipped entirely when ``INTERNAL_SECRET`` is empty
    (no ``CLOUD_RUN_SECRET`` configured).

    Args:
        x_internal_secret: Value of the ``X-Internal-Secret`` request header,
            or ``None`` when the header is absent.

    Raises:
        HTTPException: 401 when a secret is configured and the supplied value
            is missing or does not match.
    """
    import hmac

    if not INTERNAL_SECRET:
        return

    # Compare as bytes: hmac.compare_digest only accepts ASCII-only str, and
    # the header value is caller-controlled.
    supplied = (x_internal_secret or "").encode("utf-8")
    expected = INTERNAL_SECRET.encode("utf-8")

    # Constant-time comparison so response timing does not leak how much of
    # the secret a guess got right (a plain != short-circuits).
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Unauthorized")
def _resolve_device() -> str:
    """Return the torch device name to use for inference.

    An explicit ``DEVICE`` setting is returned unchanged. With ``"auto"`` the
    result is ``"cuda"`` when torch reports an available CUDA device and
    ``"cpu"`` otherwise, including when torch cannot be imported.
    """
    if DEVICE != "auto":
        return DEVICE

    try:
        import torch

        has_cuda = torch.cuda.is_available()
    except ImportError:
        return "cpu"
    return "cuda" if has_cuda else "cpu"
def _get_model(model_name: str) -> dict:
    """Return the cached model bundle for ``model_name``, loading it on first use.

    Bundles are cached per ``"<model_name>:<device>"`` key in ``_model_cache``.

    Args:
        model_name: One of the keys of ``VARIANT_MAP``.

    Returns:
        A dict with keys ``"model"`` (the model in eval mode, moved to the
        resolved device), ``"centerbias"`` (a 1024x1024 template array) and
        ``"device"`` (the ``torch.device`` the model lives on).

    Raises:
        ValueError: If ``model_name`` is not cached and not in ``VARIANT_MAP``.
    """
    device = _resolve_device()
    cache_key = f"{model_name}:{device}"

    cached = _model_cache.get(cache_key)
    if cached is not None:
        return cached

    if model_name not in VARIANT_MAP:
        raise ValueError(f"Unknown model: {model_name}. Choose from {list(VARIANT_MAP)}")
    class_name = VARIANT_MAP[model_name][0]

    # Heavy ML dependencies are imported lazily, only when a model is loaded.
    import torch
    import deepgaze_pytorch

    logger.info(f"Loading {class_name} on {device}...")
    device_obj = torch.device(device)
    model = getattr(deepgaze_pytorch, class_name)(pretrained=True).to(device_obj)
    model.eval()

    # Centerbias template: log of an isotropic Gaussian (sigma 0.5 in
    # normalized [-1, 1] coordinates), up to an additive constant.
    # _run_inference resizes and normalizes it per image.
    axis = np.linspace(-1, 1, 1024)
    grid_x, grid_y = np.meshgrid(axis, axis)
    centerbias = -0.5 * (grid_x**2 + grid_y**2) / 0.5**2

    bundle = {"model": model, "centerbias": centerbias, "device": device_obj}
    _model_cache[cache_key] = bundle
    logger.info(f"Loaded {class_name}")
    return bundle
def _run_inference(image: Image.Image, model_name: str) -> np.ndarray:
    """Run the selected DeepGaze model on ``image`` and return a normalized saliency map.

    Args:
        image: Input image; it is converted to RGB before inference.
        model_name: Key of ``VARIANT_MAP`` selecting the model (see ``_get_model``).

    Returns:
        The exponentiated model output, squeezed and min-max normalized to
        [0, 1]. When the output is (numerically) constant, an all-zero array
        of the same shape is returned instead.
        NOTE(review): presumably shaped (height, width) of ``image`` — confirm
        against the deepgaze_pytorch output shape.

    Raises:
        ValueError: Propagated from ``_get_model`` for an unknown ``model_name``.
    """
    import torch

    model_data = _get_model(model_name)
    model = model_data["model"]
    centerbias_template = model_data["centerbias"]
    device_obj = model_data["device"]

    img_np = np.array(image.convert("RGB"))
    h, w = img_np.shape[:2]

    # HWC -> CHW, then add the batch axis. torch.from_numpy wraps the array
    # directly; torch.tensor([ndarray]) walks a Python list of arrays, which
    # PyTorch itself warns is extremely slow.
    chw = np.ascontiguousarray(img_np.transpose(2, 0, 1))
    image_tensor = torch.from_numpy(chw).unsqueeze(0).float().to(device_obj)

    # Resize the template to the image size (nearest neighbour), then shift it
    # so that exp(cb) sums to 1, i.e. cb is a normalized log-density.
    cb = zoom(
        centerbias_template,
        (h / centerbias_template.shape[0], w / centerbias_template.shape[1]),
        order=0,
    )
    cb -= logsumexp(cb)
    centerbias_tensor = torch.from_numpy(cb).unsqueeze(0).float().to(device_obj)

    # NOTE(review): DeepGaze III is a scanpath model and may require fixation
    # history arguments in addition to (image, centerbias) — confirm that this
    # two-argument call works for "deepgaze_iii".
    with torch.no_grad():
        log_density = model(image_tensor, centerbias_tensor)

    saliency = torch.exp(log_density).cpu().numpy().squeeze()

    sal_min, sal_max = saliency.min(), saliency.max()
    if sal_max - sal_min > 1e-10:
        saliency = (saliency - sal_min) / (sal_max - sal_min)
    else:
        # Flat output carries no ranking information.
        saliency = np.zeros_like(saliency)
    return saliency
def _prepare_for_inference(image: Image.Image, max_size: int = 1024) -> tuple[Image.Image, float]:
    """Downscale ``image`` so its longer side is at most ``max_size`` pixels.

    Images that already fit are returned unchanged; images are never upscaled.

    Args:
        image: Source image.
        max_size: Maximum allowed length of the longer side, in pixels.

    Returns:
        ``(image, scale)`` where ``scale`` is the factor applied to both
        dimensions (``1.0`` when no resize happened).
    """
    w, h = image.size
    scale = max_size / max(w, h)
    if scale >= 1.0:
        return image, 1.0

    # Truncation can drive the shorter side of a very elongated image to 0,
    # which Image.resize rejects; keep at least one pixel per dimension.
    new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
    return image.resize(new_size, Image.LANCZOS), scale
def _upscale_saliency(saliency: np.ndarray, target_h: int, target_w: int) -> np.ndarray:
    """Resize a saliency map to ``(target_h, target_w)`` with linear interpolation.

    The input array itself (not a copy) is returned when it already has the
    target shape.
    """
    target_shape = (target_h, target_w)
    if saliency.shape != target_shape:
        factors = (target_h / saliency.shape[0], target_w / saliency.shape[1])
        return zoom(saliency, factors, order=1)
    return saliency
def _extract_gaze_sequence(saliency: np.ndarray, num_fixations: int = 5) -> list[dict]:
    """Pick up to ``num_fixations`` fixation points by iterative peak picking.

    Each round smooths the working map, takes its maximum as the next
    fixation, then zeroes a disc around it (inhibition of return) so the next
    round lands elsewhere. The loop stops early once the smoothed map is
    effectively empty. The input array is not modified.

    Args:
        saliency: 2-D saliency map.
        num_fixations: Maximum number of fixations to return.

    Returns:
        A list of dicts in visiting order with keys ``rank`` (1-based),
        ``x``/``y`` (pixel column/row), ``x_pct``/``y_pct`` (position as a
        percentage of width/height, one decimal) and ``probability`` (the
        unsmoothed input value at the fixation, four decimals).
    """
    work = saliency.copy().astype(np.float64)
    height, width = work.shape
    longest_side = max(height, width)

    # Loop invariants: smoothing width (1% of the longer side), squared
    # suppression radius (10% of the longer side) and the coordinate grids.
    blur_sigma = longest_side * 0.01
    suppress_sq = int(longest_side * 0.1) ** 2
    rows, cols = np.ogrid[:height, :width]

    sequence = []
    for rank in range(1, num_fixations + 1):
        blurred = gaussian_filter(work, sigma=blur_sigma)
        if blurred.max() < 1e-10:
            break

        row, col = np.unravel_index(np.argmax(blurred), blurred.shape)
        row, col = int(row), int(col)

        sequence.append({
            "rank": rank,
            "x": col,
            "y": row,
            "x_pct": round(col / width * 100, 1),
            "y_pct": round(row / height * 100, 1),
            # Report the original (unsmoothed, unsuppressed) value.
            "probability": round(float(saliency[row, col]), 4),
        })

        # Inhibition of return: blank out the disc around this fixation.
        work[(cols - col) ** 2 + (rows - row) ** 2 <= suppress_sq] = 0.0

    return sequence
def _extract_hotspots(saliency: np.ndarray, num_hotspots: int = 5) -> list[dict]:
    """Pick up to ``num_hotspots`` attention hotspots by iterative peak picking.

    Each round smooths the working map, takes its maximum as the next hotspot
    centre, then zeroes a disc around it so the next round lands elsewhere.
    The loop stops early once the smoothed map is effectively empty, so fewer
    than ``num_hotspots`` entries may be returned. The input array is not
    modified.

    Args:
        saliency: 2-D saliency map.
        num_hotspots: Maximum number of hotspots to return.

    Returns:
        A list of dicts in order of discovery with keys ``rank`` (1-based),
        ``center_x``/``center_y`` (peak pixel column/row), ``x``/``y``/
        ``width``/``height`` (bounding box around the peak, clipped to the
        image) and ``intensity`` (the unsmoothed input value at the peak,
        four decimals).
    """
    sal = saliency.copy()
    h, w = sal.shape
    longest_side = max(h, w)

    # Loop invariants: box/suppression radius (8% of the longer side),
    # smoothing width (1.5% of the longer side) and the coordinate grids.
    radius = int(longest_side * 0.08)
    sigma = longest_side * 0.015
    yy, xx = np.ogrid[:h, :w]

    hotspots = []
    for rank in range(1, num_hotspots + 1):
        smoothed = gaussian_filter(sal, sigma=sigma)
        # Same exhaustion check as _extract_gaze_sequence: once everything is
        # suppressed, argmax would keep returning index 0 and emit bogus
        # duplicate hotspots at the top-left corner.
        if smoothed.max() < 1e-10:
            break

        peak_idx = np.unravel_index(np.argmax(smoothed), smoothed.shape)
        py, px = int(peak_idx[0]), int(peak_idx[1])

        x1, y1 = max(0, px - radius), max(0, py - radius)
        x2, y2 = min(w, px + radius), min(h, py + radius)
        hotspots.append({
            "rank": rank,
            "center_x": px,
            "center_y": py,
            "x": x1,
            "y": y1,
            "width": x2 - x1,
            "height": y2 - y1,
            # Report the original (unsmoothed, unsuppressed) value.
            "intensity": round(float(saliency[py, px]), 4),
        })

        # Suppress the disc around this hotspot before the next round.
        sal[(xx - px) ** 2 + (yy - py) ** 2 <= radius**2] = 0.0

    return hotspots
def _concentration_percent(saliency_full: np.ndarray) -> float:
    """Return ``100 * (1 - H / H_max)`` for the map treated as a probability distribution.

    ``H`` is the Shannon entropy (bits) of the map normalized to sum to 1 and
    ``H_max = log2(number of pixels)``. Returns 0.0 when the map has no
    positive mass or fewer than two pixels (``H_max`` would be 0).
    """
    total = saliency_full.sum()
    if not (total > 0) or saliency_full.size < 2:
        return 0.0
    probs = saliency_full / total
    probs = probs[probs > 0]  # 0 * log(0) is taken as 0
    entropy = -np.sum(probs * np.log2(probs))
    max_entropy = np.log2(saliency_full.size)
    return float((1 - entropy / max_entropy) * 100)


def _peak_dominance(intensities: list[float]) -> float:
    """Score (0-100) how strongly the first hotspot outweighs the mean of the rest."""
    if not intensities:
        return 50.0
    if len(intensities) == 1:
        return 95.0
    rest_mean = float(np.mean(intensities[1:]))
    dominance_ratio = intensities[0] / rest_mean if rest_mean > 0 else 10.0
    return float(100 * (1 - np.exp(-0.5 * dominance_ratio)))


def _hierarchy_clarity(intensities: list[float]) -> float:
    """Score how consistently hotspot intensities decrease with rank."""
    from itertools import combinations

    n = len(intensities)
    if n < 2:
        return 70.0
    # Fraction of (earlier, later) pairs where the earlier one is strictly stronger.
    concordant = sum(1 for a, b in combinations(intensities, 2) if a > b)
    monotonicity = concordant / (n * (n - 1) // 2)
    drop_ratio = 1 - (intensities[-1] / intensities[0]) if intensities[0] > 0 else 0.0
    return float((0.6 * monotonicity + 0.4 * drop_ratio) * 100)


def _gaze_coherence(gaze_points: list[tuple]) -> float:
    """Score (0-100) how smooth and direct the fixation path is."""
    if len(gaze_points) < 3:
        return 70.0

    pts = np.asarray(gaze_points, dtype=np.float64)
    steps = np.diff(pts, axis=0)
    lengths = np.hypot(steps[:, 0], steps[:, 1])

    # Turning angle between each pair of consecutive non-degenerate steps.
    angles = []
    for a, b, len_a, len_b in zip(steps, steps[1:], lengths, lengths[1:]):
        if len_a > 0 and len_b > 0:
            cos_angle = np.clip(float(np.dot(a, b)) / (len_a * len_b), -1, 1)
            angles.append(float(np.degrees(np.arccos(cos_angle))))
    avg_angle = float(np.mean(angles)) if angles else 70.0
    angle_smoothness = max(0.0, 100 - (avg_angle / 180) * 100)

    # Straight-line distance from first to last point over the path length.
    total_path = float(lengths.sum())
    end_to_end = pts[-1] - pts[0]
    direct_dist = float(np.hypot(end_to_end[0], end_to_end[1]))
    path_efficiency = direct_dist / total_path if total_path > 0 else 1.0

    return float(0.7 * angle_smoothness + 0.3 * (path_efficiency * 100))


def _compute_design_score(
    saliency_full: np.ndarray, hotspots: list[dict], gaze_seq: list[dict]
) -> tuple[float, float]:
    """Combine saliency statistics into an overall design score.

    The composite is a weighted sum of four 0-100 style components:
    30% peak dominance, 25% hierarchy clarity, 25% gaze coherence and
    20% square-root-adjusted attention concentration.

    Args:
        saliency_full: Saliency map (non-negative values).
        hotspots: Hotspot dicts in rank order; only ``"intensity"`` is read.
        gaze_seq: Fixation dicts in visiting order; only ``"x"`` and ``"y"``
            are read.

    Returns:
        ``(overall_score, entropy_score)``, both rounded to one decimal and
        clipped to [0, 100]. ``entropy_score`` is the concentration measure
        ``100 * (1 - H / H_max)``: higher means attention is more
        concentrated, not higher entropy.
    """
    raw_concentration = _concentration_percent(saliency_full)
    entropy_score = round(float(np.clip(raw_concentration, 0, 100)), 1)
    # Square root lifts the low concentration values before weighting.
    entropy_adjusted = float(np.sqrt(max(raw_concentration, 0) / 100)) * 100

    intensities = [spot["intensity"] for spot in hotspots]
    peak_dominance = _peak_dominance(intensities)
    hierarchy_clarity = _hierarchy_clarity(intensities)
    gaze_coherence = _gaze_coherence([(g["x"], g["y"]) for g in gaze_seq])

    composite = (
        0.30 * peak_dominance
        + 0.25 * hierarchy_clarity
        + 0.25 * gaze_coherence
        + 0.20 * entropy_adjusted
    )
    overall_score = round(float(np.clip(composite, 0, 100)), 1)
    return overall_score, entropy_score
@app.get("/health")
async def health():
    """Liveness probe: report an OK status and the device inference would use."""
    device = _resolve_device()
    return {"status": "ok", "device": device}
@app.post("/predict")
async def predict(
    image: UploadFile = File(...),
    model: str = Form("deepgaze_iie"),
    x_internal_secret: str | None = Header(None),
):
    """Run saliency inference on an uploaded image and return derived metrics.

    Args:
        image: Uploaded image file.
        model: Model identifier; must be a key of ``VARIANT_MAP``.
        x_internal_secret: Internal auth header, checked by ``_check_auth``.

    Returns:
        A dict with ``saliency_b64`` (base64 of the float32 saliency map
        bytes), ``shape`` (``[height, width]`` of that map), ``gaze_sequence``,
        ``hotspots``, ``overall_score`` and ``entropy_score``.

    Raises:
        HTTPException: 401 on failed auth; 400 for an unknown ``model`` or an
            upload that cannot be decoded as an image.
    """
    _check_auth(x_internal_secret)

    # Reject bad input with a client error instead of letting _get_model's
    # ValueError surface as HTTP 500.
    if model not in VARIANT_MAP:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown model: {model}. Choose from {list(VARIANT_MAP)}",
        )

    image_data = await image.read()
    try:
        pil_image = Image.open(io.BytesIO(image_data)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # Pillow raises OSError subclasses for unrecognized or truncated files.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image"
        ) from exc

    orig_w, orig_h = pil_image.size
    resized, _ = _prepare_for_inference(pil_image)
    logger.info(f"Inference: model={model} original={orig_w}x{orig_h} resized={resized.size}")

    # NOTE(review): inference and post-processing run synchronously inside this
    # async handler, so the event loop is blocked for their duration.
    saliency = _run_inference(resized, model)
    saliency_full = _upscale_saliency(saliency, orig_h, orig_w)

    gaze_sequence = _extract_gaze_sequence(saliency_full, num_fixations=5)
    hotspots = _extract_hotspots(saliency_full, num_hotspots=5)
    overall_score, entropy_score = _compute_design_score(saliency_full, hotspots, gaze_sequence)

    saliency_b64 = base64.b64encode(saliency_full.astype(np.float32).tobytes()).decode()
    logger.info(f"Done: score={overall_score} entropy={entropy_score}")

    return {
        "saliency_b64": saliency_b64,
        # Report the shape of the array actually encoded so the client can
        # always reshape the decoded buffer.
        "shape": [int(dim) for dim in saliency_full.shape],
        "gaze_sequence": gaze_sequence,
        "hotspots": hotspots,
        "overall_score": overall_score,
        "entropy_score": entropy_score,
    }