Geoffrey Kip
Fix: Cache embedding model initialization to prevent concurrency crashes
7b7595c
"""
Utility functions for the Clinical Trial Agent.
Handles configuration, LanceDB index loading, data normalization, and custom filtering logic.
"""
import os
import streamlit as st
from typing import List, Optional
from llama_index.core import VectorStoreIndex, StorageContext, Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.lancedb import LanceDBVectorStore
from llama_index.llms.gemini import Gemini
from llama_index.core.postprocessor import SentenceTransformerRerank
import lancedb
from dotenv import load_dotenv
@st.cache_resource
def get_reranker(top_n: int = 50):
    """
    Return the shared cross-encoder reranker.

    ``st.cache_resource`` keeps one instance per distinct ``top_n`` until the
    Streamlit resource cache is cleared, so the (presumably expensive)
    reranker construction happens only once per ``top_n``.

    Args:
        top_n: Forwarded unchanged to ``SentenceTransformerRerank(top_n=...)``.
    """
    print("🔄 Loading Reranker Model (Cached)...")
    cross_encoder_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
    reranker = SentenceTransformerRerank(model=cross_encoder_name, top_n=top_n)
    return reranker
# --- MONKEYPATCH START ---
# Patch LanceDBVectorStore.query: log every query error together with the
# filters that triggered it, and turn the known 'nprobes' AttributeError into
# an empty result instead of a crash.
import functools

# If this module is executed again in the same process (e.g. a Streamlit hot
# reload), LanceDBVectorStore.query is already our wrapper. Recover the library
# method from the marker attribute set below so wrappers never stack.
original_query = getattr(
    LanceDBVectorStore.query, "_ct_agent_original", LanceDBVectorStore.query
)


@functools.wraps(original_query)
def patched_query(self, query, **kwargs):
    """
    Call the original ``LanceDBVectorStore.query``, logging any failure.

    Returns an empty ``VectorStoreQueryResult`` when the error message mentions
    'nprobes'; every other exception is logged and re-raised unchanged.
    """
    try:
        return original_query(self, query, **kwargs)
    except Exception as e:
        print(f"⚠️ LanceDB Query Error: {e}")
        if hasattr(query, "filters"):
            print(f" Filters: {query.filters}")
        if "nprobes" in str(e):
            from llama_index.core.vector_stores.types import VectorStoreQueryResult

            return VectorStoreQueryResult(nodes=[], similarities=[], ids=[])
        # Bare `raise` re-raises the active exception without adding this
        # frame as a new raise point.
        raise


# Marker used by the reload guard above.
patched_query._ct_agent_original = original_query
LanceDBVectorStore.query = patched_query
# Patch _to_lance_filter so metadata filters become valid LanceDB SQL:
# string literals are single-quoted with embedded quotes escaped, IN / NIN
# lists are rendered explicitly, nested filter groups are supported and the
# group's AND/OR condition is honoured.
from llama_index.vector_stores.lancedb import base as lancedb_base
from llama_index.core.vector_stores.types import FilterOperator

# Keep a handle on the library implementation. If this module is executed
# again in the same process (e.g. a Streamlit hot reload) the attribute is
# already our patch, so recover the original from the marker set below.
original_to_lance_filter = getattr(
    lancedb_base._to_lance_filter,
    "_ct_agent_original",
    lancedb_base._to_lance_filter,
)

# Ordered comparisons. Values are interpolated unquoted (as before) so numeric
# metadata such as start_year compares numerically.
_LANCE_RANGE_OPERATORS = {
    FilterOperator.GT: ">",
    FilterOperator.LT: "<",
    FilterOperator.GTE: ">=",
    FilterOperator.LTE: "<=",
}


def _lance_sql_literal(value):
    """Render *value* as a SQL literal: strings single-quoted with ' doubled."""
    if isinstance(value, str):
        # Standard SQL escaping; without it "Crohn's Disease" breaks the query.
        return "'" + value.replace("'", "''") + "'"
    return str(value)


