pimco-charts/app/main.py
Vadym Samoilenko 21d469bd82 Fix OAuth callback to use root path (match Azure AD registration)
Azure AD redirect URI is registered as /Pimco-charts (no /auth/callback),
so handle the code exchange in the index route and exempt root with ?code= in middleware.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-06 15:07:49 +00:00

228 lines
7.6 KiB
Python

"""FastAPI application: upload data, interpret brief, render SVG, iterate."""
from __future__ import annotations

import html
import json
import traceback
import uuid
from pathlib import Path

from fastapi import FastAPI, UploadFile, File, Form, Request
from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from starlette.middleware.sessions import SessionMiddleware
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware

from app.config import OUTPUT_DIR, STATIC_DIR, TEMPLATES_DIR, SESSION_SECRET_KEY
from app.auth.middleware import AuthMiddleware
from app.auth.routes import router as auth_router
from app.data.loader import load_file
from app.data.analyzer import summarize_data
from app.data.transformer import prepare_dataframe
from app.ai.brief_interpreter import interpret_brief, refine_spec
from app.ai.spec_validator import validate_and_fix_spec
from app.models.chart_spec import ChartSpec
from app.models.style import LAYOUT
from app.renderer.engine import render_chart
# The app is served behind a proxy under the "/Pimco-charts" prefix
# (the Azure AD redirect URI is registered at that path).
app = FastAPI(title="PIMCO Chart Generator", root_path="/Pimco-charts")
# Registration order matters: each later add_middleware() call wraps the
# earlier ones, so a request passes ProxyHeadersMiddleware -> SessionMiddleware
# -> AuthMiddleware -> routes. AuthMiddleware thus runs inside the session
# layer (presumably because it reads request.session -- confirm).
app.add_middleware(AuthMiddleware)
# Signed-cookie sessions; the cookie is marked Secure (HTTPS only) and SameSite=Lax.
app.add_middleware(SessionMiddleware, secret_key=SESSION_SECRET_KEY, https_only=True, same_site="lax")
# Honour X-Forwarded-* headers only when they come from a proxy on localhost.
app.add_middleware(ProxyHeadersMiddleware, trusted_hosts="127.0.0.1")
app.include_router(auth_router)
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
templates = Jinja2Templates(directory=str(TEMPLATES_DIR))
# Uploaded data files and rendered chart SVGs are both written here.
OUTPUT_DIR.mkdir(exist_ok=True)
# Simple in-memory session store:
#   session_id -> {spec, data_path, prepared, summary, width, height, history}
# NOTE(review): entries are never evicted, are lost on restart, and are not
# shared between worker processes.
_sessions: dict[str, dict] = {}
@app.get("/", response_class=HTMLResponse)
async def index(request: Request, code: str = "", state: str = "", error: str = ""):
    """Render the upload page, or finish the OAuth login when Azure AD redirects here.

    The Azure AD redirect URI is registered at the app root, so a request that
    carries ``?code=`` is treated as the authorization-code callback:

    1. ``state`` must match the ``oauth_state`` stored in the session, and a
       code verifier (session key ``pkce_verifier``) must be present;
    2. the code is exchanged via ``exchange_code``;
    3. identity claims from the returned ``id_token_claims`` are stored as
       ``request.session["user"]``;
    4. the browser is redirected to the app root, dropping the query string.

    A failed or mismatched callback leaves the user signed out.

    Args:
        request: Incoming request; its session holds the OAuth values and the user.
        code: Authorization code from Azure AD (empty on normal page loads).
        state: State value echoed back by Azure AD.
        error: Error code Azure AD may send instead of ``code``.
            NOTE(review): currently not acted on.

    Returns:
        A redirect after a callback, otherwise the rendered ``upload.html``.
    """
    if code:
        # Pop (not get) so these one-shot values cannot be reused by a replayed URL.
        stored_state = request.session.pop("oauth_state", None)
        verifier = request.session.pop("pkce_verifier", None)
        if stored_state and state == stored_state and verifier:
            try:
                from app.auth.msal_client import exchange_code
                result = exchange_code(code=code, verifier=verifier)
            except ValueError:
                # Stay signed out, but make the failure visible instead of silent.
                traceback.print_exc()
            else:
                claims = result.get("id_token_claims") or {}
                # Without claims (e.g. an error response) there is no identity to
                # store; don't sign in an anonymous placeholder user.
                if claims:
                    request.session["user"] = {
                        "name": claims.get("name", claims.get("preferred_username", "User")),
                        "email": claims.get("email", claims.get("preferred_username", "")),
                        "oid": claims.get("oid", ""),
                    }
        # Redirect to the app's own root (the root_path prefix), not the host
        # root: a bare "/" would leave the proxied "/Pimco-charts" application.
        return RedirectResponse(url=request.scope.get("root_path") or "/")
    return templates.TemplateResponse("upload.html", {
        "request": request,
        "user": request.session.get("user"),
    })
