import json
import logging
import os
import pickle
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Tuple

from app.core.config import settings

logger = logging.getLogger(__name__)


class RAGService:
    """Loads a sentence-transformer encoder plus prebuilt FAISS indexes and
    retrieves supporting chunks for a query across several legal corpora."""

    def __init__(self):
        self.encoder = None
        self.preloadedIndexes = {}
        self._initialize_encoder()
        self._load_indexes()

    def _initialize_encoder(self):
        """Load the sentence-transformer model, falling back to placeholder
        mode if the library is missing or the model fails to load."""
        try:
            from sentence_transformers import SentenceTransformer
            logger.info(f"Loading sentence transformer: {settings.sentence_transformer_model}")
            self.encoder = SentenceTransformer(settings.sentence_transformer_model)
            logger.info("Sentence transformer loaded successfully")
        except ImportError:
            logger.warning("sentence-transformers not installed - using placeholder mode")
            self.encoder = "placeholder"
        except Exception as e:
            logger.error(f"Failed to load sentence transformer: {str(e)}")
            self.encoder = "placeholder"

    def loadFaissIndexAndChunks(self, indexPath: str, chunkPath: str) -> Tuple[Any, List]:
        """Load a FAISS index and its chunk store (JSON or pickle). Returns
        (None, []) on failure, or a placeholder index if faiss is missing."""
        try:
            if not os.path.exists(indexPath) or not os.path.exists(chunkPath):
                logger.warning(f"Missing files: {indexPath} or {chunkPath}")
                return None, []
            try:
                import faiss
                index = faiss.read_index(indexPath)
            except ImportError:
                logger.warning("faiss-cpu not installed - returning placeholder")
                return "placeholder_index", []
            if chunkPath.endswith('.pkl'):
                with open(chunkPath, 'rb') as f:
                    chunks = pickle.load(f)
            else:
                with open(chunkPath, 'r', encoding='utf-8') as f:
                    chunks = json.load(f)
            logger.info(f"Loaded index from {indexPath} with {len(chunks)} chunks")
            return index, chunks
        except Exception as e:
            logger.error(f"Failed to load index {indexPath}: {str(e)}")
            return None, []

    def _load_indexes(self):
        """Preload every corpus index, then drop any that failed to load."""
        basePath = settings.faiss_indexes_base_path
        self.preloadedIndexes = {
            "constitution": self.loadFaissIndexAndChunks(
                f"{basePath}/constitution_bgeLarge.index", f"{basePath}/constitution_chunks.json"),
            "ipcSections": self.loadFaissIndexAndChunks(
                f"{basePath}/ipc_bgeLarge.index", f"{basePath}/ipc_chunks.json"),
            "ipcCase": self.loadFaissIndexAndChunks(
                f"{basePath}/ipc_case_flat.index", f"{basePath}/ipc_case_chunks.json"),
            "statutes": self.loadFaissIndexAndChunks(
                f"{basePath}/statute_index.faiss", f"{basePath}/statute_chunks.pkl"),
            "qaTexts": self.loadFaissIndexAndChunks(
                f"{basePath}/qa_faiss_index.idx", f"{basePath}/qa_text_chunks.json"),
            "caseLaw": self.loadFaissIndexAndChunks(
                f"{basePath}/case_faiss.index", f"{basePath}/case_chunks.pkl"),
        }
        # Keep only the indexes that actually loaded.
        self.preloadedIndexes = {k: v for k, v in self.preloadedIndexes.items() if v[0] is not None}
        logger.info(f"Successfully loaded {len(self.preloadedIndexes)} indexes")

    def search(self, index: Any, chunks: List, queryEmbedding, topK: int) -> List[Tuple[float, Any]]:
        """Run a top-K similarity search, pairing each score with its chunk."""
        try:
            if index == "placeholder_index":
                return [(0.5, chunk) for chunk in chunks[:topK]]
            D, I = index.search(queryEmbedding, topK)
            results = []
            for score, idx in zip(D[0], I[0]):
                # FAISS pads missing results with -1, which would otherwise
                # select chunks[-1]; guard the lower bound as well.
                if 0 <= idx < len(chunks):
                    results.append((score, chunks[idx]))
            return results
        except Exception as e:
            logger.error(f"Search failed: {str(e)}")
            return []

    def retrieveSupportChunksParallel(self, inputText: str) -> Tuple[Dict[str, List], Dict]:
        """Embed the query once, then search all loaded indexes concurrently."""
        if self.encoder == "placeholder":
            logger.info("Using placeholder RAG retrieval")
            logs = {"query": inputText}
            support = {}
            for name in ["constitution", "ipcSections", "ipcCase", "statutes", "qaTexts", "caseLaw"]:
                if name in self.preloadedIndexes:
                    _, chunks = self.preloadedIndexes[name]
                    support[name] = chunks[:5] if chunks else []
                else:
                    support[name] = []
            logs["supportChunksUsed"] = support
            return support, logs
        try:
            # normalize_embeddings=True already yields unit-norm vectors, so the
            # embedding is ready for inner-product (cosine) search as-is.
            queryEmbedding = self.encoder.encode(
                [inputText], normalize_embeddings=True
            ).astype('float32')
            logs = {"query": inputText}

            def retrieve(name):
                # Search one preloaded index and keep only the chunks (drop scores).
                if name not in self.preloadedIndexes:
                    return name, []
                idx, chunks = self.preloadedIndexes[name]
                results = self.search(idx, chunks, queryEmbedding, 5)
                return name, [c[1] for c in results]

            support = {}
            with ThreadPoolExecutor(max_workers=6) as executor:
                futures = [executor.submit(retrieve, name) for name in self.preloadedIndexes]
                for f in futures:
                    name, topChunks = f.result()
                    support[name] = topChunks
            logs["supportChunksUsed"] = support
            return support, logs
        except Exception as e:
            logger.error(f"Error retrieving support chunks: {str(e)}")
            raise ValueError(f"Support chunk retrieval failed: {str(e)}")

    def retrieveDualSupportChunks(self, inputText: str, geminiQueryModel):
        """Retrieve with both the raw case text and a model-generated search
        query, then merge and deduplicate the two result sets per corpus."""
        try:
            geminiQuery = geminiQueryModel.generateSearchQueryFromCase(inputText, geminiQueryModel)
        except Exception as e:
            logger.warning(f"Search-query generation failed, using raw text: {str(e)}")
            geminiQuery = None
        supportFromCase, _ = self.retrieveSupportChunksParallel(inputText)
        supportFromQuery, _ = self.retrieveSupportChunksParallel(geminiQuery or inputText)
        combinedSupport = {}
        for key in supportFromCase:
            combined = supportFromCase[key] + supportFromQuery.get(key, [])
            seen = set()
            unique = []
            for chunk in combined:
                # Deduplicate on the chunk's text content, whichever field holds it.
                if isinstance(chunk, str):
                    rep = chunk
                else:
                    rep = (chunk.get("text") or chunk.get("description")
                           or chunk.get("section_desc") or str(chunk))
                if rep not in seen:
                    seen.add(rep)
                    unique.append(chunk)
                if len(unique) == 10:
                    break
            combinedSupport[key] = unique
        return combinedSupport, geminiQuery

    def areIndexesLoaded(self) -> bool:
        return len(self.preloadedIndexes) > 0

    def getLoadedIndexes(self) -> List[str]:
        return list(self.preloadedIndexes.keys())

    def is_healthy(self) -> bool:
        return self.encoder is not None
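

# --- Usage sketch (illustrative, not part of the service) ---
# A minimal example of how this service might be exercised, assuming the FAISS
# index files and sentence-transformer settings above are configured.
# `DummyQueryModel` is a hypothetical stand-in for the Gemini query model; its
# method mirrors the call signature used in retrieveDualSupportChunks().
if __name__ == "__main__":
    class DummyQueryModel:
        def generateSearchQueryFromCase(self, caseText, model):
            # A real model would condense the case into a focused search query;
            # this stand-in simply truncates the input.
            return caseText[:200]

    service = RAGService()
    print(f"Loaded indexes: {service.getLoadedIndexes()}")
    if service.is_healthy():
        support, query = service.retrieveDualSupportChunks(
            "Dispute over ancestral property partition between siblings",
            DummyQueryModel(),
        )
        print(f"Generated query: {query}")
        for corpus, chunks in support.items():
            print(f"{corpus}: {len(chunks)} support chunks")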