'''Custom Hugging Face Pipeline for Indian-law question answering with
retrieval-augmented generation (RAG): the query is embedded, matched against a
local FAISS vector store of legal text chunks, and the retrieved passages are
prepended to the prompt before generation.'''

import os
import json
import pickle
from typing import Dict, List

import faiss
from sentence_transformers import SentenceTransformer
from transformers import Pipeline

class IndianLawRAGPipeline(Pipeline):
    def __init__(self, model, tokenizer, device=None, framework="pt", **kwargs):
        # Pull the retrieval/generation settings out of kwargs before handing
        # the rest to the base Pipeline
        self.retriever_k = kwargs.pop("retriever_k", 5)
        self.max_new_tokens = kwargs.pop("max_new_tokens", 512)
        self.temperature = kwargs.pop("temperature", 0.7)
        super().__init__(model=model, tokenizer=tokenizer, device=device, framework=framework, **kwargs)
        self._load_vector_store()

    # Minimal pass-through implementations of the abstract hooks required by
    # transformers.Pipeline; this class overrides __call__ directly, so these
    # only exist to make the class instantiable.
    def _sanitize_parameters(self, **kwargs):
        return {}, {}, {}

    def preprocess(self, inputs, **kwargs):
        return inputs

    def _forward(self, model_inputs, **kwargs):
        return model_inputs

    def postprocess(self, model_outputs, **kwargs):
        return model_outputs

    def _load_vector_store(self):
        '''Load the vector store for retrieval'''
        # Path relative to the model directory
        vector_dir = os.path.join(os.path.dirname(__file__), "vector_db")
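        # Expected contents of vector_db/ (inferred from the loading code below):
        #   index.faiss      - FAISS index built over the chunk embeddings
        #   chunks.pkl       - pickled list of dicts, each with a 'content' field
        #   model_info.json  - optional; {"name": "<sentence-transformers model id>"}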
        
        # Default embedding model
        default_model = "sentence-transformers/all-mpnet-base-v2"
        
        try:
            # Load embedding-model info if available; otherwise use the default
            model_info_path = os.path.join(vector_dir, "model_info.json")
            if os.path.exists(model_info_path):
                with open(model_info_path, "r") as f:
                    model_info = json.load(f)
                model_name = model_info.get("name", default_model)
            else:
                model_name = default_model
                
            # Load embedding model
            self.retriever_model = SentenceTransformer(model_name)
            
            # Load FAISS index
            self.index = faiss.read_index(os.path.join(vector_dir, "index.faiss"))
            
            # Load chunks
            with open(os.path.join(vector_dir, "chunks.pkl"), "rb") as f:
                self.chunks = pickle.load(f)
                
            self.vector_store_loaded = True
        except Exception as e:
            print(f"Error loading vector store: {str(e)}")
            self.vector_store_loaded = False
            
    def retrieve(self, query: str) -> List[Dict]:
        '''Retrieve relevant passages for the query'''
        if not hasattr(self, 'vector_store_loaded') or not self.vector_store_loaded:
            return []
        
        # Embed the query and L2-normalize it (the index is assumed to store
        # normalized embeddings and use inner-product / cosine similarity)
        query_embedding = self.retriever_model.encode([query], convert_to_numpy=True)
        query_embedding = query_embedding.astype('float32')
        faiss.normalize_L2(query_embedding)
        
        # Search index
        scores, indices = self.index.search(query_embedding, self.retriever_k)
        
        # Keep only valid indices; FAISS returns -1 when fewer than k hits exist
        valid_indices = [idx for idx in indices[0] if 0 <= idx < len(self.chunks)]
        retrieved_chunks = [self.chunks[idx] for idx in valid_indices]
        
        return retrieved_chunks
    
    def __call__(self, query_text, **kwargs):
        '''Process a query through RAG pipeline'''
        # Retrieve relevant context
        retrieved_chunks = self.retrieve(query_text)
        
        if not retrieved_chunks:
            # Fall back to direct generation if retrieval is unavailable
            inputs = self.tokenizer(query_text, return_tensors="pt").to(self.device)
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                temperature=self.temperature,
                do_sample=True
            )
            # Decode only the newly generated tokens so the query is not echoed back
            generated = outputs[0][inputs["input_ids"].shape[1]:]
            result = self.tokenizer.decode(generated, skip_special_tokens=True)
            return [{"generated_text": result.strip()}]
        
        # Build context from chunks
        context = "\n\n".join(chunk['content'] for chunk in retrieved_chunks)
        
        # Format prompt
        prompt = f'''You are an Indian legal expert. Answer strictly based on the provided context.

Context:
{context}

Question:
{query_text}

Answer:'''
        
        # Generate answer
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        outputs = self.model.generate(
            **inputs, 
            max_new_tokens=self.max_new_tokens,
            temperature=self.temperature,
            do_sample=True
        )
        
        # Decode only the newly generated tokens so the prompt and retrieved
        # context are not echoed back in the answer
        prompt_length = inputs["input_ids"].shape[1]
        answer = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        
        return [{"generated_text": answer.strip()}]