import json
import os
import pickle
from typing import Dict, List

import faiss
import torch
from sentence_transformers import SentenceTransformer
from transformers import Pipeline

class IndianLawRAGPipeline(Pipeline):
    '''Retrieval-augmented generation (RAG) pipeline for Indian legal question answering.'''

    def __init__(self, model, tokenizer, device=None, framework="pt", **kwargs):
        # Pull the custom RAG settings out of kwargs before handing the rest to the base class.
        self.retriever_k = kwargs.pop("retriever_k", 5)
        self.max_new_tokens = kwargs.pop("max_new_tokens", 512)
        self.temperature = kwargs.pop("temperature", 0.7)
        super().__init__(model=model, tokenizer=tokenizer, device=device, framework=framework, **kwargs)
        self._load_vector_store()
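
    # Recent transformers releases declare _sanitize_parameters/preprocess/_forward/
    # postprocess as abstract methods on Pipeline, so the class cannot be instantiated
    # without them. Because __call__ below handles the whole flow itself, minimal
    # pass-through stubs are enough here; this is an assumption about the installed
    # transformers version, not behaviour taken from the original file.
    def _sanitize_parameters(self, **kwargs):
        return {}, {}, {}

    def preprocess(self, inputs, **kwargs):
        return inputs

    def _forward(self, model_inputs, **kwargs):
        return model_inputs

    def postprocess(self, model_outputs, **kwargs):
        return model_outputs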

    def _load_vector_store(self):
        '''Load the vector store for retrieval'''
        # Paths are relative to the model directory
        vector_dir = os.path.join(os.path.dirname(__file__), "vector_db")
        # Default embedding model
        default_model = "sentence-transformers/all-mpnet-base-v2"
        try:
            # model_info.json, if present, records the embedding model used to build the index
            model_info_path = os.path.join(vector_dir, "model_info.json")
            if os.path.exists(model_info_path):
                with open(model_info_path, "r") as f:
                    model_info = json.load(f)
                model_name = model_info.get("name", default_model)
            else:
                model_name = default_model
            # Load embedding model
            self.retriever_model = SentenceTransformer(model_name)
            # Load FAISS index
            self.index = faiss.read_index(os.path.join(vector_dir, "index.faiss"))
            # Load chunks: a list of dicts, each with at least a 'content' key
            with open(os.path.join(vector_dir, "chunks.pkl"), "rb") as f:
                self.chunks = pickle.load(f)
            self.vector_store_loaded = True
        except Exception as e:
            print(f"Error loading vector store: {e}")
            self.vector_store_loaded = False

    def retrieve(self, query: str) -> List[Dict]:
        '''Retrieve relevant passages for the query'''
        if not getattr(self, "vector_store_loaded", False):
            return []
        # Generate query embedding
        query_embedding = self.retriever_model.encode([query], convert_to_numpy=True)
        query_embedding = query_embedding.astype("float32")
        faiss.normalize_L2(query_embedding)
        # Search index
        scores, indices = self.index.search(query_embedding, self.retriever_k)
        # Filter valid indices (FAISS pads with -1 when fewer than k neighbours are found)
        valid_indices = [idx for idx in indices[0] if 0 <= idx < len(self.chunks)]
        retrieved_chunks = [self.chunks[idx] for idx in valid_indices]
        return retrieved_chunks

    def __call__(self, query_text, **kwargs):
        '''Process a query through the RAG pipeline'''
        # Retrieve relevant context
        retrieved_chunks = self.retrieve(query_text)
        if not retrieved_chunks:
            # Fall back to direct LM generation if retrieval fails
            inputs = self.tokenizer(query_text, return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=self.max_new_tokens,
                    temperature=self.temperature,
                    do_sample=True,
                )
            # Decode only the newly generated tokens, not the echoed prompt
            result = self.tokenizer.decode(
                outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True
            )
            return [{"generated_text": result.strip()}]
        # Build context from chunks
        context = "\n\n".join(chunk['content'] for chunk in retrieved_chunks)
        # Format prompt
        prompt = f'''You are an Indian legal expert. Answer strictly based on the provided context.
Context:
{context}
Question:
{query_text}
Answer:'''
        # Generate answer
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                temperature=self.temperature,
                do_sample=True,
            )
        # Slice off the prompt tokens so only the generated answer is decoded;
        # searching the decoded text for "Answer:" is fragile when the retrieved
        # context itself happens to contain that string.
        prompt_length = inputs["input_ids"].shape[-1]
        answer = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        return [{"generated_text": answer.strip()}]