Spaces:
Sleeping
Sleeping
import os | |
from pathlib import Path | |
import logging | |
from fastapi.middleware.trustedhost import TrustedHostMiddleware | |
from fastapi.exceptions import RequestValidationError | |
from starlette.exceptions import HTTPException as StarletteHTTPException | |
import json | |
# Model caches go under /tmp, which is usually writable even when the app
# directory is not. The variables below are exported before the
# transformers / sentence_transformers imports further down — presumably
# because those libraries read them when first imported (confirm before
# reordering).
cache_dir = "/tmp/model_cache"
os.makedirs(cache_dir, exist_ok=True)
os.environ.update(
    {
        "TRANSFORMERS_CACHE": cache_dir,
        "HF_HOME": cache_dir,
        "SENTENCE_TRANSFORMERS_HOME": cache_dir,
    }
)

# Process-wide logging: INFO and above, with timestamp, logger name and level.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
from fastapi import FastAPI, HTTPException, Request | |
from fastapi.middleware.cors import CORSMiddleware | |
from pydantic import BaseModel | |
from sentence_transformers import SentenceTransformer, util | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
import torch | |
from fastapi.responses import JSONResponse, HTMLResponse | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.templating import Jinja2Templates | |
from typing import Dict, Any | |
# Single-page chat UI. The string is passed through str.format(), so every
# literal brace in the CSS/JavaScript is doubled ("{{" renders as "{", "}}"
# as "}"). It contains no replacement fields, so any keyword arguments given
# to format() are ignored.
HTML_TEMPLATE = """<!DOCTYPE html><html><head><title>SeeaFile ChatBot</title><style>
body{{font-family:Arial,sans-serif;max-width:800px;margin:40px auto;padding:20px;background:#f8f9fa}}
.chat-container{{background:white;padding:20px;border-radius:8px;box-shadow:0 2px 4px rgba(0,0,0,0.1)}}
.chat-input{{display:flex;gap:10px;margin-top:20px}}
#message{{flex:1;padding:10px;border:1px solid #ddd;border-radius:4px}}
.btn{{background:#007bff;color:white;border:none;padding:10px 20px;border-radius:4px;cursor:pointer}}
.btn:hover{{background:#0056b3}}
#chat-history{{margin-top:20px;max-height:60vh;overflow-y:auto}}
.message{{padding:10px;margin:5px 0;border-radius:4px}}
.user-msg{{background:#e9ecef}}
.bot-msg{{background:#f8f9fa;border-left:3px solid #007bff}}
</style></head><body>
<div class="chat-container">
<h1>SeeaFile ChatBot</h1>
<div id="chat-history"></div>
<div class="chat-input">
<input type="text" id="message" placeholder="Type your message..." />
<button class="btn" id="sendBtn">Send</button>
</div>
</div>
<script>
async function sendMessage() {{
    const msgInput = document.getElementById('message');
    const msg = msgInput.value.trim();
    if (!msg) return;
    // Add user message
    addMessage(msg, 'user');
    msgInput.value = '';
    try {{
        console.log('Sending message:', msg); // Debug log
        const response = await fetch('/chat', {{
            method: 'POST',
            headers: {{
                'Content-Type': 'application/json',
                'Accept': 'application/json'
            }},
            body: JSON.stringify({{message: msg}})
        }});
        // Parse the body before checking the status: error responses carry a
        // JSON "detail" field that is more useful than a bare status code.
        let data = null;
        try {{
            data = await response.json();
        }} catch (parseError) {{
            data = null;
        }}
        console.log('Received response:', data); // Debug log
        if (!response.ok) {{
            console.error('Chat error: HTTP', response.status, data); // Debug log
            const detail = (data && data.detail !== undefined)
                ? (typeof data.detail === 'string' ? data.detail : JSON.stringify(data.detail))
                : `HTTP error! status: ${{response.status}}`;
            addMessage(`Error: ${{detail}}`, 'bot');
            return;
        }}
        if (data && data.response) {{
            addMessage(data.response, 'bot');
        }} else if (data && data.detail) {{
            addMessage(`Error: ${{data.detail}}`, 'bot');
        }} else {{
            addMessage('Received invalid response format', 'bot');
        }}
    }} catch (error) {{
        console.error('Chat error:', error); // Debug log
        addMessage(`Error: ${{error.message}}`, 'bot');
    }}
}}
function addMessage(text, sender) {{
    const history = document.getElementById('chat-history');
    const msg = document.createElement('div');
    msg.className = `message ${{sender}}-msg`;
    msg.textContent = text;
    history.appendChild(msg);
    history.scrollTop = history.scrollHeight;
}}
// Add event listeners after DOM is loaded
document.addEventListener('DOMContentLoaded', () => {{
    const msgInput = document.getElementById('message');
    const sendBtn = document.getElementById('sendBtn');
    msgInput.addEventListener('keypress', (e) => {{
        if (e.key === 'Enter') {{
            e.preventDefault();
            sendMessage();
        }}
    }});
    sendBtn.addEventListener('click', () => {{
        sendMessage();
    }});
}});
</script></body></html>"""
app = FastAPI(
    title="SeeaFile ChatBot API",
    description="API for question answering using document context",
    version="1.0.0",
    docs_url="/docs",  # Enable Swagger UI
    redoc_url="/redoc",  # Enable ReDoc
)

