"""Chart Data Extractor: extract chart-ready data from content blocks and tables."""

import re
from typing import List, Optional

from pydantic import BaseModel

from models.content_models import ContentBlock, ContentBlockType
from services.attachment_parser_service import TableData


class ChartSeries(BaseModel):
    """One named sequence of numeric values, positionally aligned with the categories."""

    name: str
    values: List[float]


class ChartData(BaseModel):
    """Chart-ready payload: a category axis plus one or more value series."""

    chart_type: str  # bar, column, line, pie, doughnut, area, scatter, gantt, waterfall
    title: str
    categories: List[str]
    series: List[ChartSeries]
    unit: Optional[str] = None


# --- Public API ---

def extract(
    content_block: ContentBlock,
    table_data: Optional[TableData] = None,
) -> Optional[ChartData]:
    """Extract chart data from a content block and/or associated table.

    The table is tried first. If it is missing, empty, or has no numeric
    columns, a metric-type content block is used as the fallback source.

    Args:
        content_block: Block whose ``type`` and ``extracted_data`` are inspected.
        table_data: Optional table with ``headers`` and ``rows``.

    Returns:
        ChartData if chartable data is found, else None.
    """
    if table_data and table_data.rows and table_data.headers:
        chart = _chart_from_table(table_data)
        if chart is not None:
            return chart
        # Table was present but not chartable: fall through to the metrics.

    if content_block.type == ContentBlockType.metric:
        return _chart_from_metrics(content_block)

    return None


# --- Table → ChartData ---

def _chart_from_table(td: TableData) -> Optional[ChartData]:
    """Convert a TableData into ChartData.

    Heuristic: first column = categories, remaining numeric columns = series.
    A column counts as numeric when at least half of the rows hold a value
    that ``_to_float`` can parse. Unparseable or missing cells in a numeric
    column are charted as 0.0.

    Returns:
        ChartData, or None when there are no rows, fewer than two headers,
        or no numeric column.
    """
    if not td.rows or len(td.headers) < 2:
        return None

    row_count = len(td.rows)

    # Determine which columns are numeric (by checking majority of rows).
    numeric_cols = []
    for col_idx in range(1, len(td.headers)):
        numeric_count = sum(
            1
            for row in td.rows
            if col_idx < len(row) and _to_float(row[col_idx]) is not None
        )
        if numeric_count >= row_count * 0.5:
            numeric_cols.append(col_idx)

    if not numeric_cols:
        return None

    # Empty rows and None cells become blank labels rather than "None".
    categories = [
        "" if not row or row[0] is None else str(row[0])
        for row in td.rows
    ]

    series_list: List[ChartSeries] = []
    for position, col_idx in enumerate(numeric_cols, start=1):
        values = [
            (_to_float(row[col_idx]) if col_idx < len(row) else None) or 0.0
            for row in td.rows
        ]
        header = td.headers[col_idx]
        # ChartSeries.name must be a string; blank/None headers get a placeholder.
        name = str(header) if header not in (None, "") else f"Series {position}"
        series_list.append(ChartSeries(name=name, values=values))

    chart_type = _recommend_chart_type(categories, series_list, td)
    title = td.title or td.sheet_name or "Chart"

    return ChartData(
        chart_type=chart_type,
        title=title,
        categories=categories,
        series=series_list,
    )


# --- Metric block → ChartData ---

# Optional currency symbol, a digit run with separators, optional magnitude or
# percent suffix. Not referenced by the functions in this section.
_NUMBER_RE = re.compile(
    r"[\$€£¥]?\s?(\d[\d,.]*)\s?([KMBTkmbt%]?)",
)


def _chart_from_metrics(block: ContentBlock) -> Optional[ChartData]:
    """Build ChartData from a metric content block's extracted_data.

    Reads ``extracted_data["metrics"]``, a list of dicts with ``label`` and
    ``value`` keys. Entries that are not dicts, or whose value cannot be
    parsed by ``_parse_metric_value``, are skipped.

    Returns:
        A single-series ChartData, or None when fewer than two metrics parse.
        The chart is a "pie" only when every parsed value is a non-negative
        percentage and they sum to roughly 100; otherwise it is a "bar".
    """
    metrics = (block.extracted_data or {}).get("metrics", [])
    if not metrics:
        return None

    categories: List[str] = []
    values: List[float] = []
    unit: Optional[str] = None
    all_percent = True

    for m in metrics:
        if not isinstance(m, dict):
            continue

        label = str(m.get("label") or "").strip()

        raw_value = m.get("value", "")
        # Numeric payloads are stringified so the string parser can handle them.
        if isinstance(raw_value, (int, float)) and not isinstance(raw_value, bool):
            raw_value = str(raw_value)
        if not isinstance(raw_value, str):
            continue

        parsed = _parse_metric_value(raw_value)
        if parsed is None:
            continue

        numeric_val, val_unit = parsed
        if val_unit and not unit:
            # The first unit encountered labels the whole chart.
            unit = val_unit
        if val_unit != "%":
            all_percent = False

        categories.append(label or f"Metric {len(categories) + 1}")
        values.append(numeric_val)

    if len(values) < 2:
        return None

    chart_type = "bar"
    # If all values are percentages and sum near 100, use pie.
    if all_percent and min(values) >= 0 and 90 <= sum(values) <= 110:
        chart_type = "pie"

    return ChartData(
        chart_type=chart_type,
        title=block.source_section or "Key Metrics",
        categories=categories,
        series=[ChartSeries(name="Value", values=values)],
        unit=unit,
    )
