marriott-box-image-video-ta.../api.py
DJP 1f2c2ff8e1 Multi-token + fuzzy search; admin-only Run Now / Backfill
Search:
- Previously /api/events did one ILIKE %q% across the columns, so
  "female city" required the literal substring "female city" to
  appear somewhere. Now the query is tokenised on whitespace; every
  token must match somewhere (AND), and each token matches either
  by substring (ILIKE) across the searched columns OR by trigram
  similarity (pg_trgm) against a concatenated text blob with a 0.3
  threshold — handles typos like "femalle" → "female".
- Results ranked by summed similarity score across all tokens, then
  recency. Empty query falls back to "newest 100".
- schema.sql: CREATE EXTENSION IF NOT EXISTS pg_trgm (idempotent;
  applied by ensure_schema on api startup).

Admin gating:
- auth.py: User now carries `is_admin`. Computed from a
  comma-separated ADMIN_EMAILS env var (case-insensitive match
  against `preferred_username`/`upn`/`email` claim). New
  `require_admin` FastAPI dependency 403s non-admins.
- In DEV_AUTH_BYPASS mode the dev user is admin by default; flip
  DEV_AUTH_IS_ADMIN=false to test the read-only UX without enabling
  SSO.
- POST /api/runs and POST /api/backfill now gated by require_admin.
- /api/me carries is_admin so the SPA can hide the destructive
  buttons for non-admins.

Frontend:
- App.tsx fetches /api/me on mount and hides Run Now + Backfill
  unless `is_admin` is true. Non-admins still see search + results +
  recent-runs table.

docker-compose / .env.example: thread ADMIN_EMAILS +
DEV_AUTH_IS_ADMIN into the api container.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-11 15:51:50 -04:00

345 lines
12 KiB
Python

"""
Marriott Box Tagger — FastAPI backend.
Endpoints (all under /api/, all behind require_auth — or require_admin for
the two POST endpoints — except /api/health):
GET /api/health — liveness + config flags
GET /api/me — who am I (after auth)
GET /api/events?q=…&limit=… — search tagging_events across all
text + JSONB fields
POST /api/runs — kick off a tagging pass in a
background thread; returns run_id
POST /api/backfill — kick off a metadata backfill in a
background thread; returns run_id
GET /api/runs — recent runs (run_id + counts)
GET /api/runs/{run_id}/events — events for a single run, newest first
"""
import os
import threading
import uuid
from contextlib import contextmanager
from typing import Optional
from fastapi import Depends, FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from psycopg.rows import dict_row
import db
from auth import User, maybe_auth_info, require_admin, require_auth
# URL template used to build the `box_url` field attached to every event.
BOX_FILE_URL = "https://app.box.com/file/{file_id}"

app = FastAPI(title="Marriott Box Tagger API", version="1.0.0")

# CORS is only installed when CORS_ORIGINS is set. That is only meaningful in
# dev (the Vite dev server hits FastAPI cross-origin); in prod, Apache serves
# both SPA and API under the same origin.
_cors_origins = [
    origin
    for origin in (part.strip() for part in os.getenv("CORS_ORIGINS", "").split(","))
    if origin  # drop empty entries such as a trailing comma
]
if _cors_origins:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=_cors_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

# In-memory record of background runs (run_id → state). It lives only as long
# as the api container; the durable record of what each run produced is in
# tagging_events. Writers mutate it while holding _runs_lock.
_runs: dict[str, dict] = {}
_runs_lock = threading.Lock()
@contextmanager
def _conn():
    """Yield a connection from ``db.get_conn()`` and close it on exit."""
    connection = db.get_conn()
    try:
        yield connection
    finally:
        # Runs whether the with-body finished normally or raised.
        connection.close()
# ── Health / identity ────────────────────────────────────────────────────────
@app.get("/api/health")
def health():
    """
    Liveness probe.

    The top-level ``ok`` is always True (the process answered). ``db`` reports
    whether a ``SELECT 1`` round trip succeeded and, if not, the exception as
    ``"<ExceptionType>: <message>"``. ``auth`` is whatever ``maybe_auth_info()``
    returns.
    """
    db_status = {"ok": False, "error": None}
    try:
        with _conn() as connection, connection.cursor() as cursor:
            cursor.execute("SELECT 1")
            cursor.fetchone()
            # Flip to healthy only after the round trip has completed.
            db_status["ok"] = True
    except Exception as exc:
        db_status["error"] = f"{type(exc).__name__}: {exc}"
    return {"ok": True, "db": db_status, "auth": maybe_auth_info()}
@app.get("/api/me")
def me(user: User = Depends(require_auth)):
    """Return the authenticated caller as the dict built by ``User.to_dict()``."""
    profile = user.to_dict()
    return profile
