# ChatLM / rag.py
# ah707
# Initial deploy
# e539f46
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import textwrap
import pickle
import os
import logging
logger = logging.getLogger(__name__)


class RAG:
    """Chunk, embed, store and retrieve document text with a FAISS flat L2 index.

    Each document is split into chunks with ``textwrap.wrap``, embedded with a
    SentenceTransformer model, and appended to a FAISS index. A parallel
    ``metadata`` dict maps each FAISS row id (an ``int``) to
    ``{"file_id": ..., "text": ...}``. The index and metadata are written to
    ``index_dir`` after every ingestion and reloaded on construction.
    """

    def __init__(
        self,
        model_name: str = "all-MiniLM-L6-v2",
        index_dir: str = "faiss_index",
        chunk_size: int = 500,
    ):
        """Load the embedding model and restore any persisted index.

        Args:
            model_name: SentenceTransformer model used for embeddings.
            index_dir: Directory holding ``index.faiss`` and ``metadata.pkl``;
                created if it does not exist.
            chunk_size: Maximum width of each chunk, passed to ``textwrap.wrap``.

        Raises:
            FileNotFoundError: if ``index.faiss`` exists but ``metadata.pkl``
                does not.
        """
        self.embedder = SentenceTransformer(model_name)
        self.chunk_size = chunk_size
        self.index_file = os.path.join(index_dir, "index.faiss")
        self.metadata_file = os.path.join(index_dir, "metadata.pkl")
        # FAISS row id -> {"file_id": str, "text": str}
        self.metadata = {}
        os.makedirs(index_dir, exist_ok=True)

        if os.path.exists(self.index_file):
            self.index = faiss.read_index(self.index_file)
            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file. This is only acceptable while index_dir is trusted,
            # application-owned storage; never point it at user-supplied files.
            with open(self.metadata_file, "rb") as f:
                self.metadata = pickle.load(f)
        else:
            # Ask the model for its output size rather than hard-coding 384,
            # so a different model_name still yields a compatible index.
            dim = self.embedder.get_sentence_embedding_dimension()
            self.index = faiss.IndexFlatL2(dim)

    def _persist(self):
        """Write the index and metadata to disk via temp file + atomic rename.

        Writing to a temp file and then calling os.replace means a crash
        mid-write cannot leave a truncated index or metadata file behind.
        """
        tmp_index = self.index_file + ".tmp"
        faiss.write_index(self.index, tmp_index)
        os.replace(tmp_index, self.index_file)

        tmp_meta = self.metadata_file + ".tmp"
        with open(tmp_meta, "wb") as f:
            pickle.dump(self.metadata, f)
        os.replace(tmp_meta, self.metadata_file)

    async def embed_document(self, file_id: str, text: str):
        """Chunk, embed and index ``text`` under ``file_id``, then persist.

        NOTE(review): encode() and the disk writes are blocking calls inside
        a coroutine, so they stall the event loop while they run.

        Raises:
            ValueError: if ``text`` yields no chunks (empty or whitespace only).
        """
        chunks = textwrap.wrap(text, self.chunk_size)
        if not chunks:
            logger.warning("Empty document provided for file_id %s", file_id)
            raise ValueError("Document contains no text to embed.")

        embeddings = np.asarray(self.embedder.encode(chunks), dtype="float32")
        # IndexFlatL2 assigns sequential row ids starting at ntotal. Take the
        # start id from the index itself rather than len(self.metadata), which
        # goes wrong as soon as the two get out of sync.
        start_id = self.index.ntotal
        self.index.add(embeddings)
        for offset, chunk in enumerate(chunks):
            self.metadata[start_id + offset] = {"file_id": file_id, "text": chunk}
        self._persist()

    async def query_document(self, question: str, file_id: str, top_k: int = 3) -> list:
        """Return up to ``top_k`` chunk texts of ``file_id`` nearest to ``question``.

        Results are ordered from nearest to farthest (L2 distance). An empty
        list is returned when the index is empty or ``file_id`` has no chunks.
        """
        total = self.index.ntotal
        if total == 0 or top_k <= 0:
            return []

        query = np.asarray(self.embedder.encode([question]), dtype="float32")
        # Every document shares one index, so the globally nearest rows may
        # all belong to other files. Keep doubling the search width until
        # top_k chunks of file_id are found or the whole index has been seen.
        k = min(top_k, total)
        while True:
            _, indices = self.index.search(query, k)
            matches = []
            for idx in indices[0]:
                # FAISS pads missing results with -1; .get() skips those as
                # well as any row whose metadata is missing.
                entry = self.metadata.get(int(idx))
                if entry is not None and entry["file_id"] == file_id:
                    matches.append(entry["text"])
            if len(matches) >= top_k or k >= total:
                return matches[:top_k]
            k = min(k * 2, total)