def patched_to_lance_filter(standard_filters, metadata_keys):
    """
    Translate LlamaIndex ``MetadataFilters`` into a LanceDB SQL WHERE clause.

    Args:
        standard_filters: ``MetadataFilters`` or None. Nested
            ``MetadataFilters`` inside ``.filters`` are rendered as
            parenthesised sub-clauses.
        metadata_keys: Allowed metadata keys; filters on other keys are
            skipped. A falsy value disables this check.

    Returns:
        The WHERE clause string, or None when no filter produced a clause.
    """
    if not standard_filters:
        return None

    # MetadataFilters.condition is expected to be an enum valued "and"/"or";
    # anything other than "or" (including a missing condition) means AND.
    condition = getattr(standard_filters, "condition", None)
    condition_name = str(getattr(condition, "value", condition)).lower()
    joiner = " OR " if condition_name == "or" else " AND "

    clauses = []
    for metadata_filter in standard_filters.filters:
        # A nested filter group has its own .filters list.
        if hasattr(metadata_filter, "filters"):
            nested = patched_to_lance_filter(metadata_filter, metadata_keys)
            if nested:
                clauses.append(f"({nested})")
            continue

        key = metadata_filter.key
        if metadata_keys and key not in metadata_keys:
            continue
        # Prefix key with 'metadata.' for LanceDB struct column
        lance_key = f"metadata.{key}"
        op = metadata_filter.operator
        val = metadata_filter.value

        if op in (FilterOperator.IN, FilterOperator.NIN):
            values = val if isinstance(val, (list, tuple)) else [val]
            if not values:
                # "IN ()" is invalid SQL: an empty IN matches nothing and an
                # empty NIN matches everything.
                clauses.append("FALSE" if op == FilterOperator.IN else "TRUE")
                continue
            sql_op = "IN" if op == FilterOperator.IN else "NOT IN"
            val_str = ", ".join(_lance_sql_literal(v) for v in values)
            clauses.append(f"{lance_key} {sql_op} ({val_str})")
        elif op == FilterOperator.EQ:
            clauses.append(f"{lance_key} = {_lance_sql_literal(val)}")
        elif op == FilterOperator.NE:
            clauses.append(f"{lance_key} != {_lance_sql_literal(val)}")
        elif op in _LANCE_RANGE_OPERATORS:
            clauses.append(f"{lance_key} {_LANCE_RANGE_OPERATORS[op]} {val}")
        else:
            # Previously dropped silently, which quietly broadened results.
            print(f"⚠️ Unsupported LanceDB filter operator {op} on '{key}'; filter ignored")

    if not clauses:
        return None
    return joiner.join(clauses)


# Marker used by the reload guard above.
patched_to_lance_filter._ct_agent_original = original_to_lance_filter
lancedb_base._to_lance_filter = patched_to_lance_filter
# --- MONKEYPATCH END ---
def load_environment():
    """
    Read a ``.env`` file into ``os.environ`` using python-dotenv's default
    lookup (``load_dotenv()`` called with no arguments).
    """
    load_dotenv()
# --- Configuration ---
@st.cache_resource
def init_embedding_model():
    """
    Initializes and caches the embedding model globally.

    Builds the ``pritamdeka/S-PubMedBert-MS-MARCO`` HuggingFace embedding on
    CPU, installs it as ``Settings.embed_model`` and returns it.

    Because of ``st.cache_resource``, the body runs only on the first call (or
    after the resource cache is cleared). That includes the ``Settings``
    assignment. Later calls return the cached model without touching
    ``Settings``.

    Returns:
        The shared HuggingFaceEmbedding instance (previously ``None``).
    """
    embed_model = HuggingFaceEmbedding(
        model_name="pritamdeka/S-PubMedBert-MS-MARCO",
        device="cpu",
    )
    Settings.embed_model = embed_model
    return embed_model
def setup_llama_index(api_key: Optional[str] = None):
    """
    Install the Gemini LLM as the global LlamaIndex ``Settings.llm``.

    ``init_embedding_model()`` is invoked first, so the embedding model is set
    up whether or not an LLM ends up configured.

    Args:
        api_key: Google API key. If falsy, ``GOOGLE_API_KEY`` from the
            environment is used. With no key from either source, the function
            returns without modifying ``Settings.llm``.

    If building the Gemini client (or assigning it) raises, the error is
    printed and a ``MockLLM`` is installed instead.
    """
    init_embedding_model()

    google_key = api_key if api_key else os.environ.get("GOOGLE_API_KEY")
    if google_key:
        try:
            Settings.llm = Gemini(
                model="models/gemini-2.5-flash",
                temperature=0,
                api_key=google_key,
            )
        except Exception as init_error:
            print(f"⚠️ LLM initialization failed: {init_error}")
            from llama_index.core.llms import MockLLM

            Settings.llm = MockLLM()
@st.cache_resource
def load_index(
    db_path: str = "./ct_gov_lancedb",
    table_name: str = "clinical_trials",
) -> VectorStoreIndex:
    """
    Loads and caches the persistent LanceDB index.

    LLM setup is deliberately not performed here. The app calls
    ``setup_llama_index()`` itself. Per the original note, calling it from this
    cached loader reset the LLM when the API-key env var was missing.

    Args:
        db_path: LanceDB database directory. A relative path resolves against
            the process working directory, so pass an absolute path when the
            app may be launched from elsewhere.
        table_name: Table holding the clinical-trial vectors.

    Returns:
        A VectorStoreIndex over a LanceDBVectorStore in "hybrid" query mode.
    """
    # Keys exposed to metadata filtering. NOTE(review): presumably the store
    # passes these to _to_lance_filter, which skips filters on unlisted keys.
    metadata_keys = [
        "nct_id", "title", "org", "sponsor", "status", "phase",
        "study_type", "start_year", "condition", "intervention",
        "country", "state",
    ]

    # NOTE(review): LanceDBVectorStore opens its own connection from `uri`;
    # the previous, never-used lancedb.connect() handle was dropped.
    vector_store = LanceDBVectorStore(
        uri=db_path,
        table_name=table_name,
        query_mode="hybrid",
    )
    # Manually set metadata keys as constructor doesn't accept them
    vector_store._metadata_keys = metadata_keys

    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    return VectorStoreIndex.from_vector_store(
        vector_store, storage_context=storage_context
    )
