olivas/backend/app/services/saliency/deepgaze.py
DJP 3467dbcf03 Initial commit — OliVAS visual attention analysis platform
Full-stack application for predicting where humans look in images using
DeepGaze saliency models. Includes heatmap overlays, gaze sequence prediction,
hotspot detection, AOI analysis, rule-based insights, optional Claude AI
design analysis, and professional PDF report generation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 20:20:58 -05:00

93 lines
2.9 KiB
Python

import logging
import numpy as np
import torch
from PIL import Image
from scipy.ndimage import zoom as scipy_zoom
from scipy.special import logsumexp
from app.services.saliency.base import BaseSaliencyModel
logger = logging.getLogger("olivas.deepgaze")

# Variant key -> (attribute name looked up on deepgaze_pytorch, display label).
# Every supported variant follows the same naming pattern, so the table is
# derived from the key itself:
#   "I"   -> ("DeepGazeI",   "DeepGaze I")
#   "IIE" -> ("DeepGazeIIE", "DeepGaze IIE")
#   "III" -> ("DeepGazeIII", "DeepGaze III")
VARIANT_MAP = {
    suffix: (f"DeepGaze{suffix}", f"DeepGaze {suffix}")
    for suffix in ("I", "IIE", "III")
}
class DeepGazeModel(BaseSaliencyModel):
    """Unified wrapper for all DeepGaze model variants (I, IIE, III).

    Construction only records the configuration. Call ``load()`` once to
    instantiate the network, then ``predict()`` for each image.
    """

    def __init__(self, variant: str = "IIE", device: str = "cpu"):
        """Record which variant and device to use; no weights are loaded here.

        Args:
            variant: Key into ``VARIANT_MAP``.
            device: Device specification passed to ``torch.device``.

        Raises:
            ValueError: If ``variant`` is not a key of ``VARIANT_MAP``.
        """
        if variant not in VARIANT_MAP:
            raise ValueError(f"Unknown DeepGaze variant: {variant}. Choose from {list(VARIANT_MAP.keys())}")
        self.variant = variant
        self.class_name, self.display_name = VARIANT_MAP[variant]
        self.device = torch.device(device)
        # Both stay None until load() has run.
        self.model = None
        self.centerbias_template = None

    def load(self) -> None:
        """Instantiate the pretrained network and build the center-bias template."""
        # Imported here rather than at module level, so the package is only
        # needed once load() actually runs.
        import deepgaze_pytorch

        # Lazy %-style arguments: the message is only formatted if emitted.
        logger.info("Loading %s on %s...", self.display_name, self.device)
        model_cls = getattr(deepgaze_pytorch, self.class_name)
        self.model = model_cls(pretrained=True).to(self.device)
        self.model.eval()
        self._create_default_centerbias()
        logger.info("%s loaded successfully", self.display_name)

    def _create_default_centerbias(self, size: int = 1024, sigma: float = 0.5) -> None:
        """Create a generic center bias prior (Gaussian centered).

        Stores an unnormalised log-Gaussian ``-0.5 * r**2 / sigma**2`` on a
        ``size`` x ``size`` grid spanning [-1, 1] on both axes in
        ``self.centerbias_template``.

        Args:
            size: Side length of the square template, in samples.
            sigma: Standard deviation of the Gaussian in grid units.

        Raises:
            ValueError: If ``size`` is smaller than 1 or ``sigma`` is not positive.
        """
        if size < 1:
            raise ValueError(f"size must be >= 1, got {size}")
        if sigma <= 0:
            raise ValueError(f"sigma must be > 0, got {sigma}")
        axis = np.linspace(-1, 1, size)
        # Broadcasting a row against a column yields the same grid of squared
        # radii as np.meshgrid, without materialising two full coordinate arrays.
        squared_radius = axis[np.newaxis, :] ** 2 + axis[:, np.newaxis] ** 2
        self.centerbias_template = -0.5 * squared_radius / sigma**2

    def predict(self, image: Image.Image) -> np.ndarray:
        """Predict a saliency map for ``image``.

        Args:
            image: Input image; it is converted to RGB before inference.

        Returns:
            The exponentiated model output with singleton axes removed,
            min-max scaled to [0, 1]. If the output is (numerically)
            constant, an all-zero array of the same shape is returned.

        Raises:
            RuntimeError: If called before ``load()``.
        """
        if self.model is None or self.centerbias_template is None:
            raise RuntimeError(
                f"{self.display_name} is not loaded; call load() before predict()"
            )

        img_np = np.array(image.convert("RGB"))
        h, w = img_np.shape[:2]

        # Prepare image tensor [1, C, H, W].
        # torch.from_numpy avoids the very slow list-of-ndarrays path that
        # torch.tensor([array]) takes; the contiguous copy keeps the same
        # memory layout the original construction produced.
        chw = np.ascontiguousarray(img_np.transpose(2, 0, 1))
        image_tensor = torch.from_numpy(chw).unsqueeze(0).float().to(self.device)

        # Prepare centerbias: resample the template to the image size, then
        # subtract logsumexp so that exp(cb) sums to 1.
        template = self.centerbias_template
        cb = scipy_zoom(
            template,
            (h / template.shape[0], w / template.shape[1]),
            order=0,
        )
        cb -= logsumexp(cb)
        centerbias_tensor = torch.from_numpy(cb).unsqueeze(0).float().to(self.device)

        # NOTE(review): every variant is called with (image, centerbias) only.
        # DeepGaze III may additionally expect fixation-history inputs --
        # confirm against the deepgaze_pytorch documentation.
        with torch.no_grad():
            log_density = self.model(image_tensor, centerbias_tensor)

        saliency = torch.exp(log_density).cpu().numpy().squeeze()

        # Normalize to [0, 1]; a flat map has no usable range, so return zeros.
        sal_min, sal_max = saliency.min(), saliency.max()
        if sal_max - sal_min > 1e-10:
            saliency = (saliency - sal_min) / (sal_max - sal_min)
        else:
            saliency = np.zeros_like(saliency)
        return saliency

    def get_name(self) -> str:
        """Return the human-readable label of the configured variant."""
        return self.display_name
# Backwards-compatible alias. The PascalCase name is kept on purpose so that
# existing call sites written as ``DeepGazeIIEModel(...)`` keep working.
def DeepGazeIIEModel(device: str = "cpu") -> "DeepGazeModel":
    """Return a ``DeepGazeModel`` fixed to the ``"IIE"`` variant.

    Args:
        device: Device specification forwarded to ``DeepGazeModel``.

    Returns:
        A new ``DeepGazeModel`` instance; ``load()`` has not been called on it.
    """
    return DeepGazeModel(variant="IIE", device=device)