|
import os |
|
import logging |
|
from pathlib import Path |
|
|
|
|
|
# FAISS is treated as an optional dependency: a failed import is logged and
# recorded in _HAS_FAISS rather than raised, so this module can still be
# imported. The retriever functions below check _HAS_FAISS before using faiss.
try:
    import faiss
    _HAS_FAISS = True
except ImportError:
    logging.warning("FAISS not installed — retrieval will be disabled. Install faiss-cpu or faiss-gpu for full functionality.")
    _HAS_FAISS = False
|
|
|
from sentence_transformers import SentenceTransformer |
|
|
|
|
|
# Current user's home directory, used to build the default cache location.
_HOME = Path.home()
# Model cache directory: the SENTENCE_TRANSFORMERS_HOME environment variable
# when it is set, otherwise ~/.cache/sentence-transformers.
_ST_CACHE = os.getenv("SENTENCE_TRANSFORMERS_HOME", str(_HOME / ".cache" / "sentence-transformers"))
# Identifier of the embedding model handed to SentenceTransformer.
_ST_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
|
|
|
def _load_st_model():
    """
    Load the sentence-embedding model, returning None when it cannot be loaded.

    Steps, in order:
      1. Best-effort creation of the cache directory (a failure is only logged).
      2. Load ``_ST_MODEL_ID`` with ``cache_folder=_ST_CACHE``.
      3. If that raises, retry once with ``trust_remote_code=True``.
      4. If both attempts raise, log the failure and return None so callers
         can disable retrieval instead of crashing.
    """
    # The directory may already exist or be read-only; neither case is fatal
    # here because the load attempts below will surface any real problem.
    try:
        os.makedirs(_ST_CACHE, exist_ok=True)
    except Exception as e:
        logging.warning(f"Could not create cache directory {_ST_CACHE}: {e}")

    # First attempt: default loading options.
    try:
        model = SentenceTransformer(_ST_MODEL_ID, cache_folder=_ST_CACHE)
    except Exception as e1:
        logging.warning(f"Primary load failed for '{_ST_MODEL_ID}' with cache '{_ST_CACHE}': {e1}")
    else:
        return model

    # Second attempt: same model and cache, but with trust_remote_code enabled.
    # NOTE(review): this flag permits code shipped with the model repository
    # to run locally - confirm this fallback is intended.
    try:
        model = SentenceTransformer(_ST_MODEL_ID, cache_folder=_ST_CACHE, trust_remote_code=True)
    except Exception:
        logging.exception("Failed loading SentenceTransformer model on both attempts.")
    else:
        return model

    logging.error(
        "Disabling retrieval due to model load failure. "
        f"Check permissions for {_ST_CACHE} and HF_* env vars."
    )
    return None
|
|
|
|
|
# Embedding model, loaded once at import time. _load_st_model() returns None
# when loading fails, and the retriever functions treat None as "retrieval
# disabled".
_model = _load_st_model()

# Retriever state, populated by init_retriever():
#   _index - the FAISS index, or None while nothing has been indexed
#   _docs  - the documents; retrieve_context() uses FAISS result ids as
#            positions into this list
_index = None
_docs = []
|
|
|
def init_retriever(docs=None):
    """
    Initialize FAISS index if FAISS is available.

    docs: list[str] to index. A single string is treated as one document.

    Behaviour:
      * If the embedding model or FAISS is unavailable, only the documents
        are stored (no index is built) and the function returns.
      * If ``docs`` is empty or None, any previously built index is left
        untouched.
      * Otherwise the documents are embedded and a fresh ``IndexFlatL2`` is
        built. The module state (``_index``/``_docs``) is replaced only after
        the new index has been built successfully, so a failure while
        encoding cannot leave the stored documents out of sync with the index.
    """
    global _index, _docs

    # Snapshot the input so later mutation of the caller's list cannot shift
    # the positions the index refers to. A bare string would otherwise be
    # split into characters by list().
    if docs is None:
        doc_list = []
    elif isinstance(docs, str):
        doc_list = [docs]
    else:
        doc_list = list(docs)

    # Retrieval disabled: remember the documents but build nothing.
    if _model is None or not _HAS_FAISS:
        _docs = doc_list
        return

    if not doc_list:
        return

    import numpy as np

    embeddings = _model.encode(doc_list, convert_to_numpy=True, normalize_embeddings=False)
    # FAISS only accepts C-contiguous float32 matrices.
    embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)

    new_index = faiss.IndexFlatL2(embeddings.shape[1])
    new_index.add(embeddings)

    # Publish index and documents together, only after everything succeeded.
    _index, _docs = new_index, doc_list
|
|
|
def retrieve_context(query: str, k: int = 5):
    """
    Retrieve top-k docs matching query.

    Falls back to empty list if FAISS unavailable or not initialized, and
    also when ``k`` is not positive.

    query: text to embed and search for.
    k: maximum number of documents to return; values larger than the number
       of stored documents are clamped.

    Returns a list of documents ordered as returned by the index search
    (nearest first for the L2 index built by init_retriever()).
    """
    # The FAISS Python wrapper asserts k > 0, so a non-positive k is answered
    # here instead of being passed to the search.
    if k <= 0:
        return []
    if _model is None:
        return []
    if not _HAS_FAISS or _index is None or not _docs:
        return []

    q_emb = _model.encode([query], convert_to_numpy=True, normalize_embeddings=False)

    # Asking for more neighbours than there are documents only yields -1
    # padding, so cap the request at the corpus size.
    top_k = min(k, len(_docs))
    _, neighbours = _index.search(q_emb, top_k)

    # FAISS marks missing results with -1; the bounds check drops those and
    # any id that does not map to a stored document.
    return [_docs[i] for i in neighbours[0] if 0 <= i < len(_docs)]
|
|