# src/routes/chat_routes.py
# Last change: "Update src/routes/chat_routes.py" by FireBird-Tech (commit e252fda, verified)
from datetime import datetime
import logging
from fastapi import APIRouter, Depends, HTTPException, Query, Body
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
from src.db.init_db import session_factory
from src.db.schemas.models import ModelUsage
from src.managers.ai_manager import AI_Manager
from src.managers.chat_manager import ChatManager
from src.managers.user_manager import get_current_user, User
from src.schemas.chat_schemas import *
from src.utils.logger import Logger
import os
from dotenv import load_dotenv
# Populate the process environment from a .env file before os.getenv() below
# reads DATABASE_URL.
load_dotenv()
# Module logger: timestamps enabled, console logging disabled
logger = Logger("chat_routes", see_time=True, console_log=False)
# Every route in this module is mounted under the "/chats" prefix and grouped
# under the "chats" tag
router = APIRouter(prefix="/chats", tags=["chats"])
# Shared chat manager used by the chat/message/user routes below.
# NOTE(review): os.getenv returns None when DATABASE_URL is unset; how
# ChatManager treats a None db_url is not visible here -- confirm.
chat_manager = ChatManager(db_url=os.getenv("DATABASE_URL"))
# Shared AI manager; the debug model-usage endpoint calls it directly
ai_manager = AI_Manager()
# Routes
@router.post("/", response_model=ChatResponse)
async def create_chat(chat_create: ChatCreate):
    """Create a new chat session.

    Delegates to ``chat_manager.create_chat`` with the ``user_id`` carried in
    the request body and returns the resulting chat, serialized as
    ``ChatResponse``.

    Raises:
        HTTPException: 500 if creation fails for any reason; the underlying
            error text is included in the detail.
    """
    try:
        new_chat = chat_manager.create_chat(chat_create.user_id)
    except Exception as exc:
        reason = str(exc)
        logger.log_message(f"Error creating chat: {reason}", level=logging.ERROR)
        raise HTTPException(status_code=500, detail=f"Failed to create chat: {reason}")
    return new_chat
@router.post("/{chat_id}/messages", response_model=MessageResponse)
async def add_message(chat_id: int, message: MessageCreate, user_id: Optional[int] = None):
    """Append a message to an existing chat.

    Forwards the message's ``content`` and ``sender``, plus the optional
    ``user_id`` query parameter, to ``chat_manager.add_message`` and returns
    its result, serialized as ``MessageResponse``.

    Raises:
        HTTPException: 404 when the manager raises ``ValueError`` (presumably
            an unknown or inaccessible chat -- conditions live in ChatManager);
            500 for any other failure.
    """
    try:
        stored = chat_manager.add_message(chat_id, message.content, message.sender, user_id)
    except ValueError as not_found:
        raise HTTPException(status_code=404, detail=str(not_found))
    except Exception as exc:
        reason = str(exc)
        logger.log_message(f"Error adding message: {reason}", level=logging.ERROR)
        raise HTTPException(status_code=500, detail=f"Failed to add message: {reason}")
    return stored
@router.get("/{chat_id}", response_model=ChatDetailResponse)
async def get_chat(chat_id: int, user_id: Optional[int] = None):
    """Fetch one chat, including all of its messages.

    Looks the chat up through ``chat_manager.get_chat``; the optional
    ``user_id`` query parameter is passed along unchanged.

    Raises:
        HTTPException: 404 when the manager raises ``ValueError``;
            500 for any other failure.
    """
    try:
        found = chat_manager.get_chat(chat_id, user_id)
    except ValueError as not_found:
        raise HTTPException(status_code=404, detail=str(not_found))
    except Exception as exc:
        reason = str(exc)
        logger.log_message(f"Error retrieving chat: {reason}", level=logging.ERROR)
        raise HTTPException(status_code=500, detail=f"Failed to retrieve chat: {reason}")
    return found
@router.get("/", response_model=List[ChatResponse])
async def get_chats(
    user_id: Optional[int] = None,
    limit: int = Query(10, ge=1, le=100),
    offset: int = Query(0, ge=0)
):
    """List recent chats one page at a time.

    Args:
        user_id: Optional user filter, passed straight to
            ``chat_manager.get_user_chats`` (how ``None`` is treated is up to
            the manager).
        limit: Page size, between 1 and 100 (default 10).
        offset: Number of chats to skip, at least 0 (default 0).

    Raises:
        HTTPException: 500 if the lookup fails.
    """
    try:
        page = chat_manager.get_user_chats(user_id, limit, offset)
    except Exception as exc:
        reason = str(exc)
        logger.log_message(f"Error retrieving chats: {reason}", level=logging.ERROR)
        raise HTTPException(status_code=500, detail=f"Failed to retrieve chats: {reason}")
    return page
@router.delete("/{chat_id}")
async def delete_chat(chat_id: int, user_id: Optional[int] = None):
    """Delete a chat and its messages while keeping its model-usage records.

    ``chat_manager.delete_chat`` performs the deletion; per the original
    comments it leaves ``ModelUsage`` rows in place.

    Returns:
        A confirmation dict with ``preserved_model_usage`` set to ``True``.

    Raises:
        HTTPException: 404 when the manager reports failure (falsy result),
            i.e. the chat was not found or the user may not delete it;
            500 for any unexpected error.
    """
    try:
        if chat_manager.delete_chat(chat_id, user_id):
            return {"message": f"Chat {chat_id} deleted successfully", "preserved_model_usage": True}
        raise HTTPException(status_code=404, detail=f"Chat with ID {chat_id} not found or access denied")
    except HTTPException:
        # Let the deliberate 404 (or any HTTP error from the manager) through untouched.
        raise
    except Exception as exc:
        reason = str(exc)
        logger.log_message(f"Error deleting chat: {reason}", level=logging.ERROR)
        raise HTTPException(status_code=500, detail=f"Failed to delete chat: {reason}")
