87 lines
2.5 KiB
Python
87 lines
2.5 KiB
Python
import logging
from typing import Optional

from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from motor.motor_asyncio import AsyncIOMotorDatabase

from ..models.user import User, UserRole
from .database import get_database
from .security import decode_token
|
|
|
|
# Bearer-token extractor shared by the auth dependencies below;
# ``get_current_user`` receives its credentials via ``Depends(security)``.
security: HTTPBearer = HTTPBearer()
|
|
|
|
|
|
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncIOMotorDatabase = Depends(get_database),
) -> User:
    """Resolve the authenticated user from a bearer token.

    The token is taken from the ``Authorization`` header by the module-level
    ``security`` scheme and decoded with ``decode_token``. Its ``sub`` claim
    is then used as the ``_id`` when looking the user up in ``db.users``.

    Args:
        credentials: Bearer credentials extracted by ``security``.
        db: Database handle supplied by ``get_database``.

    Returns:
        The matching ``User``.

    Raises:
        HTTPException: 401 (with a ``WWW-Authenticate: Bearer`` header) when
            the token carries no ``sub`` claim or no user document matches it.
            Exceptions raised by ``decode_token`` itself propagate unchanged.
    """
    # RFC 7235 §3.1: a 401 response must include a WWW-Authenticate challenge.
    challenge = {"WWW-Authenticate": "Bearer"}

    payload = decode_token(credentials.credentials)
    # NOTE(review): decode_token's failure contract is not visible here. If it
    # can return None/empty instead of raising, treat that as "no subject"
    # rather than crashing with AttributeError (which would surface as a 500).
    user_id: Optional[str] = payload.get("sub") if payload else None
    if user_id is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers=challenge,
        )

    # NOTE(review): the lookup uses the raw ``sub`` value as ``_id``. If ``_id``
    # is stored as an ObjectId this never matches; confirm against the schema.
    user_doc = await db.users.find_one({"_id": user_id})
    if user_doc is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found",
            headers=challenge,
        )

    return User(**user_doc)
|
|
|
|
|
|
def require_role(required_role: UserRole):
    """Build a dependency that admits only users holding ``required_role``.

    Users whose role is ``UserRole.ADMIN`` are always admitted as well.

    Args:
        required_role: Role the caller must have, unless they are an admin.

    Returns:
        An async dependency that resolves the caller through
        ``get_current_user`` and returns that user when authorized. When the
        caller has neither ``required_role`` nor the admin role, the
        dependency raises ``HTTPException`` with status 403.
    """

    async def role_checker(current_user: User = Depends(get_current_user)) -> User:
        role = current_user.role
        # Guard clause: the authorized path returns early.
        if role == required_role or role == UserRole.ADMIN:
            return current_user
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Insufficient permissions",
        )

    return role_checker
|
|
|
|
|
|
def require_roles(*required_roles: UserRole):
    """Build a dependency that admits users holding any of ``required_roles``.

    Users whose role is ``UserRole.ADMIN`` are always admitted. When called
    with no roles at all, only admins pass.

    Args:
        *required_roles: Roles that grant access, in addition to admin.

    Returns:
        An async dependency that resolves the caller through
        ``get_current_user`` and returns that user when authorized. For any
        other role, the dependency raises ``HTTPException`` with status 403.
    """

    async def roles_checker(current_user: User = Depends(get_current_user)) -> User:
        role = current_user.role
        # Guard clause: the authorized path returns early.
        if role in required_roles or role == UserRole.ADMIN:
            return current_user
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Insufficient permissions",
        )

    return roles_checker
|
|
|
|
|
|
async def get_current_user_optional(
    request: Request,
    db: AsyncIOMotorDatabase = Depends(get_database),
) -> Optional[User]:
    """Resolve the user behind an optional bearer token, or return ``None``.

    Unlike ``get_current_user``, this never rejects the request. All of the
    following yield ``None``:

    - a missing ``Authorization`` header, or one that is not of the form
      ``Bearer <token>``;
    - a token that cannot be decoded or carries no ``sub`` claim;
    - an unknown user, or a failure while loading the user.

    Args:
        request: Incoming request; only its ``Authorization`` header is read.
        db: Database handle supplied by ``get_database``.

    Returns:
        The matching ``User``, or ``None`` when the caller is anonymous or
        could not be resolved.
    """
    authorization: Optional[str] = request.headers.get("Authorization")
    if not authorization:
        return None

    # Expect exactly two whitespace-separated parts: "<scheme> <token>".
    parts = authorization.split()
    if len(parts) != 2 or parts[0].lower() != "bearer":
        return None
    token = parts[1]

    try:
        payload = decode_token(token)
        user_id = payload.get("sub") if payload else None
    except Exception:
        # decode_token's exception types are not visible here. For an optional
        # dependency, an unusable token simply means "anonymous".
        return None
    if user_id is None:
        return None

    try:
        user_doc = await db.users.find_one({"_id": user_id})
        if user_doc is None:
            return None
        return User(**user_doc)
    except Exception:
        # Best-effort as before: the caller is treated as anonymous. The
        # failure is logged, though, so database or model-validation errors
        # are no longer hidden.
        logging.getLogger(__name__).warning(
            "Optional auth: could not load user %s", user_id, exc_info=True
        )
        return None
|