# AI_Chatbot/utils/rag_handler.py
# Uploaded by Abs6187 in commit d442ebd ("Upload 3 files", verified).
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from typing import Dict, List, Tuple
import torch
class LegalRetriever:
    """Semantic search over a fixed collection of legal sections.

    Every section text is embedded with a SentenceTransformer model and the
    vectors are stored in an exact FAISS ``IndexFlatL2``. Queries are embedded
    with the same model and matched by L2 distance.
    """

    def __init__(
        self,
        sections: Dict[str, str],
        *,
        model_name: str = 'all-MiniLM-L6-v2',
        batch_size: int = 32,
    ):
        """Load the embedding model and index every section.

        Args:
            sections: Mapping of section id -> section text. A shallow copy is
                stored so later mutation of the caller's dict cannot
                desynchronise the ids from the rows of the index.
            model_name: SentenceTransformer model to load; the default is the
                model this class has always used.
            batch_size: Number of texts encoded per forward pass while
                building the index.

        Raises:
            Re-raises any exception from model loading or indexing after
            printing it.
        """
        try:
            self.model = SentenceTransformer(model_name)
            self.batch_size = batch_size
            self.sections = dict(sections)
            # Row i of the FAISS index corresponds to section_ids[i] / section_texts[i].
            self.section_texts = list(self.sections.values())
            self.section_ids = list(self.sections.keys())
            self._create_embeddings()
        except Exception as e:
            print(f"Error initializing LegalRetriever: {str(e)}")
            raise

    def _create_embeddings(self):
        """Embed all section texts and build the FAISS index.

        Sets ``self.embeddings`` (float32, one row per section),
        ``self.dimension`` and ``self.index``.

        Raises:
            Re-raises any encoding/indexing exception after printing it.
        """
        try:
            if self.section_texts:
                with torch.no_grad():
                    # encode() batches internally, so the corpus is not sliced by hand.
                    vectors = self.model.encode(
                        self.section_texts,
                        batch_size=self.batch_size,
                        convert_to_numpy=True,
                    )
                # FAISS requires C-contiguous float32 input.
                self.embeddings = np.ascontiguousarray(vectors, dtype=np.float32)
                self.dimension = self.embeddings.shape[1]
            else:
                # np.vstack([]) would raise on an empty corpus; build an empty
                # index of the model's width so retrieve() just returns [].
                self.dimension = self.model.get_sentence_embedding_dimension()
                self.embeddings = np.empty((0, self.dimension), dtype=np.float32)
            self.index = faiss.IndexFlatL2(self.dimension)
            if self.embeddings.shape[0]:
                self.index.add(self.embeddings)
        except Exception as e:
            print(f"Error creating embeddings: {str(e)}")
            raise

    def retrieve(self, query: str, top_k: int = 3) -> List[Tuple[str, str, float]]:
        """Return the sections closest to ``query``.

        Args:
            query: Free-text search query.
            top_k: Maximum number of sections to return.

        Returns:
            Up to ``top_k`` tuples ``(section_id, content, score)``, best match
            first. ``score = 1 / (1 + distance)`` is computed from the distance
            reported by ``IndexFlatL2``, so higher means closer. Returns ``[]``
            in these cases:
            - the query is empty or blank;
            - ``top_k`` is not positive;
            - there are no sections;
            - any error occurred (the error is printed, not raised).
        """
        try:
            if not query or not query.strip():
                return []
            k = min(top_k, len(self.section_ids))
            if k <= 0:
                return []
            with torch.no_grad():
                query_embedding = self.model.encode([query], convert_to_numpy=True)
            query_embedding = np.ascontiguousarray(query_embedding, dtype=np.float32)
            distances, indices = self.index.search(query_embedding, k)
            results = []
            for idx, distance in zip(indices[0], distances[0]):
                idx = int(idx)
                # FAISS pads with -1 when it has fewer than k hits; without the
                # lower bound, -1 would silently select the last section.
                if 0 <= idx < len(self.section_ids):
                    section_id = self.section_ids[idx]
                    content = self.sections[section_id]
                    score = float(1.0 / (1.0 + distance))
                    results.append((section_id, content, score))
            return sorted(results, key=lambda x: x[2], reverse=True)
        except Exception as e:
            print(f"Error retrieving sections: {str(e)}")
            return []