# Source: chagu-dev / rag_sec / document_retriver.py
# Last commit: f861dee ("update") by talexm (1.51 kB)
import faiss
from sklearn.feature_extraction.text import TfidfVectorizer
import numpy as np
class DocumentRetriever:
    """Retrieve plain-text documents by TF-IDF similarity using a FAISS index.

    Documents are ``*.txt`` files loaded from a directory. Each document is
    embedded with a scikit-learn ``TfidfVectorizer`` and stored in an exact
    L2 FAISS index (``IndexFlatL2``). Queries are embedded with the same
    fitted vectorizer, and the nearest documents are returned.
    """

    def __init__(self):
        self.documents = []  # raw text of every loaded document, in index order
        self.vectorizer = TfidfVectorizer()
        self.index = None  # FAISS index; stays None until _build_index succeeds

    def load_documents(self, source_dir) -> None:
        """Load every ``*.txt`` file in *source_dir* and rebuild the index.

        Files are read as UTF-8 and appended to any documents loaded by
        earlier calls. The index is then rebuilt over the whole collection,
        and the printed count is that collection's total size. If
        *source_dir* is not an existing directory, a message is printed and
        nothing changes.

        Args:
            source_dir: Directory (str or path-like) containing ``.txt`` files.
        """
        from pathlib import Path

        data_dir = Path(source_dir)
        if not data_dir.is_dir():
            print(f"Source directory not found: {source_dir}")
            return
        # Sort so document order (and therefore tie-breaking in search) does
        # not depend on filesystem enumeration order.
        for file in sorted(data_dir.glob("*.txt")):
            if not file.is_file():  # e.g. a directory named "notes.txt"
                continue
            with open(file, "r", encoding="utf-8") as f:
                self.documents.append(f.read())
        print(f"Loaded {len(self.documents)} documents.")
        # Create the FAISS index
        self._build_index()

    def _build_index(self) -> None:
        """Fit the vectorizer on ``self.documents`` and rebuild the FAISS index.

        Leaves ``self.index`` as None when there is nothing usable to index,
        so that ``retrieve`` reports retrieval as not initialized instead of
        crashing.
        """
        # Drop any previous index first: it would no longer match the
        # refitted vectorizer or the grown document list.
        self.index = None
        if not self.documents:
            print("No documents to index.")
            return
        try:
            doc_vectors = self.vectorizer.fit_transform(self.documents)
        except ValueError as err:
            # TfidfVectorizer raises ValueError when no terms survive
            # tokenization (e.g. every document is empty).
            print(f"Could not build document index: {err}")
            return
        # FAISS needs a dense, C-contiguous float32 matrix.
        dense = np.ascontiguousarray(doc_vectors.toarray(), dtype=np.float32)
        index = faiss.IndexFlatL2(dense.shape[1])
        index.add(dense)
        self.index = index

    def retrieve(self, query, top_k=5) -> list:
        """Return up to *top_k* documents nearest to *query* in TF-IDF space.

        Args:
            query: Free-text query, vectorized with the fitted vectorizer.
            top_k: Maximum number of documents to return; values <= 0 give [].

        Returns:
            Document texts ordered from nearest to farthest (L2 distance).
            If no index has been built, returns a one-element list holding
            the message "Document retrieval is not initialized." instead.
        """
        if self.index is None:
            return ["Document retrieval is not initialized."]
        # Never ask FAISS for more neighbours than the index holds.
        k = min(top_k, self.index.ntotal)
        if k <= 0:
            return []
        query_vector = np.ascontiguousarray(
            self.vectorizer.transform([query]).toarray(), dtype=np.float32
        )
        _, indices = self.index.search(query_vector, k)
        # FAISS pads missing results with label -1. A bare `i < len(...)`
        # check lets -1 through and silently returns the last document.
        n_docs = len(self.documents)
        return [self.documents[i] for i in indices[0] if 0 <= i < n_docs]