# CORS for browser and mobile clients. Allowed origins come from the
# comma-separated CORS_ALLOW_ORIGINS environment variable. When it is unset or
# empty, any origin ("*") is allowed, as before.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # With allow_origins=["*"], allow_credentials=True makes Starlette reflect
    # the request's Origin instead of sending "*", which lets any website make
    # credentialed cross-origin requests. Only allow credentials when an
    # explicit origin allow-list is configured.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Catch-all HTTP middleware: any exception escaping the downstream handlers is
# logged and turned into a JSON 500 response.
# NOTE(review): no @app.middleware("http") / app.add_middleware registration
# for this function is visible here; confirm it is actually installed.
async def catch_exceptions_middleware(request: Request, call_next):
    """Run the rest of the request pipeline and convert uncaught errors to a 500.

    Args:
        request: The incoming request, forwarded unchanged to ``call_next``.
        call_next: Awaitable callable that runs the remaining pipeline.

    Returns:
        The downstream response on success; otherwise a JSONResponse with
        status 500 and body ``{"detail": "Internal server error", "error": <exception text>}``.
        NOTE(review): the exception text is sent back to the client.
    """
    try:
        outcome = await call_next(request)
    except Exception as exc:
        text = str(exc)
        logger.error(f"Unhandled error: {text}", exc_info=True)
        body = {"detail": "Internal server error", "error": text}
        return JSONResponse(status_code=500, content=body)
    return outcome
async def http_exception_handler(request, exc):
    """Render an HTTP exception as a ``{"detail": ...}`` JSON response.

    NOTE(review): no @app.exception_handler registration is visible here;
    presumably it is meant for StarletteHTTPException — confirm.

    Args:
        request: The request that raised (unused).
        exc: The HTTP exception. Its ``status_code`` and ``detail`` are used,
            and its ``headers`` too when present.

    Returns:
        JSONResponse carrying the exception's status code, detail and headers.
    """
    logger.error(f"HTTP error: {exc.detail}")
    return JSONResponse(
        status_code=exc.status_code,
        content={"detail": exc.detail},
        # Preserve headers attached to the exception (for example an Allow
        # header on a 405 or WWW-Authenticate on a 401); dropping them
        # produces responses that violate HTTP semantics.
        headers=getattr(exc, "headers", None),
    )
async def validation_exception_handler(request, exc):
    """Report a request-validation failure as a 422 JSON response.

    NOTE(review): no @app.exception_handler registration is visible here;
    presumably it is meant for RequestValidationError — confirm.

    Args:
        request: The offending request (unused).
        exc: The validation exception. Only its string form is used.

    Returns:
        JSONResponse with status 422 and body ``{"detail": str(exc)}``.
    """
    summary = str(exc)
    logger.error(f"Validation error: {summary}")
    return JSONResponse(content={"detail": summary}, status_code=422)
class HealthResponse(BaseModel):
    """Response schema holding a status string and a human-readable message.

    NOTE(review): nothing visible in this file uses this model;
    health_check() returns a different shape (``status`` plus ``models``).
    Confirm whether it is still needed.
    """
    status: str  # short status label
    message: str  # free-text description