def get_hybrid_retriever(index: VectorStoreIndex, similarity_top_k: int = 50, filters=None):
    """
    Build a retriever over *index*.

    Despite the name, no hybrid-specific option is passed here. This is
    ``index.as_retriever`` with only ``similarity_top_k`` and ``filters`` set,
    and the search mode is whatever the underlying vector store applies by
    default. NOTE(review): ``load_index`` builds the store with
    ``query_mode="hybrid"``, while the original comment says a standard
    retriever was chosen to avoid LanceDB hybrid-search issues on small
    datasets; confirm which is intended.

    Args:
        index: The loaded vector index.
        similarity_top_k: Number of top results to retrieve.
        filters: Optional MetadataFilters applied at query time.

    Returns:
        The retriever produced by ``index.as_retriever``.
    """
    retriever_kwargs = {
        "similarity_top_k": similarity_top_k,
        "filters": filters,
    }
    return index.as_retriever(**retriever_kwargs)
# --- Normalization ---
# Centralized Sponsor Mappings
# Key: Canonical Name
# Value: List of variations/aliases (including the canonical name itself if needed for matching)
import re

SPONSOR_MAPPINGS = {
    "GlaxoSmithKline": [
        "gsk", "glaxo", "glaxosmithkline",
        "GlaxoSmithKline",
    ],
    "Janssen": [
        "j&j", "johnson & johnson", "johnson and johnson", "janssen", "Janssen",
        "Janssen Research & Development, LLC",
        "Janssen Vaccines & Prevention B.V.",
        "Janssen Pharmaceutical K.K.",
        "Janssen-Cilag International NV",
        "Janssen Sciences Ireland UC",
        "Janssen Pharmaceutica N.V., Belgium",
        "Janssen Scientific Affairs, LLC",
        "Janssen-Cilag Ltd.",
        "Xian-Janssen Pharmaceutical Ltd.",
        "Janssen Korea, Ltd., Korea",
        "Janssen-Cilag G.m.b.H",
        "Janssen-Cilag, S.A.",
        "Janssen BioPharma, Inc.",
    ],
    "Bristol-Myers Squibb": [
        "bms", "bristol", "bristol myers squibb", "bristol-myers squibb",
        "Bristol-Myers Squibb",
    ],
    "Merck Sharp & Dohme": [
        "merck", "msd", "merck sharp & dohme",
        "Merck Sharp & Dohme LLC",
    ],
    "Pfizer": ["pfizer", "Pfizer", "Pfizer Inc."],
    "AstraZeneca": ["astrazeneca", "AstraZeneca"],
    "Eli Lilly and Company": ["lilly", "eli lilly", "Eli Lilly and Company"],
    "Sanofi": ["sanofi", "Sanofi"],
    "Novartis": ["novartis", "Novartis"],
}


def _contains_term(text: str, term: str) -> bool:
    """Return True if *term* occurs in *text* as a whole word, not inside another word."""
    return re.search(rf"(?<!\w){re.escape(term)}(?!\w)", text) is not None


def normalize_sponsor(sponsor: str) -> Optional[str]:
    """
    Normalizes sponsor names to canonical forms using centralized mappings.

    Matching ignores case and surrounding whitespace and runs in two passes:

    1. Exact match against any canonical name or listed variation, across
       all sponsors. A fuzzy hit on an earlier sponsor therefore can never
       shadow an exact match on a later one.
    2. Otherwise, in mapping order: a short alias (fewer than 5 characters,
       e.g. 'gsk') appearing as a whole word in the input, or the canonical
       name appearing anywhere in the input.

    Args:
        sponsor: Free-text sponsor name.

    Returns:
        The canonical name, None for empty input, or the input unchanged when
        no mapping applies.
    """
    if not sponsor:
        return None
    s = sponsor.lower().strip()

    # Pass 1: exact matches only.
    for canonical, variations in SPONSOR_MAPPINGS.items():
        if s == canonical.lower() or any(s == v.lower() for v in variations):
            return canonical

    # Pass 2: fuzzy containment.
    for canonical, variations in SPONSOR_MAPPINGS.items():
        # Short aliases must match a whole word: a raw substring test mapped
        # e.g. "Amsdell Bio" to Merck via "msd".
        if any(len(v) < 5 and _contains_term(s, v.lower()) for v in variations):
            return canonical
        if canonical.lower() in s:
            return canonical

    return sponsor


def get_sponsor_variations(sponsor: str) -> Optional[List[str]]:
    """
    Return every known spelling for the sponsor that *sponsor* normalizes to.

    The list mixes lowercase matching aliases (e.g. 'gsk') with full
    organisation names (e.g. 'Janssen Research & Development, LLC').

    Args:
        sponsor: Free-text sponsor name or alias.

    Returns:
        A new list (safe for the caller to mutate), or None when the input is
        empty or does not map to a known sponsor.
    """
    if not sponsor:
        return None
    canonical = normalize_sponsor(sponsor)
    variations = SPONSOR_MAPPINGS.get(canonical)
    # Copy so callers cannot mutate the shared module-level mapping.
    return list(variations) if variations is not None else None