# Spaces: Runtime error  (page-status text that appears to have been captured along with the code; not part of the module)
from sentence_transformers import SentenceTransformer | |
import faiss | |
import numpy as np | |
from typing import Dict, List, Tuple | |
import torch | |
class LegalRetriever:
    """Semantic search over a fixed collection of text sections.

    Every section text is embedded with a SentenceTransformer model and the
    vectors are stored in a flat FAISS L2 index. Queries are embedded with
    the same model and matched against that index.

    Attributes set by ``__init__``:
        model: the SentenceTransformer used for sections and queries.
        sections: the mapping of section id -> section text passed in.
        section_ids / section_texts: keys and values of ``sections`` in
            insertion order; row ``i`` of the index corresponds to
            ``section_ids[i]``.
        embeddings, dimension, index: built by ``_create_embeddings``.
    """

    # Number of texts sent to the model per forward pass when indexing.
    _BATCH_SIZE = 32

    def __init__(self, sections: Dict[str, str]):
        """Load the embedding model and index every section.

        Args:
            sections: Mapping of section id to section text.

        Raises:
            ValueError: If ``sections`` is empty.
            Exception: Any error raised while loading the model or building
                the index is printed and then re-raised unchanged.
        """
        try:
            self.model = SentenceTransformer('all-MiniLM-L6-v2')
            self.sections = sections
            self.section_texts = list(sections.values())
            self.section_ids = list(sections.keys())
            self._create_embeddings()
        except Exception as e:
            print(f"Error initializing LegalRetriever: {str(e)}")
            raise

    def _create_embeddings(self):
        """Embed ``self.section_texts`` and build the FAISS index.

        Sets ``self.embeddings`` (float32 matrix, one row per section),
        ``self.dimension`` (embedding width) and ``self.index``.

        Raises:
            ValueError: If there are no section texts to embed.
            Exception: Any encoding/indexing error is printed and re-raised.
        """
        try:
            if not self.section_texts:
                # Fail with a clear message instead of an opaque stacking or
                # shape error further down.
                raise ValueError("sections must contain at least one entry")

            # encode() batches internally and can return numpy directly, so
            # no manual batching or tensor -> numpy round trip is needed.
            vectors = self.model.encode(
                self.section_texts,
                batch_size=self._BATCH_SIZE,
                convert_to_numpy=True,
            )
            # FAISS requires contiguous float32 input.
            self.embeddings = np.ascontiguousarray(vectors, dtype=np.float32)
            self.dimension = self.embeddings.shape[1]
            self.index = faiss.IndexFlatL2(self.dimension)
            self.index.add(self.embeddings)
        except Exception as e:
            print(f"Error creating embeddings: {str(e)}")
            raise

    def retrieve(self, query: str, top_k: int = 3) -> List[Tuple[str, str, float]]:
        """Return the sections closest to ``query``, best match first.

        Args:
            query: Free-text query. Empty or whitespace-only queries yield
                no results.
            top_k: Maximum number of sections to return; values <= 0 yield
                no results. Capped at the number of indexed sections.

        Returns:
            A list of ``(section_id, section_text, score)`` tuples sorted by
            descending score, where ``score = 1 / (1 + distance)`` is a plain
            Python float in (0, 1] and ``distance`` is the value reported by
            the index. Returns ``[]`` on any error (best-effort: the error is
            printed, not raised).
        """
        try:
            if not query or not query.strip():
                return []

            k = min(top_k, len(self.section_ids))
            if k <= 0:
                # Nothing requested or nothing indexed; searching with a
                # non-positive k is an error, so answer directly.
                return []

            query_vec = self.model.encode([query], convert_to_numpy=True)
            query_vec = np.ascontiguousarray(query_vec, dtype=np.float32)
            distances, indices = self.index.search(query_vec, k)

            total = len(self.section_ids)
            results = []
            for idx, distance in zip(indices[0], distances[0]):
                idx = int(idx)
                # FAISS pads missing neighbours with label -1; a bare
                # "idx < total" check would let -1 through and silently
                # return the LAST section via negative indexing.
                if not 0 <= idx < total:
                    continue
                section_id = self.section_ids[idx]
                content = self.sections[section_id]
                # Convert to a built-in float so the score matches the
                # annotation and is JSON-serialisable.
                score = float(1.0 / (1.0 + float(distance)))
                results.append((section_id, content, score))

            return sorted(results, key=lambda item: item[2], reverse=True)
        except Exception as e:
            print(f"Error retrieving sections: {str(e)}")
            return []