Search: - Previously /api/events did one ILIKE %q% across the columns, so "female city" required the literal substring "female city" to appear somewhere. Now the query is tokenised on whitespace; every token must match somewhere (AND), and each token matches either by substring (ILIKE) across the searched columns OR by trigram similarity (pg_trgm) against a concatenated text blob with a 0.3 threshold — handles typos like "femalle" → "female". - Results ranked by summed similarity score across all tokens, then recency. Empty query falls back to "newest 100". - schema.sql: CREATE EXTENSION IF NOT EXISTS pg_trgm (idempotent; applied by ensure_schema on api startup). Admin gating: - auth.py: User now carries `is_admin`. Computed from a comma-separated ADMIN_EMAILS env var (case-insensitive match against `preferred_username`/`upn`/`email` claim). New `require_admin` FastAPI dependency 403s non-admins. - In DEV_AUTH_BYPASS mode the dev user is admin by default; flip DEV_AUTH_IS_ADMIN=false to test the read-only UX without enabling SSO. - POST /api/runs and POST /api/backfill now gated by require_admin. - /api/me carries is_admin so the SPA can hide the destructive buttons for non-admins. Frontend: - App.tsx fetches /api/me on mount and hides Run Now + Backfill unless `is_admin` is true. Non-admins still see search + results + recent-runs table. docker-compose / .env.example: thread ADMIN_EMAILS + DEV_AUTH_IS_ADMIN into the api container. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
345 lines
12 KiB
Python
345 lines
12 KiB
Python
"""
Marriott Box Tagger — FastAPI backend.

Endpoints (all under /api/, all behind require_auth except /api/health; the two
POST endpoints are admin-only via require_admin):

    GET  /api/health                — liveness + config flags
    GET  /api/me                    — who am I (after auth)
    GET  /api/events?q=…&limit=…    — search tagging_events across all
                                      text + JSONB fields
    POST /api/runs                  — kick off a tagging pass in a
                                      background thread; returns run_id
    POST /api/backfill              — mirror existing Box metadata into the
                                      DB in a background thread; returns run_id
    GET  /api/runs                  — recent runs (run_id + counts)
    GET  /api/runs/{run_id}/events  — events for a single run, newest first
"""
|
|
|
|
import os
|
|
import threading
|
|
import uuid
|
|
from contextlib import contextmanager
|
|
from typing import Optional
|
|
|
|
from fastapi import Depends, FastAPI, HTTPException, Query
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from psycopg.rows import dict_row
|
|
|
|
import db
|
|
from auth import User, maybe_auth_info, require_admin, require_auth
|
|
|
|
# Template for deep links to a file in the Box web UI.
BOX_FILE_URL = "https://app.box.com/file/{file_id}"

app = FastAPI(title="Marriott Box Tagger API", version="1.0.0")

# CORS is only relevant in dev, where the Vite dev server calls FastAPI from a
# different origin. In prod Apache serves the SPA and the API from one origin,
# so CORS_ORIGINS is left empty and no middleware is installed.
_cors_origins = [
    origin
    for origin in (part.strip() for part in os.getenv("CORS_ORIGINS", "").split(","))
    if origin
]
if _cors_origins:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=_cors_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

# Process-local table of background runs (run_id -> state dict). It lives only
# as long as the api container; the durable record of each run's output is the
# tagging_events table. Guard every access with _runs_lock.
_runs: dict[str, dict] = {}
_runs_lock = threading.Lock()
|
|
|
|
|
|
@contextmanager
def _conn():
    """Yield a fresh connection from ``db.get_conn()`` and always close it on exit."""
    connection = db.get_conn()
    try:
        yield connection
    finally:
        connection.close()
|
|
|
|
|
|
# ── Health / identity ────────────────────────────────────────────────────────
|
|
|
|
|
|
@app.get("/api/health")
def health():
    """Unauthenticated liveness probe.

    Always reports ``ok: True``. The ``db`` entry says whether a trivial
    ``SELECT 1`` round-trip succeeded (with the exception text if not), and
    ``auth`` carries whatever ``maybe_auth_info()`` exposes.
    """
    db_status = {"ok": False, "error": None}
    try:
        with _conn() as c, c.cursor() as cur:
            cur.execute("SELECT 1")
            cur.fetchone()
            db_status["ok"] = True
    except Exception as exc:
        db_status["error"] = f"{type(exc).__name__}: {exc}"
    return {"ok": True, "db": db_status, "auth": maybe_auth_info()}
