| """ | |
| Utility functions for the Clinical Trial Agent. | |
| Handles configuration, LanceDB index loading, data normalization, and custom filtering logic. | |
| """ | |
| import os | |
| import streamlit as st | |
| from typing import List, Optional | |
| from llama_index.core import VectorStoreIndex, StorageContext, Settings | |
| from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
| from llama_index.vector_stores.lancedb import LanceDBVectorStore | |
| from llama_index.llms.gemini import Gemini | |
| from llama_index.core.postprocessor import SentenceTransformerRerank | |
| import lancedb | |
| from dotenv import load_dotenv | |


# Cached across Streamlit reruns so the cross-encoder weights load only once.
@st.cache_resource
def get_reranker(top_n: int = 50):
    """
    Loads and caches the Reranker model.
    """
    print("🔄 Loading Reranker Model (Cached)...")
    return SentenceTransformerRerank(
        model="cross-encoder/ms-marco-MiniLM-L-12-v2",
        top_n=top_n,
    )
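

# Illustration only (hypothetical wiring, not executed in this module): the
# reranker is intended to be attached as a node postprocessor on a query
# engine built from the index returned by load_index(), e.g.
#
#   query_engine = load_index().as_query_engine(
#       similarity_top_k=50,
#       node_postprocessors=[get_reranker(top_n=5)],
#   )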


# --- MONKEYPATCH START ---
# Patch LanceDBVectorStore to handle 'nprobes' AttributeError and fix SQL quoting for IN filters.
original_query = LanceDBVectorStore.query


def patched_query(self, query, **kwargs):
    try:
        return original_query(self, query, **kwargs)
    except Exception as e:
        print(f"⚠️ LanceDB Query Error: {e}")
        if hasattr(query, "filters"):
            print(f"   Filters: {query.filters}")
        if "nprobes" in str(e):
            from llama_index.core.vector_stores.types import VectorStoreQueryResult
            return VectorStoreQueryResult(nodes=[], similarities=[], ids=[])
        raise e


LanceDBVectorStore.query = patched_query

# Patch _to_lance_filter to fix SQL quoting for IN operator with strings.
from llama_index.vector_stores.lancedb import base as lancedb_base
from llama_index.core.vector_stores.types import FilterOperator

original_to_lance_filter = lancedb_base._to_lance_filter


def patched_to_lance_filter(standard_filters, metadata_keys):
    if not standard_filters:
        return None
    # Reimplement the filter logic to ensure correct SQL generation for LanceDB.
    filters = []
    for flt in standard_filters.filters:
        key = flt.key
        if metadata_keys and key not in metadata_keys:
            continue
        # Prefix the key with 'metadata.' for the LanceDB struct column.
        lance_key = f"metadata.{key}"
        # Handle the IN operator with proper string quoting.
        if flt.operator == FilterOperator.IN:
            if isinstance(flt.value, list):
                values = []
                for v in flt.value:
                    if isinstance(v, str):
                        # Single quotes for SQL; escape embedded quotes by doubling them.
                        values.append("'" + v.replace("'", "''") + "'")
                    else:
                        values.append(str(v))
                val_str = ", ".join(values)
                filters.append(f"{lance_key} IN ({val_str})")
            continue
        # Standard operators: quote string values, pass numbers through as-is.
        op = flt.operator
        val = flt.value
        if isinstance(val, str):
            val = "'" + val.replace("'", "''") + "'"
        if op == FilterOperator.EQ:
            filters.append(f"{lance_key} = {val}")
        elif op == FilterOperator.GT:
            filters.append(f"{lance_key} > {val}")
        elif op == FilterOperator.LT:
            filters.append(f"{lance_key} < {val}")
        elif op == FilterOperator.GTE:
            filters.append(f"{lance_key} >= {val}")
        elif op == FilterOperator.LTE:
            filters.append(f"{lance_key} <= {val}")
        elif op == FilterOperator.NE:
            filters.append(f"{lance_key} != {val}")
        # Add other operators as needed.
    if not filters:
        return None
    return " AND ".join(filters)


lancedb_base._to_lance_filter = patched_to_lance_filter
# --- MONKEYPATCH END ---
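
# Illustration only (hypothetical keys/values): the patched helper turns
# LlamaIndex MetadataFilters into a LanceDB SQL predicate string, e.g.
#
#   from llama_index.core.vector_stores.types import MetadataFilter, MetadataFilters
#   f = MetadataFilters(filters=[
#       MetadataFilter(key="org", operator=FilterOperator.IN,
#                      value=["Pfizer Inc.", "Novartis"]),
#       MetadataFilter(key="start_year", operator=FilterOperator.GTE, value=2015),
#   ])
#   patched_to_lance_filter(f, ["org", "start_year"])
#   # -> "metadata.org IN ('Pfizer Inc.', 'Novartis') AND metadata.start_year >= 2015"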


def load_environment():
    """Loads environment variables from .env file."""
    load_dotenv()


# --- Configuration ---
@st.cache_resource
def init_embedding_model():
    """Initializes and caches the embedding model globally."""
    Settings.embed_model = HuggingFaceEmbedding(
        model_name="pritamdeka/S-PubMedBert-MS-MARCO",
        device="cpu",
    )


def setup_llama_index(api_key: Optional[str] = None):
    """
    Configures global LlamaIndex settings (LLM).
    The embedding model is handled by init_embedding_model().
    """
    # Ensure the embedding model is loaded
    init_embedding_model()

    # Use the passed key, or fall back to the env var
    final_key = api_key or os.environ.get("GOOGLE_API_KEY")
    if not final_key:
        return

    try:
        # Pass the key explicitly if available
        Settings.llm = Gemini(model="models/gemini-2.5-flash", temperature=0, api_key=final_key)
    except Exception as e:
        print(f"⚠️ LLM initialization failed: {e}")
        from llama_index.core.llms import MockLLM
        Settings.llm = MockLLM()
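

# Typical startup order (illustrative; the app itself wires this up):
#
#   load_environment()
#   setup_llama_index()   # or setup_llama_index(api_key=<key supplied by the app>)
#   index = load_index()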


@st.cache_resource
def load_index() -> VectorStoreIndex:
    """
    Loads and caches the persistent LanceDB index.
    """
    # Note: setup_llama_index() is intentionally not called here; the app handles
    # setup, and calling it here would reset the LLM if the env var is missing.

    # Initialize LanceDB
    db_path = "./ct_gov_lancedb"
    db = lancedb.connect(db_path)

    # Define metadata keys explicitly to ensure filters work
    metadata_keys = [
        "nct_id", "title", "org", "sponsor", "status", "phase",
        "study_type", "start_year", "condition", "intervention",
        "country", "state",
    ]

    # Create the vector store wrapper
    vector_store = LanceDBVectorStore(
        uri=db_path,
        table_name="clinical_trials",
        query_mode="hybrid",
    )
    # Manually set metadata keys since the constructor doesn't accept them
    vector_store._metadata_keys = metadata_keys

    # Create the storage context
    storage_context = StorageContext.from_defaults(vector_store=vector_store)

    # Load the index from the vector store
    index = VectorStoreIndex.from_vector_store(
        vector_store, storage_context=storage_context
    )
    return index


def get_hybrid_retriever(index: VectorStoreIndex, similarity_top_k: int = 50, filters=None):
    """
    Creates a retriever over the LanceDB-backed index.

    Hybrid search is configured on the vector store itself (query_mode="hybrid"
    in load_index()); this returns a standard retriever to avoid LanceDB hybrid
    search issues on small datasets.

    Args:
        index (VectorStoreIndex): The loaded vector index.
        similarity_top_k (int): Number of top results to retrieve.
        filters (MetadataFilters, optional): Metadata filters to apply.

    Returns:
        VectorIndexRetriever: The configured retriever.
    """
    return index.as_retriever(
        similarity_top_k=similarity_top_k,
        filters=filters,
    )
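

# Illustration only (hypothetical filter value; assumes ./ct_gov_lancedb exists):
#
#   from llama_index.core.vector_stores.types import (
#       MetadataFilter, MetadataFilters, FilterOperator,
#   )
#   index = load_index()
#   retriever = get_hybrid_retriever(
#       index,
#       similarity_top_k=20,
#       filters=MetadataFilters(filters=[
#           MetadataFilter(key="phase", operator=FilterOperator.EQ, value="PHASE3"),
#       ]),
#   )
#   nodes = retriever.retrieve("checkpoint inhibitors for NSCLC")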


# --- Normalization ---

# Centralized sponsor mappings.
# Key: canonical name.
# Value: list of variations/aliases (including the canonical name itself where
# it is needed for matching).
SPONSOR_MAPPINGS = {
    "GlaxoSmithKline": [
        "gsk", "glaxo", "glaxosmithkline",
        "GlaxoSmithKline",
    ],
| "Janssen": [ | |
| "j&j", "johnson & johnson", "johnson and johnson", "janssen", "Janssen", | |
| "Janssen Research & Development, LLC", | |
| "Janssen Vaccines & Prevention B.V.", | |
| "Janssen Pharmaceutical K.K.", | |
| "Janssen-Cilag International NV", | |
| "Janssen Sciences Ireland UC", | |
| "Janssen Pharmaceutica N.V., Belgium", | |
| "Janssen Scientific Affairs, LLC", | |
| "Janssen-Cilag Ltd.", | |
| "Xian-Janssen Pharmaceutical Ltd.", | |
| "Janssen Korea, Ltd., Korea", | |
| "Janssen-Cilag G.m.b.H", | |
| "Janssen-Cilag, S.A.", | |
| "Janssen BioPharma, Inc.", | |
| ], | |
| "Bristol-Myers Squibb": [ | |
| "bms", "bristol", "bristol myers squibb", "bristol-myers squibb", | |
| "Bristol-Myers Squibb" | |
| ], | |
| "Merck Sharp & Dohme": [ | |
| "merck", "msd", "merck sharp & dohme", | |
| "Merck Sharp & Dohme LLC" | |
| ], | |
| "Pfizer": ["pfizer", "Pfizer", "Pfizer Inc."], | |
| "AstraZeneca": ["astrazeneca", "AstraZeneca"], | |
| "Eli Lilly and Company": ["lilly", "eli lilly", "Eli Lilly and Company"], | |
| "Sanofi": ["sanofi", "Sanofi"], | |
| "Novartis": ["novartis", "Novartis"], | |
| } | |


def normalize_sponsor(sponsor: str) -> Optional[str]:
    """
    Normalizes sponsor names to canonical forms using the centralized mappings.
    Returns the input unchanged if no mapping matches, and None for empty input.
    """
    if not sponsor:
        return None
    s = sponsor.lower().strip()
    for canonical, variations in SPONSOR_MAPPINGS.items():
        # Check if the input matches the canonical name (case-insensitive)
        if s == canonical.lower():
            return canonical
        # Check variations and aliases
        for v in variations:
            v_lower = v.lower()
            if v_lower == s:
                return canonical
            # If the variation is a short alias (like 'gsk'), check for it as a substring
            if len(v) < 5 and v_lower in s:
                return canonical
        # Fall back to a substring match on the canonical name
        if canonical.lower() in s:
            return canonical
    return sponsor


def get_sponsor_variations(sponsor: str) -> Optional[List[str]]:
    """
    Returns the known name variations (short aliases plus exact database 'org'
    spellings) for a given sponsor alias, or None if the sponsor is unknown.
    """
    if not sponsor:
        return None
    # First, normalize the input to get the canonical name
    canonical = normalize_sponsor(sponsor)
    if canonical in SPONSOR_MAPPINGS:
        return SPONSOR_MAPPINGS[canonical]
    return None
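

# Quick manual check of the normalization helpers (illustrative aliases only);
# safe to run without the LanceDB database, a Streamlit session, or an API key.
if __name__ == "__main__":
    for alias in ["gsk", "J&J", "Merck Sharp & Dohme LLC", "Some Unknown Biotech"]:
        print(f"{alias!r} -> {normalize_sponsor(alias)!r}")
        print(f"    variations: {get_sponsor_variations(alias)}")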