from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
from transformers import pipeline
from duckduckgo_search import DDGS
from typing import Optional, List, Dict, Any
import logging
import re
import threading
import spacy
from rake_nltk import Rake
import nltk

# NLTK data fetched at import time (stopword list and tokenizer models);
# presumably required by the rake_nltk ``Rake`` instance created below.
nltk.download('stopwords')
nltk.download('punkt_tab')

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="Enhanced Chat API with Dynamic Model Selection",
    description="API with query classification and dynamic model loading for QA/Summarization",
    version="2.0.0"
)

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["POST", "GET"],
    allow_headers=["*"],
)


# Request/Response models
class ChatRequest(BaseModel):
    """Body of ``POST /chat``."""

    prompt: str
    # NOTE(review): max_new_tokens and temperature are accepted but not read by
    # any handler in this file; kept so existing clients still validate.
    max_new_tokens: int = 500
    use_search: bool = False
    temperature: float = 0.7


class ChatResponse(BaseModel):
    """Response of ``POST /chat``; ``search_results`` is None unless search was requested."""

    response: str
    task_type: str
    search_results: Optional[List[Dict[str, Any]]] = None


class SearchRequest(BaseModel):
    """Body of ``POST /search``."""

    query: str
    max_results: int = 5


# Model pipelines. Each stays None until its loader succeeds; the loaders are
# called from the startup hook, so loading is eager, not lazy.
classifier_pipeline = None
qa_pipeline = None
summarization_pipeline = None

# Shared spaCy English pipeline and RAKE keyword extractor.
nlp = spacy.load("en_core_web_sm")
rake = Rake()

# Label returned when classification is unavailable or fails.
_DEFAULT_TASK = 'question answering'

_WHITESPACE_RE = re.compile(r'\s+')

# /chat work runs one request at a time: RAKE keeps extraction state on the
# shared ``rake`` instance between ``extract_keywords_from_text`` and
# ``get_ranked_phrases``, and the spaCy/transformers objects are not assumed to
# be thread-safe.
_chat_lock = threading.Lock()


def extract_search_terms(text: str) -> List[str]:
    """Extract search terms from *text* by combining several signals.

    Collects named entities, noun chunks, RAKE's top-3 ranked phrases and the
    dependency-parse focus phrase, then filters/deduplicates them with
    ``clean_terms``.

    Args:
        text: Text to mine for search terms.

    Returns:
        The cleaned terms, longest first.
    """
    doc = nlp(text)

    entities = [ent.text for ent in doc.ents]
    noun_phrases = [chunk.text for chunk in doc.noun_chunks]
    focus_phrase = extract_focus_phrase(doc)

    rake.extract_keywords_from_text(text)
    keywords = rake.get_ranked_phrases()[:3]  # top 3 RAKE phrases

    terms = entities + noun_phrases + keywords
    if focus_phrase:
        terms.append(focus_phrase)

    return clean_terms(terms)


def extract_focus_phrase(doc) -> str:
    """Return the question focus taken from the dependency parse.

    Scans ROOT tokens in document order and returns the subtree text of the
    first child labelled ``attr``, ``nsubj`` or ``dobj``.

    Args:
        doc: A parsed spaCy ``Doc``.

    Returns:
        The focus phrase, or ``""`` when no such child exists.
    """
    for token in doc:
        if token.dep_ == "ROOT":
            for child in token.children:
                if child.dep_ in ("attr", "nsubj", "dobj"):
                    return " ".join([t.text for t in child.subtree])
    return ""


