File size: 1,867 Bytes
e539f46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import textwrap
import pickle
import os
import logging

logger = logging.getLogger(__name__)

class RAG:
    """Chunk, embed and retrieve document text using a FAISS L2 index.

    Each chunk's vector lives in ``self.index``; ``self.metadata`` maps the
    vector's position in the index to ``{"file_id": ..., "text": ...}``.
    Both are persisted under ``faiss_index/`` after every ingestion.

    NOTE(review): the ``async`` methods contain no ``await``; the embedding,
    search and file writes run synchronously inside the coroutine.
    """

    def __init__(self):
        """Create the embedder and load any previously persisted index."""
        self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
        # The index dimension must match the embedder's output dimension.
        self.index = faiss.IndexFlatL2(384)
        self.metadata = {}
        self.index_file = "faiss_index/index.faiss"
        self.metadata_file = "faiss_index/metadata.pkl"
        os.makedirs("faiss_index", exist_ok=True)

        index_present = os.path.exists(self.index_file)
        metadata_present = os.path.exists(self.metadata_file)
        if index_present and metadata_present:
            self.index = faiss.read_index(self.index_file)
            # SECURITY: pickle executes arbitrary code on load. Only safe
            # while this file is written exclusively by this process.
            with open(self.metadata_file, "rb") as f:
                self.metadata = pickle.load(f)
        elif index_present or metadata_present:
            # One file without the other cannot be used (vectors with no
            # text, or text with no vectors). Start empty instead of
            # failing on the missing file at every start-up.
            logger.warning(
                "Incomplete persisted state in faiss_index "
                "(index present: %s, metadata present: %s); "
                "starting with an empty index.",
                index_present,
                metadata_present,
            )

    async def embed_document(self, file_id: str, text: str):
        """Split ``text`` into chunks, embed them and persist the result.

        Args:
            file_id: Identifier stored alongside every chunk of this document.
            text: Raw document text; wrapped into chunks of at most 500
                characters with ``textwrap.wrap``.

        Raises:
            ValueError: If ``text`` yields no chunks (empty or whitespace).
        """
        chunks = textwrap.wrap(text, 500)
        if not chunks:
            logger.warning(f"Empty document provided for file_id {file_id}")
            raise ValueError("Document contains no text to embed.")

        embeddings = self.embedder.encode(chunks)
        # A flat index numbers vectors by insertion order, so the first new
        # vector lands at position ``ntotal``. Deriving ids from the index
        # (not from len(metadata)) keeps the mapping right even if the two
        # stores have drifted apart.
        first_id = int(self.index.ntotal)
        self.index.add(embeddings.astype("float32"))

        for offset, chunk in enumerate(chunks):
            # Plain int keys (not numpy scalars) keep the pickle portable.
            self.metadata[first_id + offset] = {"file_id": file_id, "text": chunk}

        faiss.write_index(self.index, self.index_file)
        with open(self.metadata_file, "wb") as f:
            pickle.dump(self.metadata, f)

    async def query_document(self, question: str, file_id: str, *, top_k: int = 3) -> list:
        """Return the chunks of ``file_id`` closest to ``question``.

        The index holds chunks from every document, so the global nearest
        neighbours may all belong to other files. The search is widened
        until ``top_k`` chunks of the requested file are found or the whole
        index has been examined.

        Args:
            question: Natural-language query to embed and search for.
            file_id: Only chunks stored under this identifier are returned.
            top_k: Maximum number of chunks to return (default 3).

        Returns:
            Up to ``top_k`` chunk texts, nearest first; an empty list when
            the index is empty or holds no chunk for ``file_id``.

        Raises:
            ValueError: If ``top_k`` is not positive.
        """
        if top_k < 1:
            raise ValueError("top_k must be a positive integer.")

        total = int(self.index.ntotal)
        if total == 0:
            return []

        query = np.asarray(self.embedder.encode([question]), dtype="float32")
        k = min(top_k, total)
        while True:
            _distances, indices = self.index.search(query, k)
            matches = []
            for raw_idx in indices[0]:
                # Positions without metadata (including FAISS's -1 filler)
                # are skipped by the lookup below.
                entry = self.metadata.get(int(raw_idx))
                if entry is not None and entry["file_id"] == file_id:
                    matches.append(entry["text"])
            if len(matches) >= top_k or k >= total:
                return matches[:top_k]
            # Not enough hits for this file yet: look further down the ranking.
            k = min(k * 4, total)