Phase 3: Content Pipeline — file parsing, content intelligence, slide mapping, native charts
- Step 10: Extended file upload for Excel/CSV/images/URLs (openpyxl, trafilatura) - Step 11: Content intelligence service with rule-based + LLM classification - Step 12: Slide mapping engine mapping content blocks to master deck layouts - Step 13: Chart data extractor, native PPTX chart service (bar/line/pie/gantt/waterfall), ChartDataEditor skeleton Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
cf21ba4516
commit
a2bd4cfefa
13 changed files with 1986 additions and 26 deletions
|
|
@ -1,22 +1,51 @@
|
|||
from http.client import HTTPException
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
from typing import Annotated, List, Optional
|
||||
from fastapi import APIRouter, Body, File, UploadFile
|
||||
|
||||
from constants.documents import UPLOAD_ACCEPTED_FILE_TYPES
|
||||
from models.decomposed_file_info import DecomposedFileInfo
|
||||
from services.temp_file_service import TEMP_FILE_SERVICE
|
||||
from services.documents_loader import DocumentsLoader
|
||||
import uuid
|
||||
from typing import Annotated, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, File, HTTPException, UploadFile
|
||||
from pydantic import BaseModel
|
||||
|
||||
from constants.documents import (
|
||||
EXCEL_TYPES,
|
||||
IMAGE_UPLOAD_TYPES,
|
||||
SPREADSHEET_TYPES,
|
||||
UPLOAD_ACCEPTED_FILE_TYPES,
|
||||
)
|
||||
from models.decomposed_file_info import DecomposedFileInfo
|
||||
from services.attachment_parser_service import (
|
||||
extract_images_metadata,
|
||||
parse_csv,
|
||||
parse_excel,
|
||||
parse_url,
|
||||
)
|
||||
from services.documents_loader import DocumentsLoader
|
||||
from services.temp_file_service import TEMP_FILE_SERVICE
|
||||
from utils.validators import validate_files
|
||||
|
||||
FILES_ROUTER = APIRouter(prefix="/files", tags=["Files"])
|
||||
|
||||
|
||||
def _is_spreadsheet(file_path: str) -> bool:
|
||||
mime, _ = mimetypes.guess_type(file_path)
|
||||
ext = os.path.splitext(file_path)[1].lower()
|
||||
return (
|
||||
mime in EXCEL_TYPES
|
||||
or mime in SPREADSHEET_TYPES
|
||||
or ext in (".xlsx", ".xls", ".csv")
|
||||
)
|
||||
|
||||
|
||||
def _is_image(file_path: str) -> bool:
|
||||
mime, _ = mimetypes.guess_type(file_path)
|
||||
return mime in IMAGE_UPLOAD_TYPES
|
||||
|
||||
|
||||
@FILES_ROUTER.post("/upload", response_model=List[str])
|
||||
async def upload_files(files: Optional[List[UploadFile]]):
|
||||
if not files:
|
||||
raise HTTPException(400, "Documents are required")
|
||||
raise HTTPException(status_code=400, detail="Documents are required")
|
||||
|
||||
temp_dir = TEMP_FILE_SERVICE.create_temp_dir(str(uuid.uuid4()))
|
||||
|
||||
|
|
@ -42,40 +71,117 @@ async def decompose_files(file_paths: Annotated[List[str], Body(embed=True)]):
|
|||
temp_dir = TEMP_FILE_SERVICE.create_temp_dir(str(uuid.uuid4()))
|
||||
|
||||
txt_files = []
|
||||
spreadsheet_files = []
|
||||
image_files = []
|
||||
other_files = []
|
||||
|
||||
for file_path in file_paths:
|
||||
if file_path.endswith(".txt"):
|
||||
txt_files.append(file_path)
|
||||
elif _is_spreadsheet(file_path):
|
||||
spreadsheet_files.append(file_path)
|
||||
elif _is_image(file_path):
|
||||
image_files.append(file_path)
|
||||
else:
|
||||
other_files.append(file_path)
|
||||
|
||||
documents_loader = DocumentsLoader(file_paths=other_files)
|
||||
await documents_loader.load_documents(temp_dir)
|
||||
parsed_documents = documents_loader.documents
|
||||
response: List[DecomposedFileInfo] = []
|
||||
|
||||
response = []
|
||||
for index, parsed_doc in enumerate(parsed_documents):
|
||||
file_path = TEMP_FILE_SERVICE.create_temp_file_path(
|
||||
f"{uuid.uuid4()}.txt", temp_dir
|
||||
)
|
||||
parsed_doc = parsed_doc.replace("<br>", "\n")
|
||||
with open(file_path, "w") as text_file:
|
||||
text_file.write(parsed_doc)
|
||||
# --- Document files (PDF, DOCX, PPTX) via DocumentsLoader ---
|
||||
if other_files:
|
||||
documents_loader = DocumentsLoader(file_paths=other_files)
|
||||
await documents_loader.load_documents(temp_dir)
|
||||
parsed_documents = documents_loader.documents
|
||||
|
||||
for index, parsed_doc in enumerate(parsed_documents):
|
||||
out_path = TEMP_FILE_SERVICE.create_temp_file_path(
|
||||
f"{uuid.uuid4()}.txt", temp_dir
|
||||
)
|
||||
parsed_doc = parsed_doc.replace("<br>", "\n")
|
||||
with open(out_path, "w") as text_file:
|
||||
text_file.write(parsed_doc)
|
||||
response.append(
|
||||
DecomposedFileInfo(
|
||||
name=os.path.basename(other_files[index]),
|
||||
file_path=out_path,
|
||||
file_type="text",
|
||||
)
|
||||
)
|
||||
|
||||
# --- Plain text files ---
|
||||
for each_file in txt_files:
|
||||
response.append(
|
||||
DecomposedFileInfo(
|
||||
name=os.path.basename(other_files[index]), file_path=file_path
|
||||
name=os.path.basename(each_file),
|
||||
file_path=each_file,
|
||||
file_type="text",
|
||||
)
|
||||
)
|
||||
|
||||
# Return the txt documents as it is
|
||||
for each_file in txt_files:
|
||||
# --- Spreadsheet files (Excel, CSV) ---
|
||||
for sp_path in spreadsheet_files:
|
||||
ext = os.path.splitext(sp_path)[1].lower()
|
||||
if ext in (".xlsx", ".xls"):
|
||||
tables = parse_excel(sp_path)
|
||||
else:
|
||||
tables = [parse_csv(sp_path)]
|
||||
|
||||
# Store parsed table data as JSON file for downstream use
|
||||
json_path = TEMP_FILE_SERVICE.create_temp_file_path(
|
||||
f"{uuid.uuid4()}.json", temp_dir
|
||||
)
|
||||
serialized = [t.model_dump() for t in tables]
|
||||
with open(json_path, "w") as jf:
|
||||
json.dump(serialized, jf)
|
||||
|
||||
response.append(
|
||||
DecomposedFileInfo(name=os.path.basename(each_file), file_path=each_file)
|
||||
DecomposedFileInfo(
|
||||
name=os.path.basename(sp_path),
|
||||
file_path=json_path,
|
||||
file_type="table",
|
||||
table_data=serialized,
|
||||
)
|
||||
)
|
||||
|
||||
# --- Image files ---
|
||||
for img_path in image_files:
|
||||
info = extract_images_metadata(img_path)
|
||||
response.append(
|
||||
DecomposedFileInfo(
|
||||
name=info.filename,
|
||||
file_path=img_path,
|
||||
file_type="image",
|
||||
image_info=info.model_dump(),
|
||||
)
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class UrlParseRequest(BaseModel):
|
||||
url: str
|
||||
|
||||
|
||||
class UrlParseResponse(BaseModel):
|
||||
content: str
|
||||
url: str
|
||||
|
||||
|
||||
@FILES_ROUTER.post("/url", response_model=UrlParseResponse)
|
||||
async def parse_url_endpoint(body: UrlParseRequest):
|
||||
"""Fetch a URL and extract its article content as text."""
|
||||
if not body.url or not body.url.strip():
|
||||
raise HTTPException(status_code=400, detail="URL is required")
|
||||
|
||||
content = await parse_url(body.url)
|
||||
if not content:
|
||||
raise HTTPException(
|
||||
status_code=422, detail="Could not extract content from the provided URL"
|
||||
)
|
||||
|
||||
return UrlParseResponse(content=content, url=body.url)
|
||||
|
||||
|
||||
@FILES_ROUTER.post("/update")
|
||||
async def update_files(
|
||||
file_path: Annotated[str, Body()],
|
||||
|
|
|
|||
|
|
@ -8,6 +8,11 @@ WORD_TYPES = [
|
|||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
]
|
||||
SPREADSHEET_TYPES = ["text/csv", "application/csv"]
|
||||
EXCEL_TYPES = [
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
"application/vnd.ms-excel",
|
||||
]
|
||||
IMAGE_UPLOAD_TYPES = ["image/png", "image/jpeg", "image/webp", "image/gif"]
|
||||
|
||||
|
||||
PNG_MIME_TYPES = ["image/png"]
|
||||
|
|
@ -16,5 +21,11 @@ WEBP_MIME_TYPES = ["image/webp"]
|
|||
|
||||
|
||||
UPLOAD_ACCEPTED_FILE_TYPES = (
|
||||
PDF_MIME_TYPES + TEXT_MIME_TYPES + POWERPOINT_TYPES + WORD_TYPES
|
||||
PDF_MIME_TYPES
|
||||
+ TEXT_MIME_TYPES
|
||||
+ POWERPOINT_TYPES
|
||||
+ WORD_TYPES
|
||||
+ SPREADSHEET_TYPES
|
||||
+ EXCEL_TYPES
|
||||
+ IMAGE_UPLOAD_TYPES
|
||||
)
|
||||
|
|
|
|||
35
backend/models/content_models.py
Normal file
35
backend/models/content_models.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
"""Content classification models for the content intelligence pipeline."""
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from services.attachment_parser_service import ImageInfo, TableData
|
||||
|
||||
|
||||
class ContentBlockType(str, Enum):
|
||||
narrative = "narrative"
|
||||
quote = "quote"
|
||||
metric = "metric"
|
||||
table = "table"
|
||||
timeline = "timeline"
|
||||
comparison = "comparison"
|
||||
list_items = "list_items"
|
||||
image_reference = "image_reference"
|
||||
call_to_action = "call_to_action"
|
||||
|
||||
|
||||
class ContentBlock(BaseModel):
|
||||
type: ContentBlockType
|
||||
raw_text: str
|
||||
extracted_data: Optional[Dict[str, Any]] = None
|
||||
source_section: Optional[str] = None
|
||||
priority: int = 5 # 1-10
|
||||
|
||||
|
||||
class ClassifiedContent(BaseModel):
|
||||
title: Optional[str] = None
|
||||
blocks: List[ContentBlock]
|
||||
tables: List[TableData] = []
|
||||
images: List[ImageInfo] = []
|
||||
summary: str = ""
|
||||
|
|
@ -1,6 +1,11 @@
|
|||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class DecomposedFileInfo(BaseModel):
|
||||
name: str
|
||||
file_path: str
|
||||
file_type: str = "text" # "text" | "table" | "image"
|
||||
table_data: Optional[List[dict]] = None # Serialized TableData for spreadsheets
|
||||
image_info: Optional[dict] = None # Serialized ImageInfo for images
|
||||
|
|
|
|||
|
|
@ -106,8 +106,17 @@ class PptxPictureModel(BaseModel):
|
|||
path: str
|
||||
|
||||
|
||||
class PptxChartDataModel(BaseModel):
|
||||
"""Inline chart data for native PPTX chart rendering."""
|
||||
chart_type: str = "column" # bar, column, line, pie, doughnut, area, scatter, gantt, waterfall
|
||||
title: str = "Chart"
|
||||
categories: List[str] = []
|
||||
series: List[dict] = [] # [{name: str, values: [float]}]
|
||||
unit: Optional[str] = None
|
||||
|
||||
|
||||
class PptxShapeModel(BaseModel):
|
||||
shape_type: Literal["textbox", "autoshape", "picture", "connector"]
|
||||
shape_type: Literal["textbox", "autoshape", "picture", "connector", "chart"]
|
||||
|
||||
|
||||
class PptxTextBoxModel(PptxShapeModel):
|
||||
|
|
@ -154,6 +163,14 @@ class PptxConnectorModel(PptxShapeModel):
|
|||
opacity: float = 1.0
|
||||
|
||||
|
||||
class PptxChartBoxModel(PptxShapeModel):
|
||||
shape_type: Literal["chart"] = "chart"
|
||||
position: PptxPositionModel
|
||||
chart_data: PptxChartDataModel
|
||||
brand_colors: Optional[List[str]] = None
|
||||
font_name: Optional[str] = None
|
||||
|
||||
|
||||
class PptxSlideModel(BaseModel):
|
||||
background: Optional[PptxFillModel] = None
|
||||
note: Optional[str] = None
|
||||
|
|
@ -162,6 +179,7 @@ class PptxSlideModel(BaseModel):
|
|||
| PptxAutoShapeBoxModel
|
||||
| PptxConnectorModel
|
||||
| PptxPictureBoxModel
|
||||
| PptxChartBoxModel
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -26,6 +26,8 @@ dependencies = [
|
|||
"alembic>=1.15",
|
||||
"msal>=1.31",
|
||||
"python-jose[cryptography]>=3.3",
|
||||
"openpyxl>=3.1",
|
||||
"trafilatura>=2.0",
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
|
|
|
|||
183
backend/services/attachment_parser_service.py
Normal file
183
backend/services/attachment_parser_service.py
Normal file
|
|
@ -0,0 +1,183 @@
|
|||
"""Service for parsing non-document attachments: Excel, CSV, images, URLs."""
|
||||
import csv
|
||||
import io
|
||||
import mimetypes
|
||||
import os
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TableData(BaseModel):
|
||||
title: Optional[str] = None
|
||||
headers: List[str]
|
||||
rows: List[List[Any]]
|
||||
sheet_name: Optional[str] = None
|
||||
|
||||
|
||||
class ImageInfo(BaseModel):
|
||||
file_path: str
|
||||
filename: str
|
||||
mime_type: str
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
|
||||
|
||||
def parse_excel(file_path: str) -> List[TableData]:
|
||||
"""Parse an Excel (.xlsx/.xls) file and return one TableData per sheet."""
|
||||
from openpyxl import load_workbook
|
||||
|
||||
wb = load_workbook(file_path, read_only=True, data_only=True)
|
||||
results: List[TableData] = []
|
||||
|
||||
for sheet_name in wb.sheetnames:
|
||||
ws = wb[sheet_name]
|
||||
rows_raw = list(ws.iter_rows(values_only=True))
|
||||
if not rows_raw:
|
||||
continue
|
||||
|
||||
# First non-empty row is treated as headers
|
||||
headers = [str(c) if c is not None else "" for c in rows_raw[0]]
|
||||
data_rows = []
|
||||
for row in rows_raw[1:]:
|
||||
# Skip completely empty rows
|
||||
if all(c is None for c in row):
|
||||
continue
|
||||
data_rows.append([_serialize_cell(c) for c in row])
|
||||
|
||||
if not data_rows and not any(h for h in headers):
|
||||
continue
|
||||
|
||||
results.append(
|
||||
TableData(
|
||||
title=sheet_name if len(wb.sheetnames) > 1 else None,
|
||||
headers=headers,
|
||||
rows=data_rows,
|
||||
sheet_name=sheet_name,
|
||||
)
|
||||
)
|
||||
|
||||
wb.close()
|
||||
return results
|
||||
|
||||
|
||||
def parse_csv(file_path: str) -> TableData:
|
||||
"""Parse a CSV file and return a single TableData."""
|
||||
with open(file_path, "r", encoding="utf-8-sig") as f:
|
||||
# Sniff delimiter
|
||||
sample = f.read(4096)
|
||||
f.seek(0)
|
||||
try:
|
||||
dialect = csv.Sniffer().sniff(sample, delimiters=",;\t|")
|
||||
except csv.Error:
|
||||
dialect = csv.excel
|
||||
|
||||
reader = csv.reader(f, dialect)
|
||||
all_rows = list(reader)
|
||||
|
||||
if not all_rows:
|
||||
return TableData(headers=[], rows=[])
|
||||
|
||||
headers = all_rows[0]
|
||||
data_rows = [[_serialize_cell(c) for c in row] for row in all_rows[1:] if any(c.strip() for c in row)]
|
||||
|
||||
return TableData(
|
||||
title=os.path.splitext(os.path.basename(file_path))[0],
|
||||
headers=headers,
|
||||
rows=data_rows,
|
||||
)
|
||||
|
||||
|
||||
def extract_images_metadata(file_path: str) -> ImageInfo:
|
||||
"""Extract metadata from an image file (dimensions, MIME type)."""
|
||||
filename = os.path.basename(file_path)
|
||||
mime_type = mimetypes.guess_type(file_path)[0] or "application/octet-stream"
|
||||
|
||||
width, height = None, None
|
||||
try:
|
||||
# Use python-pptx's image reader or basic header parsing
|
||||
# to avoid adding PIL as a dependency
|
||||
width, height = _read_image_dimensions(file_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return ImageInfo(
|
||||
file_path=file_path,
|
||||
filename=filename,
|
||||
mime_type=mime_type,
|
||||
width=width,
|
||||
height=height,
|
||||
)
|
||||
|
||||
|
||||
async def parse_url(url: str) -> str:
|
||||
"""Fetch a URL and extract its article content as markdown."""
|
||||
import trafilatura
|
||||
|
||||
downloaded = trafilatura.fetch_url(url)
|
||||
if not downloaded:
|
||||
return ""
|
||||
|
||||
text = trafilatura.extract(
|
||||
downloaded,
|
||||
output_format="txt",
|
||||
include_tables=True,
|
||||
include_links=False,
|
||||
include_images=False,
|
||||
)
|
||||
return text or ""
|
||||
|
||||
|
||||
# --- Helpers ---
|
||||
|
||||
|
||||
def _serialize_cell(value: Any) -> Any:
|
||||
"""Convert cell value to JSON-safe type."""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, (int, float, bool)):
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
|
||||
def _read_image_dimensions(file_path: str) -> tuple:
|
||||
"""Read image dimensions from file header (PNG/JPEG/GIF/WEBP)."""
|
||||
with open(file_path, "rb") as f:
|
||||
header = f.read(32)
|
||||
|
||||
# PNG
|
||||
if header[:8] == b"\x89PNG\r\n\x1a\n":
|
||||
import struct
|
||||
|
||||
w, h = struct.unpack(">II", header[16:24])
|
||||
return w, h
|
||||
|
||||
# JPEG
|
||||
if header[:2] == b"\xff\xd8":
|
||||
with open(file_path, "rb") as f:
|
||||
f.seek(2)
|
||||
while True:
|
||||
marker = f.read(2)
|
||||
if len(marker) < 2:
|
||||
break
|
||||
if marker[0] != 0xFF:
|
||||
break
|
||||
if marker[1] in (0xC0, 0xC1, 0xC2):
|
||||
f.read(3) # length + precision
|
||||
import struct
|
||||
|
||||
h, w = struct.unpack(">HH", f.read(4))
|
||||
return w, h
|
||||
else:
|
||||
length = int.from_bytes(f.read(2), "big")
|
||||
f.seek(length - 2, 1)
|
||||
return None, None
|
||||
|
||||
# GIF
|
||||
if header[:6] in (b"GIF87a", b"GIF89a"):
|
||||
import struct
|
||||
|
||||
w, h = struct.unpack("<HH", header[6:10])
|
||||
return w, h
|
||||
|
||||
return None, None
|
||||
228
backend/services/chart_data_extractor.py
Normal file
228
backend/services/chart_data_extractor.py
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
"""Chart Data Extractor: extract chart-ready data from content blocks and tables."""
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from models.content_models import ContentBlock, ContentBlockType
|
||||
from services.attachment_parser_service import TableData
|
||||
|
||||
|
||||
class ChartSeries(BaseModel):
|
||||
name: str
|
||||
values: List[float]
|
||||
|
||||
|
||||
class ChartData(BaseModel):
|
||||
chart_type: str # bar, column, line, pie, doughnut, area, scatter, gantt, waterfall
|
||||
title: str
|
||||
categories: List[str]
|
||||
series: List[ChartSeries]
|
||||
unit: Optional[str] = None
|
||||
|
||||
|
||||
# --- Public API ---
|
||||
|
||||
|
||||
def extract(
|
||||
content_block: ContentBlock,
|
||||
table_data: Optional[TableData] = None,
|
||||
) -> Optional[ChartData]:
|
||||
"""Extract chart data from a content block and/or associated table.
|
||||
|
||||
Returns ChartData if chartable data is found, else None.
|
||||
"""
|
||||
if table_data and table_data.rows and table_data.headers:
|
||||
return _chart_from_table(table_data)
|
||||
|
||||
if content_block.type == ContentBlockType.metric:
|
||||
return _chart_from_metrics(content_block)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# --- Table → ChartData ---
|
||||
|
||||
|
||||
def _chart_from_table(td: TableData) -> Optional[ChartData]:
|
||||
"""Convert a TableData into ChartData.
|
||||
|
||||
Heuristic: first column = categories, remaining numeric columns = series.
|
||||
"""
|
||||
if not td.rows or len(td.headers) < 2:
|
||||
return None
|
||||
|
||||
# Determine which columns are numeric (by checking majority of rows)
|
||||
numeric_cols = []
|
||||
for col_idx in range(1, len(td.headers)):
|
||||
numeric_count = 0
|
||||
for row in td.rows:
|
||||
if col_idx < len(row):
|
||||
val = row[col_idx]
|
||||
if _to_float(val) is not None:
|
||||
numeric_count += 1
|
||||
if numeric_count >= len(td.rows) * 0.5:
|
||||
numeric_cols.append(col_idx)
|
||||
|
||||
if not numeric_cols:
|
||||
return None
|
||||
|
||||
categories = []
|
||||
for row in td.rows:
|
||||
categories.append(str(row[0]) if row else "")
|
||||
|
||||
series_list: List[ChartSeries] = []
|
||||
for col_idx in numeric_cols:
|
||||
values = []
|
||||
for row in td.rows:
|
||||
val = row[col_idx] if col_idx < len(row) else 0
|
||||
values.append(_to_float(val) or 0.0)
|
||||
series_list.append(
|
||||
ChartSeries(name=td.headers[col_idx], values=values)
|
||||
)
|
||||
|
||||
chart_type = _recommend_chart_type(categories, series_list, td)
|
||||
title = td.title or td.sheet_name or "Chart"
|
||||
|
||||
return ChartData(
|
||||
chart_type=chart_type,
|
||||
title=title,
|
||||
categories=categories,
|
||||
series=series_list,
|
||||
)
|
||||
|
||||
|
||||
# --- Metric block → ChartData ---
|
||||
|
||||
_NUMBER_RE = re.compile(
|
||||
r"[\$€£¥]?\s?(\d[\d,.]*)\s?([KMBTkmbt%]?)",
|
||||
)
|
||||
|
||||
|
||||
def _chart_from_metrics(block: ContentBlock) -> Optional[ChartData]:
|
||||
"""Build ChartData from a metric content block's extracted_data."""
|
||||
metrics = (block.extracted_data or {}).get("metrics", [])
|
||||
if not metrics:
|
||||
return None
|
||||
|
||||
categories = []
|
||||
values = []
|
||||
unit = None
|
||||
|
||||
for m in metrics:
|
||||
label = m.get("label", "").strip()
|
||||
raw_value = m.get("value", "")
|
||||
parsed = _parse_metric_value(raw_value)
|
||||
if parsed is None:
|
||||
continue
|
||||
|
||||
numeric_val, val_unit = parsed
|
||||
if val_unit and not unit:
|
||||
unit = val_unit
|
||||
|
||||
categories.append(label or f"Metric {len(categories) + 1}")
|
||||
values.append(numeric_val)
|
||||
|
||||
if len(values) < 2:
|
||||
return None
|
||||
|
||||
chart_type = "bar"
|
||||
# If all values are percentages and sum near 100, use pie
|
||||
if unit == "%" and 90 <= sum(values) <= 110:
|
||||
chart_type = "pie"
|
||||
|
||||
return ChartData(
|
||||
chart_type=chart_type,
|
||||
title=block.source_section or "Key Metrics",
|
||||
categories=categories,
|
||||
series=[ChartSeries(name="Value", values=values)],
|
||||
unit=unit,
|
||||
)
|
||||
|
||||
|
||||
# --- Chart type recommendation ---
|
||||
|
||||
_TIME_PATTERN = re.compile(
|
||||
r"(?:19|20)\d{2}|Q[1-4]|(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def _recommend_chart_type(
|
||||
categories: List[str],
|
||||
series: List[ChartSeries],
|
||||
td: Optional[TableData] = None,
|
||||
) -> str:
|
||||
"""Auto-recommend a chart type based on data characteristics."""
|
||||
n_cats = len(categories)
|
||||
n_series = len(series)
|
||||
|
||||
# Check if categories look like time periods
|
||||
time_count = sum(1 for c in categories if _TIME_PATTERN.search(c))
|
||||
is_time_series = time_count >= n_cats * 0.6
|
||||
|
||||
if is_time_series:
|
||||
return "line"
|
||||
|
||||
# Single series
|
||||
if n_series == 1:
|
||||
vals = series[0].values
|
||||
# Parts of a whole
|
||||
total = sum(vals)
|
||||
if 2 <= n_cats <= 8 and 90 <= total <= 110:
|
||||
return "pie"
|
||||
if n_cats <= 6:
|
||||
return "bar"
|
||||
return "column"
|
||||
|
||||
# Multiple series
|
||||
if n_series == 2:
|
||||
return "bar" # grouped bar
|
||||
|
||||
return "column"
|
||||
|
||||
|
||||
# --- Helpers ---
|
||||
|
||||
|
||||
def _to_float(val) -> Optional[float]:
|
||||
"""Convert a cell value to float, handling common formats."""
|
||||
if val is None:
|
||||
return None
|
||||
if isinstance(val, (int, float)):
|
||||
return float(val)
|
||||
if isinstance(val, str):
|
||||
cleaned = val.strip().replace(",", "").replace("$", "").replace("€", "").replace("£", "").replace("¥", "").rstrip("%")
|
||||
try:
|
||||
return float(cleaned)
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _parse_metric_value(raw: str) -> Optional[tuple]:
|
||||
"""Parse a metric value string like '$2.3M' or '45%' into (float, unit)."""
|
||||
if not raw:
|
||||
return None
|
||||
|
||||
raw = raw.strip()
|
||||
unit = None
|
||||
|
||||
if raw.endswith("%"):
|
||||
unit = "%"
|
||||
raw = raw.rstrip("%").strip()
|
||||
elif raw[-1:].upper() in ("K", "M", "B", "T"):
|
||||
suffix = raw[-1].upper()
|
||||
multipliers = {"K": 1_000, "M": 1_000_000, "B": 1_000_000_000, "T": 1_000_000_000_000}
|
||||
raw_num = raw[:-1].strip()
|
||||
cleaned = raw_num.replace(",", "").replace("$", "").replace("€", "").replace("£", "").replace("¥", "")
|
||||
try:
|
||||
return float(cleaned) * multipliers[suffix], suffix
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
cleaned = raw.replace(",", "").replace("$", "").replace("€", "").replace("£", "").replace("¥", "")
|
||||
try:
|
||||
return float(cleaned), unit
|
||||
except ValueError:
|
||||
return None
|
||||
430
backend/services/content_intelligence_service.py
Normal file
430
backend/services/content_intelligence_service.py
Normal file
|
|
@ -0,0 +1,430 @@
|
|||
"""Content Intelligence Service: classify brief content into typed blocks for slide mapping."""
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from models.content_models import (
|
||||
ClassifiedContent,
|
||||
ContentBlock,
|
||||
ContentBlockType,
|
||||
)
|
||||
from models.llm_message import LLMSystemMessage, LLMUserMessage
|
||||
from services.attachment_parser_service import ImageInfo, TableData
|
||||
from services.llm_client import LLMClient
|
||||
from services.score_based_chunker import ScoreBasedChunker
|
||||
from utils.llm_provider import get_model
|
||||
|
||||
# --- Regex patterns for rule-based classification ---
|
||||
|
||||
_METRIC_RE = re.compile(
|
||||
r"""
|
||||
(?: # value-first: $2.3M, 45%, 1,200 units
|
||||
[\$€£¥]\s?\d[\d,.]*[KMBTkmbt%]? |
|
||||
\d[\d,.]*\s?% |
|
||||
\d[\d,.]*\s?[KMBTkmbt]\b
|
||||
)
|
||||
|
|
||||
(?: # "grew 45%", "increased by $2M"
|
||||
(?:grew|growth|increased?|decreased?|rose|fell|dropped|declined|revenue|profit|margin|roi|cagr|arpu)
|
||||
.{0,30}?
|
||||
[\$€£¥]?\d[\d,.]*[KMBTkmbt%]?
|
||||
)
|
||||
""",
|
||||
re.IGNORECASE | re.VERBOSE,
|
||||
)
|
||||
|
||||
_QUOTE_RE = re.compile(
|
||||
r'["\u201c\u201d].{15,300}?["\u201c\u201d]' # 15-300 chars inside quotes
|
||||
r"(?:\s*[-\u2014\u2013]\s*.{2,60})?", # optional attribution
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
_TABLE_RE = re.compile(r"^\|.+\|$", re.MULTILINE)
|
||||
|
||||
_TIMELINE_RE = re.compile(
|
||||
r"(?:(?:19|20)\d{2}|Q[1-4]|(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\w*\s+\d{4})",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
_COMPARISON_RE = re.compile(
|
||||
r"\b(?:vs\.?|versus|compared?\s+to|in\s+contrast|on\s+the\s+other\s+hand|whereas|alternatively)\b",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
_LIST_RE = re.compile(r"^[\s]*[-*•]\s+.+", re.MULTILINE)
|
||||
_NUMBERED_LIST_RE = re.compile(r"^[\s]*\d+[.)]\s+.+", re.MULTILINE)
|
||||
|
||||
_IMAGE_REF_RE = re.compile(
|
||||
r"(?:!\[|see\s+(?:figure|image|diagram|chart|photo)|attached\s+image|\.(?:png|jpg|jpeg|gif|webp|svg)\b)",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
_CTA_RE = re.compile(
|
||||
r"\b(?:contact\s+us|get\s+started|sign\s+up|learn\s+more|next\s+steps|action\s+items|call\s+to\s+action|let's\s+(?:discuss|connect|talk))\b",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Priority map: higher = more important for presentation
|
||||
_PRIORITY_MAP = {
|
||||
ContentBlockType.metric: 8,
|
||||
ContentBlockType.quote: 7,
|
||||
ContentBlockType.table: 6,
|
||||
ContentBlockType.timeline: 6,
|
||||
ContentBlockType.comparison: 6,
|
||||
ContentBlockType.call_to_action: 7,
|
||||
ContentBlockType.list_items: 5,
|
||||
ContentBlockType.image_reference: 5,
|
||||
ContentBlockType.narrative: 4,
|
||||
}
|
||||
|
||||
|
||||
class ContentIntelligenceService:
|
||||
|
||||
def __init__(self):
|
||||
self._chunker = ScoreBasedChunker()
|
||||
|
||||
async def classify(
|
||||
self,
|
||||
markdown: str,
|
||||
tables: Optional[List[TableData]] = None,
|
||||
images: Optional[List[ImageInfo]] = None,
|
||||
) -> ClassifiedContent:
|
||||
"""Classify markdown content into typed content blocks."""
|
||||
tables = tables or []
|
||||
images = images or []
|
||||
|
||||
# 1. Extract a title (first heading, if any)
|
||||
title = self._extract_title(markdown)
|
||||
|
||||
# 2. Chunk the content
|
||||
chunks = await self._chunk_content(markdown)
|
||||
|
||||
# 3. Rule-based classification per chunk
|
||||
blocks: List[ContentBlock] = []
|
||||
ambiguous_chunks: List[tuple] = [] # (index, text) for LLM classification
|
||||
|
||||
for chunk in chunks:
|
||||
text = f"{chunk.heading}\n{chunk.content}".strip()
|
||||
block_type = self._classify_by_rules(text)
|
||||
|
||||
if block_type:
|
||||
extracted = self._extract_data(block_type, text)
|
||||
blocks.append(
|
||||
ContentBlock(
|
||||
type=block_type,
|
||||
raw_text=text,
|
||||
extracted_data=extracted,
|
||||
source_section=chunk.heading.lstrip("# ").strip(),
|
||||
priority=_PRIORITY_MAP.get(block_type, 4),
|
||||
)
|
||||
)
|
||||
else:
|
||||
ambiguous_chunks.append((len(blocks), text))
|
||||
# Placeholder — will be replaced after LLM classification
|
||||
blocks.append(
|
||||
ContentBlock(
|
||||
type=ContentBlockType.narrative,
|
||||
raw_text=text,
|
||||
source_section=chunk.heading.lstrip("# ").strip(),
|
||||
priority=4,
|
||||
)
|
||||
)
|
||||
|
||||
# 4. LLM batch classification for ambiguous blocks
|
||||
if ambiguous_chunks:
|
||||
llm_types = await self._llm_classify_batch(
|
||||
[text for _, text in ambiguous_chunks]
|
||||
)
|
||||
for (idx, text), btype in zip(ambiguous_chunks, llm_types):
|
||||
extracted = self._extract_data(btype, text)
|
||||
blocks[idx] = ContentBlock(
|
||||
type=btype,
|
||||
raw_text=text,
|
||||
extracted_data=extracted,
|
||||
source_section=blocks[idx].source_section,
|
||||
priority=_PRIORITY_MAP.get(btype, 4),
|
||||
)
|
||||
|
||||
# 5. Merge attachment data
|
||||
for td in tables:
|
||||
blocks.append(
|
||||
ContentBlock(
|
||||
type=ContentBlockType.table,
|
||||
raw_text=f"Table: {td.title or td.sheet_name or 'Data'}\n"
|
||||
f"Headers: {', '.join(td.headers)}\n"
|
||||
f"Rows: {len(td.rows)}",
|
||||
extracted_data={"headers": td.headers, "row_count": len(td.rows)},
|
||||
source_section=td.title or td.sheet_name,
|
||||
priority=_PRIORITY_MAP[ContentBlockType.table],
|
||||
)
|
||||
)
|
||||
|
||||
for img in images:
|
||||
blocks.append(
|
||||
ContentBlock(
|
||||
type=ContentBlockType.image_reference,
|
||||
raw_text=f"Image: {img.filename}",
|
||||
extracted_data={
|
||||
"file_path": img.file_path,
|
||||
"width": img.width,
|
||||
"height": img.height,
|
||||
},
|
||||
source_section=None,
|
||||
priority=_PRIORITY_MAP[ContentBlockType.image_reference],
|
||||
)
|
||||
)
|
||||
|
||||
# 6. Sort by priority (descending), preserving order for same priority
|
||||
blocks.sort(key=lambda b: -b.priority)
|
||||
|
||||
# 7. Generate summary
|
||||
summary = await self._generate_summary(markdown, blocks)
|
||||
|
||||
return ClassifiedContent(
|
||||
title=title,
|
||||
blocks=blocks,
|
||||
tables=tables,
|
||||
images=images,
|
||||
summary=summary,
|
||||
)
|
||||
|
||||
async def ask_followup_questions(
|
||||
self, content: ClassifiedContent
|
||||
) -> Optional[List[str]]:
|
||||
"""Ask follow-up questions if content is too thin."""
|
||||
total_words = sum(len(b.raw_text.split()) for b in content.blocks)
|
||||
if total_words >= 200 and len(content.blocks) >= 3:
|
||||
return None
|
||||
|
||||
client = LLMClient()
|
||||
model = get_model()
|
||||
|
||||
block_summary = "\n".join(
|
||||
f"- [{b.type.value}] {b.raw_text[:100]}..." for b in content.blocks[:10]
|
||||
)
|
||||
|
||||
messages = [
|
||||
LLMSystemMessage(
|
||||
content="You help identify missing information for a presentation brief. "
|
||||
"Return a JSON array of 2-4 short questions that would help create a more complete presentation."
|
||||
),
|
||||
LLMUserMessage(
|
||||
content=f"The user provided a brief with {total_words} words and {len(content.blocks)} content blocks:\n\n"
|
||||
f"{block_summary}\n\n"
|
||||
"What additional information would be helpful?"
|
||||
),
|
||||
]
|
||||
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"questions": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"minItems": 2,
|
||||
"maxItems": 4,
|
||||
}
|
||||
},
|
||||
"required": ["questions"],
|
||||
}
|
||||
|
||||
try:
|
||||
result = await client.generate_structured(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_format=schema,
|
||||
)
|
||||
return result.get("questions", [])
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# --- Internal methods ---
|
||||
|
||||
def _extract_title(self, markdown: str) -> Optional[str]:
|
||||
for line in markdown.split("\n"):
|
||||
stripped = line.strip()
|
||||
if stripped.startswith("# ") and not stripped.startswith("## "):
|
||||
return stripped.lstrip("# ").strip()
|
||||
return None
|
||||
|
||||
async def _chunk_content(self, markdown: str):
|
||||
"""Chunk using ScoreBasedChunker. Fall back to paragraph splitting."""
|
||||
try:
|
||||
headings = self._chunker.extract_headings(markdown)
|
||||
if len(headings) >= 2:
|
||||
scores = self._chunker.score_headings(headings)
|
||||
chunks = self._chunker.get_chunks_from_headings(
|
||||
markdown, headings, scores, top_k=30
|
||||
)
|
||||
if chunks:
|
||||
return chunks
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback: split by double newlines (paragraph-based)
|
||||
from models.document_chunk import DocumentChunk
|
||||
|
||||
paragraphs = [p.strip() for p in re.split(r"\n{2,}", markdown) if p.strip()]
|
||||
return [
|
||||
DocumentChunk(
|
||||
heading=f"Section {i + 1}",
|
||||
content=p,
|
||||
heading_index=i,
|
||||
score=5.0,
|
||||
)
|
||||
for i, p in enumerate(paragraphs)
|
||||
if len(p) > 20
|
||||
]
|
||||
|
||||
def _classify_by_rules(self, text: str) -> Optional[ContentBlockType]:
|
||||
"""Apply rule-based classification. Returns None if ambiguous."""
|
||||
# Check from most specific to least
|
||||
if _QUOTE_RE.search(text):
|
||||
return ContentBlockType.quote
|
||||
|
||||
if _TABLE_RE.search(text):
|
||||
return ContentBlockType.table
|
||||
|
||||
if _IMAGE_REF_RE.search(text):
|
||||
return ContentBlockType.image_reference
|
||||
|
||||
if _CTA_RE.search(text):
|
||||
return ContentBlockType.call_to_action
|
||||
|
||||
metric_matches = _METRIC_RE.findall(text)
|
||||
if len(metric_matches) >= 2:
|
||||
return ContentBlockType.metric
|
||||
|
||||
timeline_matches = _TIMELINE_RE.findall(text)
|
||||
if len(timeline_matches) >= 2:
|
||||
return ContentBlockType.timeline
|
||||
|
||||
if _COMPARISON_RE.search(text):
|
||||
return ContentBlockType.comparison
|
||||
|
||||
list_matches = _LIST_RE.findall(text)
|
||||
numbered_matches = _NUMBERED_LIST_RE.findall(text)
|
||||
if len(list_matches) >= 3 or len(numbered_matches) >= 3:
|
||||
return ContentBlockType.list_items
|
||||
|
||||
# Single metric mention
|
||||
if metric_matches:
|
||||
return ContentBlockType.metric
|
||||
|
||||
return None # Ambiguous — defer to LLM
|
||||
|
||||
def _extract_data(
|
||||
self, block_type: ContentBlockType, text: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Extract structured data from a content block based on its type."""
|
||||
if block_type == ContentBlockType.metric:
|
||||
return self._extract_metric_data(text)
|
||||
if block_type == ContentBlockType.quote:
|
||||
return self._extract_quote_data(text)
|
||||
return None
|
||||
|
||||
def _extract_metric_data(self, text: str) -> Dict[str, Any]:
|
||||
"""Extract numeric values and labels from metric text."""
|
||||
metrics = []
|
||||
# Pattern: label ... value
|
||||
for match in re.finditer(
|
||||
r"([\w\s]+?)\s*(?::|is|was|reached|hit|grew\s+to|of)\s*"
|
||||
r"([\$€£¥]?\s?\d[\d,.]*\s?[KMBTkmbt%]*)",
|
||||
text,
|
||||
re.IGNORECASE,
|
||||
):
|
||||
label = match.group(1).strip()
|
||||
value = match.group(2).strip()
|
||||
if len(label) < 50:
|
||||
metrics.append({"label": label, "value": value})
|
||||
|
||||
# Fallback: just extract all numbers with context
|
||||
if not metrics:
|
||||
for match in _METRIC_RE.finditer(text):
|
||||
metrics.append({"value": match.group().strip()})
|
||||
|
||||
return {"metrics": metrics[:10]}
|
||||
|
||||
def _extract_quote_data(self, text: str) -> Dict[str, Any]:
|
||||
"""Extract quote text and attribution."""
|
||||
match = _QUOTE_RE.search(text)
|
||||
if match:
|
||||
full = match.group()
|
||||
# Try to split attribution
|
||||
parts = re.split(r"\s*[-\u2014\u2013]\s*", full, maxsplit=1)
|
||||
quote_text = parts[0].strip().strip('"\u201c\u201d')
|
||||
attribution = parts[1].strip() if len(parts) > 1 else None
|
||||
return {"quote": quote_text, "attribution": attribution}
|
||||
return {}
|
||||
|
||||
async def _llm_classify_batch(
|
||||
self, texts: List[str]
|
||||
) -> List[ContentBlockType]:
|
||||
"""Use LLM to classify a batch of ambiguous text chunks."""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
client = LLMClient()
|
||||
model = get_model()
|
||||
|
||||
types_list = ", ".join(t.value for t in ContentBlockType)
|
||||
numbered = "\n".join(f"{i + 1}. {t[:300]}" for i, t in enumerate(texts))
|
||||
|
||||
messages = [
|
||||
LLMSystemMessage(
|
||||
content=f"Classify each numbered text chunk into one of these content types: {types_list}.\n"
|
||||
"Return a JSON object with a 'classifications' array of strings, one per chunk, in order."
|
||||
),
|
||||
LLMUserMessage(content=numbered),
|
||||
]
|
||||
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"classifications": {
|
||||
"type": "array",
|
||||
"items": {"type": "string", "enum": [t.value for t in ContentBlockType]},
|
||||
}
|
||||
},
|
||||
"required": ["classifications"],
|
||||
}
|
||||
|
||||
try:
|
||||
result = await client.generate_structured(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_format=schema,
|
||||
)
|
||||
classifications = result.get("classifications", [])
|
||||
output = []
|
||||
for i, text in enumerate(texts):
|
||||
if i < len(classifications):
|
||||
try:
|
||||
output.append(ContentBlockType(classifications[i]))
|
||||
except ValueError:
|
||||
output.append(ContentBlockType.narrative)
|
||||
else:
|
||||
output.append(ContentBlockType.narrative)
|
||||
return output
|
||||
except Exception:
|
||||
return [ContentBlockType.narrative] * len(texts)
|
||||
|
||||
async def _generate_summary(
|
||||
self, markdown: str, blocks: List[ContentBlock]
|
||||
) -> str:
|
||||
"""Generate a brief summary of the content."""
|
||||
client = LLMClient()
|
||||
model = get_model()
|
||||
|
||||
messages = [
|
||||
LLMSystemMessage(
|
||||
content="Summarize the following content in 1-2 sentences for use as a presentation overview."
|
||||
),
|
||||
LLMUserMessage(content=markdown[:3000]),
|
||||
]
|
||||
|
||||
try:
|
||||
result = await client.generate(model=model, messages=messages)
|
||||
return result.strip()[:500]
|
||||
except Exception:
|
||||
# Fallback: first 200 chars
|
||||
return markdown[:200].strip() + "..."
|
||||
327
backend/services/native_chart_service.py
Normal file
327
backend/services/native_chart_service.py
Normal file
|
|
@ -0,0 +1,327 @@
|
|||
"""Native Chart Service: render ChartData as native python-pptx charts on slides."""
|
||||
from typing import List, Optional
|
||||
|
||||
from pptx.chart.data import CategoryChartData
|
||||
from pptx.dml.color import RGBColor
|
||||
from pptx.enum.chart import XL_CHART_TYPE, XL_LEGEND_POSITION, XL_LABEL_POSITION
|
||||
from pptx.oxml.xmlchemy import OxmlElement
|
||||
from pptx.slide import Slide
|
||||
from pptx.util import Emu, Pt
|
||||
|
||||
from services.chart_data_extractor import ChartData
|
||||
|
||||
|
||||
# Map our chart_type strings to python-pptx chart type enums
|
||||
_CHART_TYPE_MAP = {
|
||||
"bar": XL_CHART_TYPE.BAR_CLUSTERED,
|
||||
"column": XL_CHART_TYPE.COLUMN_CLUSTERED,
|
||||
"line": XL_CHART_TYPE.LINE_MARKERS,
|
||||
"pie": XL_CHART_TYPE.PIE,
|
||||
"doughnut": XL_CHART_TYPE.DOUGHNUT,
|
||||
"area": XL_CHART_TYPE.AREA,
|
||||
"scatter": XL_CHART_TYPE.XY_SCATTER,
|
||||
}
|
||||
|
||||
# Default brand-neutral colors for chart series
|
||||
_DEFAULT_COLORS = [
|
||||
"4472C4", "ED7D31", "A5A5A5", "FFC000", "5B9BD5",
|
||||
"70AD47", "264478", "9B57A0", "636363", "EB6E1F",
|
||||
]
|
||||
|
||||
|
||||
class NativeChartService:
|
||||
|
||||
def add_chart(
|
||||
self,
|
||||
slide: Slide,
|
||||
chart_data: ChartData,
|
||||
left: int,
|
||||
top: int,
|
||||
width: int,
|
||||
height: int,
|
||||
brand_colors: Optional[List[str]] = None,
|
||||
font_name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Add a native chart to a slide.
|
||||
|
||||
For standard chart types (bar, column, line, pie, doughnut, area, scatter)
|
||||
uses python-pptx's add_chart API. For gantt and waterfall, falls back
|
||||
to shape-based rendering.
|
||||
|
||||
Args:
|
||||
slide: The pptx Slide object
|
||||
chart_data: ChartData with type, categories, series
|
||||
left, top, width, height: Position/size in Pt units (raw int, will be wrapped)
|
||||
brand_colors: List of hex color strings (e.g. ["4472C4", "ED7D31"])
|
||||
font_name: Font family name for labels
|
||||
"""
|
||||
if chart_data.chart_type == "gantt":
|
||||
self._add_gantt_chart(slide, chart_data, left, top, width, height, brand_colors, font_name)
|
||||
return
|
||||
|
||||
if chart_data.chart_type == "waterfall":
|
||||
self._add_waterfall_chart(slide, chart_data, left, top, width, height, brand_colors, font_name)
|
||||
return
|
||||
|
||||
xl_chart_type = _CHART_TYPE_MAP.get(chart_data.chart_type, XL_CHART_TYPE.COLUMN_CLUSTERED)
|
||||
colors = brand_colors or _DEFAULT_COLORS
|
||||
|
||||
# Build chart data object
|
||||
pptx_data = CategoryChartData()
|
||||
pptx_data.categories = chart_data.categories
|
||||
|
||||
for series in chart_data.series:
|
||||
pptx_data.add_series(series.name, series.values)
|
||||
|
||||
# Add chart to slide
|
||||
chart_frame = slide.shapes.add_chart(
|
||||
xl_chart_type,
|
||||
Pt(left), Pt(top), Pt(width), Pt(height),
|
||||
pptx_data,
|
||||
)
|
||||
chart = chart_frame.chart
|
||||
|
||||
# Style the chart
|
||||
chart.has_legend = len(chart_data.series) > 1
|
||||
if chart.has_legend:
|
||||
chart.legend.position = XL_LEGEND_POSITION.BOTTOM
|
||||
chart.legend.include_in_layout = False
|
||||
if font_name:
|
||||
chart.legend.font.name = font_name
|
||||
chart.legend.font.size = Pt(9)
|
||||
|
||||
# Apply title
|
||||
chart.has_title = True
|
||||
chart.chart_title.text_frame.text = chart_data.title
|
||||
if font_name:
|
||||
chart.chart_title.text_frame.paragraphs[0].font.name = font_name
|
||||
chart.chart_title.text_frame.paragraphs[0].font.size = Pt(12)
|
||||
chart.chart_title.text_frame.paragraphs[0].font.bold = True
|
||||
|
||||
# Apply brand colors to series
|
||||
self._apply_series_colors(chart, colors)
|
||||
|
||||
# Style axes
|
||||
if chart_data.chart_type not in ("pie", "doughnut"):
|
||||
self._style_axes(chart, font_name)
|
||||
|
||||
# Add data labels for pie/doughnut
|
||||
if chart_data.chart_type in ("pie", "doughnut"):
|
||||
self._add_pie_labels(chart, font_name)
|
||||
|
||||
def _apply_series_colors(self, chart, colors: List[str]) -> None:
|
||||
"""Apply brand colors to each series in the chart."""
|
||||
plot = chart.plots[0]
|
||||
for i, series in enumerate(plot.series):
|
||||
color_hex = colors[i % len(colors)]
|
||||
series.format.fill.solid()
|
||||
series.format.fill.fore_color.rgb = RGBColor.from_string(color_hex)
|
||||
|
||||
# For line charts, also color the line
|
||||
if hasattr(series, 'smooth'):
|
||||
series.format.line.color.rgb = RGBColor.from_string(color_hex)
|
||||
|
||||
def _style_axes(self, chart, font_name: Optional[str]) -> None:
|
||||
"""Style category and value axes."""
|
||||
try:
|
||||
category_axis = chart.category_axis
|
||||
category_axis.has_minor_gridlines = False
|
||||
if font_name:
|
||||
category_axis.tick_labels.font.name = font_name
|
||||
category_axis.tick_labels.font.size = Pt(8)
|
||||
|
||||
value_axis = chart.value_axis
|
||||
value_axis.has_minor_gridlines = False
|
||||
if font_name:
|
||||
value_axis.tick_labels.font.name = font_name
|
||||
value_axis.tick_labels.font.size = Pt(8)
|
||||
except Exception:
|
||||
pass # Some chart types don't have axes
|
||||
|
||||
def _add_pie_labels(self, chart, font_name: Optional[str]) -> None:
|
||||
"""Add percentage labels to pie/doughnut charts."""
|
||||
plot = chart.plots[0]
|
||||
plot.has_data_labels = True
|
||||
data_labels = plot.data_labels
|
||||
data_labels.show_percentage = True
|
||||
data_labels.show_category_name = False
|
||||
data_labels.show_value = False
|
||||
data_labels.number_format = '0%'
|
||||
if font_name:
|
||||
data_labels.font.name = font_name
|
||||
data_labels.font.size = Pt(9)
|
||||
|
||||
# --- Shape-based charts (Gantt, Waterfall) ---
|
||||
|
||||
def _add_gantt_chart(
|
||||
self,
|
||||
slide: Slide,
|
||||
chart_data: ChartData,
|
||||
left: int, top: int, width: int, height: int,
|
||||
brand_colors: Optional[List[str]] = None,
|
||||
font_name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Render a Gantt chart using rectangles.
|
||||
|
||||
Expects categories = task names, series[0] = start values, series[1] = duration values.
|
||||
Values are numeric (e.g. week numbers or day offsets).
|
||||
"""
|
||||
colors = brand_colors or _DEFAULT_COLORS
|
||||
if len(chart_data.series) < 2:
|
||||
return
|
||||
|
||||
starts = chart_data.series[0].values
|
||||
durations = chart_data.series[1].values
|
||||
tasks = chart_data.categories
|
||||
|
||||
n_tasks = len(tasks)
|
||||
if n_tasks == 0:
|
||||
return
|
||||
|
||||
# Calculate bounds
|
||||
max_end = max(s + d for s, d in zip(starts, durations)) if starts else 1
|
||||
min_start = min(starts) if starts else 0
|
||||
|
||||
chart_left = left + 120 # leave room for labels
|
||||
chart_width = width - 130
|
||||
bar_height_total = height - 40 # leave room for title
|
||||
bar_h = max(bar_height_total // n_tasks - 4, 10)
|
||||
|
||||
# Title
|
||||
title_box = slide.shapes.add_textbox(Pt(left), Pt(top), Pt(width), Pt(24))
|
||||
tf = title_box.text_frame
|
||||
tf.text = chart_data.title
|
||||
if font_name:
|
||||
tf.paragraphs[0].font.name = font_name
|
||||
tf.paragraphs[0].font.size = Pt(12)
|
||||
tf.paragraphs[0].font.bold = True
|
||||
|
||||
# Draw task bars
|
||||
range_span = max_end - min_start or 1
|
||||
for i, (task, start, dur) in enumerate(zip(tasks, starts, durations)):
|
||||
y = top + 30 + i * (bar_h + 4)
|
||||
|
||||
# Task label
|
||||
label = slide.shapes.add_textbox(Pt(left), Pt(y), Pt(115), Pt(bar_h))
|
||||
label.text_frame.word_wrap = True
|
||||
label.text_frame.text = task
|
||||
if font_name:
|
||||
label.text_frame.paragraphs[0].font.name = font_name
|
||||
label.text_frame.paragraphs[0].font.size = Pt(8)
|
||||
|
||||
# Bar
|
||||
bar_x = chart_left + int((start - min_start) / range_span * chart_width)
|
||||
bar_w = max(int(dur / range_span * chart_width), 6)
|
||||
|
||||
from pptx.enum.shapes import MSO_SHAPE
|
||||
|
||||
bar = slide.shapes.add_shape(
|
||||
MSO_SHAPE.ROUNDED_RECTANGLE,
|
||||
Pt(bar_x), Pt(y), Pt(bar_w), Pt(bar_h),
|
||||
)
|
||||
bar.fill.solid()
|
||||
color = colors[i % len(colors)]
|
||||
bar.fill.fore_color.rgb = RGBColor.from_string(color)
|
||||
bar.line.fill.background() # no border
|
||||
|
||||
def _add_waterfall_chart(
|
||||
self,
|
||||
slide: Slide,
|
||||
chart_data: ChartData,
|
||||
left: int, top: int, width: int, height: int,
|
||||
brand_colors: Optional[List[str]] = None,
|
||||
font_name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Render a waterfall chart using stacked shapes.
|
||||
|
||||
series[0].values = incremental changes (positive or negative).
|
||||
The last category is treated as the total.
|
||||
"""
|
||||
colors = brand_colors or _DEFAULT_COLORS
|
||||
if not chart_data.series:
|
||||
return
|
||||
|
||||
values = chart_data.series[0].values
|
||||
cats = chart_data.categories
|
||||
n = len(values)
|
||||
if n == 0:
|
||||
return
|
||||
|
||||
# Compute cumulative
|
||||
cumulative = []
|
||||
running = 0
|
||||
for v in values:
|
||||
cumulative.append(running)
|
||||
running += v
|
||||
|
||||
all_levels = cumulative + [running]
|
||||
max_val = max(max(all_levels), max(abs(v) for v in values), 1)
|
||||
min_val = min(min(all_levels), 0)
|
||||
val_range = max_val - min_val or 1
|
||||
|
||||
chart_area_top = top + 30
|
||||
chart_area_height = height - 60
|
||||
chart_area_left = left + 10
|
||||
chart_area_width = width - 20
|
||||
bar_width = max(chart_area_width // n - 8, 12)
|
||||
|
||||
# Title
|
||||
title_box = slide.shapes.add_textbox(Pt(left), Pt(top), Pt(width), Pt(24))
|
||||
tf = title_box.text_frame
|
||||
tf.text = chart_data.title
|
||||
if font_name:
|
||||
tf.paragraphs[0].font.name = font_name
|
||||
tf.paragraphs[0].font.size = Pt(12)
|
||||
tf.paragraphs[0].font.bold = True
|
||||
|
||||
from pptx.enum.shapes import MSO_SHAPE
|
||||
|
||||
positive_color = colors[0] if colors else "4472C4"
|
||||
negative_color = colors[1] if len(colors) > 1 else "ED7D31"
|
||||
total_color = colors[2] if len(colors) > 2 else "A5A5A5"
|
||||
|
||||
for i in range(n):
|
||||
x = chart_area_left + i * (bar_width + 8)
|
||||
val = values[i]
|
||||
base = cumulative[i]
|
||||
|
||||
is_last = i == n - 1
|
||||
|
||||
if is_last:
|
||||
# Total bar from 0 to cumulative total
|
||||
total = base + val
|
||||
bar_bottom = 0
|
||||
bar_val = total
|
||||
color = total_color
|
||||
else:
|
||||
if val >= 0:
|
||||
bar_bottom = base
|
||||
bar_val = val
|
||||
color = positive_color
|
||||
else:
|
||||
bar_bottom = base + val
|
||||
bar_val = abs(val)
|
||||
color = negative_color
|
||||
|
||||
# Convert to pixel positions
|
||||
bar_top_y = chart_area_top + int((max_val - bar_bottom - bar_val) / val_range * chart_area_height)
|
||||
bar_h = max(int(bar_val / val_range * chart_area_height), 4)
|
||||
|
||||
bar = slide.shapes.add_shape(
|
||||
MSO_SHAPE.RECTANGLE,
|
||||
Pt(x), Pt(bar_top_y), Pt(bar_width), Pt(bar_h),
|
||||
)
|
||||
bar.fill.solid()
|
||||
bar.fill.fore_color.rgb = RGBColor.from_string(color)
|
||||
bar.line.fill.background()
|
||||
|
||||
# Category label below
|
||||
lbl = slide.shapes.add_textbox(
|
||||
Pt(x - 4), Pt(chart_area_top + chart_area_height + 2),
|
||||
Pt(bar_width + 8), Pt(20),
|
||||
)
|
||||
lbl.text_frame.word_wrap = True
|
||||
lbl.text_frame.text = cats[i] if i < len(cats) else ""
|
||||
if font_name:
|
||||
lbl.text_frame.paragraphs[0].font.name = font_name
|
||||
lbl.text_frame.paragraphs[0].font.size = Pt(7)
|
||||
|
|
@ -20,6 +20,7 @@ from pptx.dml.color import RGBColor
|
|||
from models.pptx_models import (
|
||||
PptxAutoShapeBoxModel,
|
||||
PptxBoxShapeEnum,
|
||||
PptxChartBoxModel,
|
||||
PptxConnectorModel,
|
||||
PptxFillModel,
|
||||
PptxFontModel,
|
||||
|
|
@ -34,6 +35,8 @@ from models.pptx_models import (
|
|||
PptxTextBoxModel,
|
||||
PptxTextRunModel,
|
||||
)
|
||||
from services.native_chart_service import NativeChartService
|
||||
from services.chart_data_extractor import ChartData, ChartSeries
|
||||
from utils.download_helpers import download_files
|
||||
from utils.image_utils import (
|
||||
clip_image,
|
||||
|
|
@ -59,6 +62,8 @@ class PptxPresentationCreator:
|
|||
self._ppt.slide_width = Pt(1280)
|
||||
self._ppt.slide_height = Pt(720)
|
||||
|
||||
self._chart_service = NativeChartService()
|
||||
|
||||
def get_sub_element(self, parent, tagname, **kwargs):
|
||||
"""Helper method to create XML elements"""
|
||||
element = OxmlElement(tagname)
|
||||
|
|
@ -161,9 +166,33 @@ class PptxPresentationCreator:
|
|||
elif model_type is PptxTextBoxModel:
|
||||
self.add_textbox(slide, shape_model)
|
||||
|
||||
elif model_type is PptxChartBoxModel:
|
||||
self.add_chart(slide, shape_model)
|
||||
|
||||
elif model_type is PptxConnectorModel:
|
||||
self.add_connector(slide, shape_model)
|
||||
|
||||
def add_chart(self, slide: Slide, chart_model: PptxChartBoxModel):
|
||||
cd = chart_model.chart_data
|
||||
chart_data = ChartData(
|
||||
chart_type=cd.chart_type,
|
||||
title=cd.title,
|
||||
categories=cd.categories,
|
||||
series=[ChartSeries(name=s["name"], values=s["values"]) for s in cd.series],
|
||||
unit=cd.unit,
|
||||
)
|
||||
pos = chart_model.position
|
||||
self._chart_service.add_chart(
|
||||
slide=slide,
|
||||
chart_data=chart_data,
|
||||
left=pos.left,
|
||||
top=pos.top,
|
||||
width=pos.width,
|
||||
height=pos.height,
|
||||
brand_colors=chart_model.brand_colors,
|
||||
font_name=chart_model.font_name,
|
||||
)
|
||||
|
||||
def add_connector(self, slide: Slide, connector_model: PptxConnectorModel):
|
||||
if connector_model.thickness == 0:
|
||||
return
|
||||
|
|
|
|||
305
backend/services/slide_mapping_engine.py
Normal file
305
backend/services/slide_mapping_engine.py
Normal file
|
|
@ -0,0 +1,305 @@
|
|||
"""Slide Mapping Engine: map classified content blocks to master deck layouts."""
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from models.content_models import ClassifiedContent, ContentBlock, ContentBlockType
|
||||
from models.llm_message import LLMSystemMessage, LLMUserMessage
|
||||
from services.llm_client import LLMClient
|
||||
from utils.llm_provider import get_model
|
||||
|
||||
|
||||
class SlideMapping(BaseModel):
|
||||
content_block_indices: List[int] # which content blocks go on this slide
|
||||
layout_id: str
|
||||
layout_name: str
|
||||
slide_type: str
|
||||
content_summary: str
|
||||
attachment_ids: List[str] = []
|
||||
|
||||
|
||||
# Map content block types to preferred layout types (as stored in MasterDeckModel.layouts[].layout_type)
|
||||
_BLOCK_TO_LAYOUT_TYPE: Dict[ContentBlockType, List[str]] = {
|
||||
ContentBlockType.metric: ["metrics", "kpi", "data", "chart", "content"],
|
||||
ContentBlockType.quote: ["quote", "testimonial", "content"],
|
||||
ContentBlockType.table: ["table", "chart", "data", "content"],
|
||||
ContentBlockType.timeline: ["timeline", "process", "content"],
|
||||
ContentBlockType.comparison: ["comparison", "two_column", "content"],
|
||||
ContentBlockType.list_items: ["content", "bullet", "list"],
|
||||
ContentBlockType.narrative: ["content", "text", "description"],
|
||||
ContentBlockType.image_reference: ["picture", "image", "content"],
|
||||
ContentBlockType.call_to_action: ["content", "title_slide"],
|
||||
}
|
||||
|
||||
|
||||
class SlideMappingEngine:
|
||||
|
||||
async def map(
|
||||
self,
|
||||
classified_content: ClassifiedContent,
|
||||
layouts: List[dict],
|
||||
n_slides: int,
|
||||
instructions: Optional[str] = None,
|
||||
) -> List[SlideMapping]:
|
||||
"""Map classified content blocks to master deck layouts.
|
||||
|
||||
Args:
|
||||
classified_content: Output from ContentIntelligenceService.classify()
|
||||
layouts: MasterDeckModel.layouts list — each dict has layout_name, layout_type, index, etc.
|
||||
n_slides: Target number of slides
|
||||
instructions: Optional user instructions
|
||||
|
||||
Returns:
|
||||
Ordered list of SlideMapping
|
||||
"""
|
||||
if not layouts:
|
||||
return self._fallback_mapping(classified_content, n_slides)
|
||||
|
||||
# Build layout index by type for fast lookup
|
||||
layout_by_type: Dict[str, List[dict]] = {}
|
||||
for layout in layouts:
|
||||
lt = (layout.get("layout_type") or "custom").lower()
|
||||
layout_by_type.setdefault(lt, []).append(layout)
|
||||
|
||||
blocks = classified_content.blocks
|
||||
|
||||
# 1. Always start with a title slide
|
||||
mappings: List[SlideMapping] = []
|
||||
title_layout = self._find_layout(layout_by_type, ["title_slide", "title"], layouts)
|
||||
mappings.append(
|
||||
SlideMapping(
|
||||
content_block_indices=[],
|
||||
layout_id=str(title_layout.get("index", 0)),
|
||||
layout_name=title_layout.get("layout_name", "Title"),
|
||||
slide_type="title_slide",
|
||||
content_summary=classified_content.title or "Presentation Title",
|
||||
)
|
||||
)
|
||||
|
||||
# 2. If many blocks, add agenda/section header
|
||||
if len(blocks) > 5:
|
||||
section_layout = self._find_layout(
|
||||
layout_by_type, ["section_header", "section", "content"], layouts
|
||||
)
|
||||
sections = list(
|
||||
{b.source_section for b in blocks if b.source_section}
|
||||
)
|
||||
mappings.append(
|
||||
SlideMapping(
|
||||
content_block_indices=[],
|
||||
layout_id=str(section_layout.get("index", 0)),
|
||||
layout_name=section_layout.get("layout_name", "Agenda"),
|
||||
slide_type="section_header",
|
||||
content_summary="Agenda: " + ", ".join(sections[:6]),
|
||||
)
|
||||
)
|
||||
|
||||
# 3. Map each content block to a layout
|
||||
remaining_slots = n_slides - len(mappings)
|
||||
block_mappings = self._assign_blocks_to_slides(
|
||||
blocks, layout_by_type, layouts, remaining_slots
|
||||
)
|
||||
mappings.extend(block_mappings)
|
||||
|
||||
# 4. If we have more slides than content, add transitional slides
|
||||
while len(mappings) < n_slides:
|
||||
content_layout = self._find_layout(
|
||||
layout_by_type, ["content", "blank"], layouts
|
||||
)
|
||||
mappings.append(
|
||||
SlideMapping(
|
||||
content_block_indices=[],
|
||||
layout_id=str(content_layout.get("index", 0)),
|
||||
layout_name=content_layout.get("layout_name", "Content"),
|
||||
slide_type="content",
|
||||
content_summary="Additional content",
|
||||
)
|
||||
)
|
||||
|
||||
# 5. Trim if over target
|
||||
if len(mappings) > n_slides:
|
||||
# Keep title + agenda, trim lowest-priority from the rest
|
||||
fixed = mappings[:2] if len(mappings) > 2 else mappings[:1]
|
||||
rest = mappings[len(fixed):]
|
||||
rest = rest[: n_slides - len(fixed)]
|
||||
mappings = fixed + rest
|
||||
|
||||
# 6. Optional LLM refinement for ambiguous mappings
|
||||
if instructions:
|
||||
mappings = await self._llm_refine(
|
||||
mappings, classified_content, layouts, instructions, n_slides
|
||||
)
|
||||
|
||||
return mappings
|
||||
|
||||
def _find_layout(
|
||||
self,
|
||||
layout_by_type: Dict[str, List[dict]],
|
||||
preferred_types: List[str],
|
||||
all_layouts: List[dict],
|
||||
) -> dict:
|
||||
"""Find best matching layout by type preference, fallback to first layout."""
|
||||
for lt in preferred_types:
|
||||
if lt in layout_by_type and layout_by_type[lt]:
|
||||
return layout_by_type[lt][0]
|
||||
return all_layouts[0] if all_layouts else {"index": 0, "layout_name": "Default", "layout_type": "content"}
|
||||
|
||||
def _assign_blocks_to_slides(
|
||||
self,
|
||||
blocks: List[ContentBlock],
|
||||
layout_by_type: Dict[str, List[dict]],
|
||||
all_layouts: List[dict],
|
||||
max_slides: int,
|
||||
) -> List[SlideMapping]:
|
||||
"""Assign content blocks to slides, respecting max_slides constraint."""
|
||||
if max_slides <= 0:
|
||||
return []
|
||||
|
||||
mappings: List[SlideMapping] = []
|
||||
|
||||
if len(blocks) <= max_slides:
|
||||
# One block per slide
|
||||
for i, block in enumerate(blocks):
|
||||
preferred = _BLOCK_TO_LAYOUT_TYPE.get(block.type, ["content"])
|
||||
layout = self._find_layout(layout_by_type, preferred, all_layouts)
|
||||
mappings.append(
|
||||
SlideMapping(
|
||||
content_block_indices=[i],
|
||||
layout_id=str(layout.get("index", 0)),
|
||||
layout_name=layout.get("layout_name", "Content"),
|
||||
slide_type=block.type.value,
|
||||
content_summary=block.raw_text[:120],
|
||||
)
|
||||
)
|
||||
else:
|
||||
# More blocks than slides — merge low-priority blocks
|
||||
# Sort by priority descending, take top max_slides groups
|
||||
sorted_blocks = sorted(
|
||||
enumerate(blocks), key=lambda x: -x[1].priority
|
||||
)
|
||||
|
||||
# High-priority blocks get their own slide
|
||||
high_priority = sorted_blocks[:max_slides]
|
||||
overflow = sorted_blocks[max_slides:]
|
||||
|
||||
# Group overflow with nearest high-priority block
|
||||
for idx, block in high_priority:
|
||||
preferred = _BLOCK_TO_LAYOUT_TYPE.get(block.type, ["content"])
|
||||
layout = self._find_layout(layout_by_type, preferred, all_layouts)
|
||||
mappings.append(
|
||||
SlideMapping(
|
||||
content_block_indices=[idx],
|
||||
layout_id=str(layout.get("index", 0)),
|
||||
layout_name=layout.get("layout_name", "Content"),
|
||||
slide_type=block.type.value,
|
||||
content_summary=block.raw_text[:120],
|
||||
)
|
||||
)
|
||||
|
||||
# Distribute overflow blocks across existing slides
|
||||
for i, (idx, block) in enumerate(overflow):
|
||||
target = i % len(mappings)
|
||||
mappings[target].content_block_indices.append(idx)
|
||||
|
||||
# Re-sort mappings by original block order
|
||||
mappings.sort(
|
||||
key=lambda m: min(m.content_block_indices) if m.content_block_indices else 999
|
||||
)
|
||||
|
||||
return mappings
|
||||
|
||||
async def _llm_refine(
|
||||
self,
|
||||
mappings: List[SlideMapping],
|
||||
content: ClassifiedContent,
|
||||
layouts: List[dict],
|
||||
instructions: str,
|
||||
n_slides: int,
|
||||
) -> List[SlideMapping]:
|
||||
"""Use LLM to refine layout assignments based on user instructions."""
|
||||
client = LLMClient()
|
||||
model = get_model()
|
||||
|
||||
layout_info = "\n".join(
|
||||
f"- Index {l.get('index')}: {l.get('layout_name')} (type: {l.get('layout_type')})"
|
||||
for l in layouts
|
||||
)
|
||||
|
||||
current_mapping = "\n".join(
|
||||
f"Slide {i + 1}: [{m.slide_type}] {m.content_summary[:80]} → layout '{m.layout_name}'"
|
||||
for i, m in enumerate(mappings)
|
||||
)
|
||||
|
||||
messages = [
|
||||
LLMSystemMessage(
|
||||
content="You refine slide-to-layout mappings for presentations. "
|
||||
"Given the current mapping and user instructions, suggest layout changes. "
|
||||
"Return a JSON with 'changes' array of {slide_index: int, new_layout_index: int} objects. "
|
||||
"Only include slides that need changing. Return empty array if no changes needed."
|
||||
),
|
||||
LLMUserMessage(
|
||||
content=f"Available layouts:\n{layout_info}\n\n"
|
||||
f"Current mapping:\n{current_mapping}\n\n"
|
||||
f"User instructions: {instructions}"
|
||||
),
|
||||
]
|
||||
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"changes": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"slide_index": {"type": "integer"},
|
||||
"new_layout_index": {"type": "integer"},
|
||||
},
|
||||
"required": ["slide_index", "new_layout_index"],
|
||||
},
|
||||
}
|
||||
},
|
||||
"required": ["changes"],
|
||||
}
|
||||
|
||||
try:
|
||||
result = await client.generate_structured(
|
||||
model=model, messages=messages, response_format=schema
|
||||
)
|
||||
for change in result.get("changes", []):
|
||||
si = change.get("slide_index", -1)
|
||||
li = change.get("new_layout_index", -1)
|
||||
if 0 <= si < len(mappings) and 0 <= li < len(layouts):
|
||||
mappings[si].layout_id = str(li)
|
||||
mappings[si].layout_name = layouts[li].get("layout_name", "Content")
|
||||
except Exception:
|
||||
pass # Keep original mapping on LLM failure
|
||||
|
||||
return mappings
|
||||
|
||||
def _fallback_mapping(
|
||||
self, content: ClassifiedContent, n_slides: int
|
||||
) -> List[SlideMapping]:
|
||||
"""Fallback when no master deck layouts are available."""
|
||||
mappings = [
|
||||
SlideMapping(
|
||||
content_block_indices=[],
|
||||
layout_id="0",
|
||||
layout_name="Title",
|
||||
slide_type="title_slide",
|
||||
content_summary=content.title or "Presentation",
|
||||
)
|
||||
]
|
||||
|
||||
for i, block in enumerate(content.blocks[: n_slides - 1]):
|
||||
mappings.append(
|
||||
SlideMapping(
|
||||
content_block_indices=[i],
|
||||
layout_id="0",
|
||||
layout_name="Content",
|
||||
slide_type=block.type.value,
|
||||
content_summary=block.raw_text[:120],
|
||||
)
|
||||
)
|
||||
|
||||
return mappings
|
||||
|
|
@ -0,0 +1,281 @@
|
|||
"use client";
|
||||
|
||||
import React, { useCallback, useState } from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table";
|
||||
import { Plus, Trash2 } from "lucide-react";
|
||||
|
||||
const CHART_TYPES = [
|
||||
{ value: "bar", label: "Bar" },
|
||||
{ value: "column", label: "Column" },
|
||||
{ value: "line", label: "Line" },
|
||||
{ value: "pie", label: "Pie" },
|
||||
{ value: "doughnut", label: "Doughnut" },
|
||||
{ value: "area", label: "Area" },
|
||||
{ value: "scatter", label: "Scatter" },
|
||||
{ value: "gantt", label: "Gantt" },
|
||||
{ value: "waterfall", label: "Waterfall" },
|
||||
];
|
||||
|
||||
export interface ChartSeries {
|
||||
name: string;
|
||||
values: number[];
|
||||
}
|
||||
|
||||
export interface ChartDataPayload {
|
||||
chart_type: string;
|
||||
title: string;
|
||||
categories: string[];
|
||||
series: ChartSeries[];
|
||||
unit?: string;
|
||||
}
|
||||
|
||||
interface ChartDataEditorProps {
|
||||
initialData?: ChartDataPayload;
|
||||
onApply: (data: ChartDataPayload) => void;
|
||||
onCancel?: () => void;
|
||||
}
|
||||
|
||||
const DEFAULT_DATA: ChartDataPayload = {
|
||||
chart_type: "column",
|
||||
title: "Chart",
|
||||
categories: ["Category 1", "Category 2", "Category 3"],
|
||||
series: [{ name: "Series 1", values: [10, 20, 30] }],
|
||||
};
|
||||
|
||||
export default function ChartDataEditor({
|
||||
initialData,
|
||||
onApply,
|
||||
onCancel,
|
||||
}: ChartDataEditorProps) {
|
||||
const [data, setData] = useState<ChartDataPayload>(
|
||||
initialData ?? DEFAULT_DATA
|
||||
);
|
||||
|
||||
const updateCategory = useCallback(
|
||||
(index: number, value: string) => {
|
||||
setData((prev) => {
|
||||
const cats = [...prev.categories];
|
||||
cats[index] = value;
|
||||
return { ...prev, categories: cats };
|
||||
});
|
||||
},
|
||||
[]
|
||||
);
|
||||
|
||||
const updateSeriesName = useCallback(
|
||||
(seriesIdx: number, name: string) => {
|
||||
setData((prev) => {
|
||||
const series = prev.series.map((s, i) =>
|
||||
i === seriesIdx ? { ...s, name } : s
|
||||
);
|
||||
return { ...prev, series };
|
||||
});
|
||||
},
|
||||
[]
|
||||
);
|
||||
|
||||
const updateCellValue = useCallback(
|
||||
(seriesIdx: number, catIdx: number, value: string) => {
|
||||
setData((prev) => {
|
||||
const series = prev.series.map((s, i) => {
|
||||
if (i !== seriesIdx) return s;
|
||||
const values = [...s.values];
|
||||
values[catIdx] = parseFloat(value) || 0;
|
||||
return { ...s, values };
|
||||
});
|
||||
return { ...prev, series };
|
||||
});
|
||||
},
|
||||
[]
|
||||
);
|
||||
|
||||
const addCategory = useCallback(() => {
|
||||
setData((prev) => ({
|
||||
...prev,
|
||||
categories: [...prev.categories, `Category ${prev.categories.length + 1}`],
|
||||
series: prev.series.map((s) => ({
|
||||
...s,
|
||||
values: [...s.values, 0],
|
||||
})),
|
||||
}));
|
||||
}, []);
|
||||
|
||||
const removeCategory = useCallback(
|
||||
(index: number) => {
|
||||
setData((prev) => ({
|
||||
...prev,
|
||||
categories: prev.categories.filter((_, i) => i !== index),
|
||||
series: prev.series.map((s) => ({
|
||||
...s,
|
||||
values: s.values.filter((_, i) => i !== index),
|
||||
})),
|
||||
}));
|
||||
},
|
||||
[]
|
||||
);
|
||||
|
||||
const addSeries = useCallback(() => {
|
||||
setData((prev) => ({
|
||||
...prev,
|
||||
series: [
|
||||
...prev.series,
|
||||
{
|
||||
name: `Series ${prev.series.length + 1}`,
|
||||
values: new Array(prev.categories.length).fill(0),
|
||||
},
|
||||
],
|
||||
}));
|
||||
}, []);
|
||||
|
||||
const removeSeries = useCallback(
|
||||
(index: number) => {
|
||||
setData((prev) => ({
|
||||
...prev,
|
||||
series: prev.series.filter((_, i) => i !== index),
|
||||
}));
|
||||
},
|
||||
[]
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-4 p-4 border rounded-lg bg-background">
|
||||
{/* Header controls */}
|
||||
<div className="flex items-center gap-3">
|
||||
<Input
|
||||
value={data.title}
|
||||
onChange={(e) => setData((prev) => ({ ...prev, title: e.target.value }))}
|
||||
placeholder="Chart title"
|
||||
className="max-w-[240px]"
|
||||
/>
|
||||
<Select
|
||||
value={data.chart_type}
|
||||
onValueChange={(val) => setData((prev) => ({ ...prev, chart_type: val }))}
|
||||
>
|
||||
<SelectTrigger className="w-[140px]">
|
||||
<SelectValue placeholder="Chart type" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
{CHART_TYPES.map((ct) => (
|
||||
<SelectItem key={ct.value} value={ct.value}>
|
||||
{ct.label}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<Input
|
||||
value={data.unit ?? ""}
|
||||
onChange={(e) =>
|
||||
setData((prev) => ({
|
||||
...prev,
|
||||
unit: e.target.value || undefined,
|
||||
}))
|
||||
}
|
||||
placeholder="Unit (e.g. %, $)"
|
||||
className="max-w-[100px]"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Spreadsheet grid */}
|
||||
<div className="overflow-x-auto">
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead className="w-[140px]">Category</TableHead>
|
||||
{data.series.map((s, si) => (
|
||||
<TableHead key={si} className="min-w-[120px]">
|
||||
<div className="flex items-center gap-1">
|
||||
<Input
|
||||
value={s.name}
|
||||
onChange={(e) => updateSeriesName(si, e.target.value)}
|
||||
className="h-7 text-xs"
|
||||
/>
|
||||
{data.series.length > 1 && (
|
||||
<button
|
||||
onClick={() => removeSeries(si)}
|
||||
className="text-muted-foreground hover:text-destructive"
|
||||
>
|
||||
<Trash2 className="h-3 w-3" />
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</TableHead>
|
||||
))}
|
||||
<TableHead className="w-[40px]" />
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{data.categories.map((cat, ci) => (
|
||||
<TableRow key={ci}>
|
||||
<TableCell>
|
||||
<Input
|
||||
value={cat}
|
||||
onChange={(e) => updateCategory(ci, e.target.value)}
|
||||
className="h-7 text-xs"
|
||||
/>
|
||||
</TableCell>
|
||||
{data.series.map((s, si) => (
|
||||
<TableCell key={si}>
|
||||
<Input
|
||||
type="number"
|
||||
value={s.values[ci] ?? 0}
|
||||
onChange={(e) => updateCellValue(si, ci, e.target.value)}
|
||||
className="h-7 text-xs"
|
||||
/>
|
||||
</TableCell>
|
||||
))}
|
||||
<TableCell>
|
||||
{data.categories.length > 1 && (
|
||||
<button
|
||||
onClick={() => removeCategory(ci)}
|
||||
className="text-muted-foreground hover:text-destructive"
|
||||
>
|
||||
<Trash2 className="h-3 w-3" />
|
||||
</button>
|
||||
)}
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
|
||||
{/* Actions */}
|
||||
<div className="flex items-center gap-2">
|
||||
<Button variant="outline" size="sm" onClick={addCategory}>
|
||||
<Plus className="h-3 w-3 mr-1" />
|
||||
Row
|
||||
</Button>
|
||||
<Button variant="outline" size="sm" onClick={addSeries}>
|
||||
<Plus className="h-3 w-3 mr-1" />
|
||||
Series
|
||||
</Button>
|
||||
<div className="ml-auto flex gap-2">
|
||||
{onCancel && (
|
||||
<Button variant="ghost" size="sm" onClick={onCancel}>
|
||||
Cancel
|
||||
</Button>
|
||||
)}
|
||||
<Button size="sm" onClick={() => onApply(data)}>
|
||||
Apply
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue