GitGlimpse / retriever.py
NathanAW24
add for streamlit_cot
fd27bb4
import os
import json
import numpy as np
from dotenv import load_dotenv
from rank_bm25 import BM25Okapi
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
def _read_documents_from_folder(folder_path):
"""
Private utility function to read all .txt files from a folder.
:param folder_path: Path to the folder containing text files.
:return: List of document contents.
"""
documents = []
for root, dirs, files in os.walk(folder_path):
for file in files:
if file.endswith(".txt"):
file_path = os.path.join(root, file)
with open(file_path, 'r', encoding='utf-8') as f:
documents.append([file, f.read()])
return documents
class BM25Retriever:
def __init__(self, documents):
"""
Initialize BM25 Retriever using rank-bm25 library.
:param documents: List of documents (each document is a string).
"""
self.tokenized_documents = [doc.split()
for [file_name, doc] in documents]
self.bm25 = BM25Okapi(self.tokenized_documents)
def retrieve(self, query, top_n=5):
"""
Retrieve the top N documents for a query based on BM25 scores.
:param query: Query string.
:param top_n: Number of top documents to return.
:return: List of tuples (document_index, score).
"""
tokenized_query = query.split() # Tokenize query
scores = self.bm25.get_scores(tokenized_query)
top_indices = sorted(
range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_n]
return [(index, scores[index]) for index in top_indices]
class EnsembleRetriever:
def __init__(self, folder_path, vector_db_dir, model_name, device="cpu"):
# Load environment variables
load_dotenv()
# Initialize documents
self.folder_path = folder_path
self.documents = _read_documents_from_folder(folder_path)
self.doc_dict = {doc[0]: doc[1] for doc in self.documents}
# Initialize BM25 Retriever
self.bm25_retriever = BM25Retriever(self.documents)
# Initialize Vector Retrieval
self.vectorstore = self._init_vectorstore(
vector_db_dir, model_name, device)
self.vector_retriever = self.vectorstore.as_retriever(
search_type="similarity", search_kwargs={"k": 5}
)
def _init_vectorstore(self, vector_db_dir, model_name, device):
embedding_model = HuggingFaceEmbeddings(
model_name=model_name,
model_kwargs={"device": device},
encode_kwargs={"normalize_embeddings": True}
)
return Chroma(
persist_directory=vector_db_dir,
embedding_function=embedding_model
)
def get_documents(self, query):
"""Retrieve documents using ensemble retrieval mechanism."""
# Retrieve from BM25
bm25_results = self.bm25_retriever.retrieve(query, top_n=5)
bm25_docs = [(self.documents[doc_index][0], self.documents[doc_index][1])
for doc_index, _ in bm25_results]
# Retrieve from Vector Retrieval
vector_results = self.vector_retriever.invoke(query)
vector_docs = [(doc.metadata["file_name"], doc.page_content)
for doc in vector_results]
# Combine results, removing duplicates by file name
combined_docs = {}
for fname, content in bm25_docs + vector_docs:
if fname not in combined_docs:
combined_docs[fname] = content
return combined_docs
# Usage example
if __name__ == "__main__":
folder_path = os.path.join(os.path.dirname(
os.path.abspath(__file__)), "processed_docs")
vector_db_dir = os.path.join(os.path.dirname(
os.path.abspath(__file__)), "vector_db")
model_name = "thenlper/gte-small"
retriever = EnsembleRetriever(folder_path, vector_db_dir, model_name)
query = "How can I enhance the validation capabilities of a Select component?"
results = retriever.get_documents(query)
# Print results
for file_name, content in results.items():
print(f"File: {file_name}\nContent: {content}\n")