@router.post("/users", response_model=dict)
async def create_or_get_user(user_info: UserInfo):
    """Return the user matching the given email, creating one if needed.

    Passes ``username`` and ``email`` from the request body to
    ``chat_manager.get_or_create_user`` and returns its result unchanged.

    Raises:
        HTTPException: 500 if the lookup/creation fails.
    """
    try:
        account = chat_manager.get_or_create_user(
            username=user_info.username,
            email=user_info.email,
        )
    except Exception as exc:
        reason = str(exc)
        logger.log_message(f"Error creating/getting user: {reason}", level=logging.ERROR)
        raise HTTPException(status_code=500, detail=f"Failed to process user: {reason}")
    return account
@router.put("/{chat_id}", response_model=ChatResponse)
async def update_chat(chat_id: int, chat_update: ChatUpdate):
    """Change a chat's title and/or owning user_id.

    Both fields are taken from the ``ChatUpdate`` body and handed to
    ``chat_manager.update_chat``; the updated chat is returned.

    Raises:
        HTTPException: 404 when the manager raises ``ValueError``;
            500 for any other failure.
    """
    try:
        updated = chat_manager.update_chat(
            chat_id=chat_id,
            title=chat_update.title,
            user_id=chat_update.user_id,
        )
    except ValueError as not_found:
        raise HTTPException(status_code=404, detail=str(not_found))
    except Exception as exc:
        reason = str(exc)
        logger.log_message(f"Error updating chat: {reason}", level=logging.ERROR)
        raise HTTPException(status_code=500, detail=f"Failed to update chat: {reason}")
    return updated
@router.post("/cleanup-empty", response_model=dict)
async def cleanup_empty_chats(request: ChatCreate):
    """Remove a user's empty chats.

    Reads ``user_id`` and ``is_admin`` from the ``ChatCreate`` body and
    delegates to ``chat_manager.delete_empty_chats``, reporting how many
    chats it removed.

    NOTE(review): ``ChatCreate`` is reused as this endpoint's body; confirm it
    declares ``is_admin``, otherwise every call fails with a 500.

    Raises:
        HTTPException: 500 if the cleanup fails.
    """
    try:
        removed = chat_manager.delete_empty_chats(request.user_id, request.is_admin)
        summary = f"Deleted {removed} empty chats"
    except Exception as exc:
        reason = str(exc)
        logger.log_message(f"Error cleaning up empty chats: {reason}", level=logging.ERROR)
        raise HTTPException(status_code=500, detail=f"Failed to clean up empty chats: {reason}")
    return {"message": summary}
@router.post("/debug/test-model-usage")
async def test_model_usage(
    model_name: str = "gpt-3.5-turbo",
    user_id: Optional[int] = None
):
    """Debug endpoint to manually test model usage tracking.

    Sends a fixed test prompt through ``ai_manager.generate_response`` under a
    placeholder chat ID (999), then reads back the newest ``ModelUsage`` row
    so the caller can see whether a usage record was written.

    Args:
        model_name: Model to invoke for the test call.
        user_id: Optional user ID forwarded to the AI manager.

    Returns:
        On success ``{"success": True, "message", "response", "usage_recorded"}``,
        where every ``usage_recorded`` value is ``None`` if no ``ModelUsage``
        row exists. On any failure ``{"success": False, "error": <message>}`` --
        errors are reported in the body, not as an HTTP error status.
    """
    # NOTE(review): this performs a real model call and declares no auth
    # dependency; confirm this debug route is not exposed in production.
    try:
        test_prompt = "This is a test message to verify model usage tracking."
        response = await ai_manager.generate_response(
            prompt=test_prompt,
            model_name=model_name,
            user_id=user_id,
            chat_id=999  # Placeholder chat ID used only for this test
        )
        session = session_factory()
        try:
            # Newest row overall (highest usage_id). It is not filtered by model
            # or user, so a concurrent request's row may be reported instead.
            latest_usage = (
                session.query(ModelUsage)
                .order_by(ModelUsage.usage_id.desc())
                .first()
            )
            # Read the row's attributes while the session is still open.
            usage_recorded = _summarize_model_usage(latest_usage)
        finally:
            session.close()
        return {
            "success": True,
            "message": "Model usage tracking test completed",
            "response": response,
            "usage_recorded": usage_recorded,
        }
    except Exception as e:
        logger.log_message(f"Error in test-model-usage: {str(e)}", level=logging.ERROR)
        return {
            "success": False,
            "error": str(e)
        }


def _summarize_model_usage(usage):
    """Return the debug-report fields of a ModelUsage row, or all ``None`` if *usage* is None."""
    if usage is None:
        return dict.fromkeys(("usage_id", "model_name", "tokens", "cost", "timestamp"))
    timestamp = usage.timestamp
    return {
        "usage_id": usage.usage_id,
        "model_name": usage.model_name,
        "tokens": usage.total_tokens,
        "cost": usage.cost,
        # Guard a NULL timestamp instead of calling .isoformat() on None, which
        # would turn a successful tracking test into an error response.
        "timestamp": timestamp.isoformat() if timestamp is not None else None,
    }