| import json | |
| from dataclasses import asdict | |
| from typing import List, Dict, Any, Optional | |
| from sentence_transformers import SentenceTransformer | |
| from classifier.head import ClassifierHead | |
| from classifier.infer import predict_query | |
| from classifier.utils import get_models | |
| from retriever import Retriever | |
| from team.candidates import get_candidates, _available | |
| from config import settings | |
| class HealthQueryPipeline: | |
| def __init__(self, use_reranker: bool = False): | |
| self.use_reranker = use_reranker | |
| self.embedding_model: Optional[SentenceTransformer] = None | |
| self.classifier: Optional[ClassifierHead] = None | |
| self.retriever: Optional[Retriever] = None | |
| self.is_initialized = False | |
| def initialize(self): | |
| """Loads models and initializes the retriever.""" | |
| if self.is_initialized: | |
| return | |
| print(f"Loading embedding model: {settings.MODEL_NAME}...") | |
| self.embedding_model, self.classifier = get_models(model_id=settings.CLASSIFIER_NAME) | |
| print("Model loaded.") | |
| print("Initializing retriever...") | |
| cfg = _available(settings.CORPORA_CONFIG) | |
| if not cfg: | |
| raise RuntimeError("No corpora files found in data/corpora. Build them first.") | |
| self.retriever = Retriever( | |
| corpora_config=cfg, | |
| use_reranker=self.use_reranker, | |
| embedding_model=self.embedding_model | |
| ) | |
| print("Retriever initialized.") | |
| self.is_initialized = True | |
| def predict(self, query: str, k: int = 10) -> Dict[str, Any]: | |
| """ | |
| Runs the full pipeline: Classification -> Retrieval (if medical). | |
| """ | |
| if not self.is_initialized: | |
| self.initialize() | |
| classification = predict_query( | |
| text=[query], | |
| embedding_model=self.embedding_model, | |
| classifier_head=self.classifier, | |
| ) | |
| predictions = classification["prediction"] | |
| result = { | |
| "query": query, | |
| "classification": { | |
| "prediction": predictions[0], | |
| "probabilities": { | |
| cat: prob | |
| for cat, prob in zip(settings.CATEGORIES, classification['probabilities']) | |
| } | |
| }, | |
| "retrieval": [] | |
| } | |
| if "medical" in predictions: | |
| hits = get_candidates( | |
| query=query, | |
| retriever=self.retriever, | |
| k_retrieve=k, | |
| ) | |
| result["retrieval"] = [asdict(hit) for hit in hits] | |
| return result | |
| def get_index_progress(self): | |
| """Returns (current, total) of the underlying index.""" | |
| if not self.retriever: | |
| return 0, 0 | |
| return self.retriever.get_index_progress() | |