# ── Search ──────────────────────────────────────────────────────────────────
# SQL expressions the ILIKE substring match (case-insensitive) is applied to.
# The last three are cast to text and coalesced to '' so a NULL value compares
# as an empty string.
_SEARCH_COLS = [
    "file_name",
    "folder_path",
    "description",
    "status",
    "file_id",
] + [
    f"coalesce({column}::text, '')"
    for column in ("validated_metadata", "raw_response", "scenes")
]

# Single concatenated text expression used for trigram similarity (fuzzy /
# typo tolerance): the pieces below joined with a single space in SQL.
_SEARCH_BLOB = "||' '||".join(
    (
        "coalesce(file_name,'')",
        "coalesce(folder_path,'')",
        "coalesce(description,'')",
        "coalesce(validated_metadata::text,'')",
        "coalesce(scenes::text,'')",
    )
)

# Fuzzy threshold for trigram similarity. 0.3 is meant to catch typos like
# "femalle" → "female" without flooding the results with noise.
_SIM_THRESHOLD = 0.3

# Tokens shorter than this (1-2 chars) are too noisy for trigrams — they fall
# back to substring match only.
_MIN_FUZZY_TOKEN_LEN = 3
def _escape_like(token: str) -> str:
    """
    Escape LIKE/ILIKE metacharacters so *token* is matched literally.

    Backslash is PostgreSQL's default LIKE escape character, so it is escaped
    first; then the `%` and `_` wildcards.
    """
    return token.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")


def _build_search_sql(q: str, limit: int):
    """
    Build the parameterised search query for /api/events.

    The query string is tokenised on whitespace and every token must match
    (AND). A token matches by literal, case-insensitive substring (ILIKE) on
    any expression in _SEARCH_COLS, or — for tokens of at least
    _MIN_FUZZY_TOKEN_LEN characters — by trigram word similarity against
    _SEARCH_BLOB above _SIM_THRESHOLD. Rows are ranked by the summed
    similarity score of the fuzzy-eligible tokens, then by recency.

    An empty / whitespace-only query returns the newest rows with no WHERE
    clause and no `_score` column.

    Args:
        q: raw user query text.
        limit: maximum number of rows (bound as a parameter, not interpolated).

    Returns:
        (sql, params): SQL text with named `%(name)s` placeholders, and the
        dict of values to bind. User text only ever appears in `params`.
    """
    # str.split() with no argument already strips and drops empty pieces.
    tokens = q.split()
    common_cols = (
        "id, run_id, created_at, file_id, file_name, folder_path, media_type, "
        "gemini_model, description, scenes, validated_metadata, raw_response, "
        "metadata_write_success, description_write_success, scene_comment_write_success, "
        "status, error_message, duration_ms"
    )
    if not tokens:
        return (
            f"SELECT {common_cols} FROM tagging_events "
            f"ORDER BY created_at DESC LIMIT %(limit)s",
            {"limit": limit},
        )

    params: dict = {"limit": limit}
    clauses: list[str] = []
    score_terms: list[str] = []
    for i, tok in enumerate(tokens):
        like_key = f"like_{i}"
        # Escape wildcards so e.g. "%" or "_" in the query are searched for
        # literally instead of matching everything / any single character.
        params[like_key] = f"%{_escape_like(tok)}%"
        substring_match = " OR ".join(
            f"{col} ILIKE %({like_key})s" for col in _SEARCH_COLS
        )
        if len(tok) < _MIN_FUZZY_TOKEN_LEN:
            # Too short for meaningful trigrams: substring match only.
            clauses.append(f"({substring_match})")
            continue

        sim_key = f"sim_{i}"
        params[sim_key] = tok
        # word_similarity(needle, haystack) scores the token against the best
        # matching stretch of the blob. Plain similarity() compares the token
        # with the *entire* blob, so for any realistically long blob the score
        # is diluted far below the threshold and typos never match.
        fuzzy_score = f"word_similarity(%({sim_key})s, {_SEARCH_BLOB})"
        clauses.append(
            f"(({substring_match}) OR {fuzzy_score} > {_SIM_THRESHOLD})"
        )
        score_terms.append(fuzzy_score)

    where = " AND ".join(clauses)
    # If every token was too short for fuzzy matching, rank by recency alone.
    score_sql = " + ".join(score_terms) if score_terms else "0"
    sql = (
        f"SELECT {common_cols}, ({score_sql}) AS _score "
        f"FROM tagging_events "
        f"WHERE {where} "
        f"ORDER BY _score DESC, created_at DESC "
        f"LIMIT %(limit)s"
    )
    return sql, params
