|
from sentence_transformers import SentenceTransformer |
|
import faiss |
|
import numpy as np |
|
import textwrap |
|
import pickle |
|
import os |
|
import logging |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
class RAG:
    """Minimal retrieval-augmented-generation store backed by a FAISS index.

    Documents are split into fixed-size chunks, embedded with a
    SentenceTransformer model, and stored in an ``IndexFlatL2``. Both the
    vector index and the chunk metadata are persisted to disk so the store
    survives restarts.
    """

    # all-MiniLM-L6-v2 produces 384-dimensional embeddings.
    EMBEDDING_DIM = 384

    def __init__(self):
        self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
        self.index = faiss.IndexFlatL2(self.EMBEDDING_DIM)
        # Maps vector position in the FAISS index -> {"file_id", "text"}.
        self.metadata = {}
        self.index_file = "faiss_index/index.faiss"
        self.metadata_file = "faiss_index/metadata.pkl"
        os.makedirs("faiss_index", exist_ok=True)

        # Restore state only when BOTH files exist: loading the index
        # without its metadata (or vice versa) would leave the two out of
        # sync — and the original unconditional open() of the metadata file
        # raised FileNotFoundError if only the index had been written.
        if os.path.exists(self.index_file) and os.path.exists(self.metadata_file):
            self.index = faiss.read_index(self.index_file)
            # NOTE(review): pickle.load can execute arbitrary code if the
            # file is attacker-controlled; confirm faiss_index/ is not
            # user-writable in deployment.
            with open(self.metadata_file, "rb") as f:
                self.metadata = pickle.load(f)

    async def embed_document(self, file_id: str, text: str) -> None:
        """Chunk *text*, embed the chunks, and persist index + metadata.

        Args:
            file_id: Identifier the chunks are stored under.
            text: Raw document text; wrapped into ~500-char chunks.

        Raises:
            ValueError: If *text* contains no embeddable content.
        """
        chunks = textwrap.wrap(text, 500)
        if not chunks:
            logger.warning("Empty document provided for file_id %s", file_id)
            raise ValueError("Document contains no text to embed.")

        embeddings = self.embedder.encode(chunks)
        # New ids follow on from ntotal — the authoritative count of vectors
        # actually stored — rather than len(self.metadata). Plain int keys
        # (not numpy ints via np.array) keep the pickled dict unambiguous.
        start = self.index.ntotal
        self.index.add(embeddings.astype("float32"))

        for offset, chunk in enumerate(chunks):
            self.metadata[start + offset] = {"file_id": file_id, "text": chunk}

        self._persist()

    def _persist(self) -> None:
        """Write the FAISS index and the metadata dict to disk."""
        faiss.write_index(self.index, self.index_file)
        with open(self.metadata_file, "wb") as f:
            pickle.dump(self.metadata, f)

    async def query_document(self, question: str, file_id: str, k: int = 3) -> list:
        """Return up to *k* chunk texts from *file_id* most similar to *question*.

        Args:
            question: Natural-language query to embed and search with.
            file_id: Only chunks stored under this id are returned.
            k: Maximum number of chunks to return (default 3).
        """
        # An empty index would only yield -1 sentinel ids; bail out early.
        if self.index.ntotal == 0:
            return []

        query_embedding = self.embedder.encode([question])[0].astype("float32")
        # The search is global across all files, so oversample: neighbours
        # belonging to other files are filtered out below and would
        # otherwise eat into the k slots available for this file.
        fetch = min(max(k * 5, k), self.index.ntotal)
        _distances, indices = self.index.search(np.array([query_embedding]), k=fetch)

        results = []
        for idx in indices[0]:
            idx = int(idx)  # faiss returns numpy ints; -1 marks "no result"
            entry = self.metadata.get(idx)
            if entry is not None and entry["file_id"] == file_id:
                results.append(entry["text"])
                if len(results) == k:
                    break
        return results
|
|