Implement repository pattern and chat service orchestration
Repositories (Data Access Layer): - BaseRepository: Generic CRUD operations with async support - UserRepository: User management, Azure AD integration - ConversationRepository: Conversation CRUD, archiving, search - MessageRepository: Message creation, retrieval, search - TokenUsageRepository: Usage tracking, cost calculation, analytics Chat Service (Business Logic): - Complete conversation lifecycle management - Message sending with OpenAI integration - Multi-turn conversation support via previous_response_id - Automatic token usage tracking - Cost calculation per message - Permission checks for user access - Conversation archiving and deletion - Token usage analytics and reporting Key Features: - Repository pattern for clean data access - Async/await throughout for performance - Proper error handling and logging - Permission verification for user actions - Citation validation from OpenAI responses - Automatic cost tracking per message - File search results stored in message metadata Integration Points: - OpenAIService for AI responses - All SQLAlchemy models - Token cost calculation from settings - Multi-turn conversations via last_response_id Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
86da0b2330
commit
29e6c2e442
6 changed files with 1142 additions and 0 deletions
126
backend/app/repositories/base.py
Normal file
126
backend/app/repositories/base.py
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
"""
|
||||
Base Repository Pattern
|
||||
|
||||
Provides generic CRUD operations for all repositories
|
||||
"""
|
||||
|
||||
from typing import TypeVar, Generic, Type, Optional, List
|
||||
from uuid import UUID
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, update, delete
|
||||
from sqlalchemy.orm import DeclarativeMeta
|
||||
|
||||
# Type variable naming the model class that a BaseRepository instance manages.
# NOTE(review): DeclarativeMeta is the metaclass of declarative models, so
# mapped classes are instances of it rather than subclasses — confirm this
# bound type-checks the way it is intended to.
ModelType = TypeVar("ModelType", bound=DeclarativeMeta)
|
||||
|
||||
|
||||
class BaseRepository(Generic[ModelType]):
    """
    Base repository with generic async CRUD operations.

    Each mutating method commits the session itself. All lookups filter on
    the model's ``id`` attribute, so the managed model must define one.

    Args:
        model: SQLAlchemy model class
        session: Async database session
    """

    def __init__(self, model: Type[ModelType], session: AsyncSession):
        self.model = model
        self.session = session

    async def create(self, **kwargs) -> ModelType:
        """
        Create a new record and commit it.

        Args:
            **kwargs: Model field values passed to the model constructor

        Returns:
            Created model instance, refreshed from the database
        """
        instance = self.model(**kwargs)
        self.session.add(instance)
        await self.session.commit()
        # Re-read the row so database-populated attributes are available.
        await self.session.refresh(instance)
        return instance

    async def get_by_id(self, id: UUID) -> Optional[ModelType]:
        """
        Get record by ID.

        Args:
            id: Record UUID

        Returns:
            Model instance or None
        """
        result = await self.session.execute(
            select(self.model).where(self.model.id == id)
        )
        return result.scalar_one_or_none()

    async def get_all(
        self,
        skip: int = 0,
        limit: int = 100
    ) -> List[ModelType]:
        """
        Get all records with pagination.

        Rows are ordered by ``id`` so that consecutive pages are stable;
        OFFSET/LIMIT without an ORDER BY may repeat or skip rows.

        Args:
            skip: Number of records to skip
            limit: Maximum number of records to return

        Returns:
            List of model instances
        """
        result = await self.session.execute(
            select(self.model)
            .order_by(self.model.id)
            .offset(skip)
            .limit(limit)
        )
        return list(result.scalars().all())

    async def update(self, id: UUID, **kwargs) -> Optional[ModelType]:
        """
        Update record by ID and commit.

        Args:
            id: Record UUID
            **kwargs: Fields to update; when empty no UPDATE is issued

        Returns:
            Updated model instance, or None if no record has this ID
        """
        # An UPDATE without any values is not a meaningful statement, so a
        # call with no fields degrades to a plain lookup.
        if kwargs:
            await self.session.execute(
                update(self.model).where(self.model.id == id).values(**kwargs)
            )
            await self.session.commit()
        return await self.get_by_id(id)

    async def delete(self, id: UUID) -> bool:
        """
        Delete record by ID and commit.

        Args:
            id: Record UUID

        Returns:
            True if deleted, False if not found
        """
        result = await self.session.execute(
            delete(self.model).where(self.model.id == id)
        )
        await self.session.commit()
        return result.rowcount > 0

    async def exists(self, id: UUID) -> bool:
        """
        Check if record exists.

        Args:
            id: Record UUID

        Returns:
            True if exists, False otherwise
        """
        # Select only the id column; the full row is not needed.
        result = await self.session.execute(
            select(self.model.id).where(self.model.id == id)
        )
        return result.scalar_one_or_none() is not None