def clean_terms(terms: List[str]) -> List[str]:
    """Filter and deduplicate candidate search terms.

    Drops terms of length <= 1 and terms made only of stop words. Then, scanning
    longest first, drops any term contained in an already-kept term; this also
    removes exact duplicates.

    Args:
        terms: Candidate terms.

    Returns:
        The surviving terms, longest first.
    """
    # ``is_stop`` is a lexical attribute, so tokenization alone is enough; the
    # full tagger/parser/NER pipeline per term is unnecessary work.
    cleaned = [
        t for t in terms
        if len(t) > 1 and not all(token.is_stop for token in nlp.make_doc(t))
    ]

    final_terms: List[str] = []
    for term in sorted(cleaned, key=len, reverse=True):
        if not any(term in other for other in final_terms):
            final_terms.append(term)
    return final_terms


def _build_pipeline(task: str, model: str):
    """Create a transformers pipeline on CUDA device 0 when available, else CPU (-1)."""
    device = 0 if torch.cuda.is_available() else -1
    return pipeline(task, model=model, device=device)


def load_classifier():
    """Load the zero-shot classifier into ``classifier_pipeline``.

    Errors are logged with a traceback and leave the global unchanged.
    """
    global classifier_pipeline
    try:
        classifier_pipeline = _build_pipeline(
            "zero-shot-classification", "valhalla/distilbart-mnli-12-3"
        )
        logger.info("Zero-shot classifier loaded")
    except Exception as e:
        logger.exception(f"Classifier load error: {e}")


def load_qa_model():
    """Load the extractive question-answering model into ``qa_pipeline``.

    Errors are logged with a traceback and leave the global unchanged.
    """
    global qa_pipeline
    try:
        qa_pipeline = _build_pipeline(
            "question-answering", "distilbert-base-uncased-distilled-squad"
        )
        logger.info("QA model loaded")
    except Exception as e:
        logger.exception(f"QA model load error: {e}")


def load_summarization_model():
    """Load the summarization model into ``summarization_pipeline``.

    Errors are logged with a traceback and leave the global unchanged.
    """
    global summarization_pipeline
    try:
        summarization_pipeline = _build_pipeline(
            "summarization", "sshleifer/distilbart-cnn-6-6"
        )
        logger.info("Summarization model loaded")
    except Exception as e:
        logger.exception(f"Summarization model load error: {e}")


def classify_query(prompt: str) -> str:
    """Classify *prompt* as ``'question answering'`` or ``'summarization'``.

    Uses the zero-shot classifier's top label. Falls back to
    ``'question answering'`` when the classifier is not loaded or raises.
    """
    candidate_labels = ['question answering', 'summarization']
    if classifier_pipeline is None:
        logger.warning("Classifier not loaded; defaulting to question answering")
        return _DEFAULT_TASK
    try:
        result = classifier_pipeline(prompt, candidate_labels, multi_label=False)
        return result['labels'][0]
    except Exception as e:
        logger.error(f"Classification failed: {e}")
        return _DEFAULT_TASK


def search_web(query: str, max_results: int = 5) -> List[Dict[str, Any]]:
    """Run a DuckDuckGo text search (safesearch off).

    Args:
        query: Search query.
        max_results: Maximum number of results to request.

    Returns:
        Dicts with ``title``, ``body`` (whitespace runs collapsed) and ``href``.
        Any error is logged and yields an empty list.
    """
    try:
        with DDGS() as ddgs:
            # Treat a falsy return as "no results" rather than iterating None.
            raw_results = ddgs.text(query, safesearch='off', max_results=max_results) or []
            return [
                {
                    "title": r.get("title") or "",
                    "body": _WHITESPACE_RE.sub(' ', r.get("body") or "").strip(),
                    "href": r.get("href") or "",
                }
                for r in raw_results
            ]
    except Exception as e:
        logger.error(f"Search error: {e}")
        return []


def format_search_context(results: List[Dict[str, Any]]) -> str:
    """Build a numbered context string from the first five search results.

    Each line is ``"N. title: body"`` with the body cut to 200 characters.
    """
    return "\n".join(
        f"{i+1}. {res['title']}: {res['body'][:200]}"
        for i, res in enumerate(results[:5])
    )