@app.post("/generate", response_class=HTMLResponse)
async def generate(
    request: Request,
    file: UploadFile = File(...),
    brief: str = Form(...),
    sheet: str = Form(""),
    width: int = Form(2560),
    height: int = Form(1440),
):
    """Create a chart from an uploaded data file and a natural-language brief.

    The upload is saved under OUTPUT_DIR and loaded into one or more sheets
    (narrowed to *sheet* when that sheet exists). Each sheet is prepared and
    the result summarised. ``interpret_brief`` turns the brief plus summary into
    a ChartSpec, which ``validate_and_fix_spec`` then checks. The state needed by
    later ``/refine`` calls is stored in ``_sessions`` under a fresh random id,
    and the chart is rendered and saved.

    Args:
        request: Incoming request (template context and signed-in user).
        file: Uploaded data file.
        brief: Natural-language description of the desired chart.
        sheet: Optional sheet name; ignored if not among the loaded sheets.
        width: Chart width handed to ``_render_and_save``.
        height: Chart height handed to ``_render_and_save``.

    Returns:
        The rendered ``preview.html`` page, or an HTML error fragment. Errors use
        status 200 (presumably so the front end swaps the fragment in -- confirm).
    """
    try:
        session_id = uuid.uuid4().hex[:12]
        # The client-supplied filename is untrusted: keep only its final path
        # component so it can never address anything outside OUTPUT_DIR. The
        # extension is kept in case load_file dispatches on it (assumption).
        upload_name = Path(file.filename or "upload").name or "upload"
        data_path = OUTPUT_DIR / f"data_{session_id}_{upload_name}"
        data_path.write_bytes(await file.read())

        # Load and prepare data
        sheets = load_file(data_path)
        if sheet and sheet in sheets:
            sheets = {sheet: sheets[sheet]}
        prepared = {name: prepare_dataframe(df) for name, df in sheets.items()}
        summary = summarize_data(prepared)

        # Interpret the brief into a chart spec.
        # NOTE(review): these helpers are called synchronously inside an async
        # endpoint, so any network I/O they do blocks the event loop meanwhile.
        spec = interpret_brief(brief, summary)
        spec = validate_and_fix_spec(spec, prepared)

        _sessions[session_id] = {
            "spec": spec,
            "data_path": str(data_path),
            "prepared": prepared,
            "summary": summary,
            "width": width,
            "height": height,
            "history": [
                {"role": "user", "message": brief},
                {"role": "assistant", "message": "Chart generated."},
            ],
        }

        svg, filename, spec_json = _render_and_save(spec, prepared, width, height)
        return templates.TemplateResponse("preview.html", {
            "request": request,
            "svg_content": svg,
            "filename": filename,
            "spec_json": spec_json,
            "session_id": session_id,
            "history": _sessions[session_id]["history"],
            "user": request.session.get("user"),
        })
    except Exception as e:
        traceback.print_exc()
        # Exception text can echo user-supplied content (brief, column names):
        # escape it before embedding it in HTML.
        return HTMLResponse(
            f'<div class="error"><h3>Error</h3><p>{html.escape(str(e))}</p></div>',
            status_code=200,
        )