|
|
|
|
|
|
@app.get("/api/me")
def me(user: User = Depends(require_auth)):
    """Return the authenticated caller, serialized via ``User.to_dict()``."""
    profile = user.to_dict()
    return profile
|
|
|
|
|
|
# ── Search ──────────────────────────────────────────────────────────────────
|
|
|
|
|
|
# Columns the search ILIKE walks over (substring match, case-insensitive).
|
|
_SEARCH_COLS = [
|
|
"file_name",
|
|
"folder_path",
|
|
"description",
|
|
"status",
|
|
"file_id",
|
|
"coalesce(validated_metadata::text, '')",
|
|
"coalesce(raw_response::text, '')",
|
|
"coalesce(scenes::text, '')",
|
|
]
|
|
|
|
# Single concatenated text used for trigram similarity (fuzzy / typo tolerance).
|
|
_SEARCH_BLOB = (
|
|
"coalesce(file_name,'')||' '||coalesce(folder_path,'')||' '||"
|
|
"coalesce(description,'')||' '||coalesce(validated_metadata::text,'')||' '||"
|
|
"coalesce(scenes::text,'')"
|
|
)
|
|
|
|
# Fuzzy threshold for trigram similarity. 0.3 catches typos like
|
|
# "femalle" → "female" without flooding the results with noise.
|
|
_SIM_THRESHOLD = 0.3
|
|
|
|
# Short tokens (1-2 chars) are too noisy for trigrams — fall back to substring
|
|
# match only for those.
|
|
_MIN_FUZZY_TOKEN_LEN = 3
|
|
|
|
|
|
def _build_search_sql(q: str, limit: int):
|
|
"""
|
|
Tokenise the query on whitespace, AND-match every token across the columns,
|
|
where each token may match by substring OR by trigram similarity. Results
|
|
ranked by summed similarity score, then recency.
|
|
"""
|
|
tokens = [t for t in q.strip().split() if t]
|
|
common_cols = (
|
|
"id, run_id, created_at, file_id, file_name, folder_path, media_type, "
|
|
"gemini_model, description, scenes, validated_metadata, raw_response, "
|
|
"metadata_write_success, description_write_success, scene_comment_write_success, "
|
|
"status, error_message, duration_ms"
|
|
)
|
|
|
|
if not tokens:
|
|
return (
|
|
f"SELECT {common_cols} FROM tagging_events "
|
|
f"ORDER BY created_at DESC LIMIT %(limit)s",
|
|
{"limit": limit},
|
|
)
|
|
|
|
params: dict = {"limit": limit}
|
|
clauses: list[str] = []
|
|
score_terms: list[str] = []
|
|
for i, tok in enumerate(tokens):
|
|
like_key = f"like_{i}"
|
|
sim_key = f"sim_{i}"
|
|
params[like_key] = f"%{tok}%"
|
|
params[sim_key] = tok
|
|
col_ors = " OR ".join(f"{c} ILIKE %({like_key})s" for c in _SEARCH_COLS)
|
|
if len(tok) >= _MIN_FUZZY_TOKEN_LEN:
|
|
clauses.append(
|
|
f"(({col_ors}) "
|
|
f"OR similarity({_SEARCH_BLOB}, %({sim_key})s) > {_SIM_THRESHOLD})"
|
|
)
|
|
score_terms.append(f"similarity({_SEARCH_BLOB}, %({sim_key})s)")
|
|
else:
|
|
clauses.append(f"({col_ors})")
|
|
|
|
where = " AND ".join(clauses)
|
|
score_sql = " + ".join(score_terms) if score_terms else "0"
|
|
sql = (
|
|
f"SELECT {common_cols}, ({score_sql}) AS _score "
|
|
f"FROM tagging_events "
|
|
f"WHERE {where} "
|
|
f"ORDER BY _score DESC, created_at DESC "
|
|
f"LIMIT %(limit)s"
|
|
)
|
|
return sql, params
|
|
|
|
|
|
def _event_to_dict(row):
|
|
out = dict(row)
|
|
fid = out.get("file_id")
|
|
out["box_url"] = BOX_FILE_URL.format(file_id=fid) if fid else None
|
|
if out.get("run_id") is not None:
|
|
out["run_id"] = str(out["run_id"])
|
|
if out.get("created_at") is not None:
|
|
out["created_at"] = out["created_at"].isoformat()
|
|
return out
|
|
|
|
|
|
@app.get("/api/events")
def search_events(
    q: str = Query("", description="Free-text search across all fields"),
    limit: int = Query(100, ge=1, le=500),
    user: User = Depends(require_auth),
):
    """Free-text search over tagging_events.

    The SQL and its bound parameters come from ``_build_search_sql``. Each
    row is serialized with ``_event_to_dict``, and the original query is
    echoed back alongside the hit count.
    """
    query, bind = _build_search_sql(q, limit)
    with _conn() as c, c.cursor(row_factory=dict_row) as cur:
        cur.execute(query, bind)
        records = cur.fetchall()
    results = [_event_to_dict(record) for record in records]
    return {"q": q, "count": len(records), "results": results}