|
||||
145
backend/app/repositories/conversation_repository.py
Normal file
145
backend/app/repositories/conversation_repository.py
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
"""
|
||||
Conversation Repository
|
||||
|
||||
Data access layer for Conversation model
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.models.conversation import Conversation
|
||||
from app.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class ConversationRepository(BaseRepository[Conversation]):
    """Repository for Conversation model operations"""

    def __init__(self, session: AsyncSession):
        super().__init__(Conversation, session)

    @staticmethod
    def _escape_like(term: str) -> str:
        """Escape LIKE wildcards so ``term`` matches literally (escape char: backslash)."""
        return (
            term.replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )

    async def _set_fields(self, conversation_id: UUID, **fields) -> bool:
        """Apply ``fields`` to one conversation; True only if the row exists."""
        updated = await self.update(conversation_id, **fields)
        return updated is not None

    async def get_by_user(
        self,
        user_id: UUID,
        include_archived: bool = False,
        skip: int = 0,
        limit: int = 50
    ) -> List[Conversation]:
        """
        Get conversations for a specific user, most recent message first.

        Args:
            user_id: User UUID
            include_archived: Include archived conversations
            skip: Number of records to skip
            limit: Maximum number of records

        Returns:
            List of conversations
        """
        query = select(Conversation).where(Conversation.user_id == user_id)

        if not include_archived:
            query = query.where(Conversation.is_archived.is_(False))

        query = (
            query.order_by(Conversation.last_message_at.desc())
            .offset(skip)
            .limit(limit)
        )

        result = await self.session.execute(query)
        return list(result.scalars().all())

    async def get_with_messages(self, conversation_id: UUID) -> Optional[Conversation]:
        """
        Get conversation with all messages loaded.

        Args:
            conversation_id: Conversation UUID

        Returns:
            Conversation with messages or None
        """
        result = await self.session.execute(
            select(Conversation)
            .where(Conversation.id == conversation_id)
            # Load messages eagerly; lazy loading is not usable after the
            # awaited query returns in an async session.
            .options(selectinload(Conversation.messages))
        )
        return result.scalar_one_or_none()

    async def archive(self, conversation_id: UUID) -> bool:
        """
        Archive a conversation.

        Args:
            conversation_id: Conversation UUID

        Returns:
            True if archived, False if the conversation does not exist
        """
        return await self._set_fields(conversation_id, is_archived=True)

    async def unarchive(self, conversation_id: UUID) -> bool:
        """
        Unarchive a conversation.

        Args:
            conversation_id: Conversation UUID

        Returns:
            True if unarchived, False if the conversation does not exist
        """
        return await self._set_fields(conversation_id, is_archived=False)

    async def update_last_response_id(
        self,
        conversation_id: UUID,
        response_id: str
    ) -> bool:
        """
        Update the last OpenAI response ID for multi-turn conversations.

        Args:
            conversation_id: Conversation UUID
            response_id: OpenAI Responses API response ID

        Returns:
            True if updated, False if the conversation does not exist
        """
        return await self._set_fields(
            conversation_id,
            last_response_id=response_id
        )

    async def search_by_title(
        self,
        user_id: UUID,
        search_term: str,
        limit: int = 20
    ) -> List[Conversation]:
        """
        Search a user's conversations by case-insensitive title substring.

        ``%`` and ``_`` in the search term are matched literally rather than
        acting as SQL wildcards.

        Args:
            user_id: User UUID
            search_term: Search term for title
            limit: Maximum number of results

        Returns:
            List of matching conversations
        """
        pattern = f"%{self._escape_like(search_term)}%"
        result = await self.session.execute(
            select(Conversation)
            .where(
                and_(
                    Conversation.user_id == user_id,
                    Conversation.title.ilike(pattern, escape="\\"),
                )
            )
            .order_by(Conversation.last_message_at.desc())
            .limit(limit)
        )
        return list(result.scalars().all())
|
||||
166
backend/app/repositories/message_repository.py
Normal file
166
backend/app/repositories/message_repository.py
Normal file
|
|
@ -0,0 +1,166 @@
|
|||
"""
|
||||
Message Repository
|
||||
|
||||
Data access layer for Message model
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_
|
||||
|
||||
from app.models.message import Message
|
||||
from app.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class MessageRepository(BaseRepository[Message]):
    """Repository for Message model operations"""

    def __init__(self, session: AsyncSession):
        super().__init__(Message, session)

    @staticmethod
    def _escape_like(term: str) -> str:
        """Escape LIKE wildcards so ``term`` matches literally (escape char: backslash)."""
        return (
            term.replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )

    async def get_by_conversation(
        self,
        conversation_id: UUID,
        skip: int = 0,
        limit: int = 100
    ) -> List[Message]:
        """
        Get messages for a specific conversation.

        Args:
            conversation_id: Conversation UUID
            skip: Number of records to skip
            limit: Maximum number of records

        Returns:
            List of messages ordered by creation time (oldest first)
        """
        result = await self.session.execute(
            select(Message)
            .where(Message.conversation_id == conversation_id)
            .order_by(Message.created_at.asc())
            .offset(skip)
            .limit(limit)
        )
        return list(result.scalars().all())

    async def get_latest_by_conversation(
        self,
        conversation_id: UUID,
        count: int = 10
    ) -> List[Message]:
        """
        Get latest N messages from a conversation.

        Args:
            conversation_id: Conversation UUID
            count: Number of messages to retrieve

        Returns:
            The latest messages, in chronological order
        """
        result = await self.session.execute(
            select(Message)
            .where(Message.conversation_id == conversation_id)
            .order_by(Message.created_at.desc())
            .limit(count)
        )
        # The query returns newest-first; flip it back to oldest-first.
        newest_first = result.scalars().all()
        return list(reversed(newest_first))

    async def create_user_message(
        self,
        conversation_id: UUID,
        content: str,
        token_count: int = 0
    ) -> Message:
        """
        Create a user message.

        Args:
            conversation_id: Conversation UUID
            content: Message content
            token_count: Estimated token count

        Returns:
            Created message instance
        """
        return await self.create(
            conversation_id=conversation_id,
            role="user",
            content=content,
            token_count=token_count,
        )

    async def create_assistant_message(
        self,
        conversation_id: UUID,
        content: str,
        openai_response_id: str,
        token_count: int = 0,
        metadata: Optional[dict] = None
    ) -> Message:
        """
        Create an assistant message.

        Args:
            conversation_id: Conversation UUID
            content: Message content
            openai_response_id: OpenAI Responses API response ID
            token_count: Token count from API
            metadata: Additional metadata (e.g., file_search_results);
                None is stored as an empty dict

        Returns:
            Created message instance
        """
        # NOTE(review): SQLAlchemy's Declarative API reserves the attribute
        # name ``metadata``; confirm the Message model really accepts a
        # ``metadata`` keyword, otherwise this value may not be persisted.
        return await self.create(
            conversation_id=conversation_id,
            role="assistant",
            content=content,
            openai_response_id=openai_response_id,
            token_count=token_count,
            metadata=metadata or {},
        )

    async def get_conversation_token_count(self, conversation_id: UUID) -> int:
        """
        Get total token count for a conversation.

        The sum is computed in the database over every message of the
        conversation, so it is not limited by the page size of
        ``get_by_conversation``.

        Args:
            conversation_id: Conversation UUID

        Returns:
            Total token count (0 when the conversation has no messages)
        """
        # Local import: this module's top-level imports do not include func.
        from sqlalchemy import func

        result = await self.session.execute(
            # SUM skips NULL token counts; COALESCE covers the no-rows case.
            select(func.coalesce(func.sum(Message.token_count), 0))
            .where(Message.conversation_id == conversation_id)
        )
        return int(result.scalar_one())

    async def search_content(
        self,
        conversation_id: UUID,
        search_term: str,
        limit: int = 20
    ) -> List[Message]:
        """
        Search a conversation's messages by case-insensitive substring.

        ``%`` and ``_`` in the search term are matched literally rather than
        acting as SQL wildcards.

        Args:
            conversation_id: Conversation UUID
            search_term: Search term
            limit: Maximum number of results

        Returns:
            List of matching messages, newest first
        """
        pattern = f"%{self._escape_like(search_term)}%"
        result = await self.session.execute(
            select(Message)
            .where(
                and_(
                    Message.conversation_id == conversation_id,
                    Message.content.ilike(pattern, escape="\\"),
                )
            )
            .order_by(Message.created_at.desc())
            .limit(limit)
        )
        return list(result.scalars().all())