@app.post("/refine", response_class=HTMLResponse)
async def refine(
    request: Request,
    session_id: str = Form(...),
    edit: str = Form(...),
):
    """Apply a natural-language edit to an existing chart and re-render it.

    The edit is appended to the session's history and passed, together with
    the current spec and data summary, to ``refine_spec``. The new spec is
    validated against the session's prepared data and rendered. The session's
    spec and history are committed only after the render succeeds. On failure,
    the edit is removed from the history again, so the previous chart state
    stays intact.

    Args:
        request: Incoming request (template context and signed-in user).
        session_id: Key into the in-memory ``_sessions`` store issued by /generate.
        edit: The requested change, in natural language.

    Returns:
        The rendered ``preview.html`` page, a "Session expired" fragment for an
        unknown id, or an HTML error fragment (all with status 200).
    """
    try:
        session = _sessions.get(session_id)
        if not session:
            return HTMLResponse(
                '<div class="error"><h3>Session expired</h3>'
                '<p>Please re-upload your data and generate a new chart.</p></div>',
                status_code=200,
            )
        history = session["history"]
        committed_len = len(history)
        # refine_spec receives the history with this edit as its latest entry.
        history.append({"role": "user", "message": edit})
        try:
            new_spec = refine_spec(session["spec"], edit, session["summary"], history)
            new_spec = validate_and_fix_spec(new_spec, session["prepared"])
            svg, filename, spec_json = _render_and_save(
                new_spec, session["prepared"], session["width"], session["height"]
            )
        except Exception:
            # Drop the unanswered edit so the next attempt starts from the
            # last good state rather than from a failed one.
            del history[committed_len:]
            raise
        # Commit only once the new chart has rendered successfully.
        session["spec"] = new_spec
        history.append({"role": "assistant", "message": "Chart updated."})
        return templates.TemplateResponse("preview.html", {
            "request": request,
            "svg_content": svg,
            "filename": filename,
            "spec_json": spec_json,
            "session_id": session_id,
            "history": history,
            "user": request.session.get("user"),
        })
    except Exception as e:
        traceback.print_exc()
        # Exception text may echo the user's edit: escape it before embedding.
        return HTMLResponse(
            f'<div class="error"><h3>Error</h3><p>{html.escape(str(e))}</p></div>',
            status_code=200,
        )
@app.get("/download/{filename}")
async def download(filename: str):
    """Serve a rendered chart SVG from OUTPUT_DIR as a file download.

    Only bare ``chart_*.svg`` names are accepted (the names ``_render_and_save``
    produces), and they must refer to a regular file. Without this check the
    route would also expose the uploaded data files stored in the same
    directory, directories, and names like ``..``.

    Args:
        filename: Name of the chart file, as returned by ``_render_and_save``.

    Returns:
        The SVG as an attachment, or a 404 HTML response.
    """
    is_chart_name = (
        Path(filename).name == filename  # no directory components
        and filename.startswith("chart_")
        and filename.endswith(".svg")
    )
    filepath = OUTPUT_DIR / filename
    if not (is_chart_name and filepath.is_file()):
        return HTMLResponse("<p>File not found</p>", status_code=404)
    return FileResponse(
        filepath,
        media_type="image/svg+xml",
        filename=filename,
    )
def _render_and_save(
    spec: ChartSpec,
    prepared: dict,
    width: int,
    height: int,
) -> tuple[str, str, str]:
    """Render a ChartSpec, save the SVG to OUTPUT_DIR, and return the results.

    Args:
        spec: Chart specification to render.
        prepared: Mapping of sheet name -> prepared data. The first entry is
            exposed to the renderer as ``"_default"``. With several sheets, every
            sheet is also available under its own name.
        width: Width written into the "single" and "dual_panel" LAYOUT entries.
        height: Height written into the same LAYOUT entries.

    Returns:
        ``(svg_str, filename, spec_json)``, where *filename* is the saved file's
        name inside OUTPUT_DIR.

    Raises:
        ValueError: if *prepared* is empty (there is nothing to render).
    """
    if not prepared:
        raise ValueError("No data available to render the chart.")

    # NOTE(review): this mutates the shared module-level LAYOUT, so the size
    # leaks into any other render using it. It is only safe while renders run
    # one at a time.
    for layout_name in ("single", "dual_panel"):
        LAYOUT[layout_name]["width"] = width
        LAYOUT[layout_name]["height"] = height

    first_frame = next(iter(prepared.values()))
    if len(prepared) == 1:
        data_dict = {"_default": first_frame}
    else:
        data_dict = {**prepared, "_default": first_frame}
    svg = render_chart(spec, data_dict)

    # A full uuid avoids name collisions as chart files accumulate over time.
    filename = f"chart_{uuid.uuid4().hex}.svg"
    # Explicit UTF-8: the platform default encoding may not represent chart text.
    (OUTPUT_DIR / filename).write_text(svg, encoding="utf-8")

    spec_json = json.dumps(spec.model_dump(), indent=2, default=str)
    return svg, filename, spec_json