"""
Admin API routes — every endpoint requires jwt_required + admin_required.

Users:
    GET/POST /api/admin/users
    GET/PUT  /api/admin/users/<user_id>
    POST     /api/admin/users/<user_id>/disable
    POST     /api/admin/users/<user_id>/enable
    POST     /api/admin/users/<user_id>/reset-password
Usage:
    GET /api/admin/usage/summary
    GET /api/admin/usage/events
Pricing:
    GET/POST /api/admin/pricing
Focus groups:
    GET /api/admin/focus-groups
"""
import asyncio
import logging
import re
from datetime import datetime, timezone, timedelta

from quart import Blueprint, jsonify, request
from bson import ObjectId

from app.auth.quart_jwt import jwt_required, get_jwt_identity
from app.utils import admin_required, make_serializable
from app.models.user import User
from app.models.usage_event import UsageEvent
from app.models.model_pricing import ModelPricing
from app.db import get_db

logger = logging.getLogger(__name__)

admin_bp = Blueprint('admin', __name__)

# Minimum accepted length for admin-created / admin-reset passwords.
_MIN_PASSWORD_LEN = 8
# bcrypt only consumes the first 72 bytes of its input (some bcrypt releases
# raise ValueError instead of truncating), so longer passwords are rejected.
_MAX_PASSWORD_BYTES = 72


# ─────────────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────────────

def _safe_user(doc: dict) -> dict:
    """Return a JSON-serializable copy of a user document without password_hash.

    Returns {} for a missing/empty document.
    """
    if not doc:
        return {}
    out = {k: v for k, v in doc.items() if k != 'password_hash'}
    return make_serializable(out)


def _month_start() -> datetime:
    """Return 00:00:00.000000 UTC on the first day of the current month."""
    now = datetime.now(timezone.utc)
    return now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)


def _parse_iso(s: str) -> datetime:
    """Parse an ISO-8601 datetime string into a timezone-aware datetime.

    A trailing 'Z' is rewritten to '+00:00' (fromisoformat only accepts 'Z'
    from Python 3.11). Naive results are taken to be UTC so that every bound
    of a range filter is timezone-aware.

    Raises:
        ValueError: if *s* is not a valid ISO-8601 datetime.
    """
    if s.endswith(('Z', 'z')):
        s = s[:-1] + '+00:00'
    dt = datetime.fromisoformat(s)
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=timezone.utc)
    return dt


def _period_match(from_str: str | None, to_str: str | None) -> dict:
    """Build a MongoDB ts-range filter from optional ISO-8601 bounds.

    Returns {} (all time) when both bounds are absent. When only one bound is
    given, a missing 'from' defaults to the start of the current UTC month and
    a missing 'to' defaults to now.

    Raises:
        ValueError: with a client-safe message if either bound is malformed.
    """
    if not from_str and not to_str:
        return {}
    try:
        from_dt = _parse_iso(from_str) if from_str else _month_start()
        to_dt = _parse_iso(to_str) if to_str else datetime.now(timezone.utc)
    except ValueError as err:
        raise ValueError("'from' and 'to' must be ISO-8601 datetimes") from err
    return {'ts': {'$gte': from_dt, '$lte': to_dt}}


def _int_arg(name: str, default: int, *, lo: int, hi: int | None = None) -> int:
    """Read integer query-string argument *name*, clamped to [lo, hi].

    A missing or blank value yields *default* (which is then clamped too).

    Raises:
        ValueError: with a client-safe message if the value is not an integer.
    """
    raw = (request.args.get(name) or '').strip()
    try:
        value = int(raw) if raw else default
    except ValueError as err:
        raise ValueError(f"'{name}' must be an integer") from err
    value = max(lo, value)
    return value if hi is None else min(hi, value)


async def _json_body() -> dict:
    """Return the request's JSON body when it is an object, otherwise {}."""
    data = await request.get_json(silent=True)
    return data if isinstance(data, dict) else {}


def _password_error(password: str) -> str | None:
    """Return a client-facing error message if *password* is unacceptable, else None."""
    if len(password) < _MIN_PASSWORD_LEN:
        return f'Password must be at least {_MIN_PASSWORD_LEN} characters'
    if len(password.encode()) > _MAX_PASSWORD_BYTES:
        return f'Password must be at most {_MAX_PASSWORD_BYTES} bytes'
    return None


