from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import textwrap
import pickle
import os
import logging

logger = logging.getLogger(__name__)


class RAG:
    """Minimal retrieval store: chunks documents, embeds them with a
    SentenceTransformer, and indexes the vectors in a flat L2 FAISS index.

    Index and metadata are persisted under ``faiss_index/`` and reloaded on
    construction.  Metadata maps the FAISS sequential row id (int) to
    ``{"file_id": ..., "text": ...}`` for the corresponding chunk.
    """

    INDEX_DIR = "faiss_index"
    # Embedding dimensionality of all-MiniLM-L6-v2.
    EMBEDDING_DIM = 384
    CHUNK_WIDTH = 500
    TOP_K = 3

    def __init__(self):
        self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
        self.index = faiss.IndexFlatL2(self.EMBEDDING_DIM)
        self.metadata = {}
        self.index_file = os.path.join(self.INDEX_DIR, "index.faiss")
        self.metadata_file = os.path.join(self.INDEX_DIR, "metadata.pkl")
        os.makedirs(self.INDEX_DIR, exist_ok=True)
        # Only reload persisted state when BOTH files are present: the index
        # without its metadata (or vice versa) is unusable, and the original
        # code crashed with FileNotFoundError if the pickle was missing.
        if os.path.exists(self.index_file) and os.path.exists(self.metadata_file):
            self.index = faiss.read_index(self.index_file)
            # NOTE: unpickling is only safe because this file is written by
            # this process; never point metadata_file at untrusted data.
            with open(self.metadata_file, "rb") as f:
                self.metadata = pickle.load(f)
        elif os.path.exists(self.index_file) or os.path.exists(self.metadata_file):
            logger.warning(
                "Inconsistent persisted state in %s (index or metadata missing); "
                "starting with an empty index.",
                self.INDEX_DIR,
            )

    async def embed_document(self, file_id: str, text: str) -> None:
        """Chunk *text*, embed the chunks, add them to the index and persist.

        Raises:
            ValueError: if *text* contains no wrappable content.
        """
        chunks = textwrap.wrap(text, self.CHUNK_WIDTH)
        if not chunks:
            logger.warning("Empty document provided for file_id %s", file_id)
            raise ValueError("Document contains no text to embed.")
        embeddings = self.embedder.encode(chunks)
        # FAISS assigns sequential row ids on add(); base the metadata keys on
        # index.ntotal (not len(self.metadata)) so the two can never drift
        # apart, and store plain ints rather than numpy scalars.
        start_id = int(self.index.ntotal)
        self.index.add(embeddings.astype("float32"))
        for offset, chunk in enumerate(chunks):
            self.metadata[start_id + offset] = {"file_id": file_id, "text": chunk}
        self._persist()

    def _persist(self) -> None:
        """Write the FAISS index and metadata mapping to disk."""
        faiss.write_index(self.index, self.index_file)
        with open(self.metadata_file, "wb") as f:
            pickle.dump(self.metadata, f)

    async def query_document(self, question: str, file_id: str) -> list:
        """Return up to TOP_K chunk texts from *file_id* most similar to *question*.

        The index is shared across documents, so we over-fetch neighbors and
        filter by file_id afterwards; searching only TOP_K globally could let
        other documents' chunks crowd out every match for this file.
        """
        if self.index.ntotal == 0:
            return []
        query_embedding = self.embedder.encode([question])[0].astype("float32")
        # Over-fetch (bounded by the index size) to survive the file_id filter.
        k = min(max(self.TOP_K * 10, self.TOP_K), self.index.ntotal)
        distances, indices = self.index.search(np.array([query_embedding]), k=k)
        results = []
        for idx in indices[0]:
            idx = int(idx)  # FAISS pads with -1 when k > matches; lookup skips it
            entry = self.metadata.get(idx)
            if entry is not None and entry["file_id"] == file_id:
                results.append(entry["text"])
                if len(results) == self.TOP_K:
                    break
        return results