# BearSystemsChat / model_cache / custom_pipeline.py
import os
import json
import pickle
from typing import Dict, List

import torch
import faiss
from sentence_transformers import SentenceTransformer
from transformers import Pipeline
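
# Expected layout of the bundled vector store, as read by _load_vector_store():
#   vector_db/
#       index.faiss      - FAISS index over the document chunk embeddings
#       chunks.pkl       - pickled list of chunk dicts, each with a 'content' field
#       model_info.json  - optional; stores the embedding model name under "name"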
class IndianLawRAGPipeline(Pipeline):
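    '''RAG pipeline for Indian legal question answering.

    Retrieves relevant passages from a local FAISS vector store (see the
    vector_db/ layout above) and conditions the language model's answer on
    that context; falls back to plain generation when the store is missing.
    '''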
    def __init__(self, model, tokenizer, device=None, framework="pt", **kwargs):
        super().__init__(model=model, tokenizer=tokenizer, device=device, framework=framework, **kwargs)
        self.retriever_k = kwargs.get('retriever_k', 5)
        self.max_new_tokens = kwargs.get('max_new_tokens', 512)
        self.temperature = kwargs.get('temperature', 0.7)
        self._load_vector_store()
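
    # The transformers Pipeline base class declares _sanitize_parameters,
    # preprocess, _forward and postprocess as abstract methods. Since
    # __call__ is overridden directly below, minimal pass-through
    # implementations are provided here so the class can be instantiated
    # (this assumes a transformers version where these methods are required).
    def _sanitize_parameters(self, **kwargs):
        return {}, {}, {}

    def preprocess(self, inputs, **kwargs):
        return inputs

    def _forward(self, model_inputs, **kwargs):
        return model_inputs

    def postprocess(self, model_outputs, **kwargs):
        return model_outputs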

    def _load_vector_store(self):
        '''Load the vector store used for retrieval'''
        # Path relative to the model directory
        vector_dir = os.path.join(os.path.dirname(__file__), "vector_db")
        # Default embedding model
        default_model = "sentence-transformers/all-mpnet-base-v2"
        try:
            # Read the embedding model name recorded when the index was built
            model_info_path = os.path.join(vector_dir, "model_info.json")
            if os.path.exists(model_info_path):
                with open(model_info_path, "r") as f:
                    model_info = json.load(f)
                model_name = model_info.get("name", default_model)
            else:
                model_name = default_model

            # Load embedding model
            self.retriever_model = SentenceTransformer(model_name)

            # Load FAISS index
            self.index = faiss.read_index(os.path.join(vector_dir, "index.faiss"))

            # Load the text chunks aligned with the index rows
            with open(os.path.join(vector_dir, "chunks.pkl"), "rb") as f:
                self.chunks = pickle.load(f)

            self.vector_store_loaded = True
        except Exception as e:
            print(f"Error loading vector store: {e}")
            self.vector_store_loaded = False

    def retrieve(self, query: str) -> List[Dict]:
        '''Retrieve relevant passages for the query'''
        if not getattr(self, 'vector_store_loaded', False):
            return []

        # Embed the query and L2-normalize it so the index search behaves as
        # cosine similarity (the index is assumed to hold normalized vectors)
        query_embedding = self.retriever_model.encode([query], convert_to_numpy=True)
        query_embedding = query_embedding.astype('float32')
        faiss.normalize_L2(query_embedding)

        # Search the index for the top-k chunks
        scores, indices = self.index.search(query_embedding, self.retriever_k)

        # FAISS reports missing results as -1, so keep only valid indices
        valid_indices = [idx for idx in indices[0] if 0 <= idx < len(self.chunks)]
        retrieved_chunks = [self.chunks[idx] for idx in valid_indices]
        return retrieved_chunks

    def __call__(self, query_text, **kwargs):
        '''Process a query through the RAG pipeline'''
        # Retrieve relevant context
        retrieved_chunks = self.retrieve(query_text)

        if not retrieved_chunks:
            # Fall back to direct LM generation if retrieval fails
            inputs = self.tokenizer(query_text, return_tensors="pt").to(self.device)
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                temperature=self.temperature,
                do_sample=True
            )
            result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return [{"generated_text": result}]

        # Build context from chunks
        context = "\n\n".join(chunk['content'] for chunk in retrieved_chunks)

        # Format prompt
        prompt = f'''You are an Indian legal expert. Answer strictly based on the provided context.

Context:
{context}

Question:
{query_text}

Answer:'''

        # Generate answer
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=self.max_new_tokens,
            temperature=self.temperature,
            do_sample=True
        )
        full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # The decoded text includes the prompt, so keep only the part after "Answer:"
        answer_start = full_response.find("Answer:")
        if answer_start != -1:
            answer = full_response[answer_start + len("Answer:"):]
        else:
            answer = full_response

        return [{"generated_text": answer.strip()}]
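
# Minimal usage sketch (illustrative only): the checkpoint name below is a
# placeholder, not necessarily the model shipped with this repository.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "path/to/causal-lm-checkpoint"  # placeholder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    pipe = IndianLawRAGPipeline(model=model, tokenizer=tokenizer, retriever_k=5)
    print(pipe("What does Section 420 of the Indian Penal Code deal with?"))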