111 lines
4.3 KiB
Python
111 lines
4.3 KiB
Python
import secrets
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Annotated
|
|
|
|
import bcrypt
|
|
from fastapi import Depends, HTTPException, Security, status
|
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|
from jose import JWTError, jwt
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import selectinload
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from src.config import settings
|
|
from src.database import get_db
|
|
from src.models import ApiKey, User
|
|
|
|
# Bearer-token security scheme consumed by get_current_user via Security().
# auto_error=False: the dependency is declared there as possibly None, and
# get_current_user raises its own 401 when no credentials were supplied.
bearer_scheme = HTTPBearer(auto_error=False)
|
|
|
|
# JWT signing algorithm; used both when encoding tokens and as the only
# algorithm accepted when decoding them.
ALGORITHM = "HS256"
|
|
|
|
|
|
# ── Password ──────────────────────────────────────────────────────────────────
|
|
|
|
def hash_password(password: str) -> str:
    """Return a bcrypt hash of *password* as text.

    The password is encoded to bytes, hashed with a freshly generated bcrypt
    salt, and the resulting hash bytes are decoded back to ``str``.
    """
    secret = password.encode()
    salt = bcrypt.gensalt()
    digest = bcrypt.hashpw(secret, salt)
    return digest.decode()
|
|
|
|
|
|
def verify_password(plain: str, hashed: str) -> bool:
    """Return whether *plain* matches the bcrypt hash *hashed*.

    Both arguments are encoded to bytes and handed to ``bcrypt.checkpw``,
    whose result is returned unchanged.
    """
    candidate = plain.encode()
    stored = hashed.encode()
    return bcrypt.checkpw(candidate, stored)
|
|
|
|
|
|
# ── JWT ───────────────────────────────────────────────────────────────────────
|
|
|
|
def create_access_token(user_id: str, role: str) -> str:
    """Build a signed access JWT for *user_id* carrying *role*.

    The token expires ``settings.ACCESS_TOKEN_EXPIRE_MINUTES`` minutes from
    the current UTC time and is tagged ``"type": "access"``. It is signed with
    ``settings.SECRET_KEY`` using ``ALGORITHM``.
    """
    issued_at = datetime.now(timezone.utc)
    lifetime = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    claims = {
        "sub": user_id,
        "role": role,
        "exp": issued_at + lifetime,
        "type": "access",
    }
    return jwt.encode(claims, settings.SECRET_KEY, algorithm=ALGORITHM)
|
|
|
|
|
|
def create_refresh_token(user_id: str) -> str:
    """Build a signed refresh JWT for *user_id*.

    The token expires ``settings.REFRESH_TOKEN_EXPIRE_DAYS`` days from the
    current UTC time and is tagged ``"type": "refresh"``. It is signed with
    ``settings.SECRET_KEY`` using ``ALGORITHM``.
    """
    issued_at = datetime.now(timezone.utc)
    lifetime = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
    claims = {
        "sub": user_id,
        "exp": issued_at + lifetime,
        "type": "refresh",
    }
    return jwt.encode(claims, settings.SECRET_KEY, algorithm=ALGORITHM)
|
|
|
|
|
|
def decode_token(token: str) -> dict:
    """Decode and verify a JWT produced by this module.

    The token is verified against ``settings.SECRET_KEY`` and only
    ``ALGORITHM`` is accepted.

    Args:
        token: The encoded JWT string.

    Returns:
        The decoded claims dictionary.

    Raises:
        HTTPException: 401 with detail ``"Invalid token"`` when ``jwt.decode``
            raises ``JWTError``.
    """
    try:
        return jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as exc:
        # Chain the original error so the real cause (bad signature, malformed
        # token, ...) stays visible in tracebacks and logs; the client-facing
        # response is unchanged.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
        ) from exc
