import os from pathlib import Path import logging from fastapi.middleware.trustedhost import TrustedHostMiddleware from fastapi.exceptions import RequestValidationError from starlette.exceptions import HTTPException as StarletteHTTPException import json # Create cache directory in /tmp which is usually writable cache_dir = "/tmp/model_cache" os.makedirs(cache_dir, exist_ok=True) # Set environment variables for model caching os.environ["TRANSFORMERS_CACHE"] = cache_dir os.environ["HF_HOME"] = cache_dir os.environ["SENTENCE_TRANSFORMERS_HOME"] = cache_dir # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) from fastapi import FastAPI, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from sentence_transformers import SentenceTransformer, util from transformers import AutoTokenizer, AutoModelForSeq2SeqLM import torch from fastapi.responses import JSONResponse, HTMLResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates from typing import Dict, Any # Updated HTML template with escaped curly braces HTML_TEMPLATE = """SeeaFile ChatBot

SeeaFile ChatBot

""" app = FastAPI( title="SeeaFile ChatBot API", description="API for question answering using document context", version="1.0.0", docs_url="/docs", # Enable Swagger UI redoc_url="/redoc" # Enable ReDoc ) # Add CORS middleware for mobile app access app.add_middleware( CORSMiddleware, allow_origins=["*"], # Adjust in production allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Add error handling middleware @app.middleware("http") async def catch_exceptions_middleware(request: Request, call_next): try: return await call_next(request) except Exception as e: logger.error(f"Unhandled error: {str(e)}", exc_info=True) return JSONResponse( status_code=500, content={ "detail": "Internal server error", "error": str(e) } ) @app.exception_handler(StarletteHTTPException) async def http_exception_handler(request, exc): logger.error(f"HTTP error: {exc.detail}") return JSONResponse( status_code=exc.status_code, content={"detail": exc.detail} ) @app.exception_handler(RequestValidationError) async def validation_exception_handler(request, exc): logger.error(f"Validation error: {str(exc)}") return JSONResponse( status_code=422, content={"detail": str(exc)} ) class HealthResponse(BaseModel): status: str message: str class ChatResponse(BaseModel): answer: str confidence: float = 1.0 @app.get("/", response_class=HTMLResponse) async def root(): """ Health check endpoint with HTML response """ status_data = { "status": "ok", "message": "SeeaFile ChatBot API is running", "endpoints": { "docs": "/docs", "test": "/test", "generate": "/generate", "health": "/health" } } json_str = json.dumps(status_data, indent=2).replace("'", "\\'").replace('"', '\\"') html_content = HTML_TEMPLATE.format(status_json=json_str) return HTMLResponse(content=html_content) @app.get("/test") async def test_endpoint(): """Test endpoint to verify RAG model functionality""" sample_query = Query( question="What is this API about?", documents=["This is a chatbot API that uses RAG (Retrieval Augmented Generation) to answer questions based on provided documents."] ) try: response = await generate_answer(sample_query) return { "status": "success", "test_result": response, "message": "RAG model is working correctly" } except Exception as e: raise HTTPException(status_code=500, detail=f"Test failed: {str(e)}") # Wrap model loading in try-except try: retriever = SentenceTransformer('all-MiniLM-L6-v2', cache_folder=cache_dir) tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-cnn', cache_dir=cache_dir) generator = AutoModelForSeq2SeqLM.from_pretrained('facebook/bart-large-cnn', cache_dir=cache_dir) except Exception as e: logger.error(f"Failed to load models: {str(e)}", exc_info=True) raise RuntimeError(f"Model initialization failed: {str(e)}") class Query(BaseModel): question: str documents: list[str] class ChatMessage(BaseModel): message: str @app.post("/chat") async def chat(message: ChatMessage): """ Simple chat endpoint that processes a single message """ try: if not message.message.strip(): raise HTTPException(status_code=400, detail="Empty message") # Use a default context for simple chat context = "I am an AI assistant that helps answer questions about documents and general topics." # Prepare input for the generator input_text = f"question: {message.message} context: {context}" inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True) # Generate the answer output_ids = generator.generate( inputs.input_ids, max_length=150, num_beams=4, early_stopping=True, temperature=0.7 ) response = tokenizer.decode(output_ids[0], skip_special_tokens=True) logger.info(f"Generated response for message: {message.message[:50]}...") return {"response": response} except Exception as e: logger.error(f"Chat error: {str(e)}", exc_info=True) raise HTTPException( status_code=500, detail=f"Chat failed: {str(e)}" ) @app.post("/generate", response_model=ChatResponse) async def generate_answer(query: Query): """ Generate an answer based on the provided question and documents Args: query (Query): Question and list of documents Returns: ChatResponse: Generated answer with confidence score """ try: if not query.documents: raise HTTPException(status_code=400, detail="No documents provided.") # Encode the documents and the query doc_embeddings = retriever.encode(query.documents, convert_to_tensor=True) query_embedding = retriever.encode(query.question, convert_to_tensor=True) # Compute cosine similarities similarities = util.pytorch_cos_sim(query_embedding, doc_embeddings)[0] top_doc_index = torch.argmax(similarities).item() top_doc = query.documents[top_doc_index] # Prepare input for the generator input_text = f"question: {query.question} context: {top_doc}" inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True) # Generate the answer output_ids = generator.generate(inputs.input_ids, max_length=150, num_beams=5, early_stopping=True) answer = tokenizer.decode(output_ids[0], skip_special_tokens=True) return ChatResponse( answer=answer, confidence=similarities[top_doc_index].item() ) except Exception as e: logger.error(f"Error generating answer: {str(e)}", exc_info=True) raise HTTPException( status_code=500, detail=f"Failed to generate answer: {str(e)}" ) # Add health check with model status @app.get("/health") async def health_check(): try: # Test models are loaded return { "status": "healthy", "models": { "retriever": retriever is not None, "tokenizer": tokenizer is not None, "generator": generator is not None } } except Exception as e: logger.error(f"Health check failed: {str(e)}") return { "status": "unhealthy", "error": str(e) }