|
||||
180
backend/app/repositories/token_usage_repository.py
Normal file
180
backend/app/repositories/token_usage_repository.py
Normal file
|
|
@ -0,0 +1,180 @@
|
|||
"""
|
||||
Token Usage Repository
|
||||
|
||||
Data access layer for TokenUsage model
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, and_
|
||||
from decimal import Decimal
|
||||
|
||||
from app.models.token_usage import TokenUsage
|
||||
from app.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class TokenUsageRepository(BaseRepository[TokenUsage]):
    """Repository for TokenUsage model operations"""

    def __init__(self, session: AsyncSession):
        super().__init__(TokenUsage, session)

    @staticmethod
    def _window_start(days: int) -> datetime:
        """Return the naive UTC timestamp ``days`` days before now."""
        # NOTE(review): this is a naive datetime; confirm that
        # TokenUsage.created_at is stored as naive UTC as well.
        return datetime.utcnow() - timedelta(days=days)

    async def _sum_for_user(self, column, user_id: UUID, days: Optional[int]):
        """Return SUM(column) over a user's usage rows, optionally time-limited."""
        query = select(func.sum(column)).where(TokenUsage.user_id == user_id)

        # A falsy ``days`` (None or 0) means "no time filter".
        if days:
            query = query.where(
                TokenUsage.created_at >= self._window_start(days)
            )

        result = await self.session.execute(query)
        return result.scalar()

    async def record_usage(
        self,
        user_id: UUID,
        conversation_id: UUID,
        message_id: UUID,
        prompt_tokens: int,
        completion_tokens: int,
        total_tokens: int,
        model: str,
        cost_usd: Decimal,
        operation_type: str = "chat",
        metadata: Optional[dict] = None
    ) -> TokenUsage:
        """
        Record token usage for a message.

        Args:
            user_id: User UUID
            conversation_id: Conversation UUID
            message_id: Message UUID
            prompt_tokens: Number of prompt tokens
            completion_tokens: Number of completion tokens
            total_tokens: Total tokens used
            model: Model name
            cost_usd: Cost in USD
            operation_type: Type of operation
            metadata: Additional metadata; None is stored as an empty dict

        Returns:
            Created TokenUsage instance
        """
        # NOTE(review): SQLAlchemy's Declarative API reserves the attribute
        # name ``metadata``; confirm the TokenUsage model really accepts a
        # ``metadata`` keyword, otherwise this value may not be persisted.
        return await self.create(
            user_id=user_id,
            conversation_id=conversation_id,
            message_id=message_id,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            model=model,
            cost_usd=cost_usd,
            operation_type=operation_type,
            metadata=metadata or {},
        )

    async def get_user_total_tokens(
        self,
        user_id: UUID,
        days: Optional[int] = None
    ) -> int:
        """
        Get total tokens used by user.

        Args:
            user_id: User UUID
            days: Optional filter for last N days

        Returns:
            Total token count (0 when there is no usage)
        """
        total = await self._sum_for_user(TokenUsage.total_tokens, user_id, days)
        return total or 0

    async def get_user_total_cost(
        self,
        user_id: UUID,
        days: Optional[int] = None
    ) -> Decimal:
        """
        Get total cost for user.

        Args:
            user_id: User UUID
            days: Optional filter for last N days

        Returns:
            Total cost in USD (zero when there is no usage)
        """
        total = await self._sum_for_user(TokenUsage.cost_usd, user_id, days)
        return total or Decimal("0.0")

    async def get_usage_by_conversation(
        self,
        conversation_id: UUID
    ) -> List[TokenUsage]:
        """
        Get token usage for a conversation.

        Args:
            conversation_id: Conversation UUID

        Returns:
            List of TokenUsage records, oldest first
        """
        result = await self.session.execute(
            select(TokenUsage)
            .where(TokenUsage.conversation_id == conversation_id)
            .order_by(TokenUsage.created_at.asc())
        )
        return list(result.scalars().all())

    async def get_daily_usage(
        self,
        user_id: UUID,
        days: int = 30
    ) -> List[dict]:
        """
        Get daily token usage for user.

        Args:
            user_id: User UUID
            days: Number of days to retrieve

        Returns:
            List of dicts with keys ``date`` (ISO string), ``tokens`` (int)
            and ``cost`` (float), ordered by date ascending
        """
        since = self._window_start(days)
        day = func.date(TokenUsage.created_at)

        result = await self.session.execute(
            select(
                day.label("date"),
                func.sum(TokenUsage.total_tokens).label("tokens"),
                func.sum(TokenUsage.cost_usd).label("cost"),
            )
            .where(
                and_(
                    TokenUsage.user_id == user_id,
                    TokenUsage.created_at >= since,
                )
            )
            .group_by(day)
            .order_by(day.asc())
        )

        return [
            {
                # Some backends return a date object from DATE(), others
                # (e.g. SQLite) return the ISO string directly.
                "date": (
                    row.date.isoformat()
                    if hasattr(row.date, "isoformat")
                    else str(row.date)
                ),
                # SUM is NULL when every value in the group is NULL.
                "tokens": int(row.tokens or 0),
                "cost": float(row.cost or 0),
            }
            for row in result
        ]