async def _hash_password(password: str) -> str:
    """bcrypt-hash *password* in a worker thread so the event loop is not blocked."""
    import bcrypt as _bcrypt

    def _hash() -> str:
        return _bcrypt.hashpw(password.encode(), _bcrypt.gensalt()).decode()

    return await asyncio.to_thread(_hash)


async def _user_period_cost(user_id: str, from_str: str | None, to_str: str | None) -> float:
    """Cost for a single user within the given period (all time when no bounds).

    Raises:
        ValueError: if either bound is malformed (see _period_match).
    """
    match = {'user_id': user_id, **_period_match(from_str, to_str)}
    return await UsageEvent.sum_cost(match)


# ─────────────────────────────────────────────────────────────────────────────
# Users
# ─────────────────────────────────────────────────────────────────────────────

@admin_bp.route('/users', methods=['GET'])
@jwt_required()
@admin_required
async def list_users():
    """GET /api/admin/users?q=&role=&skip=&limit=&from=ISO&to=ISO

    q is a case-insensitive *literal* substring match on username or email.
    Each user's cost_mtd is the cost for the requested period; it covers all
    time when neither 'from' nor 'to' is given (the key name is kept for API
    compatibility).
    """
    try:
        skip = _int_arg('skip', 0, lo=0)
        limit = _int_arg('limit', 50, lo=1, hi=100)
        period = _period_match(request.args.get('from'), request.args.get('to'))
    except ValueError as err:
        return jsonify({'error': str(err)}), 400

    q = request.args.get('q', '').strip()
    role_filter = request.args.get('role', '').strip()

    query: dict = {}
    if q:
        # Escape user input so it cannot inject regex syntax (incl. catastrophic backtracking).
        pattern = re.escape(q)
        query['$or'] = [
            {'username': {'$regex': pattern, '$options': 'i'}},
            {'email': {'$regex': pattern, '$options': 'i'}},
        ]
    if role_filter:
        query['role'] = role_filter

    users = await User.find_all(query, skip=skip, limit=limit)
    total = await User.count(query)

    result = []
    for u in users:
        safe = _safe_user(u)
        # Period is parsed once above so every user is costed over the same window.
        safe['cost_mtd'] = await UsageEvent.sum_cost(
            {'user_id': str(u.get('_id', '')), **period}
        )
        result.append(safe)

    return jsonify({'users': result, 'total': total, 'skip': skip, 'limit': limit}), 200


@admin_bp.route('/users/<user_id>', methods=['GET'])
@jwt_required()
@admin_required
async def get_user(user_id):
    """GET /api/admin/users/<user_id> — one user; cost_mtd here is the all-time cost."""
    try:
        user = await User.find_by_id(user_id)
    except Exception:
        return jsonify({'error': 'Invalid user ID'}), 400
    if not user:
        return jsonify({'error': 'User not found'}), 404

    safe = _safe_user(user)
    safe['cost_mtd'] = await _user_period_cost(user_id, None, None)
    return jsonify(safe), 200


@admin_bp.route('/users/<user_id>', methods=['PUT'])
@jwt_required()
@admin_required
async def update_user(user_id):
    """PUT /api/admin/users/<user_id> — update role, is_active, quota, override_quota.

    Unknown keys are ignored. role must be 'user' or 'admin'; is_active and
    override_quota must be JSON booleans (quota is not validated here). An
    admin may not deactivate their own account, nor demote themselves when
    they are the last admin. Deactivating or changing the role bumps the
    user's token_version so existing JWTs are invalidated.
    """
    data = await _json_body()
    allowed = {'role', 'is_active', 'quota', 'override_quota'}
    fields = {k: v for k, v in data.items() if k in allowed}
    if not fields:
        return jsonify({'error': 'No valid fields to update'}), 400

    if 'role' in fields and fields['role'] not in ('user', 'admin'):
        return jsonify({'error': 'Invalid role. Must be user or admin'}), 400
    for flag in ('is_active', 'override_quota'):
        # A string like "false" would be stored truthy and skip the token bump below.
        if flag in fields and not isinstance(fields[flag], bool):
            return jsonify({'error': f'{flag} must be a boolean'}), 400

    if get_jwt_identity() == user_id:
        # Same rule as disable_user: admins cannot lock themselves out via PUT either.
        if fields.get('is_active') is False:
            return jsonify({'error': 'Cannot disable your own account'}), 400
        # Demoting someone else leaves the (admin) requester in place, so only
        # self-demotion can remove the last admin.
        if fields.get('role') == 'user' and await User.count({'role': 'admin'}) <= 1:
            return jsonify({'error': 'Cannot demote the last admin'}), 409

    try:
        updated = await User.update(user_id, fields)
    except Exception:
        return jsonify({'error': 'Invalid user ID'}), 400
    if not updated:
        return jsonify({'error': 'User not found'}), 404

    # Bump token_version so existing JWTs are immediately invalidated
    if fields.get('is_active') is False or 'role' in fields:
        await User.bump_token_version(user_id)

    logger.info("Admin updated user %s: %s", user_id, list(fields))
    user = await User.find_by_id(user_id)
    return jsonify(_safe_user(user)), 200


