import os
import json
import pickle
from typing import Dict, List

import faiss
import torch
from sentence_transformers import SentenceTransformer
from transformers import Pipeline
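
# Expected layout of the retrieval assets (inferred from _load_vector_store below):
#   vector_db/
#     index.faiss       - FAISS index over the passage embeddings
#     chunks.pkl        - pickled list of passage dicts, each with a 'content' key
#     model_info.json   - optional, e.g. {"name": "<sentence-transformers model>"}
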
class IndianLawRAGPipeline(Pipeline):
    def __init__(self, model, tokenizer, device=None, framework="pt", **kwargs):
        super().__init__(model=model, tokenizer=tokenizer, device=device, framework=framework, **kwargs)
        self.retriever_k = kwargs.get('retriever_k', 5)
        self.max_new_tokens = kwargs.get('max_new_tokens', 512)
        self.temperature = kwargs.get('temperature', 0.7)
        self._load_vector_store()
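
    # transformers' Pipeline declares _sanitize_parameters, preprocess, _forward and
    # postprocess as abstract methods. Minimal pass-through implementations keep this
    # subclass instantiable; generation is handled entirely in __call__ below.
    def _sanitize_parameters(self, **kwargs):
        return {}, {}, {}

    def preprocess(self, inputs, **kwargs):
        return inputs

    def _forward(self, model_inputs, **kwargs):
        return model_inputs

    def postprocess(self, model_outputs, **kwargs):
        return model_outputs
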
    def _load_vector_store(self):
        '''Load the vector store for retrieval'''
        # Path relative to the model directory
        vector_dir = os.path.join(os.path.dirname(__file__), "vector_db")
        # Default embedding model
        default_model = "sentence-transformers/all-mpnet-base-v2"
        try:
            # Load embedding-model info if available
            model_info_path = os.path.join(vector_dir, "model_info.json")
            if os.path.exists(model_info_path):
                with open(model_info_path, "r") as f:
                    model_info = json.load(f)
                model_name = model_info.get("name", default_model)
            else:
                model_name = default_model

            # Load embedding model
            self.retriever_model = SentenceTransformer(model_name)

            # Load FAISS index
            self.index = faiss.read_index(os.path.join(vector_dir, "index.faiss"))

            # Load chunks
            with open(os.path.join(vector_dir, "chunks.pkl"), "rb") as f:
                self.chunks = pickle.load(f)

            self.vector_store_loaded = True
        except Exception as e:
            print(f"Error loading vector store: {str(e)}")
            self.vector_store_loaded = False
    def retrieve(self, query: str) -> List[Dict]:
        '''Retrieve relevant passages for the query'''
        if not getattr(self, 'vector_store_loaded', False):
            return []

        # Generate query embedding (normalized to match the index's similarity metric)
        query_embedding = self.retriever_model.encode([query], convert_to_numpy=True)
        query_embedding = query_embedding.astype('float32')
        faiss.normalize_L2(query_embedding)

        # Search index
        scores, indices = self.index.search(query_embedding, self.retriever_k)

        # Filter valid indices (FAISS pads results with -1 when fewer matches are found)
        valid_indices = [idx for idx in indices[0] if 0 <= idx < len(self.chunks)]
        retrieved_chunks = [self.chunks[idx] for idx in valid_indices]
        return retrieved_chunks
    def __call__(self, query_text, **kwargs):
        '''Process a query through the RAG pipeline'''
        # Retrieve relevant context
        retrieved_chunks = self.retrieve(query_text)

        if not retrieved_chunks:
            # Fall back to direct LM generation if retrieval fails
            inputs = self.tokenizer(query_text, return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=self.max_new_tokens,
                    temperature=self.temperature,
                    do_sample=True
                )
            result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return [{"generated_text": result}]

        # Build context from chunks
        context = "\n\n".join(chunk['content'] for chunk in retrieved_chunks)

        # Format prompt
        prompt = f'''You are an Indian legal expert. Answer strictly based on the provided context.

Context:
{context}

Question:
{query_text}

Answer:'''

        # Generate answer
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                temperature=self.temperature,
                do_sample=True
            )
        full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Try to extract only the answer portion
        answer_start = full_response.find("Answer:")
        if answer_start != -1:
            answer = full_response[answer_start + len("Answer:"):]
        else:
            answer = full_response

        return [{"generated_text": answer.strip()}]