Azure AD redirect URI is registered as /Pimco-charts (no /auth/callback), so handle the code exchange in the index route and exempt root with ?code= in middleware. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
228 lines
7.6 KiB
Python
228 lines
7.6 KiB
Python
"""FastAPI application: upload data, interpret brief, render SVG, iterate."""

from __future__ import annotations

import html
import json
import traceback
import uuid
from pathlib import Path

from fastapi import FastAPI, UploadFile, File, Form, Request
from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from starlette.middleware.sessions import SessionMiddleware
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware

from app.config import OUTPUT_DIR, STATIC_DIR, TEMPLATES_DIR, SESSION_SECRET_KEY
from app.auth.middleware import AuthMiddleware
from app.auth.routes import router as auth_router
from app.data.loader import load_file
from app.data.analyzer import summarize_data
from app.data.transformer import prepare_dataframe
from app.ai.brief_interpreter import interpret_brief, refine_spec
from app.ai.spec_validator import validate_and_fix_spec
from app.models.chart_spec import ChartSpec
from app.models.style import LAYOUT
from app.renderer.engine import render_chart


# root_path matches the /Pimco-charts prefix the app is presumably mounted
# under by a reverse proxy (it is also the Azure AD redirect URI handled by
# index()); Starlette uses it when building URLs for this app.
app = FastAPI(title="PIMCO Chart Generator", root_path="/Pimco-charts")

# Middleware order matters: each add_middleware() call wraps everything added
# before it, so a request passes through ProxyHeadersMiddleware first, then
# SessionMiddleware, then AuthMiddleware. AuthMiddleware presumably reads
# request.session, which is why SessionMiddleware is registered after it.
app.add_middleware(AuthMiddleware)
# Signed-cookie sessions. same_site="lax" lets the cookie travel on the
# top-level redirect back from Azure AD, where index() reads the stored OAuth
# state and PKCE verifier; https_only marks the cookie Secure.
app.add_middleware(SessionMiddleware, secret_key=SESSION_SECRET_KEY, https_only=True, same_site="lax")
# Honour X-Forwarded-* headers only when they come from the local reverse
# proxy, so the request scheme/client reflect the original request.
app.add_middleware(ProxyHeadersMiddleware, trusted_hosts="127.0.0.1")

app.include_router(auth_router)
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
templates = Jinja2Templates(directory=str(TEMPLATES_DIR))

# Holds uploaded data files and rendered SVG charts.
OUTPUT_DIR.mkdir(exist_ok=True)

# Simple in-memory session store, keyed by the 12-hex-char id created in
# /generate: session_id -> {spec, data_path, prepared, summary, width, height,
# history}. It is a plain module-level dict, so it is per-process (lost on
# restart, not shared between workers); nothing in this file evicts entries.
_sessions: dict[str, dict] = {}


@app.get("/", response_class=HTMLResponse)
async def index(request: Request, code: str = "", state: str = "", error: str = ""):
    """Render the upload page, or complete the Azure AD sign-in.

    The Azure AD redirect URI is registered as the app root (/Pimco-charts),
    so the authorization-code callback arrives here as ``?code=&state=``.

    Args:
        request: Incoming request. Its session carries the ``oauth_state`` and
            ``pkce_verifier`` values, presumably stored by the login route
            before redirecting to Azure AD.
        code: Authorization code from Azure AD; empty on a normal page load.
        state: OAuth ``state`` echoed back by Azure AD. It must equal the
            value stored in the session.
        error: Error code an OAuth provider may send instead of ``code``.
            It is accepted so the query parses, but it is not acted on here.

    Returns:
        After a callback, a redirect back to the app root that drops
        ``code``/``state`` from the URL. Otherwise the rendered
        ``upload.html``.
    """
    # OAuth callback: Azure redirects back to root with ?code=&state=
    if code:
        # Pop (not get) so the state/verifier pair can be used at most once.
        stored_state = request.session.pop("oauth_state", None)
        verifier = request.session.pop("pkce_verifier", None)

        if stored_state and state == stored_state and verifier:
            try:
                from app.auth.msal_client import exchange_code

                result = exchange_code(code=code, verifier=verifier)
            except ValueError:
                # Failed code exchange: log it instead of silently swallowing
                # it. The user is redirected without a signed-in session.
                traceback.print_exc()
            else:
                claims = result.get("id_token_claims") or {}
                # Sign the user in only when the response carries identity
                # claims and is not an error payload. Otherwise the defaults
                # below would create a placeholder "User" session.
                if claims and "error" not in result:
                    request.session["user"] = {
                        "name": claims.get("name", claims.get("preferred_username", "User")),
                        "email": claims.get("email", claims.get("preferred_username", "")),
                        "oid": claims.get("oid", ""),
                    }

        # Redirect so code/state leave the address bar. Target the mounted
        # prefix (root_path, e.g. /Pimco-charts) rather than a bare "/".
        # Behind the proxy, "/" would point the browser at the host root
        # instead of this app.
        return RedirectResponse(url=request.scope.get("root_path") or "/")

    return templates.TemplateResponse("upload.html", {
        "request": request,
        "user": request.session.get("user"),
    })


