from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
from transformers import pipeline
from duckduckgo_search import DDGS
from typing import Optional, List, Dict, Any
import logging
import re
import spacy
from rake_nltk import Rake
import nltk

nltk.download('stopwords')
nltk.download('punkt_tab')
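# The downloads above fetch the stopword list and punkt tokenizer data that
# rake_nltk needs for keyword extraction.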

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="Enhanced Chat API with Dynamic Model Selection",
    description="API with query classification and dynamic model loading for QA/Summarization",
    version="2.0.0"
)

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["POST", "GET"],
    allow_headers=["*"],
)
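# Note: allow_origins=["*"] accepts requests from any origin, which is fine
# for a demo Space but should be restricted in production.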

# Request/Response models
class ChatRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 500
    use_search: bool = False
    temperature: float = 0.7

class ChatResponse(BaseModel):
    response: str
    task_type: str
    search_results: Optional[List[Dict[str, Any]]] = None

class SearchRequest(BaseModel):
    query: str
    max_results: int = 5

# Global model pipelines: the transformer pipelines are populated at startup
# (see startup_event below); spaCy and RAKE are initialized eagerly here.
classifier_pipeline = None
qa_pipeline = None
summarization_pipeline = None
nlp = spacy.load("en_core_web_sm")
rake = Rake()
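# The small English spaCy model must be installed beforehand:
#   python -m spacy download en_core_web_sm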

def extract_search_terms(text: str) -> List[str]:
    """Extract enhanced search terms using combined NER, syntax, and keywords"""
    doc = nlp(text)
    # 1. Extract named entities
    entities = [ent.text for ent in doc.ents]
    # 2. Extract noun phrases through syntactic analysis
    noun_phrases = list(doc.noun_chunks)
    # 3. Extract the question focus using dependency parsing
    focus_phrase = extract_focus_phrase(doc)
    # 4. Get keywords using RAKE
    rake.extract_keywords_from_text(text)
    keywords = rake.get_ranked_phrases()[:3]  # Top 3 ranked keyword phrases
    # Combine and filter terms
    terms = entities + [np.text for np in noun_phrases] + keywords
    if focus_phrase:
        terms.append(focus_phrase)
    # Clean and deduplicate
    return clean_terms(terms)

def extract_focus_phrase(doc) -> str:
    """Extract the main question focus using the dependency parse"""
    for token in doc:
        if token.dep_ == "ROOT":
            for child in token.children:
                if child.dep_ in ("attr", "nsubj", "dobj"):
                    return " ".join(t.text for t in child.subtree)
    return ""

def clean_terms(terms: List[str]) -> List[str]:
    """Remove duplicates and irrelevant terms"""
    # Drop single characters and terms consisting entirely of stopwords
    cleaned = [
        t for t in terms
        if len(t) > 1 and not all(token.is_stop for token in nlp(t))
    ]
    # Remove terms that are substrings of longer terms already kept
    final_terms = []
    for term in sorted(cleaned, key=len, reverse=True):
        if not any(term in other for other in final_terms):
            final_terms.append(term)
    return final_terms
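# Illustrative behavior (hypothetical input): clean_terms(["machine learning
# models", "machine learning", "models"]) keeps only "machine learning models",
# since the shorter terms are substrings of the longer one.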

def load_classifier():
    """Load zero-shot classification model"""
    global classifier_pipeline
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        classifier_pipeline = pipeline(
            "zero-shot-classification",
            model="valhalla/distilbart-mnli-12-3",
            device=0 if device == "cuda" else -1
        )
        logger.info("Zero-shot classifier loaded")
    except Exception as e:
        logger.error(f"Classifier load error: {e}")

def load_qa_model():
    """Load the question-answering model"""
    global qa_pipeline
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        qa_pipeline = pipeline(
            "question-answering",
            model="distilbert-base-uncased-distilled-squad",
            device=0 if device == "cuda" else -1
        )
        logger.info("QA model loaded")
    except Exception as e:
        logger.error(f"QA model load error: {e}")

def load_summarization_model():
    """Load the summarization model"""
    global summarization_pipeline
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        summarization_pipeline = pipeline(
            "summarization",
            model="sshleifer/distilbart-cnn-6-6",
            device=0 if device == "cuda" else -1
        )
        logger.info("Summarization model loaded")
    except Exception as e:
        logger.error(f"Summarization model load error: {e}")

def classify_query(prompt: str) -> str:
    """Classify the query using zero-shot learning"""
    candidate_labels = ['question answering', 'summarization']
    try:
        result = classifier_pipeline(
            prompt,
            candidate_labels,
            multi_label=False
        )
        return result['labels'][0]
    except Exception as e:
        logger.error(f"Classification failed: {e}")
        return 'question answering'  # Fallback to QA
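# The zero-shot pipeline returns its 'labels' sorted by score in descending
# order, so labels[0] is the better match of the two candidate tasks.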

def search_web(query: str, max_results: int = 5) -> List[Dict[str, Any]]:
    """Enhanced web search with error handling"""
    try:
        with DDGS() as ddgs:
            return [{
                "title": r.get("title", ""),
                "body": re.sub(r'\s+', ' ', r.get("body", "")).strip(),
                "href": r.get("href", "")
            } for r in ddgs.text(query, safesearch='off', max_results=max_results)]
    except Exception as e:
        logger.error(f"Search error: {e}")
        return []

def format_search_context(results: List[Dict[str, Any]]) -> str:
    """Create condensed context from search results"""
    return "\n".join(
        f"{i+1}. {res['title']}: {res['body'][:200]}"
        for i, res in enumerate(results[:5])
    )

def preprocess_text(text: str) -> str:
    """Use spaCy for fast text cleaning/normalization"""
    doc = nlp(text)
    # Lemmatize and remove stopwords, truncating to 1024 characters
    return " ".join([
        token.lemma_ for token in doc
        if not token.is_stop and not token.is_punct
    ])[:1024]
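# Illustrative behavior (hypothetical input): preprocess_text("The cats are
# running") typically yields "cat run" after lemmatization and stopword removal.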

@app.on_event("startup")
async def startup_event():
    """Initialize core models on startup"""
    load_classifier()
    load_qa_model()
    load_summarization_model()

@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """Enhanced chat endpoint with dynamic model selection"""
    logger.info(f"Request: {request.prompt}")
    try:
        search_results = []
        search_context = ""

        # Web search processing
        if request.use_search:
            search_terms = extract_search_terms(request.prompt.lower())
            search_query = " ".join(search_terms) or request.prompt
            search_results = search_web(search_query)
            search_context = format_search_context(search_results)
            logger.info(f"Search Context: {search_context}")

        # Query classification
        task_type = classify_query(request.prompt)
        logger.info(f"Classified task: {task_type}")

        if task_type == 'question answering':
            response = qa_pipeline(
                question=request.prompt,
                context=search_context or request.prompt,
                max_answer_len=100
            )['answer']
        else:
            response = summarization_pipeline(
                search_context.lower() or request.prompt.lower(),
                max_length=150,
                min_length=30,
                do_sample=False
            )[0]['summary_text']

        return ChatResponse(
            response=response,
            task_type=task_type,
            search_results=search_results if request.use_search else None
        )
    except Exception as e:
        logger.error(f"Chat error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
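# A POST to /chat with {"prompt": "...", "use_search": true} returns a
# ChatResponse JSON carrying the answer or summary, the detected task_type,
# and the raw search results that were used as context.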

@app.post("/search")
async def search_endpoint(request: SearchRequest):
    """Search endpoint with improved error handling"""
    try:
        return {"results": search_web(request.query, request.max_results)}
    except Exception as e:
        logger.error(f"Search endpoint error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/")
async def root():
    """Enhanced health check with model status"""
    return {
        "message": "Open Source Chat API is running!",
        "models": {
            "classifier": bool(classifier_pipeline),
            "qa": bool(qa_pipeline),
            "summarization": bool(summarization_pipeline)
        },
        "endpoints": {
            "chat": "/chat",
            "search": "/search",
            "docs": "/docs"
        }
    }

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
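
# Quick manual test once the server is running (port 7860, as configured above):
#   curl http://localhost:7860/
#   curl -X POST http://localhost:7860/chat \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "What is FastAPI?", "use_search": true}'
#   curl -X POST http://localhost:7860/search \
#     -H "Content-Type: application/json" \
#     -d '{"query": "FastAPI tutorial", "max_results": 3}'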