olivas/backend/app/services/saliency/model_manager.py
DJP 3467dbcf03 Initial commit — OliVAS visual attention analysis platform
Full-stack application for predicting where humans look in images using
DeepGaze saliency models. Includes heatmap overlays, gaze sequence prediction,
hotspot detection, AOI analysis, rule-based insights, optional Claude AI
design analysis, and professional PDF report generation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 20:20:58 -05:00

55 lines
1.6 KiB
Python

import logging
import numpy as np
from PIL import Image
from app.services.saliency.base import BaseSaliencyModel
# Dotted name makes this a child of the "olivas" logger: records propagate to
# handlers attached to "olivas" (and to the root logger).
logger = logging.getLogger(name="olivas.model_manager")
class ModelManager:
    """Registry of loaded saliency models, keyed by a short identifier.

    Models are registered by :meth:`load_models`; :meth:`predict` dispatches
    to a registered model by key, falling back to ``default_model`` when no
    key is given.
    """

    # (registry key, DeepGaze variant passed to DeepGazeModel) pairs that
    # load_models() tries, in this order.
    _DEEPGAZE_VARIANTS = (
        ("deepgaze_i", "I"),
        ("deepgaze_iie", "IIE"),
        ("deepgaze_iii", "III"),
    )

    def __init__(self, default_model: str = "deepgaze_iie") -> None:
        """Create an empty registry.

        Args:
            default_model: Registry key used by :meth:`predict` when no model
                name is supplied. Defaults to ``"deepgaze_iie"``.
        """
        self.models: dict[str, BaseSaliencyModel] = {}
        self.default_model = default_model

    def load_models(self, device: str = "cpu") -> None:
        """Construct and load every DeepGaze variant, skipping failures.

        Loading is best-effort: a variant whose construction or ``load()``
        raises is logged (with traceback) and left out of the registry, and
        the loop continues with the next variant. A warning is logged if the
        default model did not end up registered, because :meth:`predict`
        without an explicit model name would then raise.

        Args:
            device: Device string forwarded to ``DeepGazeModel``
                (e.g. ``"cpu"``).
        """
        # Imported here rather than at module top, so the DeepGaze module is
        # only imported when load_models() actually runs.
        from app.services.saliency.deepgaze import DeepGazeModel

        for key, variant in self._DEEPGAZE_VARIANTS:
            try:
                model = DeepGazeModel(variant=variant, device=device)
                model.load()
                self.models[key] = model
                logger.info("Loaded %s", model.get_name())
            except Exception as e:
                # Broad on purpose: one failing variant is logged and the
                # remaining variants are still attempted.
                logger.warning(
                    "Failed to load DeepGaze %s: %s", variant, e, exc_info=True
                )

        if self.default_model not in self.models:
            logger.warning(
                "Default model '%s' is not loaded; predict() without an explicit "
                "model name will fail. Available: %s",
                self.default_model,
                list(self.models),
            )

    def predict(self, image: Image.Image, model_name: str | None = None) -> np.ndarray:
        """Run saliency prediction on *image* with a registered model.

        Args:
            image: Input PIL image, passed unchanged to the model.
            model_name: Registry key of the model to use. Any falsy value
                (``None`` or ``""``) falls back to ``default_model``.

        Returns:
            The selected model's ``predict(image)`` result.

        Raises:
            RuntimeError: If the requested model is not in the registry.
        """
        name = model_name or self.default_model
        model = self.models.get(name)
        if model is None:
            raise RuntimeError(f"Model '{name}' not loaded. Available: {list(self.models)}")
        return model.predict(image)

    def list_models(self) -> list[dict]:
        """Describe the loaded models.

        Returns:
            One ``{"id": key, "name": model.get_name()}`` dict per registered
            model, in registration order.
        """
        return [{"id": key, "name": model.get_name()} for key, model in self.models.items()]

    def cleanup(self) -> None:
        """Drop all registered models and release cached CUDA memory if possible.

        Clearing the registry removes this object's references to the models;
        ``torch.cuda.empty_cache()`` is then called when torch is importable
        and CUDA is available. If torch is not installed, only the registry
        is cleared.
        """
        self.models.clear()
        try:
            import torch
        except ImportError:
            return
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Module-level instance; because Python caches imported modules, every importer
# of this module gets this same object. It starts with an empty registry, so no
# model is usable until load_models() has been called on it (presumably once at
# application startup — confirm against the caller).
model_manager = ModelManager()