# rag.py
import os
import json
import pickle
import logging
from typing import List, Tuple, Optional

import numpy as np
import faiss
from sentence_transformers import SentenceTransformer

from config import VECTORSTORE_DIR, EMBEDDING_MODEL

log = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

class RAGAgent:
    """
    Loads a FAISS index + metadata from VECTORSTORE_DIR (config).
    Provides retrieve(query, k) -> (contexts: List[str], sources: List[dict]).
    """

    def __init__(self, vectorstore_dir: Optional[str] = None, embedding_model: Optional[str] = None):
        self.vectorstore_dir = vectorstore_dir or str(VECTORSTORE_DIR)
        self.embedding_model_name = embedding_model or EMBEDDING_MODEL
        self.index: Optional[faiss.Index] = None
        self.metadata: Optional[List[dict]] = None
        self._embedder: Optional[SentenceTransformer] = None
        self._loaded = False

    def _find_index_file(self) -> str:
        if not os.path.isdir(self.vectorstore_dir):
            raise FileNotFoundError(f"Vectorstore dir not found: {self.vectorstore_dir}")
        for fname in os.listdir(self.vectorstore_dir):
            # Skip metadata files first, so e.g. "index.pkl" is never mistaken
            # for the FAISS index by the startswith("index") fallback.
            if fname.endswith((".pkl", ".json")):
                continue
            if fname.endswith((".faiss", ".index", ".bin")) or fname.startswith("index"):
                return os.path.join(self.vectorstore_dir, fname)
        raise FileNotFoundError(f"No FAISS index file (.faiss/.index/.bin) found in {self.vectorstore_dir}")

    def _find_meta_file(self) -> str:
        for candidate in ("index.pkl", "metadata.pkl", "index_meta.pkl", "metadata.json", "index.json"):
            p = os.path.join(self.vectorstore_dir, candidate)
            if os.path.exists(p):
                return p
        for fname in os.listdir(self.vectorstore_dir):
            if fname.endswith(".pkl"):
                return os.path.join(self.vectorstore_dir, fname)
        raise FileNotFoundError(f"No metadata (.pkl/.json) found in {self.vectorstore_dir}")

    @property
    def embedder(self) -> SentenceTransformer:
        # Lazy-load the embedding model on first use; retrieve() accesses this
        # as an attribute, so it must be a property.
        if self._embedder is None:
            log.info("Loading embedder: %s", self.embedding_model_name)
            self._embedder = SentenceTransformer(self.embedding_model_name)
        return self._embedder

    def load(self) -> None:
        """Load index and metadata into memory (idempotent)."""
        if self._loaded:
            return
        idx_path = self._find_index_file()
        meta_path = self._find_meta_file()
        log.info("Loading FAISS index from: %s", idx_path)
        try:
            self.index = faiss.read_index(idx_path)
        except Exception as e:
            raise RuntimeError(f"Failed to read faiss index {idx_path}: {e}") from e
        log.info("Loading metadata from: %s", meta_path)
        if meta_path.endswith(".json"):
            with open(meta_path, "r", encoding="utf-8") as f:
                self.metadata = json.load(f)
        else:
            with open(meta_path, "rb") as f:
                self.metadata = pickle.load(f)
        # Coerce metadata into a list so positions line up with FAISS row ids.
        if not isinstance(self.metadata, list):
            if isinstance(self.metadata, dict):
                keys = sorted(self.metadata.keys())
                try:
                    self.metadata = [self.metadata[k] for k in keys]
                except Exception:
                    self.metadata = list(self.metadata.values())
            else:
                self.metadata = list(self.metadata)
        log.info("Loaded index and metadata: metadata length=%d", len(self.metadata))
        self._loaded = True

    def retrieve(self, query: str, k: int = 3) -> Tuple[List[str], List[dict]]:
        """
        Return two lists:
          - contexts: [str, ...] top-k chunk texts (may be fewer)
          - sources: [{meta..., "score": float}, ...]
        """
        if not self._loaded:
            self.load()
        if self.index is None or self.metadata is None:
            return [], []
        # faiss.normalize_L2 requires float32, so cast before normalizing.
        q_emb = self.embedder.encode([query], convert_to_numpy=True).astype("float32")
        # Normalize in case the index stores normalized vectors
        # (cosine similarity via inner product).
        try:
            faiss.normalize_L2(q_emb)
        except Exception:
            pass
        # Safe search call.
        try:
            D, I = self.index.search(q_emb, k)
        except Exception as e:
            log.warning("FAISS search error: %s", e)
            return [], []
        # Ensure shapes.
        if I is None or D is None:
            return [], []
        indices = np.array(I).reshape(-1)[:k].tolist()
        scores = np.array(D).reshape(-1)[:k].tolist()
        contexts = []
        sources = []
        for idx, score in zip(indices, scores):
            idx = int(idx)
            if idx < 0:
                # FAISS pads results with -1 when fewer than k hits exist.
                continue
            # Guard against idx out of metadata bounds.
            if idx >= len(self.metadata):
                log.debug("Index %s >= metadata length %d; skipping", idx, len(self.metadata))
                continue
            meta = self.metadata[idx]
            # Extract text from common keys.
            text = None
            for key in ("text", "page_content", "content", "chunk_text", "source_text"):
                if isinstance(meta, dict) and key in meta and meta[key]:
                    text = meta[key]
                    break
            if text is None:
                # Fallback if metadata itself is a string or nests its fields.
                if isinstance(meta, str):
                    text = meta
                elif isinstance(meta, dict) and "metadata" in meta and isinstance(meta["metadata"], dict):
                    # Sometimes nested one level down.
                    text = meta["metadata"].get("text") or meta["metadata"].get("page_content")
                else:
                    text = str(meta)
            contexts.append(text)
            sources.append({"meta": meta, "score": float(score)})
        return contexts, sources
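

# --- Build sketch (illustrative, not part of the original module) -----------
# One way to produce files this class can load, assuming cosine search over
# L2-normalized vectors via an inner-product index. The chunk dicts use a
# "text" key because retrieve() looks for it; real metadata may differ.
def _example_build_vectorstore(chunks: List[str], out_dir: str) -> None:
    os.makedirs(out_dir, exist_ok=True)
    model = SentenceTransformer(EMBEDDING_MODEL)
    embs = model.encode(chunks, convert_to_numpy=True).astype("float32")
    faiss.normalize_L2(embs)  # unit vectors: inner product == cosine similarity
    index = faiss.IndexFlatIP(embs.shape[1])
    index.add(embs)
    faiss.write_index(index, os.path.join(out_dir, "index.faiss"))
    with open(os.path.join(out_dir, "index.pkl"), "wb") as f:
        pickle.dump([{"text": c} for c in chunks], f)


# --- Usage sketch ------------------------------------------------------------
# Minimal example of driving RAGAgent, assuming a vectorstore already exists
# at VECTORSTORE_DIR; the query string is illustrative.
if __name__ == "__main__":
    agent = RAGAgent()
    contexts, sources = agent.retrieve("What does this project do?", k=3)
    for ctx, src in zip(contexts, sources):
        print(f"[score={src['score']:.3f}] {ctx[:120]}")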