def _event_to_dict(row):
    """
    Copy a tagging_events row into a plain dict for the response.

    Adds `box_url` (built from BOX_FILE_URL, or None when `file_id` is
    missing/empty), stringifies `run_id`, and renders `created_at` with
    `.isoformat()`. The input row is not modified.
    """
    event = dict(row)

    file_id = event.get("file_id")
    if file_id:
        event["box_url"] = BOX_FILE_URL.format(file_id=file_id)
    else:
        event["box_url"] = None

    run_id = event.get("run_id")
    if run_id is not None:
        event["run_id"] = str(run_id)

    created_at = event.get("created_at")
    if created_at is not None:
        event["created_at"] = created_at.isoformat()

    return event
@app.get("/api/events")
def search_events(
    q: str = Query("", description="Free-text search across all fields"),
    limit: int = Query(100, ge=1, le=500),
    user: User = Depends(require_auth),
):
    """
    Search tagging_events with the query built by `_build_search_sql`.

    Returns the original query text, the number of rows, and the rows
    themselves after `_event_to_dict` conversion.
    """
    sql, params = _build_search_sql(q, limit)
    with _conn() as connection, connection.cursor(row_factory=dict_row) as cursor:
        cursor.execute(sql, params)
        records = cursor.fetchall()
    results = [_event_to_dict(record) for record in records]
    return {"q": q, "count": len(records), "results": results}
# ── Run-now ─────────────────────────────────────────────────────────────────
def _run_pass_in_thread(run_id: uuid.UUID):
    """
    Background worker for a tagging pass.

    Registers the run in `_runs` as "running", opens a fresh DB connection,
    ensures the schema, and calls into the tagger pipeline. The final state
    ("completed" or "failed" plus an error string) is written back to `_runs`.
    The connection is always closed.

    Args:
        run_id: identifier of this run; `str(run_id)` is the `_runs` key.
    """
    key = str(run_id)
    with _runs_lock:
        _runs[key] = {"run_id": key, "state": "running", "error": None}

    def _finish(state, error=None):
        # Write the error before the state so a reader that does not hold the
        # lock never observes "failed" without its error message.
        with _runs_lock:
            entry = _runs[key]
            entry["error"] = error
            entry["state"] = state

    db_conn = None
    try:
        # Import inside the thread so we don't pay tagger-side init cost at API
        # startup — and inside the try so an import failure is recorded as a
        # failed run instead of leaving the run stuck in "running".
        import main as tagger

        db_conn = db.get_conn()
        db.ensure_schema(db_conn)
        tagger._run_pass(run_id, db_conn)
        _finish("completed")
    except SystemExit as e:
        # SystemExit is not an Exception subclass, so it needs its own handler.
        _finish("failed", f"SystemExit({e.code})")
    except Exception as e:
        _finish("failed", f"{type(e).__name__}: {e}")
    finally:
        if db_conn is not None:
            try:
                db_conn.close()
            except Exception:
                # Deliberate best-effort: a failing close must not mask the
                # run's recorded outcome.
                pass
@app.post("/api/runs")
def start_run(user: User = Depends(require_admin)):
    """
    Start a tagging pass on a daemon thread and return immediately.

    The response carries the new run's id, the initial state "running", and
    who started it (the user's email, falling back to their oid).
    """
    run_id = uuid.uuid4()
    worker = threading.Thread(
        target=_run_pass_in_thread,
        args=(run_id,),
        daemon=True,
    )
    worker.start()
    return {
        "run_id": str(run_id),
        "state": "running",
        "started_by": user.email or user.oid,
    }
def _run_backfill_in_thread(run_id: uuid.UUID):
    """
    Background worker for a backfill.

    Registers the run in `_runs` as "running" with kind "backfill", opens a
    fresh DB connection, ensures the schema, and calls the tagger's backfill
    entry point. The final state ("completed" or "failed" plus an error
    string) is written back to `_runs`. The connection is always closed.

    Args:
        run_id: identifier of this run; `str(run_id)` is the `_runs` key.
    """
    key = str(run_id)
    with _runs_lock:
        _runs[key] = {"run_id": key, "state": "running", "error": None, "kind": "backfill"}

    def _finish(state, error=None):
        # Write the error before the state so a reader that does not hold the
        # lock never observes "failed" without its error message.
        with _runs_lock:
            entry = _runs[key]
            entry["error"] = error
            entry["state"] = state

    db_conn = None
    try:
        # Imported lazily in the worker, and inside the try so an import
        # failure is recorded as a failed run instead of leaving the run
        # stuck in "running".
        import main as tagger

        db_conn = db.get_conn()
        db.ensure_schema(db_conn)
        tagger._run_backfill(run_id, db_conn)
        _finish("completed")
    except SystemExit as e:
        # SystemExit is not an Exception subclass, so it needs its own handler.
        _finish("failed", f"SystemExit({e.code})")
    except Exception as e:
        _finish("failed", f"{type(e).__name__}: {e}")
    finally:
        if db_conn is not None:
            try:
                db_conn.close()
            except Exception:
                # Deliberate best-effort: a failing close must not mask the
                # run's recorded outcome.
                pass
