import os
import json
import pickle

import faiss
import torch
from sentence_transformers import SentenceTransformer
from transformers import Pipeline
from typing import Dict, List


class IndianLawRAGPipeline(Pipeline):
    '''RAG pipeline for Indian law: retrieves passages from a local FAISS
    vector store and conditions the language model's answer on them.'''

    def __init__(self, model, tokenizer, device=None, framework="pt", **kwargs):
        super().__init__(model=model, tokenizer=tokenizer, device=device, framework=framework, **kwargs)
        self.retriever_k = kwargs.get('retriever_k', 5)
        self.max_new_tokens = kwargs.get('max_new_tokens', 512)
        self.temperature = kwargs.get('temperature', 0.7)
        self._load_vector_store()

    # The base Pipeline class declares these methods as abstract; minimal
    # implementations are provided so the class can be instantiated. They are
    # bypassed by the overridden __call__ below.
    def _sanitize_parameters(self, **kwargs):
        return {}, {}, {}

    def preprocess(self, inputs, **kwargs):
        return inputs

    def _forward(self, model_inputs, **kwargs):
        return model_inputs

    def postprocess(self, model_outputs, **kwargs):
        return model_outputs

    def _load_vector_store(self):
        '''Load the vector store for retrieval'''
        # Path relative to the model directory
        vector_dir = os.path.join(os.path.dirname(__file__), "vector_db")

        # Default embedding model
        default_model = "sentence-transformers/all-mpnet-base-v2"

        try:
            # Load model info to find out which embedding model built the index
            model_info_path = os.path.join(vector_dir, "model_info.json")
            if os.path.exists(model_info_path):
                with open(model_info_path, "r") as f:
                    model_info = json.load(f)
                model_name = model_info.get("name", default_model)
            else:
                model_name = default_model

            # Load embedding model
            self.retriever_model = SentenceTransformer(model_name)

            # Load FAISS index
            self.index = faiss.read_index(os.path.join(vector_dir, "index.faiss"))

            # Load chunks
            with open(os.path.join(vector_dir, "chunks.pkl"), "rb") as f:
                self.chunks = pickle.load(f)

            self.vector_store_loaded = True
        except Exception as e:
            print(f"Error loading vector store: {str(e)}")
            self.vector_store_loaded = False

    def retrieve(self, query: str) -> List[Dict]:
        '''Retrieve relevant passages for the query'''
        if not getattr(self, 'vector_store_loaded', False):
            return []

        # Generate the query embedding and L2-normalize it to match the index
        query_embedding = self.retriever_model.encode([query], convert_to_numpy=True)
        query_embedding = query_embedding.astype('float32')
        faiss.normalize_L2(query_embedding)

        # Search index
        scores, indices = self.index.search(query_embedding, self.retriever_k)

        # Keep only valid indices (FAISS pads with -1 when fewer results exist)
        valid_indices = [idx for idx in indices[0] if 0 <= idx < len(self.chunks)]
        retrieved_chunks = [self.chunks[idx] for idx in valid_indices]

        return retrieved_chunks

    def __call__(self, query_text, **kwargs):
        '''Process a query through the RAG pipeline'''
        # Retrieve relevant context
        retrieved_chunks = self.retrieve(query_text)

        if not retrieved_chunks:
            # Fall back to direct LM generation if retrieval fails
            inputs = self.tokenizer(query_text, return_tensors="pt").to(self.device)
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                temperature=self.temperature,
                do_sample=True,
            )
            result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return [{"generated_text": result}]

        # Build context from chunks
        context = "\n\n".join(chunk['content'] for chunk in retrieved_chunks)

        # Format prompt
        prompt = f'''You are an Indian legal expert. Answer strictly based on the provided context.

Context:
{context}

Question: {query_text}

Answer:'''

        # Generate answer
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=self.max_new_tokens,
            temperature=self.temperature,
            do_sample=True,
        )

        full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract only the answer portion; use the last "Answer:" marker, since
        # the prompt itself ends with one and the retrieved context may also contain it
        answer_start = full_response.rfind("Answer:")
        if answer_start != -1:
            answer = full_response[answer_start + len("Answer:"):]
        else:
            answer = full_response

        return [{"generated_text": answer.strip()}]
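

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal example of instantiating and calling the pipeline directly.
# The model id below is a placeholder, not the actual checkpoint name; the
# example also assumes the "vector_db/" directory (index.faiss, chunks.pkl,
# optionally model_info.json) sits next to this file, as _load_vector_store()
# expects.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "your-org/indian-law-llm"  # placeholder checkpoint id
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    # retriever_k controls how many chunks are pulled from the vector store
    rag = IndianLawRAGPipeline(model=model, tokenizer=tokenizer, retriever_k=3)
    result = rag("What does the Indian Contract Act say about void agreements?")
    print(result[0]["generated_text"])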