Spaces:
Sleeping
Sleeping
| """ | |
| Medical Q&A UI - BM25 + Dense Retrieval Models WITH DISK CACHING | |
| This version caches the indexes to disk for fast startup (30 seconds vs 5-8 minutes!) | |
| """ | |
| import gradio as gr | |
| from typing import Dict, List | |
| from pathlib import Path | |
| import pickle | |
| import hashlib | |
| import json | |
| import subprocess | |
| import sys | |
| from retriever.index_bm25 import BM25Index | |
| from retriever.index_dense import DenseIndex | |
| from retriever.ingest import load_jsonl | |
| from retriever.rrf import rrf | |
| from team.interfaces import Candidate | |
| # Cache directory | |
| CACHE_DIR = Path("cache") | |
| CACHE_DIR.mkdir(exist_ok=True) | |
| # Ensure embeddings cache directory exists (for Dense index) | |
| EMBEDDINGS_CACHE_DIR = Path(".cache/embeddings") | |
| EMBEDDINGS_CACHE_DIR.mkdir(parents=True, exist_ok=True) | |
| print("=" * 70) | |
| print(" Medical Document Retrieval System (CACHED VERSION)") | |
| print(" Using BM25 + Dense Embeddings + RRF Fusion") | |
| print(" With disk caching for fast startup!") | |
| print("=" * 70) | |
| def _ensure_corpora_exist(): | |
| """Build corpora files if they don't exist""" | |
| data_dir = Path("data/corpora") | |
| required_files = [ | |
| data_dir / "medical_qa.jsonl", | |
| data_dir / "miriad_text.jsonl", | |
| data_dir / "unidoc_qa.jsonl" | |
| ] | |
| if all(f.exists() for f in required_files): | |
| return # All files exist | |
| print("\n" + "=" * 70) | |
| print("⚠️ Corpora files not found. Building them now...") | |
| print(" This will take 2-3 minutes on first launch.") | |
| print("=" * 70 + "\n") | |
| try: | |
| # Run build_corpora.py | |
| subprocess.run( | |
| [sys.executable, "adapters/build_corpora.py"], | |
| check=True, | |
| capture_output=False | |
| ) | |
| print("\n✓ Corpora files built successfully!\n") | |
| except subprocess.CalledProcessError as e: | |
| print(f"\n✗ Failed to build corpora: {e}") | |
| raise RuntimeError("Could not build corpora files. Please run 'python adapters/build_corpora.py' manually.") | |
| def _default_corpora_config() -> Dict[str, dict]: | |
| return { | |
| "medical_qa": {"path": "data/corpora/medical_qa.jsonl", | |
| "text_fields": ["question", "answer", "title"]}, | |
| "miriad": {"path": "data/corpora/miriad_text.jsonl", | |
| "text_fields": ["question", "answer", "title"]}, | |
| "unidoc": {"path": "data/corpora/unidoc_qa.jsonl", | |
| "text_fields": ["question", "answer", "title"]}, | |
| } | |
| def _available(cfg: Dict[str, dict]) -> Dict[str, dict]: | |
| return {k: v for k, v in cfg.items() if Path(v["path"]).exists()} | |
| def _get_cache_key(corpora_config: Dict[str, dict]) -> str: | |
| """Generate a unique cache key based on corpora config""" | |
| config_str = json.dumps(corpora_config, sort_keys=True) | |
| return hashlib.md5(config_str.encode()).hexdigest() | |
| class CachedRetriever: | |
| """Retriever with disk caching for BM25 and documents (Dense has its own caching)""" | |
| def __init__(self, corpora_config: Dict[str, dict], use_reranker: bool = False): | |
| self.corpora_config = corpora_config | |
| self.use_reranker = use_reranker | |
| self.cache_key = _get_cache_key(corpora_config) | |
| # Cache file paths | |
| self.bm25_cache = CACHE_DIR / f"bm25_{self.cache_key}.pkl" | |
| self.docs_cache = CACHE_DIR / f"docs_{self.cache_key}.pkl" | |
| # Note: Dense index uses its own caching in .cache/embeddings/ | |
| # Load or build indexes | |
| self.docs_all = self._load_or_build_docs() | |
| self.bm25 = self._load_or_build_bm25() | |
| self.dense = self._load_or_build_dense() | |
| def _load_or_build_docs(self) -> List: | |
| """Load documents from cache or build from scratch""" | |
| if self.docs_cache.exists(): | |
| print(f"Loading documents from cache... ({self.docs_cache.name})") | |
| try: | |
| with open(self.docs_cache, 'rb') as f: | |
| docs_all = pickle.load(f) | |
| print(f" ✓ Loaded {len(docs_all)} documents from cache") | |
| return docs_all | |
| except Exception as e: | |
| print(f" ✗ Cache load failed: {e}") | |
| print(" → Rebuilding documents...") | |
| print("Building documents from corpora files...") | |
| docs_all = [] | |
| for name, cfg in self.corpora_config.items(): | |
| print(f" Loading {name}...") | |
| docs = load_jsonl(cfg["path"], tuple(cfg.get("text_fields", ("question", "answer")))) | |
| docs_all.extend(docs) | |
| # Save to cache | |
| print(f"Saving documents to cache... ({len(docs_all)} docs)") | |
| with open(self.docs_cache, 'wb') as f: | |
| pickle.dump(docs_all, f) | |
| return docs_all | |
| def _load_or_build_bm25(self) -> BM25Index: | |
| """Load BM25 index from cache or build from scratch""" | |
| if self.bm25_cache.exists(): | |
| print(f"Loading BM25 index from cache... ({self.bm25_cache.name})") | |
| try: | |
| with open(self.bm25_cache, 'rb') as f: | |
| bm25_index = pickle.load(f) | |
| print(f" ✓ BM25 index loaded from cache") | |
| return bm25_index | |
| except Exception as e: | |
| print(f" ✗ Cache load failed: {e}") | |
| print(" → Rebuilding BM25 index...") | |
| print("Building BM25 index from scratch...") | |
| bm25_index = BM25Index(self.docs_all) | |
| # Save to cache | |
| print(f"Saving BM25 index to cache...") | |
| with open(self.bm25_cache, 'wb') as f: | |
| pickle.dump(bm25_index, f) | |
| return bm25_index | |
| def _load_or_build_dense(self) -> DenseIndex: | |
| """Build Dense index (note: Dense index has its own internal caching)""" | |
| print("Initializing Dense index (uses internal caching)...") | |
| # DenseIndex has its own caching system in .cache/embeddings/ | |
| # We don't need to pickle it - just let it build/load from its own cache | |
| dense_index = DenseIndex(self.docs_all) | |
| print(f" ✓ Dense index ready") | |
| return dense_index | |
| # Ensure corpora files exist (auto-build if missing) | |
| _ensure_corpora_exist() | |
| # Initialize cached retriever (fast if cached, slow first time) | |
| print("\nInitializing retrieval system...") | |
| cfg = _available(_default_corpora_config()) | |
| if not cfg: | |
| raise RuntimeError("No corpora files found in data/corpora. Build them first.") | |
| retriever = CachedRetriever(corpora_config=cfg, use_reranker=False) | |
| print("\n✓ Retrieval system ready!") | |
| print(f" Total documents indexed: {len(retriever.docs_all):,}") | |
| print("=" * 70) | |
| def get_candidates_cached(query: str, k_retrieve: int = 50) -> List[Candidate]: | |
| """ | |
| Returns top-N fused candidates with component scores (bm25, dense, rrf). | |
| Uses the cached retriever for fast queries. | |
| """ | |
| # Get separate result lists (doc, score) | |
| bm = retriever.bm25.search(query, k=max(k_retrieve, 100)) | |
| de = retriever.dense.search(query, k=max(k_retrieve, 100)) | |
| # Maps for score lookup | |
| bm_map = {d.id: float(s) for d, s in bm} | |
| de_map = {d.id: float(s) for d, s in de} | |
| # Fuse and pick candidate set | |
| fused = rrf([bm, de], k=max(k_retrieve, 50)) | |
| # Compute RRF per candidate using rank positions | |
| K = 60 | |
| bm_rank = {d.id: i for i, (d, _) in enumerate(bm)} | |
| de_rank = {d.id: i for i, (d, _) in enumerate(de)} | |
| out: List[Candidate] = [] | |
| for doc, _ in fused[:k_retrieve]: | |
| rrf_score = 0.0 | |
| if doc.id in bm_rank: | |
| rrf_score += 1.0 / (K + bm_rank[doc.id] + 1) | |
| if doc.id in de_rank: | |
| rrf_score += 1.0 / (K + de_rank[doc.id] + 1) | |
| out.append(Candidate( | |
| id=doc.id, | |
| title=doc.title or "", | |
| text=doc.text, | |
| meta=doc.meta or {}, | |
| bm25=bm_map.get(doc.id, 0.0), | |
| dense=de_map.get(doc.id, 0.0), | |
| rrf=rrf_score, | |
| )) | |
| # Baseline order: RRF | |
| out.sort(key=lambda c: c.rrf, reverse=True) | |
| return out | |
| def retrieve_documents(query, num_results=5): | |
| """Retrieve relevant medical documents using your team's models""" | |
| if not query or not query.strip(): | |
| return """ | |
| <div style="padding: 20px; background-color: #e7f3ff; border-radius: 10px; border-left: 5px solid #2196f3;"> | |
| <h3 style="margin-top: 0; color: #0d47a1;">How to Use</h3> | |
| <p style="margin: 0; color: #1565c0;">Enter a medical query and we'll find relevant documents using BM25 + Dense retrieval with RRF fusion.</p> | |
| <p style="margin: 8px 0 0 0; color: #1565c0;"><strong>Example:</strong> "headache with blurred vision" or "symptoms of diabetes"</p> | |
| </div> | |
| """ | |
| try: | |
| # Use cached retrieval system (fast!) | |
| hits = get_candidates_cached(query=query, k_retrieve=num_results) | |
| if not hits: | |
| return """ | |
| <div style="padding: 20px; background-color: #fff3cd; border-radius: 10px; border-left: 5px solid #ffc107;"> | |
| <h3 style="margin-top: 0; color: #856404;">No Results Found</h3> | |
| <p style="margin: 0; color: #856404;">Try rephrasing your query or using different medical terms.</p> | |
| </div> | |
| """ | |
| # Build results HTML | |
| result_html = f""" | |
| <div style="padding: 15px; background: linear-gradient(135deg, #d4edda 0%, #c3e6cb 100%); border-radius: 10px; margin-bottom: 20px; border-left: 5px solid #28a745;"> | |
| <h3 style="margin-top: 0; color: #155724;">Found {len(hits)} Relevant Medical Documents</h3> | |
| <p style="margin: 0;"><strong>Retrieved using:</strong> BM25 + Dense Embeddings + RRF Fusion (CACHED)</p> | |
| </div> | |
| """ | |
| for i, hit in enumerate(hits, 1): | |
| title = hit.title if hit.title and hit.title.strip() else None | |
| source = hit.meta.get('source', 'Unknown') if hit.meta else 'Unknown' | |
| # Check if we have separate question/answer fields in metadata | |
| question = hit.meta.get('question', '') if hit.meta else '' | |
| answer = hit.meta.get('answer', '') if hit.meta else '' | |
| # If we have separate Q&A, format them nicely | |
| if question and answer: | |
| content_html = f""" | |
| <div style="margin-bottom: 12px;"> | |
| <strong style="color: #1976d2;">Question:</strong> | |
| <p style="margin: 5px 0 0 0; line-height: 1.6; color: #424242;">{question}</p> | |
| </div> | |
| <div> | |
| <strong style="color: #388e3c;">Answer:</strong> | |
| <p style="margin: 5px 0 0 0; line-height: 1.6; color: #424242;">{answer[:500] + ("..." if len(answer) > 500 else "")}</p> | |
| </div> | |
| """ | |
| else: | |
| # Fallback to combined text | |
| text = hit.text[:500] + ("..." if len(hit.text) > 500 else "") | |
| content_html = f'<p style="margin: 0; line-height: 1.7; color: #34495e;">{text}</p>' | |
| # Display relevance scores | |
| bm25_score = hit.bm25 | |
| dense_score = hit.dense | |
| rrf_score = hit.rrf | |
| # Build title HTML only if title exists | |
| title_html = f'<h4 style="margin: 0 0 15px 0; color: #2c3e50;">{title}</h4>' if title else '' | |
| result_html += f""" | |
| <div style="border: 2px solid #dee2e6; padding: 20px; margin: 20px 0; border-radius: 10px; background-color: white; box-shadow: 0 2px 4px rgba(0,0,0,0.1);"> | |
| <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 15px; margin: -20px -20px 20px -20px; border-radius: 8px 8px 0 0;"> | |
| <div style="display: flex; justify-content: space-between; align-items: center;"> | |
| <h4 style="margin: 0; color: white;">Document #{i}</h4> | |
| <span style="background-color: rgba(255,255,255,0.2); padding: 4px 12px; border-radius: 12px; font-size: 0.85em; color: white;"> | |
| {source} | |
| </span> | |
| </div> | |
| </div> | |
| <div style="margin-bottom: 15px;"> | |
| {title_html} | |
| {content_html} | |
| </div> | |
| <div style="padding-top: 12px; border-top: 1px solid #e9ecef;"> | |
| <div style="display: grid; grid-template-columns: 1fr 1fr 1fr; gap: 10px;"> | |
| <div style="background-color: #e3f2fd; padding: 8px; border-radius: 5px; text-align: center;"> | |
| <div style="font-size: 0.75em; color: #1976d2; font-weight: bold;">BM25</div> | |
| <div style="font-size: 1.1em; color: #0d47a1;">{bm25_score:.4f}</div> | |
| </div> | |
| <div style="background-color: #f3e5f5; padding: 8px; border-radius: 5px; text-align: center;"> | |
| <div style="font-size: 0.75em; color: #7b1fa2; font-weight: bold;">Dense</div> | |
| <div style="font-size: 1.1em; color: #4a148c;">{dense_score:.4f}</div> | |
| </div> | |
| <div style="background-color: #e8f5e9; padding: 8px; border-radius: 5px; text-align: center;"> | |
| <div style="font-size: 0.75em; color: #388e3c; font-weight: bold;">RRF Fusion</div> | |
| <div style="font-size: 1.1em; color: #1b5e20;">{rrf_score:.4f}</div> | |
| </div> | |
| </div> | |
| </div> | |
| </div> | |
| """ | |
| return result_html | |
| except Exception as e: | |
| return f""" | |
| <div style="padding: 20px; background-color: #f8d7da; border-radius: 10px; border-left: 5px solid #dc3545;"> | |
| <h3 style="margin-top: 0; color: #721c24;">Error</h3> | |
| <p style="margin: 0; color: #721c24;">{str(e)}</p> | |
| </div> | |
| """ | |
| # Create Gradio interface | |
| with gr.Blocks(title="Medical Document Retrieval (Cached)") as demo: | |
| gr.Markdown(""" | |
| # Medical Document Retrieval System (CACHED VERSION) | |
| **Models:** | |
| - BM25 Index (keyword-based retrieval) | |
| - Dense Embeddings (embeddinggemma-300m-medical) | |
| - RRF Fusion (combines both approaches) | |
| ### Features: | |
| - Searches across 10,000+ medical documents | |
| - Shows relevance scores from each model component | |
| - Returns the most relevant medical information | |
| """) | |
| with gr.Row(): | |
| with gr.Column(): | |
| query_input = gr.Textbox( | |
| label="Enter your medical query", | |
| placeholder="Example: headache with blurred vision", | |
| lines=2 | |
| ) | |
| num_results = gr.Slider( | |
| minimum=1, | |
| maximum=10, | |
| value=5, | |
| step=1, | |
| label="Number of results to retrieve" | |
| ) | |
| submit_btn = gr.Button("Retrieve Documents", variant="primary", size="lg") | |
| output_html = gr.HTML(label="Search Results") | |
| submit_btn.click( | |
| fn=retrieve_documents, | |
| inputs=[query_input, num_results], | |
| outputs=output_html | |
| ) | |
| gr.Examples( | |
| examples=[ | |
| "headache with blurred vision", | |
| "symptoms of diabetes", | |
| "chest pain when exercising", | |
| "treatment for high blood pressure", | |
| "causes of chronic fatigue", | |
| ], | |
| inputs=query_input, | |
| label="Try these example queries:" | |
| ) | |
| gr.Markdown(""" | |
| --- | |
| ### Technical Details | |
| - **BM25**: Statistical keyword matching (TF-IDF based) | |
| - **Dense**: Semantic search using transformer embeddings | |
| - **RRF Fusion**: Reciprocal Rank Fusion combines both methods | |
| - **Caching**: Indexes saved to disk in `cache/` folder for fast reloading | |
| *Note: First launch builds and caches indexes (5-8 min). After that, startup takes only ~30 seconds!* | |
| """) | |
| print("\nOpening web interface...") | |
| print("=" * 70) | |
| if __name__ == "__main__": | |
| # Auto-detect environment: HuggingFace Spaces vs local | |
| import os | |
| is_spaces = os.getenv("SPACE_ID") is not None | |
| if is_spaces: | |
| # HuggingFace Spaces: listen on all interfaces, default port | |
| demo.launch(server_name="0.0.0.0", server_port=7860) | |
| else: | |
| # Local: standard config | |
| print(" Local access: http://127.0.0.1:7863") | |
| demo.launch(server_name="127.0.0.1", server_port=7863) | |