import logging
import os
import threading
from datetime import datetime
from typing import Optional, List, Dict

import requests
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from chatbot import RAGChatbot
|
|
|
|
|
|
|
|
|
# Configure the root logger once at import time; everything at INFO and above is emitted.
logging.basicConfig(level=logging.INFO)

# Module-level logger shared by the session helpers and the route handlers in this file.
logger = logging.getLogger(__name__)
|
|
|
# The ASGI application. Interactive docs are served at /docs (Swagger UI)
# and /redoc, as configured by docs_url / redoc_url below.
app = FastAPI(
    title="RAG Chatbot API - Multi-User",
    description="HR Assistant Chatbot with Per-User Session Management",
    version="2.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# Cross-origin access: every origin, method and header is accepted.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is the
# most permissive combination possible; FastAPI's CORS documentation asks for
# explicit origins when credentials are enabled. Confirm this is intentional
# before exposing the service beyond a trusted network.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
# Shared chatbot instance. Stays None until startup_event() assigns it; the
# route handlers answer 503 while it is still None.
base_chatbot: Optional[RAGChatbot] = None

# In-memory session store: user_id -> UserSession. Lives only as long as the process.
user_sessions: Dict[str, "UserSession"] = {}

# Guards mutations of user_sessions. threading.Lock is NOT reentrant: code that
# already holds it must not call another helper that acquires it again.
session_lock = threading.Lock()

# Session count above which get_or_create_session() triggers an expiry sweep.
MAX_SESSIONS = 100

# Idle time, in seconds, after which a session is eligible for removal.
SESSION_TIMEOUT = 3600
|
|
|
|
|
class UserSession:
    """In-memory conversation state belonging to a single user.

    Attributes:
        user_id: Identifier the session was created for.
        chat_history: List of past exchanges, starts empty.
        conversation_context: Dict with the keys ``current_employee``
            (initially ``None``) and ``last_mentioned_entities``
            (initially an empty list).
        last_activity: Naive local ``datetime`` of the most recent
            creation or ``update_activity()`` call.
    """

    def __init__(self, user_id: str):
        self.user_id = user_id
        self.chat_history = list()
        self.conversation_context = dict(
            current_employee=None,
            last_mentioned_entities=list(),
        )
        # A brand-new session counts as "just active".
        self.update_activity()

    def update_activity(self):
        """Stamp the session with the current local time."""
        self.last_activity = datetime.now()
|
|
|
|
|
def cleanup_old_sessions():
    """Delete every session that has been idle for more than SESSION_TIMEOUT seconds.

    Acquires ``session_lock`` for the whole sweep, so it must not be called
    by code that already holds that lock. Each removal is logged at INFO.
    """
    with session_lock:
        now = datetime.now()

        # Collect the ids first: the dict cannot be resized while it is being iterated.
        expired_ids = [
            uid
            for uid, sess in user_sessions.items()
            if (now - sess.last_activity).total_seconds() > SESSION_TIMEOUT
        ]

        for uid in expired_ids:
            del user_sessions[uid]
            logger.info(f"Cleaned up session for user: {uid}")
|
|
|
|
|
def get_or_create_session(user_id: str) -> UserSession:
    """Return the session for *user_id*, creating it on first use.

    When the store holds more than MAX_SESSIONS entries, sessions idle for
    longer than SESSION_TIMEOUT seconds are purged first. MAX_SESSIONS is a
    sweep threshold, not a hard cap: sessions that are still active are
    never evicted here.

    Args:
        user_id: Key identifying the caller's session.

    Returns:
        The existing or newly created ``UserSession``, with its
        ``last_activity`` refreshed.
    """
    with session_lock:
        if len(user_sessions) > MAX_SESSIONS:
            # Purge inline. Calling cleanup_old_sessions() here would try to
            # re-acquire session_lock, which is a non-reentrant
            # threading.Lock, and block this thread forever.
            now = datetime.now()
            expired_ids = [
                uid
                for uid, sess in user_sessions.items()
                if (now - sess.last_activity).total_seconds() > SESSION_TIMEOUT
            ]
            for uid in expired_ids:
                del user_sessions[uid]
                logger.info(f"Cleaned up session for user: {uid}")

        session = user_sessions.get(user_id)
        if session is None:
            session = UserSession(user_id)
            user_sessions[user_id] = session
            logger.info(f"Created new session for user: {user_id}")

        session.update_activity()
        return session
|
|
|
|
|
|
|
# Request body accepted by POST /api/chat.
class ChatRequest(BaseModel):
    # The user's message; the chat handler rejects it when it is blank after stripping.
    question: str
    # Key of the per-user session the question belongs to; the chat handler rejects an empty value.
    user_id: str
|
|
|
|
|
# Response body returned by POST /api/chat.
class ChatResponse(BaseModel):
    # The question exactly as it was received in the request.
    question: str
    # Text taken from the first choice of the model's reply.
    answer: str
    # ISO-8601 string produced with datetime.now().isoformat() (naive local time).
    timestamp: str
    # Echo of the user_id the request was made for.
    user_id: str
    # The chat handler fills this with 'total_messages' and 'current_context'.
    session_info: Dict
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Create the shared RAGChatbot when the application starts.

    Reads the document location from the ``PDF_PATH`` environment variable
    (default ``./data/policies.pdf``) and the token from ``HF_TOKEN``.

    Raises:
        ValueError: If no file exists at the configured path.
        Exception: Any failure during construction is logged and re-raised,
            so the error propagates out of the startup hook.
    """
    global base_chatbot

    logger.info("=== Starting RAG Chatbot Initialization ===")

    try:
        pdf_path = os.environ.get("PDF_PATH", "./data/policies.pdf")
        hf_token = os.environ.get("HF_TOKEN")

        logger.info(f"PDF Path: {pdf_path}")
        pdf_present = os.path.exists(pdf_path)
        logger.info(f"File exists: {pdf_present}")

        if not pdf_present:
            raise ValueError(f"PDF file not found at {pdf_path}")

        base_chatbot = RAGChatbot(pdf_path, hf_token)
        logger.info("=== Base chatbot initialized successfully! ===")
    except Exception as e:
        logger.error(f"Failed to initialize chatbot: {e}")
        raise
|
|
|
|
|
# Service banner: static metadata, the current session count, whether the
# chatbot has been loaded, and a map of the available endpoints.
@app.get("/")
async def root():
    endpoint_map = {
        "docs": "/docs",
        "chat": "POST /api/chat",
        "history": "GET /api/history/{user_id}",
        "reset": "POST /api/reset?user_id=xxx",
        "sessions": "GET /api/sessions",
    }
    session_count = len(user_sessions)
    chatbot_loaded = base_chatbot is not None

    return {
        "service": "RAG Chatbot API",
        "version": "2.0.0",
        "status": "healthy",
        "active_sessions": session_count,
        "chatbot_loaded": chatbot_loaded,
        "endpoints": endpoint_map,
    }
|
|
|
|
|
# Health probe: 503 until the chatbot has been initialised, otherwise a small
# status payload with the current time and the number of live sessions.
@app.get("/api/health")
async def health_check():
    if base_chatbot is not None:
        checked_at = datetime.now().isoformat()
        return {
            "status": "healthy",
            "timestamp": checked_at,
            "chatbot_ready": True,
            "active_sessions": len(user_sessions),
        }

    raise HTTPException(status_code=503, detail="Chatbot not initialized")
|
|
|
|
|
# Answers one question inside the caller's own session.
#
# Flow: validate the request -> fetch/create the session -> resolve pronouns
# against the session context -> retrieve documents -> look up related past
# exchanges -> build the prompt -> call the model endpoint -> update the
# session context and history -> return a ChatResponse.
#
# Responses: 503 while the chatbot is not initialised, 400 for a blank
# question or user_id, 500 (with the error text in `detail`) for any failure
# in the pipeline.
#
# The model call uses the synchronous `requests` library, so it is pushed to
# a worker thread; running it directly in this coroutine would freeze the
# event loop, and with it every other client, for up to the 60 s timeout.
# Session state is still only touched from the event-loop side of the await.
# Requests from the same user may interleave around that await.
@app.post("/api/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Send a question to the chatbot with user session isolation"""
    if base_chatbot is None:
        raise HTTPException(status_code=503, detail="Chatbot not initialized")

    if not request.question.strip():
        raise HTTPException(status_code=400, detail="Question cannot be empty")

    # Treat whitespace-only ids like empty ones, consistent with the question check.
    if not request.user_id.strip():
        raise HTTPException(status_code=400, detail="user_id is required")

    try:
        logger.info(f"User {request.user_id}: {request.question[:50]}...")

        session = get_or_create_session(request.user_id)

        resolved_question = base_chatbot._resolve_pronouns_for_session(
            request.question,
            session.conversation_context
        )

        retrieved_data = base_chatbot._retrieve(resolved_question, k=20)

        relevant_past_chats = base_chatbot._search_session_history(
            resolved_question,
            session.chat_history,
            k=5
        )

        prompt = base_chatbot._build_prompt_for_session(
            resolved_question,
            retrieved_data,
            relevant_past_chats,
            session.chat_history,
            session.conversation_context
        )

        payload = {
            "model": base_chatbot.model_name,
            "messages": [
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "max_tokens": 512,
            "temperature": 0.3
        }

        # Blocking HTTP call: run it off the event loop.
        response = await run_in_threadpool(
            requests.post,
            base_chatbot.api_url,
            headers=base_chatbot.headers,
            json=payload,
            timeout=60
        )
        response.raise_for_status()
        result = response.json()

        answer = result["choices"][0]["message"]["content"]

        base_chatbot._update_conversation_context_for_session(
            request.question,
            answer,
            session.conversation_context
        )

        # One timestamp for both the stored entry and the response, so the
        # client sees the same value that /api/history later reports.
        answered_at = datetime.now().isoformat()

        chat_entry = {
            'timestamp': answered_at,
            'question': request.question,
            'answer': answer,
            'used_past_context': len(relevant_past_chats) > 0
        }
        session.chat_history.append(chat_entry)

        response_data = ChatResponse(
            question=request.question,
            answer=answer,
            timestamp=answered_at,
            user_id=request.user_id,
            session_info={
                'total_messages': len(session.chat_history),
                'current_context': session.conversation_context.get('current_employee')
            }
        )

        logger.info(f"User {request.user_id}: Question processed successfully")
        return response_data

    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"Error for user {request.user_id}: {e}")
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}") from e
|
|
|
# Drops the stored session for `user_id` (query parameter). Replies with
# status "success" whether or not a session existed; 400 for an empty id.
@app.post("/api/reset")
async def reset_chat(user_id: str):
    """Reset chat history for specific user"""
    if not user_id:
        raise HTTPException(status_code=400, detail="user_id is required")

    with session_lock:
        # Guard clause: nothing stored for this user.
        if user_id not in user_sessions:
            return {"message": f"No session found for user {user_id}", "status": "success"}

        del user_sessions[user_id]
        logger.info(f"Reset session for user: {user_id}")
        return {"message": f"Chat history reset for user {user_id}", "status": "success"}
