from datetime import datetime, timezone import logging logger = logging.getLogger(__name__) class QuotaExceededError(Exception): def __init__(self, scope: str, limit_usd: float, used_usd: float, period_start=None): self.scope = scope # "user" | "focus_group" self.limit_usd = limit_usd self.used_usd = used_usd self.period_start = period_start super().__init__( f"Quota exceeded ({scope}): used ${used_usd:.4f} of ${limit_usd:.2f} limit" ) class QuotaWarning: def __init__(self, scope: str, limit_usd: float, used_usd: float, pct: float): self.scope = scope self.limit_usd = limit_usd self.used_usd = used_usd self.pct = pct async def check_quota(user_id: str | None, focus_group_id: str | None) -> QuotaWarning | None: """Check quotas and raise QuotaExceededError if either is exceeded. Returns a QuotaWarning (not raised) when usage is between 80 % and 100 %. Returns None if all quotas are fine. Admins and users with override_quota=True bypass user-level quota. Focus-group quotas apply to everyone (including admins) — they are project budgets. """ from app.models.user import User from app.models.usage_event import UsageEvent warning = None if user_id: try: user = await User.find_by_id(user_id) if user: is_admin = user.get("role") == "admin" override = user.get("override_quota", False) if not is_admin and not override: quota = user.get("quota") or {} limit = quota.get("monthly_usd") if limit: now = datetime.now(timezone.utc) period_start = now.replace( day=1, hour=0, minute=0, second=0, microsecond=0 ) spent = await UsageEvent.sum_cost( {"user_id": user_id, "ts": {"$gte": period_start}} ) pct = spent / limit if limit else 0 if spent >= limit: raise QuotaExceededError("user", limit, spent, period_start) elif pct >= 0.8: warning = QuotaWarning("user", limit, spent, pct) except QuotaExceededError: raise except Exception: logger.warning("Quota check failed (non-fatal, allowing call)", exc_info=True) if focus_group_id: try: from app.models.focus_group import FocusGroup fg = await FocusGroup.find_by_id(focus_group_id) if fg: fg_quota = fg.get("quota") or {} fg_limit = fg_quota.get("total_usd") if fg_limit: spent = await UsageEvent.sum_cost({"focus_group_id": focus_group_id}) pct = spent / fg_limit if fg_limit else 0 if spent >= fg_limit: raise QuotaExceededError("focus_group", fg_limit, spent, None) elif pct >= 0.8 and not warning: warning = QuotaWarning("focus_group", fg_limit, spent, pct) except QuotaExceededError: raise except Exception: logger.warning("Focus-group quota check failed (non-fatal)", exc_info=True) return warning