# --- Chart type recommendation ---

# Matches a 4-digit year (19xx/20xx), a quarter (Q1-Q4) or a month name.
# Lookarounds keep the tokens from matching inside longer numbers or ordinary
# words ("120000", "Marketing", "Decor") while still accepting compact forms
# such as "FY2023", "2023Q1" and "Jan-23".
_TIME_PATTERN = re.compile(
    r"(?<!\d)(?:19|20)\d{2}(?!\d)"
    r"|(?<![A-Za-z])Q[1-4](?!\d)"
    r"|(?<![A-Za-z])(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May"
    r"|June?|July?|Aug(?:ust)?|Sep(?:t(?:ember)?)?|Oct(?:ober)?"
    r"|Nov(?:ember)?|Dec(?:ember)?)(?![A-Za-z])",
    re.IGNORECASE,
)


def _recommend_chart_type(
    categories: List[str],
    series: "List[ChartSeries]",
    td: "Optional[TableData]" = None,
) -> str:
    """Auto-recommend a chart type based on data characteristics.

    Args:
        categories: Category labels, one per data point.
        series: Value series; only ``.values`` of the first one is read.
        td: Source table. Accepted for interface compatibility; not used.

    Returns:
        "line" when at least 60% of the categories look like time periods;
        for a single series, "pie" (2-8 non-negative values summing to
        90..110), "bar" (up to 6 categories) or "column"; "bar" for exactly
        two series; "column" otherwise.
    """
    n_cats = len(categories)
    n_series = len(series)

    # Check if categories look like time periods. An empty axis is never one.
    if n_cats:
        time_count = sum(1 for c in categories if _TIME_PATTERN.search(str(c)))
        if time_count >= n_cats * 0.6:
            return "line"

    # Single series
    if n_series == 1:
        vals = series[0].values
        # Parts of a whole: shares must be non-negative and total roughly 100.
        if (
            2 <= n_cats <= 8
            and all(v >= 0 for v in vals)
            and 90 <= sum(vals) <= 110
        ):
            return "pie"
        return "bar" if n_cats <= 6 else "column"

    # Multiple series
    if n_series == 2:
        return "bar"  # grouped bar
    return "column"


# --- Helpers ---

# Thousands separators and currency symbols removed before numeric parsing.
_NUMERIC_NOISE = str.maketrans("", "", ",$€£¥")

# Magnitude suffixes recognised on metric values, keyed by upper-case letter.
_MAGNITUDE_MULTIPLIERS = {
    "K": 1_000,
    "M": 1_000_000,
    "B": 1_000_000_000,
    "T": 1_000_000_000_000,
}


def _finite_float(value) -> Optional[float]:
    """Return ``float(value)``, or None if it fails or is NaN/infinite."""
    import math  # local import: this section does not own the file's import block

    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        return None
    return number if math.isfinite(number) else None


def _to_float(val) -> Optional[float]:
    """Convert a cell value to float, handling common formats.

    Ints and floats are converted directly. Strings are stripped of
    surrounding whitespace, thousands separators, currency symbols
    ($, €, £, ¥) and trailing percent signs before parsing.

    Returns:
        The parsed number, or None for None, unsupported types, unparseable
        text, and non-finite values ("nan", "inf").
    """
    if val is None:
        return None
    if isinstance(val, (int, float)):
        return _finite_float(val)
    if isinstance(val, str):
        cleaned = val.strip().translate(_NUMERIC_NOISE).rstrip("%")
        return _finite_float(cleaned)
    return None


def _parse_metric_value(raw: str) -> Optional[tuple]:
    """Parse a metric value string like '$2.3M' or '45%' into (float, unit).

    A trailing '%' yields unit "%". A trailing K/M/B/T (any case) multiplies
    the number by the matching magnitude and yields the upper-case letter as
    the unit. Anything else yields unit None. Plain ints and floats are
    accepted as well and returned with unit None.

    Returns:
        ``(value, unit)``, or None when the input is empty, not numeric,
        or not finite.
    """
    if raw is None or isinstance(raw, bool):
        return None
    if isinstance(raw, (int, float)):
        number = _finite_float(raw)
        return None if number is None else (number, None)
    if not isinstance(raw, str):
        return None

    text = raw.strip()
    if not text:
        return None

    unit = None
    multiplier = 1
    if text.endswith("%"):
        unit = "%"
        text = text.rstrip("%")
    else:
        suffix = text[-1].upper()
        if suffix in _MAGNITUDE_MULTIPLIERS:
            unit = suffix
            multiplier = _MAGNITUDE_MULTIPLIERS[suffix]
            text = text[:-1]

    number = _finite_float(text.translate(_NUMERIC_NOISE))
    if number is None:
        return None
    return number * multiplier, unit