@app.post("/generate", response_class=HTMLResponse)
async def generate(
    request: Request,
    file: UploadFile = File(...),
    brief: str = Form(...),
    sheet: str = Form(""),
    width: int = Form(2560),
    height: int = Form(1440),
):
    """Upload a data file, interpret the brief with Claude, and render a chart.

    Steps:
        1. Save the upload under OUTPUT_DIR.
        2. Load and prepare each sheet, then summarize the data.
        3. Turn the summary into a ChartSpec via ``interpret_brief`` and
           ``validate_and_fix_spec``.
        4. Store the result in ``_sessions`` for later /refine calls.
        5. Render the chart to SVG.

    Args:
        request: Incoming request (used for the template and session user).
        file: Uploaded data file.
        brief: Free-text description of the desired chart.
        sheet: Optional sheet name. When it names a loaded sheet, only that
            sheet is used.
        width: Output width passed to ``_render_and_save``.
        height: Output height passed to ``_render_and_save``.

    Returns:
        ``preview.html`` on success. On any failure, an HTML error fragment
        with status 200. The 200 status is kept as-is, presumably so a
        client-side partial swap still displays the error.
    """
    try:
        # Save uploaded file
        session_id = uuid.uuid4().hex[:12]
        # The client-supplied filename is untrusted. Keep only its final
        # component (treating "\" as a separator too) so it cannot add path
        # segments under OUTPUT_DIR. The extension is preserved for the loader.
        client_name = (file.filename or "").replace("\\", "/")
        safe_name = Path(client_name).name or "upload"
        data_path = OUTPUT_DIR / f"data_{session_id}_{safe_name}"
        content = await file.read()
        with open(data_path, "wb") as f:
            f.write(content)

        # Load and prepare data
        sheets = load_file(data_path)
        if sheet and sheet in sheets:
            sheets = {sheet: sheets[sheet]}

        prepared = {name: prepare_dataframe(df) for name, df in sheets.items()}

        summary = summarize_data(prepared)

        # Interpret brief with Claude.
        # NOTE(review): interpret_brief is a synchronous call inside this
        # async handler, so it blocks the event loop for its duration.
        spec = interpret_brief(brief, summary)
        spec = validate_and_fix_spec(spec, prepared)

        # Store session
        _sessions[session_id] = {
            "spec": spec,
            "data_path": str(data_path),
            "prepared": prepared,
            "summary": summary,
            "width": width,
            "height": height,
            "history": [
                {"role": "user", "message": brief},
                {"role": "assistant", "message": "Chart generated."},
            ],
        }

        # Render
        svg, filename, spec_json = _render_and_save(spec, prepared, width, height)

        return templates.TemplateResponse("preview.html", {
            "request": request,
            "svg_content": svg,
            "filename": filename,
            "spec_json": spec_json,
            "session_id": session_id,
            "history": _sessions[session_id]["history"],
            "user": request.session.get("user"),
        })

    except Exception as e:
        traceback.print_exc()
        # Exception text can echo user input (file names, sheet names, brief
        # content), so escape it before embedding it in HTML.
        return HTMLResponse(
            f'<div class="error"><h3>Error</h3><p>{html.escape(str(e))}</p></div>',
            status_code=200,
        )


@app.post("/refine", response_class=HTMLResponse)
async def refine(
    request: Request,
    session_id: str = Form(...),
    edit: str = Form(...),
):
    """Apply a follow-up edit to an existing chart session and re-render it.

    The edit is added to the session history. ``refine_spec`` then produces a
    new ChartSpec, which is validated and rendered. Only after a successful
    render does the session adopt the new spec and record the assistant
    reply. If any step fails, the history entries added by this attempt are
    removed and the previous spec is kept.

    Args:
        request: Incoming request (used for the template and session user).
        session_id: Key into ``_sessions`` returned by /generate.
        edit: Free-text change request.

    Returns:
        ``preview.html`` on success. A "Session expired" fragment for unknown
        ids. An HTML error fragment on failure. All of these use status 200,
        as in /generate.
    """
    try:
        session = _sessions.get(session_id)
        if not session:
            return HTMLResponse(
                '<div class="error"><h3>Session expired</h3>'
                '<p>Please re-upload your data and generate a new chart.</p></div>',
                status_code=200,
            )

        old_spec = session["spec"]
        summary = session["summary"]
        history = session["history"]

        # Remember where this attempt starts so a failure can be rolled back.
        history_mark = len(history)
        # Add the edit to history (refine_spec receives it as part of history).
        history.append({"role": "user", "message": edit})

        try:
            # Ask Claude to refine the spec
            new_spec = refine_spec(old_spec, edit, summary, history)
            new_spec = validate_and_fix_spec(new_spec, session["prepared"])
            # Render before committing, so a spec that fails to render never
            # becomes the session's current spec.
            svg, filename, spec_json = _render_and_save(
                new_spec, session["prepared"], session["width"], session["height"]
            )
        except Exception:
            # Drop this attempt's history entries. Otherwise an unanswered
            # user turn would linger and be sent with every later refine.
            del history[history_mark:]
            raise

        # Commit the successful update to the session.
        session["spec"] = new_spec
        history.append({"role": "assistant", "message": "Chart updated."})

        return templates.TemplateResponse("preview.html", {
            "request": request,
            "svg_content": svg,
            "filename": filename,
            "spec_json": spec_json,
            "session_id": session_id,
            "history": history,
            "user": request.session.get("user"),
        })

    except Exception as e:
        traceback.print_exc()
        # Exception text can echo user input, so escape it before embedding.
        return HTMLResponse(
            f'<div class="error"><h3>Error</h3><p>{html.escape(str(e))}</p></div>',
            status_code=200,
        )


@app.get("/download/{filename}")
async def download(filename: str):
    """Serve a rendered chart SVG from OUTPUT_DIR as a file download.

    Args:
        filename: Bare file name, as returned by ``_render_and_save``.

    Returns:
        A ``FileResponse`` with the SVG (``filename`` sets the download
        name). Returns a 404 HTML response when ``filename`` does not name a
        regular ``.svg`` file directly inside OUTPUT_DIR.
    """
    output_dir = OUTPUT_DIR.resolve()
    filepath = (output_dir / filename).resolve()
    # Only serve regular .svg files located directly in OUTPUT_DIR. This
    # rejects "..", directories, symlinks that point elsewhere, and the
    # uploaded data files stored alongside the charts (unless they happen to
    # end in .svg).
    if (
        filepath.parent != output_dir
        or filepath.suffix.lower() != ".svg"
        or not filepath.is_file()
    ):
        return HTMLResponse("<p>File not found</p>", status_code=404)
    return FileResponse(
        filepath,
        media_type="image/svg+xml",
        filename=filename,
    )


def _render_and_save(
    spec: ChartSpec,
    prepared: dict,
    width: int,
    height: int,
) -> tuple[str, str, str]:
    """Render a ChartSpec and save the SVG. Returns (svg_str, filename, spec_json).

    Args:
        spec: Chart specification to render.
        prepared: Prepared data frames keyed by sheet name, as built in
            /generate.
        width: Output width (presumably pixels).
        height: Output height (presumably pixels).

    Returns:
        A tuple ``(svg, filename, spec_json)``:
            svg: the SVG markup.
            filename: the name of the file written under OUTPUT_DIR.
            spec_json: the spec serialized as indented JSON.

    Raises:
        ValueError: if ``prepared`` is empty.
    """
    if not prepared:
        # Fail with a readable message instead of a bare StopIteration from
        # next(iter(...)) below.
        raise ValueError("No data to render: the uploaded file contained no sheets.")

    # NOTE: this mutates the shared module-level LAYOUT (presumably read by
    # render_chart), and the new sizes persist after the call. That is safe
    # only while renders run one at a time on the event loop. Revisit before
    # moving this into a thread pool.
    for layout_name in ("single", "dual_panel"):
        LAYOUT[layout_name]["width"] = width
        LAYOUT[layout_name]["height"] = height

    # The first frame is always exposed as "_default". With several sheets,
    # each one is also exposed under its own name. A lone sheet is passed
    # only as "_default".
    first_frame = next(iter(prepared.values()))
    data_dict = dict(prepared) if len(prepared) > 1 else {}
    data_dict["_default"] = first_frame

    svg = render_chart(spec, data_dict)

    filename = f"chart_{uuid.uuid4().hex[:8]}.svg"
    # SVG is XML (UTF-8 by default). Write it explicitly as UTF-8 rather than
    # relying on the platform's default text encoding.
    (OUTPUT_DIR / filename).write_text(svg, encoding="utf-8")

    spec_json = json.dumps(spec.model_dump(), indent=2, default=str)
    return svg, filename, spec_json