import json
import logging
import os
import pickle
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Tuple

from app.core.config import settings

logger = logging.getLogger(__name__)

class RAGService:
    """Loads FAISS indexes and their chunk stores, and serves retrieval over them."""

    def __init__(self):
        self.encoder = None
        self.preloadedIndexes = {}
        self._initialize_encoder()
        self._load_indexes()

    def _initialize_encoder(self):
        try:
            from sentence_transformers import SentenceTransformer
            logger.info(f"Loading sentence transformer: {settings.sentence_transformer_model}")
            self.encoder = SentenceTransformer(settings.sentence_transformer_model)
            logger.info("Sentence transformer loaded successfully")
        except ImportError:
            logger.warning("sentence-transformers not installed - using placeholder mode")
            # Sentinel string checked by callers before encoding.
            self.encoder = "placeholder"
        except Exception as e:
            logger.error(f"Failed to load sentence transformer: {str(e)}")
            self.encoder = "placeholder"

    def loadFaissIndexAndChunks(self, indexPath: str, chunkPath: str) -> Tuple[Any, List]:
        try:
            if not os.path.exists(indexPath) or not os.path.exists(chunkPath):
                logger.warning(f"Missing files: {indexPath} or {chunkPath}")
                return None, []
            # Load the chunk store first so placeholder mode still gets data.
            # Caution: pickle.load executes arbitrary code, so .pkl chunk
            # stores must come from a trusted source.
            if chunkPath.endswith('.pkl'):
                with open(chunkPath, 'rb') as f:
                    chunks = pickle.load(f)
            else:
                with open(chunkPath, 'r', encoding='utf-8') as f:
                    chunks = json.load(f)
            try:
                import faiss
                index = faiss.read_index(indexPath)
            except ImportError:
                logger.warning("faiss-cpu not installed - returning placeholder")
                return "placeholder_index", chunks
            logger.info(f"Loaded index from {indexPath} with {len(chunks)} chunks")
            return index, chunks
        except Exception as e:
            logger.error(f"Failed to load index {indexPath}: {str(e)}")
            return None, []

    def _load_indexes(self):
        basePath = settings.faiss_indexes_base_path
        self.preloadedIndexes = {
            "constitution": self.loadFaissIndexAndChunks(f"{basePath}/constitution_bgeLarge.index", f"{basePath}/constitution_chunks.json"),
            "ipcSections": self.loadFaissIndexAndChunks(f"{basePath}/ipc_bgeLarge.index", f"{basePath}/ipc_chunks.json"),
            "ipcCase": self.loadFaissIndexAndChunks(f"{basePath}/ipc_case_flat.index", f"{basePath}/ipc_case_chunks.json"),
            "statutes": self.loadFaissIndexAndChunks(f"{basePath}/statute_index.faiss", f"{basePath}/statute_chunks.pkl"),
            "qaTexts": self.loadFaissIndexAndChunks(f"{basePath}/qa_faiss_index.idx", f"{basePath}/qa_text_chunks.json"),
            "caseLaw": self.loadFaissIndexAndChunks(f"{basePath}/case_faiss.index", f"{basePath}/case_chunks.pkl"),
        }
        self.preloadedIndexes = {k: v for k, v in self.preloadedIndexes.items() if v[0] is not None}
        logger.info(f"Successfully loaded {len(self.preloadedIndexes)} indexes")

    def search(self, index: Any, chunks: List, queryEmbedding, topK: int) -> List[Tuple[float, Any]]:
        try:
            if index == "placeholder_index":
                return [(0.5, chunk) for chunk in chunks[:topK]]
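            # Index.search returns (D, I): D holds the scores and I the row
            # positions into the chunk list. Assuming these indexes are
            # inner-product indexes over L2-normalized vectors, the scores are
            # cosine similarities; an L2 index would instead return distances,
            # where smaller is better.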
            D, I = index.search(queryEmbedding, topK)
            results = []
            for score, idx in zip(D[0], I[0]):
                # faiss pads I with -1 when fewer than topK neighbors exist,
                # so guard the lower bound as well.
                if 0 <= idx < len(chunks):
                    results.append((score, chunks[idx]))
            return results
        except Exception as e:
            logger.error(f"Search failed: {str(e)}")
            return []

    def retrieveSupportChunksParallel(self, inputText: str) -> Tuple[Dict[str, List], Dict]:
        # Placeholder path: no encoder, so return the first few chunks per source.
        if self.encoder == "placeholder":
            logger.info("Using placeholder RAG retrieval")
            logs = {"query": inputText}
            support = {}
            for name in ["constitution", "ipcSections", "ipcCase", "statutes", "qaTexts", "caseLaw"]:
                if name in self.preloadedIndexes:
                    _, chunks = self.preloadedIndexes[name]
                    support[name] = chunks[:5] if chunks else []
                else:
                    support[name] = []
            logs["supportChunksUsed"] = support
            return support, logs
        try:
            # encode(normalize_embeddings=True) already L2-normalizes the
            # query, so no separate faiss.normalize_L2 pass is needed.
            queryEmbedding = self.encoder.encode([inputText], normalize_embeddings=True).astype('float32')
            logs = {"query": inputText}

            def retrieve(name):
                if name not in self.preloadedIndexes:
                    return name, []
                idx, chunks = self.preloadedIndexes[name]
                results = self.search(idx, chunks, queryEmbedding, 5)
                return name, [c[1] for c in results]
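
            # Threading pays off here because faiss's Python bindings release
            # the GIL inside Index.search, so the per-index lookups can run
            # concurrently rather than serializing on the interpreter lock.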
            support = {}
            with ThreadPoolExecutor(max_workers=6) as executor:
                futures = [executor.submit(retrieve, name) for name in self.preloadedIndexes]
                for f in futures:
                    name, topChunks = f.result()
                    support[name] = topChunks
            logs["supportChunksUsed"] = support
            return support, logs
        except Exception as e:
            logger.error(f"Error retrieving support chunks: {str(e)}")
            raise ValueError(f"Support chunk retrieval failed: {str(e)}")

    def retrieveDualSupportChunks(self, inputText: str, geminiQueryModel):
        """Retrieve with both the raw case text and a generated search query, then merge."""
        try:
            geminiQuery = geminiQueryModel.generateSearchQueryFromCase(inputText, geminiQueryModel)
        except Exception as e:
            logger.warning(f"Search-query generation failed, falling back to case text: {str(e)}")
            geminiQuery = None
        supportFromCase, _ = self.retrieveSupportChunksParallel(inputText)
        supportFromQuery, _ = self.retrieveSupportChunksParallel(geminiQuery or inputText)
        combinedSupport = {}
        for key in supportFromCase:
            combined = supportFromCase[key] + supportFromQuery.get(key, [])
            # De-duplicate on the chunk's text content, keeping at most 10 per source.
            seen = set()
            unique = []
            for chunk in combined:
                if isinstance(chunk, str):
                    rep = chunk
                else:
                    rep = chunk.get("text") or chunk.get("description") or chunk.get("section_desc") or str(chunk)
                if rep not in seen:
                    seen.add(rep)
                    unique.append(chunk)
                if len(unique) >= 10:
                    break
            combinedSupport[key] = unique
        return combinedSupport, geminiQuery

    def areIndexesLoaded(self) -> bool:
        return len(self.preloadedIndexes) > 0

    def getLoadedIndexes(self) -> List[str]:
        return list(self.preloadedIndexes.keys())

    def is_healthy(self) -> bool:
        # Note: the "placeholder" sentinel is not None, so placeholder mode
        # still reports healthy.
        return self.encoder is not None
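
# Minimal usage sketch (illustrative only): it assumes the index/chunk files
# listed in _load_indexes exist under settings.faiss_indexes_base_path and
# that sentence-transformers and faiss-cpu are installed; otherwise the
# service degrades to placeholder mode. The query string is made up.
if __name__ == "__main__":
    service = RAGService()
    print(f"Loaded indexes: {service.getLoadedIndexes()}")
    if service.is_healthy():
        support, logs = service.retrieveSupportChunksParallel(
            "What does Article 21 of the Constitution guarantee?"
        )
        for name, chunks in support.items():
            print(f"{name}: {len(chunks)} chunks retrieved")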