@admin_bp.route('/users/<user_id>/disable', methods=['POST'])
@jwt_required()
@admin_required
async def disable_user(user_id):
    """POST /api/admin/users/<user_id>/disable — deactivate a user and revoke their JWTs."""
    if get_jwt_identity() == user_id:
        return jsonify({'error': 'Cannot disable your own account'}), 400
    try:
        updated = await User.update(user_id, {'is_active': False})
    except Exception:
        return jsonify({'error': 'Invalid user ID'}), 400
    if not updated:
        return jsonify({'error': 'User not found'}), 404

    # Mirror update_user: a deactivated account's outstanding JWTs must stop working.
    await User.bump_token_version(user_id)
    logger.info("Admin disabled user %s", user_id)
    return jsonify({'message': 'User disabled'}), 200


@admin_bp.route('/users/<user_id>/enable', methods=['POST'])
@jwt_required()
@admin_required
async def enable_user(user_id):
    """POST /api/admin/users/<user_id>/enable — reactivate a user."""
    try:
        updated = await User.update(user_id, {'is_active': True})
    except Exception:
        return jsonify({'error': 'Invalid user ID'}), 400
    if not updated:
        return jsonify({'error': 'User not found'}), 404

    logger.info("Admin enabled user %s", user_id)
    return jsonify({'message': 'User enabled'}), 200


# ─────────────────────────────────────────────────────────────────────────────
# Usage
# ─────────────────────────────────────────────────────────────────────────────

@admin_bp.route('/usage/summary', methods=['GET'])
@jwt_required()
@admin_required
async def usage_summary():
    """
    GET /api/admin/usage/summary?from=ISO&to=ISO&group_by=user|model|feature|day|focus_group&user_id=&focus_group_id=

    Returns aggregated cost + token totals per group (at most 500 groups,
    most expensive first) plus overall totals. An unknown group_by falls back
    to 'user', and the response reports the grouping actually used.
    """
    from_str = request.args.get('from')
    to_str = request.args.get('to')
    group_by = request.args.get('group_by', 'user')
    filter_user = request.args.get('user_id')
    filter_fg = request.args.get('focus_group_id')

    try:
        match: dict = _period_match(from_str, to_str)
    except ValueError as err:
        return jsonify({'error': str(err)}), 400
    if filter_user:
        match['user_id'] = filter_user
    if filter_fg:
        match['focus_group_id'] = filter_fg

    group_keys = {
        'user': '$user_id',
        'model': '$model',
        'feature': '$feature',
        'day': {'$dateToString': {'format': '%Y-%m-%d', 'date': '$ts'}},
        'focus_group': '$focus_group_id',
    }
    if group_by not in group_keys:
        group_by = 'user'

    # Accumulators shared by the per-group and the overall pipelines.
    sums = {
        'total_cost': {'$sum': '$cost_usd.total'},
        'prompt_tokens': {'$sum': '$prompt_tokens'},
        'completion_tokens': {'$sum': '$completion_tokens'},
        'calls': {'$sum': 1},
    }

    try:
        db = await get_db()
        rows = await db.usage_events.aggregate([
            {'$match': match},
            {'$group': {'_id': group_keys[group_by], **sums}},
            {'$sort': {'total_cost': -1}},
        ]).to_list(500)
        totals_raw = await db.usage_events.aggregate([
            {'$match': match},
            {'$group': {'_id': None, **sums}},
        ]).to_list(1)
    except Exception:
        # Log the details server-side; do not echo internal errors to the client.
        logger.exception("Usage summary error")
        return jsonify({'error': 'Failed to compute usage summary'}), 500

    totals = totals_raw[0] if totals_raw else {
        'total_cost': 0, 'prompt_tokens': 0, 'completion_tokens': 0, 'calls': 0
    }
    totals.pop('_id', None)

    return jsonify({
        'rows': make_serializable(rows),
        'totals': make_serializable(totals),
        'from': from_str,
        'to': to_str,
        'group_by': group_by,
    }), 200