|
|
|
|
|
|
# ── Run-now ─────────────────────────────────────────────────────────────────
|
|
|
|
|
|
def _record_run_started(run_id: uuid.UUID, kind: str) -> None:
    """Register *run_id* in the in-memory ``_runs`` table as a running job of *kind*."""
    key = str(run_id)
    with _runs_lock:
        _runs[key] = {"run_id": key, "state": "running", "error": None, "kind": kind}


def _record_run_outcome(run_id: uuid.UUID, state: str, error: Optional[str] = None) -> None:
    """Set the final ``state`` (and ``error``) of a run already in ``_runs``."""
    with _runs_lock:
        record = _runs[str(run_id)]
        record["state"] = state
        record["error"] = error


def _run_job_in_thread(run_id: uuid.UUID, kind: str, entry_point: str) -> None:
    """Background worker shared by run-now and backfill.

    Marks the run as running, then opens a fresh DB connection, runs
    ``db.ensure_schema``, and calls ``main.<entry_point>(run_id, db_conn)``.
    The outcome ("completed" or "failed" plus an error string) is written to
    ``_runs``, and failures are logged with their traceback. The connection
    is always closed.
    """
    import logging

    log = logging.getLogger(__name__)
    _record_run_started(run_id, kind)

    db_conn = None
    try:
        # Imported lazily so the API doesn't pay tagger-side init cost at startup.
        # Inside the try so an import failure is recorded as a failed run instead
        # of silently killing the thread.
        import main as tagger

        db_conn = db.get_conn()
        db.ensure_schema(db_conn)
        getattr(tagger, entry_point)(run_id, db_conn)
    except SystemExit as e:
        # Tagger code may call sys.exit(); record it rather than letting it end
        # the thread without a trace.
        log.error("%s %s exited with SystemExit(%s)", kind, run_id, e.code)
        _record_run_outcome(run_id, "failed", f"SystemExit({e.code})")
    except Exception as e:
        log.exception("%s %s failed", kind, run_id)
        _record_run_outcome(run_id, "failed", f"{type(e).__name__}: {e}")
    else:
        _record_run_outcome(run_id, "completed")
    finally:
        if db_conn is not None:
            try:
                db_conn.close()
            except Exception:
                # Best-effort cleanup: the run's outcome is already recorded.
                log.warning("closing DB connection for %s %s failed", kind, run_id, exc_info=True)


def _run_pass_in_thread(run_id: uuid.UUID):
    """Background worker for POST /api/runs: one tagging pass via ``main._run_pass``."""
    _run_job_in_thread(run_id, "run", "_run_pass")


@app.post("/api/runs")
def start_run(user: User = Depends(require_admin)):
    """Start a tagging pass in a daemon thread (admin only) and return its run_id."""
    run_id = uuid.uuid4()
    # Register before the thread starts so the live state is visible to
    # GET /api/runs/{run_id}/events as soon as this response is returned.
    _record_run_started(run_id, "run")
    threading.Thread(target=_run_pass_in_thread, args=(run_id,), daemon=True).start()
    return {"run_id": str(run_id), "state": "running", "started_by": user.email or user.oid}