class ChatResponse(BaseModel):
    """Answer returned by generate_answer().

    Attributes:
        answer: The generated answer text.
        confidence: Score attached to the answer, defaulting to 1.0.
            generate_answer() sets it to the cosine similarity between the
            question and the document it selected.
    """
    answer: str
    confidence: float = 1.0
async def root():
    """Serve the single-page chat UI as HTML.

    An endpoint summary is serialized and passed to ``format`` as
    ``status_json``. HTML_TEMPLATE has no ``{status_json}`` field, though,
    so the summary does not appear in the page; ``format`` only collapses
    the template's doubled braces.
    NOTE(review): no route decorator is visible here for this function.
    """
    endpoints = {
        "docs": "/docs",
        "test": "/test",
        "generate": "/generate",
        "health": "/health",
    }
    summary = {
        "status": "ok",
        "message": "SeeaFile ChatBot API is running",
        "endpoints": endpoints,
    }
    # Backslash-escape both quote characters so the JSON could be embedded in
    # a JavaScript string literal.
    escaped = json.dumps(summary, indent=2).translate(
        {ord("'"): "\\'", ord('"'): '\\"'}
    )
    page = HTML_TEMPLATE.format(status_json=escaped)
    return HTMLResponse(content=page)
async def test_endpoint():
    """Smoke-test the retrieval + generation pipeline with a fixed sample query.

    NOTE(review): no route decorator is visible here; root() advertises this
    endpoint as "/test".

    Returns:
        dict with ``status``, ``test_result`` (the value returned by
        generate_answer()) and ``message``.

    Raises:
        HTTPException: 500 with a ``"Test failed: ..."`` detail if the
            pipeline fails.
    """
    sample_query = Query(
        question="What is this API about?",
        documents=[
            "This is a chatbot API that uses RAG (Retrieval Augmented Generation) to answer questions based on provided documents."
        ],
    )
    try:
        response = await generate_answer(sample_query)
    except Exception as e:
        # generate_answer() reports failures as HTTPException. str() of such
        # an exception is not simply its detail (depending on the Starlette
        # version it is empty or prefixed with the status code), so use
        # .detail directly.
        reason = e.detail if isinstance(e, StarletteHTTPException) else str(e)
        raise HTTPException(status_code=500, detail=f"Test failed: {reason}") from e
    return {
        "status": "success",
        "test_result": response,
        "message": "RAG model is working correctly",
    }
# Model checkpoints, loaded from / cached into cache_dir.
RETRIEVER_MODEL_NAME = "all-MiniLM-L6-v2"
GENERATOR_MODEL_NAME = "facebook/bart-large-cnn"

# Load the models once, at import time. Any failure is logged and re-raised as
# RuntimeError, so importing this module fails instead of leaving the
# handlers without models.
try:
    retriever = SentenceTransformer(RETRIEVER_MODEL_NAME, cache_folder=cache_dir)
    tokenizer = AutoTokenizer.from_pretrained(GENERATOR_MODEL_NAME, cache_dir=cache_dir)
    generator = AutoModelForSeq2SeqLM.from_pretrained(GENERATOR_MODEL_NAME, cache_dir=cache_dir)
except Exception as e:
    logger.error(f"Failed to load models: {str(e)}", exc_info=True)
    # Chain explicitly so the original loading error is reported as the
    # direct cause in the traceback.
    raise RuntimeError(f"Model initialization failed: {str(e)}") from e
class Query(BaseModel):
    """Request body for document-grounded question answering.

    Attributes:
        question: The question to answer.
        documents: Candidate context passages. generate_answer() embeds them,
            picks the one most similar to the question, and rejects an empty
            list.
    """
    question: str
    documents: list[str]  # builtin generic syntax: requires Python 3.9+
class ChatMessage(BaseModel):
    """Request body for chat(): a single user message."""
    message: str  # user text; chat() rejects blank / whitespace-only values