@admin_bp.route('/usage/events', methods=['GET'])
@jwt_required()
@admin_required
async def usage_events():
    """GET /api/admin/usage/events?user_id=&focus_group_id=&feature=&skip=&limit=

    Raw usage events, newest first (limit clamped to 1..500).
    """
    try:
        skip = _int_arg('skip', 0, lo=0)
        limit = _int_arg('limit', 50, lo=1, hi=500)
    except ValueError as err:
        return jsonify({'error': str(err)}), 400

    filters = {
        'user_id': request.args.get('user_id'),
        'focus_group_id': request.args.get('focus_group_id'),
        'feature': request.args.get('feature'),
    }
    match = {field: value for field, value in filters.items() if value}

    db = await get_db()
    cursor = db.usage_events.find(match).sort('ts', -1).skip(skip).limit(limit)
    events = await cursor.to_list(length=limit)
    total = await db.usage_events.count_documents(match)

    return jsonify({
        'events': make_serializable(events),
        'total': total,
        'skip': skip,
        'limit': limit,
    }), 200


# ─────────────────────────────────────────────────────────────────────────────
# Pricing
# ─────────────────────────────────────────────────────────────────────────────

@admin_bp.route('/pricing', methods=['GET'])
@jwt_required()
@admin_required
async def list_pricing():
    """GET /api/admin/pricing — pricing rows in effect now, for all models (max 100)."""
    db = await get_db()
    now = datetime.now(timezone.utc)
    cursor = db.model_pricing.find({
        'effective_from': {'$lte': now},
        '$or': [{'effective_until': None}, {'effective_until': {'$gt': now}}],
    }).sort([('model', 1), ('effective_from', -1)])
    rows = await cursor.to_list(length=100)
    return jsonify({'pricing': make_serializable(rows)}), 200


# ─────────────────────────────────────────────────────────────────────────────
# Users — extended
# ─────────────────────────────────────────────────────────────────────────────

@admin_bp.route('/users', methods=['POST'])
@jwt_required()
@admin_required
async def create_user():
    """POST /api/admin/users — create a local (non-SSO) user.

    Body: username, email, password (required) and role ('user' by default).
    The password follows the same rules as reset-password.
    """
    data = await _json_body()
    username = (data.get('username') or '').strip()
    email = (data.get('email') or '').strip()
    password = (data.get('password') or '').strip()
    role = data.get('role', 'user')

    if not username or not email or not password:
        return jsonify({'error': 'username, email, password required'}), 400
    pw_error = _password_error(password)
    if pw_error:
        return jsonify({'error': pw_error}), 400
    if role not in ('user', 'admin'):
        return jsonify({'error': 'Invalid role. Must be user or admin'}), 400

    db = await get_db()
    # NOTE(review): check-then-insert is racy; only a unique index on
    # username/email (not visible here) makes duplicates impossible.
    if await db.users.find_one({'$or': [{'username': username}, {'email': email}]}):
        return jsonify({'error': 'Username or email already exists'}), 409

    pw_hash = await _hash_password(password)
    now = datetime.now(timezone.utc)
    doc = {
        'username': username,
        'email': email,
        'password_hash': pw_hash,
        'role': role,
        'is_active': True,
        'override_quota': False,
        'token_version': 0,
        'created_at': now,
        'updated_at': now,
    }
    result = await db.users.insert_one(doc)
    doc['_id'] = result.inserted_id

    logger.info("Admin created user %s (%s)", username, email)
    return jsonify(_safe_user(doc)), 201


