import os
import logging
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException
# Create cache directory in /tmp which is usually writable
cache_dir = "/tmp/model_cache"
os.makedirs(cache_dir, exist_ok=True)
# Set environment variables for model caching
os.environ["TRANSFORMERS_CACHE"] = cache_dir
os.environ["HF_HOME"] = cache_dir
os.environ["SENTENCE_TRANSFORMERS_HOME"] = cache_dir
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from fastapi.responses import JSONResponse, HTMLResponse
# HTML template; literal braces are escaped as {{ }} because the string is rendered with str.format()
HTML_TEMPLATE = """<!DOCTYPE html><html><head><title>SeeaFile ChatBot</title><style>
body{{font-family:Arial,sans-serif;max-width:800px;margin:40px auto;padding:20px;background:#f8f9fa}}
.chat-container{{background:white;padding:20px;border-radius:8px;box-shadow:0 2px 4px rgba(0,0,0,0.1)}}
.chat-input{{display:flex;gap:10px;margin-top:20px}}
#message{{flex:1;padding:10px;border:1px solid #ddd;border-radius:4px}}
.btn{{background:#007bff;color:white;border:none;padding:10px 20px;border-radius:4px;cursor:pointer}}
.btn:hover{{background:#0056b3}}
#chat-history{{margin-top:20px;max-height:400px;overflow-y:auto}}
.message{{padding:10px;margin:5px 0;border-radius:4px}}
.user-msg{{background:#e9ecef}}
.bot-msg{{background:#f8f9fa;border-left:3px solid #007bff}}
</style></head><body>
<div class="chat-container">
<h1>SeeaFile ChatBot</h1>
<div id="chat-history"></div>
<div class="chat-input">
<input type="text" id="message" placeholder="Type your message..." />
<button class="btn" id="sendBtn">Send</button>
</div>
</div>
<script>
async function sendMessage() {{
const msgInput = document.getElementById('message');
const msg = msgInput.value.trim();
if (!msg) return;
// Add user message
addMessage(msg, 'user');
msgInput.value = '';
try {{
console.log('Sending message:', msg); // Debug log
const response = await fetch('/chat', {{
method: 'POST',
headers: {{
'Content-Type': 'application/json',
'Accept': 'application/json'
}},
body: JSON.stringify({{message: msg}})
}});
if (!response.ok) {{
throw new Error(`HTTP error! status: ${{response.status}}`);
}}
const data = await response.json();
console.log('Received response:', data); // Debug log
if (data.response) {{
addMessage(data.response, 'bot');
}} else if (data.detail) {{
addMessage(`Error: ${{data.detail}}`, 'bot');
}} else {{
addMessage('Received invalid response format', 'bot');
}}
}} catch (error) {{
console.error('Chat error:', error); // Debug log
addMessage(`Error: ${{error.message}}`, 'bot');
}}
}}
function addMessage(text, sender) {{
const history = document.getElementById('chat-history');
const msg = document.createElement('div');
msg.className = `message ${{sender}}-msg`;
msg.textContent = text;
history.appendChild(msg);
history.scrollTop = history.scrollHeight;
}}
// Add event listeners after DOM is loaded
document.addEventListener('DOMContentLoaded', () => {{
const msgInput = document.getElementById('message');
const sendBtn = document.getElementById('sendBtn');
msgInput.addEventListener('keypress', (e) => {{
if (e.key === 'Enter') {{
e.preventDefault();
sendMessage();
}}
}});
sendBtn.addEventListener('click', () => {{
sendMessage();
}});
}});
</script></body></html>"""
app = FastAPI(
title="SeeaFile ChatBot API",
description="API for question answering using document context",
version="1.0.0",
docs_url="/docs", # Enable Swagger UI
redoc_url="/redoc" # Enable ReDoc
)
# Add CORS middleware for mobile app access
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Adjust in production
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Add error handling middleware
@app.middleware("http")
async def catch_exceptions_middleware(request: Request, call_next):
try:
return await call_next(request)
except Exception as e:
logger.error(f"Unhandled error: {str(e)}", exc_info=True)
return JSONResponse(
status_code=500,
content={
"detail": "Internal server error",
"error": str(e)
}
)
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request, exc):
logger.error(f"HTTP error: {exc.detail}")
return JSONResponse(
status_code=exc.status_code,
content={"detail": exc.detail}
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
logger.error(f"Validation error: {str(exc)}")
return JSONResponse(
status_code=422,
content={"detail": str(exc)}
)
class HealthResponse(BaseModel):
status: str
message: str
class ChatResponse(BaseModel):
answer: str
confidence: float = 1.0
@app.get("/", response_class=HTMLResponse)
async def root():
"""
Health check endpoint with HTML response
"""
status_data = {
"status": "ok",
"message": "SeeaFile ChatBot API is running",
"endpoints": {
"docs": "/docs",
"test": "/test",
"generate": "/generate",
"health": "/health"
}
}
json_str = json.dumps(status_data, indent=2).replace("'", "\\'").replace('"', '\\"')
html_content = HTML_TEMPLATE.format(status_json=json_str)
return HTMLResponse(content=html_content)
@app.get("/test")
async def test_endpoint():
"""Test endpoint to verify RAG model functionality"""
sample_query = Query(
question="What is this API about?",
documents=["This is a chatbot API that uses RAG (Retrieval Augmented Generation) to answer questions based on provided documents."]
)
try:
response = await generate_answer(sample_query)
return {
"status": "success",
"test_result": response,
"message": "RAG model is working correctly"
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Test failed: {str(e)}")
# Load the retriever and generator at import time, wrapped in try/except so a
# failure is logged and aborts startup. Note: facebook/bart-large-cnn is
# fine-tuned for summarization, so it tends to summarize the context rather
# than answer the question directly; a QA-tuned seq2seq model may fit better.
try:
retriever = SentenceTransformer('all-MiniLM-L6-v2', cache_folder=cache_dir)
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-cnn', cache_dir=cache_dir)
generator = AutoModelForSeq2SeqLM.from_pretrained('facebook/bart-large-cnn', cache_dir=cache_dir)
except Exception as e:
logger.error(f"Failed to load models: {str(e)}", exc_info=True)
raise RuntimeError(f"Model initialization failed: {str(e)}")
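
# Note: because the models load at import time, the first startup downloads
# the weights (on the order of a couple of GB for bart-large-cnn) into
# /tmp/model_cache before the server starts accepting requests.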
class Query(BaseModel):
question: str
documents: list[str]
class ChatMessage(BaseModel):
message: str
@app.post("/chat")
async def chat(message: ChatMessage):
"""
Simple chat endpoint that processes a single message
"""
try:
if not message.message.strip():
raise HTTPException(status_code=400, detail="Empty message")
# Use a default context for simple chat
context = "I am an AI assistant that helps answer questions about documents and general topics."
# Prepare input for the generator
input_text = f"question: {message.message} context: {context}"
inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
        # Generate the answer. Beam search is deterministic, so a sampling
        # temperature is ignored unless do_sample=True; it is omitted here.
        output_ids = generator.generate(
            inputs.input_ids,
            max_length=150,
            num_beams=4,
            early_stopping=True
        )
response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
logger.info(f"Generated response for message: {message.message[:50]}...")
return {"response": response}
    except HTTPException:
        # Re-raise HTTP errors (e.g. the 400 above) instead of letting the
        # generic handler below turn them into 500s
        raise
    except Exception as e:
        logger.error(f"Chat error: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Chat failed: {str(e)}"
        )
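# Illustrative usage (host and port are assumptions, not part of this file):
#   curl -X POST http://localhost:8000/chat \
#     -H "Content-Type: application/json" \
#     -d '{"message": "What can you do?"}'
# Successful responses have the shape {"response": "<generated text>"}.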
@app.post("/generate", response_model=ChatResponse)
async def generate_answer(query: Query):
"""
Generate an answer based on the provided question and documents
Args:
query (Query): Question and list of documents
Returns:
ChatResponse: Generated answer with confidence score
"""
try:
if not query.documents:
raise HTTPException(status_code=400, detail="No documents provided.")
# Encode the documents and the query
doc_embeddings = retriever.encode(query.documents, convert_to_tensor=True)
query_embedding = retriever.encode(query.question, convert_to_tensor=True)
# Compute cosine similarities
        similarities = util.cos_sim(query_embedding, doc_embeddings)[0]
top_doc_index = torch.argmax(similarities).item()
top_doc = query.documents[top_doc_index]
# Prepare input for the generator
input_text = f"question: {query.question} context: {top_doc}"
inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
# Generate the answer
output_ids = generator.generate(inputs.input_ids, max_length=150, num_beams=5, early_stopping=True)
answer = tokenizer.decode(output_ids[0], skip_special_tokens=True)
return ChatResponse(
answer=answer,
confidence=similarities[top_doc_index].item()
)
    except HTTPException:
        # Let the 400 for an empty document list propagate unchanged
        raise
    except Exception as e:
        logger.error(f"Error generating answer: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to generate answer: {str(e)}"
        )
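# Illustrative usage (values are made up for the example):
#   POST /generate with body
#   {"question": "What is this API about?",
#    "documents": ["This API answers questions using RAG."]}
# returns a ChatResponse such as {"answer": "...", "confidence": 0.73},
# where confidence is the cosine similarity of the best-matching document.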
# Add health check with model status
@app.get("/health")
async def health_check():
try:
# Test models are loaded
return {
"status": "healthy",
"models": {
"retriever": retriever is not None,
"tokenizer": tokenizer is not None,
"generator": generator is not None
}
}
except Exception as e:
logger.error(f"Health check failed: {str(e)}")
return {
"status": "unhealthy",
"error": str(e)
}
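
# Minimal local entry point. This is a sketch: it assumes uvicorn is installed
# and port 8000 is free; a deployment may launch the app differently (e.g. via
# a container CMD or `uvicorn <module>:app`).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)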