async def chat(message: ChatMessage):
    """Generate a reply to a single free-form chat message.

    The message is combined with a fixed assistant "context" and fed to the
    seq2seq generator. No document retrieval is performed.
    NOTE(review): no route decorator is visible here; the bundled UI posts to
    "/chat". facebook/bart-large-cnn is a summarization checkpoint, and the
    "question: ... context: ..." prompt is presumably not a format it was
    trained on — expect summary-like output.

    Args:
        message: Request body; ``message.message`` is the user's text.

    Returns:
        dict ``{"response": <generated text>}``.

    Raises:
        HTTPException: 400 if the message is empty or whitespace-only; 500 if
            tokenization or generation fails.
    """
    import asyncio

    # Validate before the try block so this 400 is not caught below and
    # re-reported as a 500.
    if not message.message.strip():
        raise HTTPException(status_code=400, detail="Empty message")
    try:
        # Use a default context for simple chat
        context = "I am an AI assistant that helps answer questions about documents and general topics."
        input_text = f"question: {message.message} context: {context}"
        # Tokenize on the event-loop thread: it is cheap, and it keeps the
        # shared tokenizer from being used by several threads at once.
        inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
        # generate() is blocking compute; run it in a worker thread so the
        # event loop keeps serving other requests in the meantime.
        # `temperature` is omitted: it only applies when do_sample=True, so it
        # had no effect on this beam search.
        output_ids = await asyncio.to_thread(
            generator.generate,
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=150,
            num_beams=4,
            early_stopping=True,
        )
        response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        logger.info(f"Generated response for message: {message.message[:50]}...")
        return {"response": response}
    except Exception as e:
        logger.error(f"Chat error: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Chat failed: {str(e)}"
        ) from e
async def generate_answer(query: Query):
    """Answer a question using the most relevant of the supplied documents.

    Every document and the question are embedded with the retriever. The
    document with the highest cosine similarity to the question becomes the
    generator's context.
    NOTE(review): no route decorator is visible here; root() advertises
    "/generate".

    Args:
        query: The question plus the list of candidate documents.

    Returns:
        ChatResponse: the generated answer, with ``confidence`` set to the
        cosine similarity of the selected document.

    Raises:
        HTTPException: 400 if ``query.documents`` is empty; 500 if embedding
            or generation fails.
    """
    import asyncio

    # Validate before the try block so this 400 is not caught below and
    # re-reported as a 500.
    if not query.documents:
        raise HTTPException(status_code=400, detail="No documents provided.")
    try:
        # Encode the documents and the query
        doc_embeddings = retriever.encode(query.documents, convert_to_tensor=True)
        query_embedding = retriever.encode(query.question, convert_to_tensor=True)
        # Similarity matrix for one query; row 0 holds the question's score
        # against each document.
        similarities = util.pytorch_cos_sim(query_embedding, doc_embeddings)[0]
        top_doc_index = torch.argmax(similarities).item()
        top_doc = query.documents[top_doc_index]
        input_text = f"question: {query.question} context: {top_doc}"
        # Tokenize on the event-loop thread; only the heavy generate() call is
        # moved to a worker thread, so the loop stays responsive.
        inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
        output_ids = await asyncio.to_thread(
            generator.generate,
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=150,
            num_beams=5,
            early_stopping=True,
        )
        answer = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return ChatResponse(
            answer=answer,
            confidence=similarities[top_doc_index].item()
        )
    except Exception as e:
        logger.error(f"Error generating answer: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to generate answer: {str(e)}"
        ) from e
# Health-check handler reporting whether each model global is bound.
# NOTE(review): no route decorator is visible here; root() advertises "/health".
async def health_check():
    """Report service health plus a per-model "loaded" flag.

    Model loading at import time raises on failure, so by the time this
    handler runs the flags are expected to be True.

    Returns:
        ``{"status": "healthy", "models": {"retriever": bool, "tokenizer": bool,
        "generator": bool}}``, or ``{"status": "unhealthy", "error": <text>}``
        if building that report raises.
    """
    try:
        loaded = {
            name: model is not None
            for name, model in (
                ("retriever", retriever),
                ("tokenizer", tokenizer),
                ("generator", generator),
            )
        }
    except Exception as exc:
        logger.error(f"Health check failed: {str(exc)}")
        return {"status": "unhealthy", "error": str(exc)}
    return {"status": "healthy", "models": loaded}