@admin_bp.route('/users/<user_id>/reset-password', methods=['POST'])
@jwt_required()
@admin_required
async def reset_password(user_id):
    """POST /api/admin/users/<user_id>/reset-password — body {"password": "..."}.

    Also bumps token_version so the user's existing JWTs are invalidated.
    """
    data = await _json_body()
    new_password = (data.get('password') or '').strip()
    pw_error = _password_error(new_password)
    if pw_error:
        return jsonify({'error': pw_error}), 400

    # Validate the id before paying for a bcrypt hash.
    if not ObjectId.is_valid(user_id):
        return jsonify({'error': 'Invalid user ID'}), 400

    pw_hash = await _hash_password(new_password)
    db = await get_db()
    result = await db.users.update_one(
        {'_id': ObjectId(user_id)},
        {'$set': {'password_hash': pw_hash}},
    )
    if result.matched_count == 0:
        return jsonify({'error': 'User not found'}), 404

    await User.bump_token_version(user_id)
    logger.info("Admin reset password for user %s", user_id)
    return jsonify({'ok': True}), 200


# ─────────────────────────────────────────────────────────────────────────────
# Pricing — extended
# ─────────────────────────────────────────────────────────────────────────────

@admin_bp.route('/pricing', methods=['POST'])
@jwt_required()
@admin_required
async def create_pricing():
    """POST /api/admin/pricing — insert a new pricing row effective from now.

    Body: model, provider, tiers (required), notes, expire_current. When
    expire_current is truthy, the model's currently open-ended rows
    (effective_until = None) are closed at the same timestamp.
    """
    data = await _json_body()
    model = (data.get('model') or '').strip()
    provider = (data.get('provider') or '').strip()
    tiers = data.get('tiers', [])
    if not model or not provider or not tiers:
        return jsonify({'error': 'model, provider, tiers required'}), 400

    now = datetime.now(timezone.utc)
    expire_current = bool(data.get('expire_current', False))

    db = await get_db()
    if expire_current:
        await db.model_pricing.update_many(
            {'model': model, 'effective_until': None},
            {'$set': {'effective_until': now}},
        )

    doc = {
        'model': model,
        'provider': provider,
        'currency': 'USD',
        'tiers': tiers,
        'effective_from': now,
        'effective_until': None,
        'notes': data.get('notes', ''),
    }
    result = await db.model_pricing.insert_one(doc)
    doc['_id'] = result.inserted_id

    logger.info("Admin created pricing row for model %s", model)
    return jsonify(make_serializable(doc)), 201


# ─────────────────────────────────────────────────────────────────────────────
# Focus Groups (admin view)
# ─────────────────────────────────────────────────────────────────────────────

@admin_bp.route('/focus-groups', methods=['GET'])
@jwt_required()
@admin_required
async def list_focus_groups():
    """GET /api/admin/focus-groups?skip=&limit=&from=ISO&to=ISO — list with cost totals.

    Focus groups are paged by date (newest first). Each one gets cost_total
    and call_count for the requested period (all time when no bounds). The
    page is then re-sorted by cost_total descending; that sort applies within
    the page only, not across all focus groups.
    """
    try:
        skip = _int_arg('skip', 0, lo=0)
        limit = _int_arg('limit', 50, lo=1, hi=200)
        period_filter = _period_match(request.args.get('from'), request.args.get('to'))
    except ValueError as err:
        return jsonify({'error': str(err)}), 400

    db = await get_db()
    cursor = db.focus_groups.find(
        {},
        {'name': 1, 'date': 1, 'status': 1, 'llm_model': 1, 'quota': 1},
    ).sort('date', -1).skip(skip).limit(limit)
    fgs = await cursor.to_list(length=limit)

    costs: dict = {}
    if fgs:
        # One aggregation for the whole page instead of one query per focus group.
        # usage_events.focus_group_id is matched as the string form of the _id.
        fg_ids = [str(fg['_id']) for fg in fgs]
        pipeline = [
            {'$match': {'focus_group_id': {'$in': fg_ids}, **period_filter}},
            {'$group': {
                '_id': '$focus_group_id',
                'total': {'$sum': '$cost_usd.total'},
                'calls': {'$sum': 1},
            }},
        ]
        rows = await db.usage_events.aggregate(pipeline).to_list(len(fg_ids))
        costs = {row['_id']: row for row in rows}

    for fg in fgs:
        agg = costs.get(str(fg['_id']))
        fg['cost_total'] = agg['total'] if agg else 0
        fg['call_count'] = agg['calls'] if agg else 0

    fgs.sort(key=lambda x: x['cost_total'], reverse=True)
    total = await db.focus_groups.count_documents({})
    return jsonify({'focus_groups': make_serializable(fgs), 'total': total}), 200