def preprocess_text(text: str) -> str:
    """Lemmatize *text*, dropping stop words and punctuation.

    The joined result is truncated to 1024 characters.
    """
    doc = nlp(text)
    return " ".join([
        token.lemma_ for token in doc
        if not token.is_stop and not token.is_punct
    ])[:1024]


@app.on_event("startup")
async def startup_event():
    """Load the classifier, QA and summarization pipelines at startup."""
    load_classifier()
    load_qa_model()
    load_summarization_model()


def _answer_chat(request: ChatRequest) -> ChatResponse:
    """Blocking /chat work: optional web search, classification, then inference.

    Raises:
        HTTPException: 503 when the model for the classified task is not loaded.
    """
    search_results: List[Dict[str, Any]] = []
    search_context = ""

    if request.use_search:
        # NOTE(review): the prompt is lowercased before term extraction; spaCy's
        # NER may rely on capitalization cues — confirm this is intended.
        search_terms = extract_search_terms(request.prompt.lower())
        search_query = " ".join(search_terms) or request.prompt
        search_results = search_web(search_query)
        search_context = format_search_context(search_results)
        logger.info("Search Context: %s", search_context)

    task_type = classify_query(request.prompt)
    logger.info("Classified task: %s", task_type)

    if task_type == 'question answering':
        if qa_pipeline is None:
            raise HTTPException(status_code=503, detail="QA model is not loaded")
        response = qa_pipeline(
            question=request.prompt,
            # Without search results the prompt doubles as its own context.
            context=search_context or request.prompt,
            max_answer_len=100,
        )['answer']
    else:
        if summarization_pipeline is None:
            raise HTTPException(status_code=503, detail="Summarization model is not loaded")
        response = summarization_pipeline(
            search_context.lower() or request.prompt.lower(),
            max_length=150,
            min_length=30,
            do_sample=False,
            # Truncate inputs longer than the model accepts instead of erroring.
            truncation=True,
        )[0]['summary_text']

    return ChatResponse(
        response=response,
        task_type=task_type,
        search_results=search_results if request.use_search else None,
    )


def _answer_chat_serialized(request: ChatRequest) -> ChatResponse:
    """Run ``_answer_chat`` while holding ``_chat_lock``."""
    with _chat_lock:
        return _answer_chat(request)


@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """Answer or summarize the prompt, optionally grounded in a web search.

    Raises:
        HTTPException: 400 for an empty prompt, 503 when the needed model is not
            loaded, 500 for any other failure.
    """
    logger.info("Request: %s", request.prompt)
    if not request.prompt.strip():
        raise HTTPException(status_code=400, detail="prompt must not be empty")
    try:
        # Inference and web search block; running them in the threadpool keeps
        # the event loop free to serve other requests meanwhile.
        return await run_in_threadpool(_answer_chat_serialized, request)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Chat error: %s", e)
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/search")
async def search_endpoint(request: SearchRequest):
    """Return raw web-search results for ``request.query``.

    Raises:
        HTTPException: 500 if running the search raises.
    """
    try:
        # search_web performs blocking network I/O; keep it off the event loop.
        results = await run_in_threadpool(search_web, request.query, request.max_results)
    except Exception as e:
        logger.exception("Search endpoint error: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
    return {"results": results}
@app.get("/")
async def root():
    """Health check reporting which model pipelines are loaded.

    Returns:
        A dict with a status message, a per-model "loaded" flag, and the paths
        of the available endpoints.
    """
    return {
        "message": "Open Source Chat API is running!",
        "models": {
            # The pipeline globals stay None when loading failed, so test for
            # None explicitly instead of relying on the objects' truthiness.
            "classifier": classifier_pipeline is not None,
            "qa": qa_pipeline is not None,
            "summarization": summarization_pipeline is not None,
        },
        "endpoints": {
            "chat": "/chat",
            "search": "/search",
            "docs": "/docs",
        },
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)