|
|
|
|
|
|
# ── API Key ───────────────────────────────────────────────────────────────────
|
|
|
|
def generate_api_key() -> tuple[str, str, str]:
    """Create a new API key and return ``(raw_key, prefix, hash)``.

    ``raw_key`` is ``"cc_"`` followed by a URL-safe random token and is the
    only place the full key is available (the original notes it is shown to
    the user once). ``prefix`` is the first 11 characters of the raw key
    (``"cc_"`` plus 8 token characters); ``hash`` is ``hash_password(raw_key)``.
    """
    token = secrets.token_urlsafe(32)
    raw_key = f"cc_{token}"
    return raw_key, raw_key[:11], hash_password(raw_key)
|
|
|
|
|
|
async def verify_api_key(raw_key: str, db: AsyncSession) -> User | None:
    """Resolve the user that owns *raw_key*, or ``None`` if it is not valid.

    Keys that are empty or do not start with ``"cc_"`` are rejected without a
    query. Otherwise, active ``ApiKey`` rows with the same 11-character prefix
    and an active owner are loaded, and each stored hash is checked against
    the raw key. On the first match ``last_used_at`` is set to the current UTC
    time, the session is committed, and the owning user is returned.
    """
    if not (raw_key and raw_key.startswith("cc_")):
        return None

    # Narrow candidates by the stored prefix; the hash check below decides.
    stmt = (
        select(ApiKey)
        .options(selectinload(ApiKey.user))
        .where(ApiKey.key_prefix == raw_key[:11], ApiKey.is_active == True)  # noqa: E712
        .join(ApiKey.user)
        .where(User.is_active == True)  # noqa: E712
    )
    candidates = (await db.execute(stmt)).scalars().all()

    # Stop at the first candidate whose hash matches the presented key.
    matched = next(
        (row for row in candidates if verify_password(raw_key, row.key_hash)),
        None,
    )
    if matched is None:
        return None

    matched.last_used_at = datetime.now(timezone.utc)
    await db.commit()
    # NOTE(review): .user is read after commit, as in the original; presumably
    # the session is configured with expire_on_commit=False -- confirm.
    return matched.user
|
|
|
|
|
|
# ── FastAPI dependencies ──────────────────────────────────────────────────────
|
|
|
|
async def get_current_user(
    credentials: Annotated[HTTPAuthorizationCredentials | None, Security(bearer_scheme)],
    db: AsyncSession = Depends(get_db),
) -> User:
    """FastAPI dependency that resolves the authenticated, active user.

    Args:
        credentials: Bearer credentials extracted by ``bearer_scheme``, or
            ``None`` when none were supplied.
        db: Database session provided by ``get_db``.

    Returns:
        The ``User`` identified by the token's ``sub`` claim.

    Raises:
        HTTPException: 401 when credentials are missing, the token cannot be
            decoded, the token is not an access token, the token carries no
            ``sub`` claim, or the user does not exist or is inactive.
    """
    if not credentials:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated")

    payload = decode_token(credentials.credentials)
    # Refresh tokens are signed with the same key, so the type must be checked.
    if payload.get("type") != "access":
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token type")

    # A validly signed token without "sub" must yield 401, not a KeyError (500).
    user_id = payload.get("sub")
    if user_id is None:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")

    user = await db.get(User, user_id)
    if not user or not user.is_active:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
    return user
|
|
|
|
|
|
async def get_admin_user(user: User = Depends(get_current_user)) -> User:
    """FastAPI dependency that admits only users whose role is ``"admin"``.

    Returns the user resolved by ``get_current_user`` unchanged when the role
    matches; otherwise raises a 403 ``HTTPException`` ("Admin required").
    """
    if user.role == "admin":
        return user
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin required")
|
|
|
|
|
|
# Annotated alias: a parameter typed CurrentUser is resolved through
# get_current_user.
CurrentUser = Annotated[User, Depends(get_current_user)]
|
|
# Annotated alias: a parameter typed AdminUser is resolved through
# get_admin_user, which additionally requires role == "admin".
AdminUser = Annotated[User, Depends(get_admin_user)]
|