- Remove contact references from system prompt, add language matching rule - Add copy-to-clipboard button on assistant messages with iframe fallback - Increase token lifetime to 24h/30d, add refresh queue, remove hard redirect - Fix adaptive layout for iframe/standalone, pin input at bottom - Fix CSS specificity conflict (8px→2px spacing), add markdown post-processing Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
385 lines
11 KiB
Python
385 lines
11 KiB
Python
"""
|
|
Authentication Service
|
|
|
|
Handles MSAL authentication and JWT session management
|
|
"""
|
|
|
|
import logging
|
|
from typing import Optional, Dict
|
|
from datetime import datetime, timedelta
|
|
from uuid import UUID
|
|
import httpx
|
|
import bcrypt
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.security import create_access_token, create_refresh_token, hash_token
|
|
from app.repositories.user_repository import UserRepository
|
|
from app.models.session import Session
|
|
from app.models.user import User
|
|
from app.config import get_settings
|
|
|
|
# Module-scoped logger, named after this module so its records can be
# filtered/configured independently.
logger = logging.getLogger(name=__name__)

# Application settings, loaded once when this module is imported.
# NOTE(review): nothing in AuthService reads `settings`; confirm it is still
# needed (get_settings() may be relied on to validate config at import time)
# before removing it.
settings = get_settings()
|
|
|
|
|
|
class AuthService:
    """
    Authentication service for MSAL and JWT management.

    Handles:
    - Azure AD token validation (via Microsoft Graph ``/me``)
    - User creation/update from Azure AD
    - Email/password login for test users
    - JWT access/refresh token generation
    - Session record management: created on login, rotated to the new
      access token on refresh, deactivated on logout
    """

    # Lifetime of an access token. Also used for the Session row's expires_at
    # and for the ``expires_in`` value reported to clients, so all three agree.
    # NOTE(review): must match the expiry that create_access_token() encodes
    # (configured in app.core.security / settings) — confirm.
    ACCESS_TOKEN_LIFETIME = timedelta(hours=24)

    # Microsoft Graph endpoint used to validate Azure AD tokens.
    GRAPH_ME_URL = "https://graph.microsoft.com/v1.0/me"

    # Timeout (seconds) for the Graph validation request.
    GRAPH_TIMEOUT_SECONDS = 10.0

    def __init__(self, session: AsyncSession):
        """
        Initialize auth service.

        Args:
            session: Database session used for all queries and commits
        """
        self.session = session
        self.user_repo = UserRepository(session)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _utcnow() -> datetime:
        """
        Return the current time as a naive UTC datetime.

        NOTE(review): naive UTC is kept to match how ``Session.expires_at`` is
        written and compared in this class. Moving off the deprecated
        ``datetime.utcnow()`` needs a matching change to the column type;
        confirm before migrating.
        """
        return datetime.utcnow()

    @classmethod
    def _expires_in_seconds(cls) -> int:
        """Return the access-token lifetime in whole seconds, as reported to clients."""
        return int(cls.ACCESS_TOKEN_LIFETIME.total_seconds())

    @staticmethod
    def _token_claims(user: User) -> Dict:
        """Build the JWT claims shared by access and refresh tokens."""
        return {
            "sub": str(user.id),
            "email": user.email,
            "role": user.role,
        }

    def _start_session(
        self,
        user: User,
        ip_address: Optional[str],
        user_agent: Optional[str],
    ) -> Dict:
        """
        Issue access/refresh tokens for ``user`` and stage a Session row.

        The Session row is added to the DB session but NOT committed; the
        caller commits. The response payload is built here, before the
        commit, so it does not read ``user`` attributes after a commit that
        may have expired them (e.g. with ``expire_on_commit=True``).

        Args:
            user: Authenticated, active user
            ip_address: Client IP address
            user_agent: Client user agent

        Returns:
            Login payload: access_token, refresh_token, token_type,
            expires_in and a public ``user`` summary
        """
        claims = self._token_claims(user)
        access_token = create_access_token(claims)
        refresh_token = create_refresh_token(claims)

        # Only token hashes are persisted; the raw tokens go back to the client.
        self.session.add(
            Session(
                user_id=user.id,
                access_token_hash=hash_token(access_token),
                refresh_token_hash=hash_token(refresh_token),
                expires_at=self._utcnow() + self.ACCESS_TOKEN_LIFETIME,
                ip_address=ip_address,
                user_agent=user_agent,
                is_active=True,
            )
        )

        return {
            "access_token": access_token,
            "refresh_token": refresh_token,
            "token_type": "bearer",
            "expires_in": self._expires_in_seconds(),
            "user": {
                "id": str(user.id),
                "email": user.email,
                "display_name": user.display_name,
                "role": user.role,
            },
        }

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def validate_azure_token(self, id_token: str) -> Optional[Dict]:
        """
        Validate an Azure AD token by calling Microsoft Graph ``/me``.

        The token is sent as a bearer credential. A token is considered valid
        when Graph answers 200 with a JSON body that contains a user ``id``.

        NOTE(review): despite the parameter name, Graph only accepts access
        tokens issued for Graph; an OIDC ID token would be rejected. Confirm
        which token the frontend actually sends.

        Args:
            id_token: Azure AD token obtained via MSAL

        Returns:
            Dict with azure_ad_id, email (``mail`` falling back to
            ``userPrincipalName``), display_name, given_name and surname,
            or None if the token is rejected or the request fails
        """
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    self.GRAPH_ME_URL,
                    headers={"Authorization": f"Bearer {id_token}"},
                    timeout=self.GRAPH_TIMEOUT_SECONDS,
                )
        except httpx.RequestError as e:
            logger.error(f"Failed to validate Azure AD token: {e}", exc_info=True)
            return None

        if response.status_code != 200:
            logger.error(f"Azure AD validation failed: {response.status_code}")
            return None

        try:
            user_info = response.json()
        except ValueError as e:
            logger.error(f"Azure AD validation returned invalid JSON: {e}")
            return None

        # Without a stable Azure object id we must not create or match a user.
        if not isinstance(user_info, dict) or not user_info.get("id"):
            logger.error("Azure AD validation response has no user id")
            return None

        logger.info(f"Successfully validated Azure AD token for: {user_info.get('mail')}")

        return {
            "azure_ad_id": user_info.get("id"),
            "email": user_info.get("mail") or user_info.get("userPrincipalName"),
            "display_name": user_info.get("displayName"),
            "given_name": user_info.get("givenName"),
            "surname": user_info.get("surname"),
        }

    async def login_simple(
        self,
        email: str,
        password: str,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None
    ) -> Optional[Dict]:
        """
        Log a test user in with email and password.

        Only users whose ``meta_data`` has ``is_test_user`` set and a bcrypt
        ``password_hash`` can log in this way. On success a Session row is
        created and ``last_login_at`` is updated in one commit.

        Args:
            email: User email
            password: Plain-text password to check against the stored hash
            ip_address: Client IP address
            user_agent: Client user agent

        Returns:
            Dict with access_token, refresh_token, token_type, expires_in and
            user info, or None on any authentication failure
        """
        from sqlalchemy import select

        result = await self.session.execute(
            select(User).where(User.email == email)
        )
        user = result.scalar_one_or_none()

        if not user:
            logger.warning(f"Login attempt for non-existent user: {email}")
            return None

        # Password login is restricted to test users flagged in meta_data.
        if not user.meta_data or not user.meta_data.get("is_test_user"):
            logger.warning(f"Simple login attempted for non-test user: {email}")
            return None

        password_hash = user.meta_data.get("password_hash")
        if not password_hash:
            logger.warning(f"No password hash found for user: {email}")
            return None

        try:
            if not bcrypt.checkpw(password.encode('utf-8'), password_hash.encode('utf-8')):
                logger.warning(f"Invalid password for user: {email}")
                return None
        except Exception as e:
            # A malformed stored hash (or non-str value) must deny login, not crash.
            logger.error(f"Password verification error for {email}: {e}")
            return None

        if not user.is_active:
            logger.warning(f"User {user.id} is deactivated")
            return None

        response = self._start_session(user, ip_address, user_agent)
        user.last_login_at = self._utcnow()
        await self.session.commit()

        logger.info(f"User {response['user']['id']} logged in successfully (simple auth)")
        return response

    async def login(
        self,
        id_token: str,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None
    ) -> Optional[Dict]:
        """
        Log a user in with an Azure AD token.

        Validates the token against Microsoft Graph, gets or creates the local
        user, and creates a Session row.

        Args:
            id_token: Azure AD token (see validate_azure_token)
            ip_address: Client IP address
            user_agent: Client user agent

        Returns:
            Dict with access_token, refresh_token, token_type, expires_in and
            user info, or None if validation fails or the user is deactivated
        """
        azure_user = await self.validate_azure_token(id_token)

        if not azure_user:
            logger.warning("Failed to validate Azure AD token")
            return None

        user = await self.user_repo.get_or_create_from_azure(
            azure_ad_id=azure_user["azure_ad_id"],
            email=azure_user["email"],
            display_name=azure_user["display_name"],
            given_name=azure_user.get("given_name"),
            surname=azure_user.get("surname"),
        )

        if not user.is_active:
            logger.warning(f"User {user.id} is deactivated")
            return None

        response = self._start_session(user, ip_address, user_agent)
        await self.session.commit()

        logger.info(f"User {response['user']['id']} logged in successfully")
        return response

    async def logout(self, user_id: UUID, access_token: str) -> bool:
        """
        Log a user out by deactivating the session bound to ``access_token``.

        Args:
            user_id: User UUID
            access_token: Access token whose session should be deactivated

        Returns:
            Always True; no error is raised when no matching active session
            exists
        """
        from sqlalchemy import update

        await self.session.execute(
            update(Session)
            .where(
                Session.user_id == user_id,
                Session.access_token_hash == hash_token(access_token),
                Session.is_active == True,  # noqa: E712 - SQL expression, not a bool test
            )
            .values(is_active=False)
        )
        await self.session.commit()

        logger.info(f"User {user_id} logged out")
        return True

    async def refresh_access_token(self, refresh_token: str) -> Optional[Dict]:
        """
        Issue a new access token from a refresh token.

        Conditions for success:
        - the refresh token decodes to a non-empty payload;
        - its ``sub`` is a valid UUID of an existing, active user;
        - it matches an active Session row, so logging out revokes it.

        On success the Session row is rotated to the new access token and its
        expiry is extended, so verify_session() and logout() keep working
        with the refreshed token.

        Args:
            refresh_token: Refresh token issued at login

        Returns:
            Dict with access_token, token_type and expires_in, or None if the
            token is invalid or revoked, or the user is missing or inactive
        """
        from sqlalchemy import select
        from app.core.security import decode_token

        payload = decode_token(refresh_token)
        if not payload:
            logger.warning("Invalid refresh token")
            return None

        # A missing or malformed "sub" claim is an invalid token, not a server error.
        try:
            user_id = UUID(str(payload.get("sub")))
        except ValueError:
            logger.warning("Invalid refresh token")
            return None

        user = await self.user_repo.get_by_id(user_id)
        if not user or not user.is_active:
            logger.warning(f"User {user_id} not found or inactive")
            return None

        # Session.expires_at is intentionally not checked. The row tracks one
        # access-token lifetime, while the refresh token carries its own
        # (presumably longer) expiry, enforced when decode_token() validates it.
        result = await self.session.execute(
            select(Session).where(
                Session.user_id == user_id,
                Session.refresh_token_hash == hash_token(refresh_token),
                Session.is_active == True,  # noqa: E712 - SQL expression, not a bool test
            )
        )
        session_record = result.scalars().first()
        if session_record is None:
            logger.warning(f"No active session for refresh token of user {user_id}")
            return None

        new_access_token = create_access_token(self._token_claims(user))

        # Rebind the session to the new access token so it can be verified and logged out.
        session_record.access_token_hash = hash_token(new_access_token)
        session_record.expires_at = self._utcnow() + self.ACCESS_TOKEN_LIFETIME
        await self.session.commit()

        logger.info(f"Refreshed access token for user {user_id}")

        return {
            "access_token": new_access_token,
            "token_type": "bearer",
            "expires_in": self._expires_in_seconds(),
        }

    async def get_current_user(self, access_token: str) -> Optional[User]:
        """
        Resolve the active user identified by an access token's subject.

        This reads only the token subject and the user table; it does not
        consult Session rows (see verify_session for that).

        Args:
            access_token: JWT access token

        Returns:
            The active User, or None if the token has no usable subject or the
            user is missing or inactive
        """
        from app.core.security import get_token_subject

        user_id_str = get_token_subject(access_token)
        if not user_id_str:
            return None

        try:
            user_id = UUID(str(user_id_str))
        except ValueError:
            return None

        user = await self.user_repo.get_by_id(user_id)
        if user and user.is_active:
            return user

        return None

    async def verify_session(self, access_token: str) -> bool:
        """
        Check that ``access_token`` is bound to an active, unexpired session.

        Args:
            access_token: Access token to verify

        Returns:
            True if at least one active Session row with this token hash has
            ``expires_at`` in the future
        """
        from sqlalchemy import select

        result = await self.session.execute(
            select(Session)
            .where(
                Session.access_token_hash == hash_token(access_token),
                Session.is_active == True,  # noqa: E712 - SQL expression, not a bool test
                Session.expires_at > self._utcnow(),
            )
        )
        # first() rather than scalar_one_or_none(): identical tokens (and thus
        # duplicate hashes) must not raise MultipleResultsFound.
        return result.scalars().first() is not None