# LegalLens-API — app/services/rag_service.py
# Last commit: d769b8f "Update app/services/rag_service.py" by negi2725 (verified)
import json
import os
import pickle
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Any, Tuple
from app.core.config import settings
import logging
# Module-level logger named after this module's import path (via __name__).
logger = logging.getLogger(__name__)
class RAGService:
    """Retrieve supporting text chunks from a fixed set of preloaded FAISS indexes.

    On construction it loads the configured SentenceTransformer encoder (falling
    back to the ``"placeholder"`` sentinel when the library or model cannot be
    loaded) and preloads every index listed in ``_INDEX_FILES`` together with
    its chunk list. Retrieval returns, per source name, the chunks whose vectors
    are nearest to the encoded query.
    """

    # Stored in ``self.encoder`` when no real encoder could be loaded.
    _PLACEHOLDER_ENCODER = "placeholder"
    # Stored in place of a FAISS index when the ``faiss`` package is not installed.
    _PLACEHOLDER_INDEX = "placeholder_index"
    # Source name -> (index file, chunk file), relative to
    # ``settings.faiss_indexes_base_path``. Dict order is the order results are reported in.
    _INDEX_FILES = {
        "constitution": ("constitution_bgeLarge.index", "constitution_chunks.json"),
        "ipcSections": ("ipc_bgeLarge.index", "ipc_chunks.json"),
        "ipcCase": ("ipc_case_flat.index", "ipc_case_chunks.json"),
        "statutes": ("statute_index.faiss", "statute_chunks.pkl"),
        "qaTexts": ("qa_faiss_index.idx", "qa_text_chunks.json"),
        "caseLaw": ("case_faiss.index", "case_chunks.pkl"),
    }

    def __init__(self):
        self.encoder = None
        # name -> (index, chunks); only sources whose index file loaded are kept.
        self.preloadedIndexes = {}
        self._initialize_encoder()
        self._load_indexes()

    def _initialize_encoder(self):
        """Load ``settings.sentence_transformer_model``, or fall back to the placeholder sentinel."""
        try:
            from sentence_transformers import SentenceTransformer
            logger.info("Loading sentence transformer: %s", settings.sentence_transformer_model)
            self.encoder = SentenceTransformer(settings.sentence_transformer_model)
            logger.info("Sentence transformer loaded successfully")
        except ImportError:
            logger.warning("sentence-transformers not installed - using placeholder mode")
            self.encoder = self._PLACEHOLDER_ENCODER
        except Exception as e:
            logger.error("Failed to load sentence transformer: %s", e)
            self.encoder = self._PLACEHOLDER_ENCODER

    @staticmethod
    def _load_chunks(chunkPath: str) -> List:
        """Read a chunk list from a ``.pkl`` (pickle) or any other extension (UTF-8 JSON) file."""
        if chunkPath.endswith('.pkl'):
            # NOTE(security): pickle.load can execute arbitrary code embedded in the file.
            # Only point faiss_indexes_base_path at trusted, locally built artifacts.
            with open(chunkPath, 'rb') as f:
                return pickle.load(f)
        with open(chunkPath, 'r', encoding='utf-8') as f:
            return json.load(f)

    def loadFaissIndexAndChunks(self, indexPath: str, chunkPath: str) -> Tuple[Any, List]:
        """Load one FAISS index and its chunk list from disk.

        Returns:
            ``(index, chunks)``. ``(None, [])`` when either file is missing or
            loading fails. When ``faiss`` is not installed, ``index`` is the
            ``"placeholder_index"`` sentinel. The chunks are still loaded in that
            case, so placeholder retrieval has data to return.
        """
        try:
            if not os.path.exists(indexPath) or not os.path.exists(chunkPath):
                logger.warning("Missing files: %s or %s", indexPath, chunkPath)
                return None, []
            try:
                import faiss
            except ImportError:
                logger.warning("faiss-cpu not installed - returning placeholder")
                index = self._PLACEHOLDER_INDEX
            else:
                index = faiss.read_index(indexPath)
            chunks = self._load_chunks(chunkPath)
            logger.info("Loaded index from %s with %d chunks", indexPath, len(chunks))
            return index, chunks
        except Exception as e:
            logger.error("Failed to load index %s: %s", indexPath, e)
            return None, []

    def _load_indexes(self):
        """Populate ``self.preloadedIndexes`` with every index in ``_INDEX_FILES`` that loads."""
        basePath = settings.faiss_indexes_base_path
        loaded = {}
        for name, (indexFile, chunkFile) in self._INDEX_FILES.items():
            index, chunks = self.loadFaissIndexAndChunks(
                os.path.join(basePath, indexFile), os.path.join(basePath, chunkFile)
            )
            if index is not None:
                loaded[name] = (index, chunks)
        self.preloadedIndexes = loaded
        logger.info("Successfully loaded %d indexes", len(self.preloadedIndexes))

    def search(self, index: Any, chunks: List, queryEmbedding, topK: int) -> List[Tuple[float, Any]]:
        """Return up to ``topK`` ``(score, chunk)`` pairs for ``queryEmbedding``.

        With the placeholder index, the first ``topK`` chunks are returned with a
        dummy score of 0.5. Any exception is logged and yields ``[]``.
        """
        try:
            if isinstance(index, str) and index == self._PLACEHOLDER_INDEX:
                return [(0.5, chunk) for chunk in chunks[:topK]]
            distances, labels = index.search(queryEmbedding, topK)
            # FAISS pads missing neighbours with label -1. The lower bound matters:
            # without it, -1 passes the length check and silently returns chunks[-1].
            return [
                (score, chunks[idx])
                for score, idx in zip(distances[0], labels[0])
                if 0 <= idx < len(chunks)
            ]
        except Exception as e:
            logger.error("Search failed: %s", e)
            return []

    @staticmethod
    def _normalize_query(queryEmbedding):
        """L2-normalise the query in place with FAISS when it is installed.

        The embedding is already requested with ``normalize_embeddings=True``,
        so this is a float32 re-normalisation. It is skipped when ``faiss`` is
        absent, which is the case where only placeholder indexes are loaded.
        """
        try:
            import faiss
        except ImportError:
            return
        faiss.normalize_L2(queryEmbedding)

    def retrieveSupportChunksParallel(self, inputText: str, topK: int = 5) -> Tuple[Dict[str, List], Dict]:
        """Retrieve up to ``topK`` chunks per source for ``inputText``.

        Every source name in ``_INDEX_FILES`` appears in the result; sources
        whose index did not load map to ``[]``.

        Returns:
            ``(support, logs)`` where ``support`` maps source name to a chunk
            list, and ``logs`` holds the query and the same ``support`` dict.

        Raises:
            ValueError: if encoding or retrieval fails with a real encoder.
        """
        logs = {"query": inputText}
        names = list(self._INDEX_FILES)
        if isinstance(self.encoder, str):
            logger.info("Using placeholder RAG retrieval")
            support = {}
            for name in names:
                _, chunks = self.preloadedIndexes.get(name, (None, []))
                support[name] = chunks[:topK] if chunks else []
            logs["supportChunksUsed"] = support
            return support, logs
        try:
            queryEmbedding = self.encoder.encode([inputText], normalize_embeddings=True).astype('float32')
            self._normalize_query(queryEmbedding)

            def retrieve(name):
                loaded = self.preloadedIndexes.get(name)
                if loaded is None:
                    return []
                index, chunks = loaded
                return [chunk for _, chunk in self.search(index, chunks, queryEmbedding, topK)]

            # One worker per source so every index is searched concurrently.
            with ThreadPoolExecutor(max_workers=len(names)) as executor:
                support = dict(zip(names, executor.map(retrieve, names)))
            logs["supportChunksUsed"] = support
            return support, logs
        except Exception as e:
            logger.error("Error retrieving support chunks: %s", e)
            raise ValueError(f"Support chunk retrieval failed: {str(e)}") from e

    @staticmethod
    def _chunk_key(chunk) -> str:
        """Return the text used to recognise the same chunk across two retrievals."""
        if isinstance(chunk, str):
            return chunk
        rep = None
        if isinstance(chunk, dict):
            rep = chunk.get("text") or chunk.get("description") or chunk.get("section_desc")
        if not rep:
            return str(chunk)
        return rep if isinstance(rep, str) else str(rep)

    def retrieveDualSupportChunks(self, inputText: str, geminiQueryModel, maxPerSource: int = 10):
        """Merge retrievals for the raw case text and a model-generated search query.

        For each source, results for ``inputText`` come first. They are followed
        by results for the generated query, deduplicated by ``_chunk_key`` and
        capped at ``maxPerSource``.

        Returns:
            ``(combinedSupport, geminiQuery)``. ``geminiQuery`` is ``None`` when
            query generation failed.
        """
        try:
            # NOTE(review): the model object is also passed as the method's second
            # argument; kept as-is because the callee's signature is not visible here.
            geminiQuery = geminiQueryModel.generateSearchQueryFromCase(inputText, geminiQueryModel)
        except Exception as e:
            logger.warning("Search-query generation failed, using the raw input only: %s", e)
            geminiQuery = None
        supportFromCase, _ = self.retrieveSupportChunksParallel(inputText)
        secondQuery = geminiQuery or inputText
        if secondQuery == inputText:
            # Same query, same results: skip the redundant second retrieval.
            supportFromQuery = supportFromCase
        else:
            supportFromQuery, _ = self.retrieveSupportChunksParallel(secondQuery)
        combinedSupport = {}
        for key, caseChunks in supportFromCase.items():
            seen = set()
            unique = []
            for chunk in caseChunks + supportFromQuery.get(key, []):
                if len(unique) >= maxPerSource:
                    break
                rep = self._chunk_key(chunk)
                if rep in seen:
                    continue
                seen.add(rep)
                unique.append(chunk)
            combinedSupport[key] = unique
        return combinedSupport, geminiQuery

    def areIndexesLoaded(self) -> bool:
        """Return True if at least one index was loaded."""
        return bool(self.preloadedIndexes)

    def getLoadedIndexes(self) -> List[str]:
        """Return the names of the sources whose index loaded."""
        return list(self.preloadedIndexes.keys())

    def is_healthy(self) -> bool:
        """Return True once an encoder (real or placeholder sentinel) has been set."""
        return self.encoder is not None