# Medica_DecisionSupportAI / retriever.py
# Last updated by VED-AGI-1 (commit 1dae236, verified).
import os
import logging
from pathlib import Path

# Optional FAISS (keeps original behavior)
# FAISS is a soft dependency: when absent, the module still imports and the
# retrieval functions below degrade to no-ops instead of crashing the app.
try:
    import faiss
    _HAS_FAISS = True
except ImportError:
    logging.warning("FAISS not installed — retrieval will be disabled. Install faiss-cpu or faiss-gpu for full functionality.")
    _HAS_FAISS = False

from sentence_transformers import SentenceTransformer

# ---- Writable cache + stable repo id for Spaces ----
# The cache location is overridable via SENTENCE_TRANSFORMERS_HOME, useful on
# hosts (e.g. HF Spaces) where the default home directory may be read-only.
_HOME = Path.home()
_ST_CACHE = os.getenv("SENTENCE_TRANSFORMERS_HOME", str(_HOME / ".cache" / "sentence-transformers"))
_ST_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2" # canonical repo id
def _load_st_model():
    """Load the sentence-transformers embedding model, or ``None`` on failure.

    Tries a plain load first, then retries with ``trust_remote_code=True``.
    If both attempts fail, the failure is logged and ``None`` is returned so
    the application degrades to "retrieval disabled" rather than crashing
    at import time.

    Returns:
        SentenceTransformer | None: the loaded model, or ``None`` when both
        load attempts failed.
    """
    # Ensure the cache dir exists. Non-fatal: the library may still manage
    # to use another writable location.
    try:
        os.makedirs(_ST_CACHE, exist_ok=True)
    except OSError as e:  # makedirs raises OSError; no need to catch broader
        logging.warning("Could not create cache directory %s: %s", _ST_CACHE, e)

    # Primary attempt: canonical repo id with an explicit cache folder.
    try:
        return SentenceTransformer(_ST_MODEL_ID, cache_folder=_ST_CACHE)
    except Exception as e1:
        logging.warning("Primary load failed for '%s' with cache '%s': %s",
                        _ST_MODEL_ID, _ST_CACHE, e1)

    # Secondary attempt (allow trust_remote_code just in case).
    try:
        return SentenceTransformer(_ST_MODEL_ID, cache_folder=_ST_CACHE,
                                   trust_remote_code=True)
    except Exception:
        logging.exception("Failed loading SentenceTransformer model on both attempts.")
        # Soft-fail: disable retrieval rather than crashing the whole app.
        logging.error(
            "Disabling retrieval due to model load failure. "
            "Check permissions for %s and HF_* env vars.", _ST_CACHE
        )
        return None
# Load embedding model (works even if FAISS missing)
_model = _load_st_model()

# Module-level retrieval state, populated by init_retriever().
_index = None  # FAISS index once built; stays None when FAISS/model missing
_docs = []     # the indexed documents, parallel to the index ids
def init_retriever(docs=None):
    """
    Initialize FAISS index if FAISS is available.

    docs: list[str] to index. When falsy, any previously built index is
    left untouched (no-op in the full-retrieval path).

    Side effects: rebinds the module-level ``_index`` and ``_docs``.
    Degrades gracefully — when the embedding model or FAISS is missing,
    only ``_docs`` is stored and no index is built.
    """
    global _index, _docs
    # Retrieval disabled (no model or no FAISS): just remember the docs so
    # callers can still inspect the corpus.
    if _model is None or not _HAS_FAISS:
        _docs = docs or []
        return
    if not docs:
        return
    _docs = docs
    # convert_to_numpy=True already yields an ndarray, so no explicit
    # numpy import/conversion is needed here.
    embeddings = _model.encode(docs, convert_to_numpy=True, normalize_embeddings=False)
    dim = embeddings.shape[1]
    _index = faiss.IndexFlatL2(dim)
    _index.add(embeddings)
def retrieve_context(query: str, k: int = 5):
    """
    Retrieve top-k docs matching query.
    Falls back to empty list if FAISS unavailable or not initialized.
    """
    # Guard clauses: every precondition for a FAISS search must hold.
    if _model is None or not _HAS_FAISS:
        return []
    if _index is None or not _docs:
        return []

    query_vec = _model.encode([query], convert_to_numpy=True, normalize_embeddings=False)
    _, neighbor_ids = _index.search(query_vec, k)

    # Defensive: drop any out-of-range ids (e.g. when k exceeds the number
    # of indexed documents) before mapping ids back to documents.
    n_docs = len(_docs)
    hits = []
    for doc_id in neighbor_ids[0]:
        if 0 <= doc_id < n_docs:
            hits.append(_docs[doc_id])
    return hits