from typing import Any

from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from app.auth.providers.azure_ad import validate_azure_token
from app.auth.providers.jwt_provider import JWTAuthProvider
from app.models.user import User, UserRole, UserStatus


class AuthService:
    """Authentication service wrapping the JWT provider.

    Supports password login, Azure AD single sign-on (with automatic user
    provisioning) and refresh-token rotation. Every successful flow returns a
    dict with ``access_token``, ``refresh_token`` and ``token_type`` keys;
    every failure returns None.
    """

    def __init__(self) -> None:
        self.provider = JWTAuthProvider()

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    async def _get_user_by_email(email: str, db: AsyncSession) -> User | None:
        """Return the user whose email equals *email* exactly, or None."""
        result = await db.execute(select(User).where(User.email == email))
        return result.scalar_one_or_none()

    @staticmethod
    def _is_active(user: User) -> bool:
        """Return True if *user*'s status value is ``"active"``."""
        return user.status.value == "active"

    @staticmethod
    def _claims_for(user: User) -> dict[str, str]:
        """Build the token payload (sub/email/role/name) for *user*."""
        return {
            "sub": str(user.id),
            "email": user.email,
            "role": user.role.value,
            "name": user.name,
        }

    def _issue_tokens(self, token_data: dict[str, str]) -> dict[str, str]:
        """Mint an access/refresh token pair from *token_data*."""
        return {
            "access_token": self.provider.create_access_token(token_data),
            "refresh_token": self.provider.create_refresh_token(token_data),
            "token_type": "bearer",
        }

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def login(
        self, email: str, password: str, db: AsyncSession
    ) -> dict[str, str] | None:
        """Authenticate a user and return tokens, or None if invalid.

        Returns None when no user has this email, the user has no password
        hash (SSO-only account), the password does not verify, or the
        account status is not ``"active"``.
        """
        # NOTE(review): the email is matched verbatim here, whereas sso_login
        # strips and lower-cases it before lookup. Confirm how emails are
        # normalised when accounts are created.
        user = await self._get_user_by_email(email, db)
        if user is None:
            return None
        if user.password_hash is None:
            return None  # SSO-only user, cannot authenticate with password
        if not self.provider.verify_password(password, user.password_hash):
            return None
        if not self._is_active(user):
            return None
        return self._issue_tokens(self._claims_for(user))

    async def sso_login(
        self, azure_token: str, db: AsyncSession
    ) -> dict[str, str] | None:
        """Validate an Azure AD token, auto-provision the user, and return app tokens.

        Returns None when the validated claims are empty, carry no usable
        email, or the matching account is not active. A user seen for the
        first time is created with the ``viewer`` role and ``azure_ad``
        provider and flushed to the session; this method never commits.
        """
        claims = await validate_azure_token(azure_token)
        if not claims:
            return None

        # Azure AD may use either 'email' or 'preferred_username'.
        email: str = claims.get("email") or claims.get("preferred_username") or ""
        email = email.strip().lower()
        if not email:
            return None
        name: str = claims.get("name") or email.split("@")[0]

        user = await self._get_user_by_email(email, db)
        if user is None:
            # Auto-provision SSO user with viewer role and no password.
            user = User(
                email=email,
                name=name,
                password_hash=None,
                role=UserRole.viewer,
                status=UserStatus.active,
                auth_provider="azure_ad",
            )
            db.add(user)
            try:
                await db.flush()
            except IntegrityError:
                # Race condition: another request created the user first.
                # NOTE(review): rollback() discards the entire session
                # transaction, not just this insert; a savepoint
                # (begin_nested) would be narrower.
                await db.rollback()
                user = await self._get_user_by_email(email, db)
                # The row that won the race may not be active (it need not
                # have been created by this flow), so check it like any
                # other existing user.
                if user is None or not self._is_active(user):
                    return None
        elif not self._is_active(user):
            return None

        return self._issue_tokens(self._claims_for(user))

    def refresh_tokens(self, refresh_token: str) -> dict[str, str] | None:
        """Validate a refresh token and issue a new token pair.

        Returns None if the token fails provider validation, is not of type
        ``"refresh"``, or has no ``sub`` claim.

        NOTE(review): this does not consult the database, so changes to the
        user's status or role after the refresh token was issued are not
        reflected until the provider itself rejects the token.
        """
        claims = self.provider.validate_token(refresh_token)
        if claims is None or claims.get("type") != "refresh":
            return None
        subject = claims.get("sub")
        if not subject:
            # A token without a subject cannot identify a user.
            return None
        token_data = {
            "sub": subject,
            "email": claims.get("email", ""),
            "role": claims.get("role", ""),
            "name": claims.get("name", ""),
        }
        return self._issue_tokens(token_data)

    def validate_token(self, token: str) -> dict[str, Any] | None:
        """Validate a token and return its claims, or None if invalid."""
        return self.provider.validate_token(token)

    def hash_password(self, password: str) -> str:
        """Hash a password using the underlying provider."""
        return self.provider.hash_password(password)