|
||||
117
backend/app/repositories/user_repository.py
Normal file
117
backend/app/repositories/user_repository.py
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
"""
|
||||
User Repository
|
||||
|
||||
Data access layer for User model
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.user import User
|
||||
from app.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class UserRepository(BaseRepository[User]):
    """Repository for User model operations"""

    def __init__(self, session: AsyncSession):
        super().__init__(User, session)

    async def _get_one_where(self, condition) -> Optional[User]:
        """Return the single user matching ``condition``, or None."""
        result = await self.session.execute(select(User).where(condition))
        return result.scalar_one_or_none()

    async def get_by_email(self, email: str) -> Optional[User]:
        """
        Get user by email.

        Args:
            email: User email address

        Returns:
            User instance or None
        """
        return await self._get_one_where(User.email == email)

    async def get_by_azure_ad_id(self, azure_ad_id: str) -> Optional[User]:
        """
        Get user by Azure AD ID.

        Args:
            azure_ad_id: Azure AD identifier

        Returns:
            User instance or None
        """
        return await self._get_one_where(User.azure_ad_id == azure_ad_id)

    async def get_or_create_from_azure(
        self,
        azure_ad_id: str,
        email: str,
        display_name: str,
        given_name: Optional[str] = None,
        surname: Optional[str] = None,
    ) -> User:
        """
        Get existing user or create new from Azure AD data.

        An existing user only has ``last_login_at`` refreshed; the other
        arguments are used solely when a new user is created.

        Args:
            azure_ad_id: Azure AD identifier
            email: User email
            display_name: User display name
            given_name: User first name
            surname: User last name

        Returns:
            User instance (existing or newly created)
        """
        # Local import: this module's top-level imports do not include func.
        from sqlalchemy import func

        user = await self.get_by_azure_ad_id(azure_ad_id)

        if user is None:
            return await self.create(
                azure_ad_id=azure_ad_id,
                email=email,
                display_name=display_name,
                given_name=given_name,
                surname=surname,
                role="user",
                is_active=True,
            )

        # Let the database stamp the login time: assigning a SQL expression
        # makes the flush emit "SET last_login_at = now()", which avoids any
        # naive/aware datetime mismatch with the column type.
        user.last_login_at = func.now()
        await self.session.commit()
        # Reload so last_login_at holds the stored timestamp, not the
        # SQL expression.
        await self.session.refresh(user)
        return user

    async def deactivate(self, user_id) -> bool:
        """
        Deactivate user account.

        Args:
            user_id: User UUID

        Returns:
            True if deactivated, False if the user does not exist
        """
        updated = await self.update(user_id, is_active=False)
        return updated is not None

    async def activate(self, user_id) -> bool:
        """
        Activate user account.

        Args:
            user_id: User UUID

        Returns:
            True if activated, False if the user does not exist
        """
        updated = await self.update(user_id, is_active=True)
        return updated is not None
|
||||
408
backend/app/services/chat_service.py
Normal file
408
backend/app/services/chat_service.py
Normal file
|
|
@ -0,0 +1,408 @@
|
|||
"""
|
||||
Chat Service
|
||||
|
||||
Orchestrates chat functionality between API, repositories, and OpenAI service
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Optional, List
|
||||
from uuid import UUID
|
||||
from decimal import Decimal
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.openai_service import OpenAIService
|
||||
from app.repositories.conversation_repository import ConversationRepository
|
||||
from app.repositories.message_repository import MessageRepository
|
||||
from app.repositories.token_usage_repository import TokenUsageRepository
|
||||
from app.config import get_settings
|
||||
|
||||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||||
# Settings are fetched once at import time; the service code below reads
# OPENAI_MODEL, PROMPT_TOKEN_COST and COMPLETION_TOKEN_COST from this object.
settings = get_settings()
|
||||
|
||||
|
||||
class ChatService:
    """
    Main chat service for handling conversations and messages.

    Orchestrates:
    - Conversation management
    - Message creation and retrieval
    - OpenAI API integration
    - Token usage tracking

    Only ``send_message`` and ``delete_conversation`` verify that the
    conversation belongs to the given user; the other methods take no user
    and perform no ownership check.
    """

    def __init__(
        self,
        session: AsyncSession,
        openai_service: Optional[OpenAIService] = None
    ):
        """
        Initialize chat service.

        Args:
            session: Database session shared by all repositories
            openai_service: Optional OpenAI service instance; a default one
                is constructed when omitted
        """
        self.session = session
        self.openai_service = openai_service or OpenAIService()

        # Initialize repositories
        self.conversation_repo = ConversationRepository(session)
        self.message_repo = MessageRepository(session)
        self.token_repo = TokenUsageRepository(session)

    @staticmethod
    def _iso_or_none(value) -> Optional[str]:
        """Return ``value.isoformat()``, or None when ``value`` is falsy."""
        return value.isoformat() if value else None

    async def _get_owned_conversation(self, user_id: UUID, conversation_id: UUID):
        """Load a conversation, raising unless it exists and belongs to ``user_id``."""
        conversation = await self.conversation_repo.get_by_id(conversation_id)
        if not conversation:
            raise ValueError(f"Conversation {conversation_id} not found")

        if conversation.user_id != user_id:
            raise PermissionError(f"User {user_id} does not have access to conversation {conversation_id}")

        return conversation

    async def create_conversation(
        self,
        user_id: UUID,
        title: Optional[str] = None
    ) -> Dict:
        """
        Create a new conversation.

        Args:
            user_id: User UUID
            title: Optional conversation title; defaults to "New Conversation"

        Returns:
            Dict with conversation data
        """
        conversation = await self.conversation_repo.create(
            user_id=user_id,
            title=title or "New Conversation"
        )

        logger.info(f"Created conversation {conversation.id} for user {user_id}")

        return {
            "id": str(conversation.id),
            "user_id": str(conversation.user_id),
            "title": conversation.title,
            "created_at": conversation.created_at.isoformat(),
            "last_message_at": self._iso_or_none(conversation.last_message_at),
            "is_archived": conversation.is_archived,
        }

    async def get_conversation(self, conversation_id: UUID) -> Optional[Dict]:
        """
        Get conversation by ID.

        NOTE(review): no ownership check is performed here; callers must
        authorize access themselves.

        Args:
            conversation_id: Conversation UUID

        Returns:
            Dict with conversation data or None
        """
        conversation = await self.conversation_repo.get_by_id(conversation_id)

        if not conversation:
            return None

        return {
            "id": str(conversation.id),
            "user_id": str(conversation.user_id),
            "title": conversation.title,
            "created_at": conversation.created_at.isoformat(),
            "last_message_at": self._iso_or_none(conversation.last_message_at),
            "is_archived": conversation.is_archived,
            "last_response_id": conversation.last_response_id,
        }

    async def list_conversations(
        self,
        user_id: UUID,
        include_archived: bool = False,
        skip: int = 0,
        limit: int = 50
    ) -> List[Dict]:
        """
        List conversations for a user.

        Args:
            user_id: User UUID
            include_archived: Include archived conversations
            skip: Number of records to skip
            limit: Maximum number of records

        Returns:
            List of conversation dicts
        """
        conversations = await self.conversation_repo.get_by_user(
            user_id=user_id,
            include_archived=include_archived,
            skip=skip,
            limit=limit
        )

        return [
            {
                "id": str(conv.id),
                "title": conv.title,
                "created_at": conv.created_at.isoformat(),
                "last_message_at": self._iso_or_none(conv.last_message_at),
                "is_archived": conv.is_archived,
            }
            for conv in conversations
        ]

    async def send_message(
        self,
        user_id: UUID,
        conversation_id: UUID,
        message_content: str
    ) -> Dict:
        """
        Send a message and get AI response.

        The user message is saved before the OpenAI call, so it remains
        stored even if generating the response fails.

        Args:
            user_id: User UUID
            conversation_id: Conversation UUID
            message_content: User's message text

        Returns:
            Dict with user message, assistant response, usage and cost

        Raises:
            ValueError: If the conversation does not exist
            PermissionError: If the conversation belongs to another user
            Exception: Any error raised by the OpenAI service is logged and
                re-raised unchanged
        """
        # Verify conversation exists and belongs to user
        conversation = await self._get_owned_conversation(user_id, conversation_id)

        logger.info(f"Processing message from user {user_id} in conversation {conversation_id}")

        # 1. Save user message
        user_message = await self.message_repo.create_user_message(
            conversation_id=conversation_id,
            content=message_content,
            token_count=self._estimate_tokens(message_content)
        )

        # 2. Generate AI response, chaining onto the previous turn (if any)
        try:
            openai_response = await self.openai_service.generate_response(
                user_message=message_content,
                previous_response_id=conversation.last_response_id
            )
        except Exception as e:
            logger.exception(f"OpenAI API error: {e}")
            raise

        usage = openai_response["usage"]
        response_id = openai_response["response_id"]
        needs_review = openai_response.get("needs_review", False)

        logger.info(
            f"Generated OpenAI response {response_id} "
            f"with {usage['total_tokens']} tokens"
        )

        # 3. Save assistant message
        assistant_message = await self.message_repo.create_assistant_message(
            conversation_id=conversation_id,
            content=openai_response["content"],
            openai_response_id=response_id,
            token_count=usage["completion_tokens"],
            metadata={
                "file_search_results": openai_response["file_search_results"],
                "has_citations": openai_response["has_citations"],
                "needs_review": needs_review
            }
        )

        # 4. Remember the response ID for the next turn and stamp the
        #    conversation's last activity. Conversation lists are ordered by
        #    last_message_at, and no other code in this service or its
        #    repositories sets it.
        await self.conversation_repo.update(
            conversation_id,
            last_response_id=response_id,
            last_message_at=assistant_message.created_at
        )

        # 5. Record token usage
        cost_usd = self._calculate_cost(
            prompt_tokens=usage["prompt_tokens"],
            completion_tokens=usage["completion_tokens"]
        )

        await self.token_repo.record_usage(
            user_id=user_id,
            conversation_id=conversation_id,
            message_id=assistant_message.id,
            prompt_tokens=usage["prompt_tokens"],
            completion_tokens=usage["completion_tokens"],
            total_tokens=usage["total_tokens"],
            model=settings.OPENAI_MODEL,
            cost_usd=cost_usd,
            operation_type="chat"
        )

        logger.info(
            f"Completed message exchange in conversation {conversation_id}. "
            f"Cost: ${cost_usd:.6f}"
        )

        # 6. Return response
        return {
            "user_message": {
                "id": str(user_message.id),
                "content": user_message.content,
                "created_at": user_message.created_at.isoformat()
            },
            "assistant_message": {
                "id": str(assistant_message.id),
                "content": assistant_message.content,
                "created_at": assistant_message.created_at.isoformat(),
                "file_search_results": openai_response["file_search_results"],
                "needs_review": needs_review
            },
            "usage": usage,
            "cost_usd": float(cost_usd)
        }

    async def get_messages(
        self,
        conversation_id: UUID,
        skip: int = 0,
        limit: int = 100
    ) -> List[Dict]:
        """
        Get messages for a conversation.

        NOTE(review): no ownership check is performed here; callers must
        authorize access themselves.

        Args:
            conversation_id: Conversation UUID
            skip: Number of messages to skip
            limit: Maximum number of messages

        Returns:
            List of message dicts
        """
        messages = await self.message_repo.get_by_conversation(
            conversation_id=conversation_id,
            skip=skip,
            limit=limit
        )

        return [
            {
                "id": str(msg.id),
                "role": msg.role,
                "content": msg.content,
                "created_at": msg.created_at.isoformat(),
                # NOTE(review): ``metadata`` is a reserved attribute name on
                # declarative models; confirm this reads the stored message
                # metadata and not SQLAlchemy's MetaData object.
                "metadata": msg.metadata
            }
            for msg in messages
        ]

    async def update_conversation_title(
        self,
        conversation_id: UUID,
        title: str
    ) -> bool:
        """
        Update conversation title.

        NOTE(review): no ownership check is performed here; callers must
        authorize access themselves.

        Args:
            conversation_id: Conversation UUID
            title: New title

        Returns:
            True if updated, False if the conversation does not exist
        """
        updated = await self.conversation_repo.update(conversation_id, title=title)
        if updated is None:
            return False

        logger.info(f"Updated conversation {conversation_id} title to: {title}")
        return True

    async def archive_conversation(self, conversation_id: UUID) -> bool:
        """
        Archive a conversation.

        NOTE(review): no ownership check is performed here; callers must
        authorize access themselves.

        Args:
            conversation_id: Conversation UUID

        Returns:
            The success flag reported by the conversation repository
        """
        archived = await self.conversation_repo.archive(conversation_id)
        if archived:
            logger.info(f"Archived conversation {conversation_id}")
        return archived

    async def delete_conversation(
        self,
        user_id: UUID,
        conversation_id: UUID
    ) -> bool:
        """
        Delete a conversation (with permission check).

        Args:
            user_id: User UUID
            conversation_id: Conversation UUID

        Returns:
            True if a row was deleted; False if it was already gone by the
            time the delete ran

        Raises:
            ValueError: If the conversation does not exist
            PermissionError: If the conversation belongs to another user
        """
        # Verify ownership
        await self._get_owned_conversation(user_id, conversation_id)

        # Delete (will cascade to messages)
        deleted = await self.conversation_repo.delete(conversation_id)
        if deleted:
            logger.info(f"Deleted conversation {conversation_id} for user {user_id}")
        return deleted

    async def get_token_usage_summary(
        self,
        user_id: UUID,
        days: Optional[int] = None
    ) -> Dict:
        """
        Get token usage summary for user.

        Args:
            user_id: User UUID
            days: Optional filter for last N days. Totals are unfiltered when
                omitted, while the daily breakdown then covers 30 days.

        Returns:
            Dict with usage statistics
        """
        total_tokens = await self.token_repo.get_user_total_tokens(user_id, days)
        total_cost = await self.token_repo.get_user_total_cost(user_id, days)
        daily_usage = await self.token_repo.get_daily_usage(user_id, days or 30)

        return {
            "total_tokens": total_tokens,
            "total_cost_usd": float(total_cost),
            "daily_breakdown": daily_usage,
            "period_days": days
        }

    def _estimate_tokens(self, text: str) -> int:
        """
        Estimate token count for text (rough approximation).

        Args:
            text: Input text

        Returns:
            Estimated token count
        """
        # Rough estimate: ~4 characters per token
        return len(text) // 4

    def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> Decimal:
        """
        Calculate cost in USD.

        The configured PROMPT_TOKEN_COST and COMPLETION_TOKEN_COST are applied
        per 1000 tokens. Values pass through ``str`` before becoming Decimals
        so float settings do not carry binary rounding noise.

        Args:
            prompt_tokens: Number of prompt tokens
            completion_tokens: Number of completion tokens

        Returns:
            Total cost in USD
        """
        per_thousand = Decimal("1000")
        prompt_cost = (
            Decimal(str(prompt_tokens))
            * Decimal(str(settings.PROMPT_TOKEN_COST))
            / per_thousand
        )
        completion_cost = (
            Decimal(str(completion_tokens))
            * Decimal(str(settings.COMPLETION_TOKEN_COST))
            / per_thousand
        )

        return prompt_cost + completion_cost
|
||||
Loading…
Add table
Reference in a new issue