import os
import logging
from pathlib import Path

# Optional FAISS (keeps original behavior)
try:
    import faiss
    _HAS_FAISS = True
except ImportError:
    logging.warning(
        "FAISS not installed — retrieval will be disabled. "
        "Install faiss-cpu or faiss-gpu for full functionality."
    )
    _HAS_FAISS = False

from sentence_transformers import SentenceTransformer

# ---- Writable cache + stable repo id for Spaces ----
_HOME = Path.home()
_ST_CACHE = os.getenv("SENTENCE_TRANSFORMERS_HOME", str(_HOME / ".cache" / "sentence-transformers"))
_ST_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"  # canonical repo id


def _load_st_model():
    # Ensure cache dir exists
    try:
        os.makedirs(_ST_CACHE, exist_ok=True)
    except Exception as e:
        logging.warning(f"Could not create cache directory {_ST_CACHE}: {e}")

    # Primary attempt
    try:
        return SentenceTransformer(_ST_MODEL_ID, cache_folder=_ST_CACHE)
    except Exception as e1:
        logging.warning(f"Primary load failed for '{_ST_MODEL_ID}' with cache '{_ST_CACHE}': {e1}")

    # Secondary attempt (allow trust_remote_code just in case)
    try:
        return SentenceTransformer(_ST_MODEL_ID, cache_folder=_ST_CACHE, trust_remote_code=True)
    except Exception:
        logging.exception("Failed loading SentenceTransformer model on both attempts.")

    # Soft-fail: disable retrieval rather than crashing the whole app
    logging.error(
        "Disabling retrieval due to model load failure. "
        f"Check permissions for {_ST_CACHE} and HF_* env vars."
    )
    return None


# Load embedding model (works even if FAISS missing)
_model = _load_st_model()
_index = None
_docs = []


def init_retriever(docs=None):
    """
    Initialize FAISS index if FAISS is available.
    docs: list[str] to index
    """
    global _index, _docs

    if _model is None:
        _docs = docs or []
        return

    if not _HAS_FAISS:
        _docs = docs or []
        return

    if docs:
        _docs = docs
        embeddings = _model.encode(docs, convert_to_numpy=True, normalize_embeddings=False)
        d = embeddings.shape[1]
        _index = faiss.IndexFlatL2(d)
        _index.add(embeddings)


def retrieve_context(query: str, k: int = 5):
    """
    Retrieve top-k docs matching query.
    Falls back to empty list if FAISS unavailable or not initialized.
    """
    if _model is None:
        return []
    if not _HAS_FAISS or _index is None or not _docs:
        return []
    q_emb = _model.encode([query], convert_to_numpy=True, normalize_embeddings=False)
    D, I = _index.search(q_emb, k)
    return [_docs[i] for i in I[0] if 0 <= i < len(_docs)]
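

# --- Usage sketch (illustrative assumption, not part of the module's public surface) ---
# A minimal demo of the two entry points above: build an index from a handful of
# sample strings, then query it. The sample documents and query text are hypothetical.
# If FAISS or the embedding model failed to load, both calls degrade gracefully and
# retrieve_context() simply returns an empty list.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    sample_docs = [
        "FAISS is a library for efficient similarity search over dense vectors.",
        "Sentence-Transformers produces dense text embeddings from sentences.",
        "Hugging Face Spaces provide a writable cache under the home directory.",
    ]

    # Build the in-memory index (no-op if FAISS or the model is unavailable).
    init_retriever(sample_docs)

    # Retrieve the two closest documents; falls back to [] when retrieval is disabled.
    results = retrieve_context("How do I compute text embeddings?", k=2)
    print(results)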