# Atlas / app.py
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
from transformers import pipeline
from duckduckgo_search import DDGS
from typing import Optional, List, Dict, Any
import logging
import re
import spacy
from rake_nltk import Rake
import nltk

# RAKE relies on NLTK's stopword list and sentence tokenizer
nltk.download('stopwords')
nltk.download('punkt_tab')
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize FastAPI app
app = FastAPI(
title="Enhanced Chat API with Dynamic Model Selection",
description="API with query classification and dynamic model loading for QA/Summarization",
version="2.0.0"
)
# Configure CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["POST", "GET"],
allow_headers=["*"],
)
# Request/Response models
class ChatRequest(BaseModel):
prompt: str
    max_new_tokens: int = 500  # Accepted but unused by the current extractive QA/summarization pipelines
    use_search: bool = False
    temperature: float = 0.7  # Accepted but unused by the current pipelines
class ChatResponse(BaseModel):
response: str
task_type: str
search_results: Optional[List[Dict[str, Any]]] = None
class SearchRequest(BaseModel):
query: str
max_results: int = 5
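# Illustrative request bodies (hypothetical prompts, not from the original source):
#   POST /chat   {"prompt": "Who founded SpaceX?", "use_search": true}
#   POST /search {"query": "FastAPI CORS middleware", "max_results": 3}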
# Global pipelines with lazy loading
classifier_pipeline = None
qa_pipeline = None
summarization_pipeline = None
# The spaCy model and RAKE extractor are cheap to load, so they are initialized eagerly
nlp = spacy.load("en_core_web_sm")
rake = Rake()
def extract_search_terms(text: str) -> List[str]:
"""Extract enhanced search terms using combined NER, syntax, and keywords"""
doc = nlp(text)
# 1. Extract named entities
entities = [ent.text for ent in doc.ents]
# 2. Extract noun phrases through syntactic analysis
noun_phrases = list(doc.noun_chunks)
# 3. Extract question focus using dependency parsing
focus_phrase = extract_focus_phrase(doc)
# 4. Get keywords using RAKE
rake.extract_keywords_from_text(text)
keywords = rake.get_ranked_phrases()[:3] # Top 3 keywords
# Combine and filter terms
terms = entities + [np.text for np in noun_phrases] + keywords
if focus_phrase:
terms.append(focus_phrase)
# Clean and deduplicate
return clean_terms(terms)
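# Rough walkthrough (hypothetical; exact output depends on the installed spaCy and
# RAKE versions): for "Who founded SpaceX and when was it founded?", the function
# combines NER hits (e.g. "SpaceX"), noun chunks, up to three RAKE phrases, and
# the ROOT-based focus phrase, then deduplicates via clean_terms().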
def extract_focus_phrase(doc) -> str:
"""Extract main question focus using dependency parse"""
for token in doc:
if token.dep_ == "ROOT":
for child in token.children:
if child.dep_ in ("attr", "nsubj", "dobj"):
return " ".join([t.text for t in child.subtree])
return ""
def clean_terms(terms: List[str]) -> List[str]:
"""Remove duplicates and irrelevant terms"""
# Remove stopwords and single characters
cleaned = [
t for t in terms
if len(t) > 1 and not all(token.is_stop for token in nlp(t))
]
# Remove redundant subphrases
final_terms = []
for term in sorted(cleaned, key=len, reverse=True):
if not any(term in other for other in final_terms):
final_terms.append(term)
return final_terms
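# Example: clean_terms(["New York City", "York", "the", "climate policy"])
# keeps "New York City" and "climate policy"; "York" is dropped as a substring
# of an already-kept term, and "the" is filtered out as a stopword.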
def load_classifier():
"""Load zero-shot classification model"""
global classifier_pipeline
try:
device = "cuda" if torch.cuda.is_available() else "cpu"
classifier_pipeline = pipeline(
"zero-shot-classification",
model="valhalla/distilbart-mnli-12-3",
device=0 if device == "cuda" else -1
)
logger.info("Zero-shot classifier loaded")
except Exception as e:
logger.error(f"Classifier load error: {e}")
def load_qa_model():
"""Load question-answering model on demand"""
global qa_pipeline
try:
device = "cuda" if torch.cuda.is_available() else "cpu"
qa_pipeline = pipeline(
"question-answering",
model="distilbert-base-uncased-distilled-squad",
device=0 if device == "cuda" else -1
)
logger.info("QA model loaded")
except Exception as e:
logger.error(f"QA model load error: {e}")
def load_summarization_model():
"""Load summarization model on demand"""
global summarization_pipeline
try:
device = "cuda" if torch.cuda.is_available() else "cpu"
summarization_pipeline = pipeline(
"summarization",
model="sshleifer/distilbart-cnn-6-6",
device=0 if device == "cuda" else -1
)
logger.info("Summarization model loaded")
except Exception as e:
logger.error(f"Summarization model load error: {e}")
def classify_query(prompt: str) -> str:
"""Classify query using zero-shot learning"""
candidate_labels = ['question answering', 'summarization']
try:
result = classifier_pipeline(
prompt,
candidate_labels,
multi_label=False
)
return result['labels'][0]
except Exception as e:
logger.error(f"Classification failed: {e}")
return 'question answering' # Fallback to QA
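# The zero-shot pipeline returns labels sorted by descending score, e.g.
#   {"sequence": "...", "labels": ["summarization", "question answering"],
#    "scores": [0.87, 0.13]}  (scores illustrative)
# so result['labels'][0] is the highest-confidence task.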
def search_web(query: str, max_results: int = 5) -> List[Dict[str, Any]]:
"""Enhanced web search with error handling"""
try:
with DDGS() as ddgs:
return [{
"title": r.get("title", ""),
"body": re.sub(r'\s+', ' ', r.get("body", "")).strip(),
"href": r.get("href", "")
} for r in ddgs.text(query, safesearch='off', max_results=max_results)]
except Exception as e:
logger.error(f"Search error: {e}")
return []
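# Each hit is normalized to {"title", "body", "href"} with whitespace collapsed
# in the body; e.g. (illustrative, live results vary):
#   [{"title": "FastAPI", "body": "FastAPI framework, high performance...",
#     "href": "https://fastapi.tiangolo.com/"}]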
def format_search_context(results: List[Dict[str, Any]]) -> str:
"""Create condensed context from search results"""
return "\n".join(
f"{i+1}. {res['title']}: {res['body'][:200]}"
for i, res in enumerate(results[:5])
)
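# Example output for two results (illustrative):
#   1. FastAPI: FastAPI framework, high performance, easy to learn, fast to code
#   2. Starlette: The little ASGI framework that shines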
def preprocess_text(text: str) -> str:
    """spaCy-based text cleaning/normalization (currently unused; the summarization path feeds raw text)"""
doc = nlp(text)
# Lemmatize and remove stopwords
return " ".join([
token.lemma_ for token in doc
if not token.is_stop and not token.is_punct
])[:1024]
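# Example (approximate; lemmas depend on the spaCy model):
#   preprocess_text("The cats were sitting on the mats") -> "cat sit mat"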
@app.on_event("startup")  # Note: on_event is deprecated in recent FastAPI; lifespan handlers are the modern replacement
async def startup_event():
"""Initialize core models on startup"""
load_classifier()
load_qa_model()
load_summarization_model()
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
"""Enhanced chat endpoint with dynamic model selection"""
logger.info(f"Request: {request.prompt}")
try:
search_results = []
search_context = ""
# Web search processing
if request.use_search:
            # Keep the original casing: spaCy's NER is case-sensitive
            search_terms = extract_search_terms(request.prompt)
search_query = " ".join(search_terms) or request.prompt
search_results = search_web(search_query)
search_context = format_search_context(search_results)
logger.info(f"Search Context: {search_context}")
# Query classification
task_type = classify_query(request.prompt)
logger.info(f"Classified task: {task_type}")
        if task_type == 'question answering':
            if qa_pipeline is None:
                raise HTTPException(status_code=503, detail="QA model is not loaded")
            response = qa_pipeline(
                question=request.prompt,
                context=search_context or request.prompt,
                max_answer_len=100
            )['answer']
        else:
            if summarization_pipeline is None:
                raise HTTPException(status_code=503, detail="Summarization model is not loaded")
            # distilbart-cnn is a cased model, so pass the text through without lowercasing
            response = summarization_pipeline(
                search_context or request.prompt,
                max_length=150,
                min_length=30,
                do_sample=False
            )[0]['summary_text']
return ChatResponse(
response=response,
task_type=task_type,
search_results=search_results if request.use_search else None
)
    except HTTPException:
        # Re-raise explicit HTTP errors (e.g. 503 for missing models) unchanged
        raise
    except Exception as e:
        logger.error(f"Chat error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
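# Example invocation (hypothetical prompt):
#   curl -X POST http://localhost:7860/chat \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "Summarize recent developments in battery recycling", "use_search": true}'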
@app.post("/search")
async def search_endpoint(request: SearchRequest):
"""Search endpoint with improved error handling"""
try:
return {"results": search_web(request.query, request.max_results)}
except Exception as e:
logger.error(f"Search endpoint error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/")
async def root():
"""Enhanced health check with model status"""
return {
"message": "Open Source Chat API is running!",
"models": {
"classifier": bool(classifier_pipeline),
"qa": bool(qa_pipeline),
"summarization": bool(summarization_pipeline)
},
"endpoints": {
"chat": "/chat",
"search": "/search",
"docs": "/docs"
}
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)
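# Equivalent CLI launch: uvicorn app:app --host 0.0.0.0 --port 7860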