@app.post("/api/backfill")
def start_backfill(user: User = Depends(require_admin)):
    """
    Start a backfill on a daemon thread and return immediately.

    The backfill walks the Box folder and mirrors any existing marriottUsa
    metadata into the local DB as `status='backfilled'` rows. Run it after the
    first deploy (or after restoring an empty DB) so the per-file skip check
    doesn't re-tag files Box already has metadata for.
    """
    run_id = uuid.uuid4()
    worker = threading.Thread(
        target=_run_backfill_in_thread,
        args=(run_id,),
        daemon=True,
    )
    worker.start()
    return {
        "run_id": str(run_id),
        "state": "running",
        "kind": "backfill",
        "started_by": user.email or user.oid,
    }
@app.get("/api/runs")
def list_runs(user: User = Depends(require_auth), limit: int = Query(20, ge=1, le=100)):
    """
    List the most recently active runs recorded in tagging_events.

    Each entry carries per-run aggregates from the DB (first/last event time,
    event count, success and error counts) plus `live_state` / `live_error`
    from the in-memory `_runs` table when this process knows the run
    (otherwise both are None).
    """
    with _conn() as c:
        with c.cursor(row_factory=dict_row) as cur:
            cur.execute(
                """
                SELECT run_id,
                       min(created_at) AS started_at,
                       max(created_at) AS last_event_at,
                       count(*) AS events,
                       count(*) FILTER (WHERE status = 'success') AS successes,
                       count(*) FILTER (WHERE status LIKE '%%_error') AS errors
                FROM tagging_events
                GROUP BY run_id
                ORDER BY max(created_at) DESC
                LIMIT %s
                """,
                (limit,),
            )
            rows = cur.fetchall()

    run_ids = [str(r["run_id"]) for r in rows]
    # Workers update state and error together under _runs_lock; copy the
    # entries under the same lock so a run is never reported as "failed" with
    # its error still missing.
    with _runs_lock:
        live_by_id = {rid: dict(_runs[rid]) for rid in run_ids if rid in _runs}

    out = []
    for rid, r in zip(run_ids, rows):
        live = live_by_id.get(rid)
        out.append({
            "run_id": rid,
            "started_at": r["started_at"].isoformat() if r["started_at"] else None,
            "last_event_at": r["last_event_at"].isoformat() if r["last_event_at"] else None,
            "events": r["events"],
            "successes": r["successes"],
            "errors": r["errors"],
            "live_state": live["state"] if live else None,
            "live_error": live["error"] if live else None,
        })
    return {"runs": out}
@app.get("/api/runs/{run_id}/events")
def run_events(run_id: str, user: User = Depends(require_auth), limit: int = Query(500, ge=1, le=2000)):
    """
    Return the events of a single run, newest first.

    `run_id` may be any textual form `uuid.UUID` accepts; it is normalised to
    the canonical lowercase hyphenated form, which is what `_runs` is keyed by
    and what is echoed back in the response.

    Raises:
        HTTPException(400): `run_id` is not a valid UUID.
    """
    try:
        # Canonicalise once: uuid.UUID() also accepts uppercase, braced and
        # un-hyphenated input, but _runs keys are always str(uuid.UUID).
        canonical_id = str(uuid.UUID(run_id))
    except ValueError:
        raise HTTPException(status_code=400, detail="run_id must be a UUID") from None

    with _conn() as c:
        with c.cursor(row_factory=dict_row) as cur:
            cur.execute(
                """
                SELECT id, run_id, created_at, file_id, file_name, folder_path, media_type,
                       gemini_model, description, scenes, validated_metadata,
                       metadata_write_success, description_write_success,
                       scene_comment_write_success, status, error_message, duration_ms
                FROM tagging_events
                WHERE run_id = %s
                ORDER BY created_at DESC
                LIMIT %s
                """,
                (canonical_id, limit),
            )
            rows = cur.fetchall()

    # Workers update state and error together under _runs_lock; copy under
    # the same lock so the pair is read consistently.
    with _runs_lock:
        live = dict(_runs[canonical_id]) if canonical_id in _runs else None

    return {
        "run_id": canonical_id,
        "live_state": live["state"] if live else None,
        "live_error": live["error"] if live else None,
        "count": len(rows),
        "events": [_event_to_dict(r) for r in rows],
    }