|
|
|
|
|
# Returns the stored conversation for `user_id`.
#
# This is a read-only lookup: an unknown user gets an empty history and no
# session is created. Creating one here would let arbitrary GET requests
# inflate the in-memory session store. For an existing session the activity
# timestamp is refreshed, and a copy of the history list is returned so the
# payload is not affected by later appends.
@app.get("/api/history/{user_id}")
async def get_history(user_id: str):
    """Get chat history for specific user"""
    with session_lock:
        session = user_sessions.get(user_id)

        if session is None:
            # Same payload shape as a freshly created, empty session.
            return {
                "user_id": user_id,
                "total_conversations": 0,
                "current_context": None,
                "history": []
            }

        session.update_activity()
        history = list(session.chat_history)
        current_context = session.conversation_context.get('current_employee')

    return {
        "user_id": user_id,
        "total_conversations": len(history),
        "current_context": current_context,
        "history": history
    }
|
|
|
|
|
# Snapshot of the session store: the configured limits plus, per session,
# its user id, message count, last-activity time and current context.
@app.get("/api/sessions")
async def get_active_sessions():
    """Get list of active sessions"""
    with session_lock:
        total = len(user_sessions)

        summaries = []
        for uid, sess in user_sessions.items():
            summaries.append(
                {
                    "user_id": uid,
                    "messages": len(sess.chat_history),
                    "last_activity": sess.last_activity.isoformat(),
                    "current_context": sess.conversation_context.get('current_employee'),
                }
            )

        return {
            "total_sessions": total,
            "max_sessions": MAX_SESSIONS,
            "session_timeout_seconds": SESSION_TIMEOUT,
            "sessions": summaries,
        }
|
|
|
|
|
# Runs the expiry sweep on demand and reports how many sessions remain.
@app.post("/api/cleanup")
async def manual_cleanup():
    """Manually trigger session cleanup"""
    cleanup_old_sessions()

    remaining = len(user_sessions)
    return {"message": "Cleanup completed", "active_sessions": remaining}
|
|
|
|
|
# Script entry point: only runs when the file is executed directly, not when
# it is imported by an ASGI server.
if __name__ == "__main__":
    # Imported here so uvicorn is only required for direct execution.
    import uvicorn

    # Listen on all network interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)