def _run_backfill_in_thread(run_id: uuid.UUID):
    """Background worker for POST /api/backfill via ``main._run_backfill``."""
    _run_job_in_thread(run_id, "backfill", "_run_backfill")


@app.post("/api/backfill")
def start_backfill(user: User = Depends(require_admin)):
    """
    Walk the Box folder and mirror any existing marriottUsa metadata into the
    local DB as `status='backfilled'` rows. Use this after first deploy (or
    after restoring an empty DB) so the per-file skip check doesn't re-tag
    files Box already has metadata for.
    """
    run_id = uuid.uuid4()
    # Register before the thread starts (see start_run).
    _record_run_started(run_id, "backfill")
    threading.Thread(target=_run_backfill_in_thread, args=(run_id,), daemon=True).start()
    return {"run_id": str(run_id), "state": "running", "kind": "backfill", "started_by": user.email or user.oid}
|
|
|
|
|
|
@app.get("/api/runs")
def list_runs(user: User = Depends(require_auth), limit: int = Query(20, ge=1, le=100)):
    """Recent runs in the DB, plus the in-memory state if the run is still active.

    Runs are grouped from tagging_events by run_id and ordered by their most
    recent event. ``errors`` counts events whose status ends in ``_error``.
    """
    with _conn() as c:
        with c.cursor(row_factory=dict_row) as cur:
            cur.execute(
                """
                SELECT run_id,
                       min(created_at) AS started_at,
                       max(created_at) AS last_event_at,
                       count(*) AS events,
                       count(*) FILTER (WHERE status = 'success') AS successes,
                       count(*) FILTER (WHERE right(status, 6) = '_error') AS errors
                FROM tagging_events
                GROUP BY run_id
                ORDER BY max(created_at) DESC
                LIMIT %s
                """,
                (limit,),
            )
            rows = cur.fetchall()

    out = []
    for r in rows:
        rid = str(r["run_id"])
        # Copy under the lock: workers update "state" and "error" together.
        with _runs_lock:
            live = dict(_runs[rid]) if rid in _runs else None
        out.append({
            "run_id": rid,
            "started_at": r["started_at"].isoformat() if r["started_at"] else None,
            "last_event_at": r["last_event_at"].isoformat() if r["last_event_at"] else None,
            "events": r["events"],
            "successes": r["successes"],
            "errors": r["errors"],
            "live_state": live["state"] if live else None,
            "live_error": live["error"] if live else None,
        })
    return {"runs": out}
|
|
|
|
|
|
@app.get("/api/runs/{run_id}/events")
def run_events(run_id: str, user: User = Depends(require_auth), limit: int = Query(500, ge=1, le=2000)):
    """Return up to *limit* events of one run (newest first), plus its live state if known.

    Raises:
        HTTPException(400): if *run_id* is not a UUID.
    """
    try:
        # Canonicalize: uuid.UUID accepts upper-case, braced and "urn:uuid:" forms,
        # but _runs is keyed by str(uuid) and Postgres rejects the "urn:uuid:" prefix.
        run_id = str(uuid.UUID(run_id))
    except ValueError:
        raise HTTPException(status_code=400, detail="run_id must be a UUID") from None
    with _conn() as c:
        with c.cursor(row_factory=dict_row) as cur:
            cur.execute(
                """
                SELECT id, run_id, created_at, file_id, file_name, folder_path, media_type,
                       gemini_model, description, scenes, validated_metadata,
                       metadata_write_success, description_write_success,
                       scene_comment_write_success, status, error_message, duration_ms
                FROM tagging_events
                WHERE run_id = %s
                ORDER BY created_at DESC
                LIMIT %s
                """,
                (run_id, limit),
            )
            rows = cur.fetchall()
    # Copy under the lock: workers update "state" and "error" together.
    with _runs_lock:
        live = dict(_runs[run_id]) if run_id in _runs else None
    return {
        "run_id": run_id,
        "live_state": live["state"] if live else None,
        "live_error": live["error"] if live else None,
        "count": len(rows),
        "events": [_event_to_dict(r) for r in rows],
    }
|