File size: 7,781 Bytes
76bba0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
"""
RAG (Retrieval Augmented Generation) store for fraud pattern matching
"""
import os
import json
import logging
from typing import List, Dict, Any, Optional
import pickle
from sentence_transformers import SentenceTransformer
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

logger = logging.getLogger(__name__)

class RAGStore:
    """Simple RAG store using sentence transformers and local file storage.

    Documents (texts plus per-document metadata dicts) are embedded with a
    SentenceTransformer model and persisted under ``collection_dir`` as three
    files: a pickled numpy embedding matrix, a JSON list of texts, and a JSON
    list of metadata dicts. All failures are logged and swallowed so a broken
    store degrades to "empty" rather than crashing callers.
    """

    def __init__(self, collection_dir: str, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """Create (or reopen) a store rooted at ``collection_dir``.

        Args:
            collection_dir: Directory for the persisted files; created if missing.
            model_name: SentenceTransformer model identifier to load.
        """
        self.collection_dir = collection_dir
        self.model_name = model_name
        self.embeddings_file = os.path.join(collection_dir, "embeddings.pkl")
        self.texts_file = os.path.join(collection_dir, "texts.json")
        self.metadata_file = os.path.join(collection_dir, "metadata.json")

        os.makedirs(collection_dir, exist_ok=True)

        # Initialize sentence transformer; on failure the store stays usable
        # but read-only-empty (add/query no-op with a warning).
        try:
            self.encoder = SentenceTransformer(model_name)
            logger.info(f"Initialized SentenceTransformer: {model_name}")
        except Exception as e:
            logger.error(f"Failed to initialize SentenceTransformer: {e}")
            self.encoder = None

        # In-memory state, hydrated from disk if previous runs saved anything.
        self.texts = []        # list[str], parallel to self.metadatas
        self.metadatas = []    # list[dict], one per text
        self.embeddings = None # 2-D numpy array (n_docs x dim), or None when empty
        self._load_data()

    def _load_data(self):
        """Load existing embeddings, texts, and metadata from disk.

        Any error resets the in-memory state to empty so the store is never
        left half-loaded.
        """
        try:
            if os.path.exists(self.texts_file):
                with open(self.texts_file, 'r', encoding='utf-8') as f:
                    self.texts = json.load(f)

            if os.path.exists(self.metadata_file):
                with open(self.metadata_file, 'r', encoding='utf-8') as f:
                    self.metadatas = json.load(f)

            if os.path.exists(self.embeddings_file):
                # SECURITY NOTE: pickle.load executes arbitrary code from the
                # file — only safe because this file is written by _save_data
                # in our own collection_dir. Never point this at untrusted data.
                with open(self.embeddings_file, 'rb') as f:
                    self.embeddings = pickle.load(f)

            logger.info(f"Loaded {len(self.texts)} existing documents")

        except Exception as e:
            logger.error(f"Error loading RAG data: {e}")
            self.texts = []
            self.metadatas = []
            self.embeddings = None

    def _save_data(self):
        """Persist embeddings, texts, and metadata to files (best-effort)."""
        try:
            with open(self.texts_file, 'w', encoding='utf-8') as f:
                json.dump(self.texts, f)

            with open(self.metadata_file, 'w', encoding='utf-8') as f:
                # default=str makes non-JSON types (e.g. datetimes) serializable.
                json.dump(self.metadatas, f, default=str)

            if self.embeddings is not None:
                with open(self.embeddings_file, 'wb') as f:
                    pickle.dump(self.embeddings, f)

            logger.info(f"Saved {len(self.texts)} documents to storage")

        except Exception as e:
            logger.error(f"Error saving RAG data: {e}")

    def add(self, texts: List[str], metadatas: List[Dict[str, Any]]):
        """Add new documents to the RAG store and persist the result.

        Args:
            texts: Document strings to embed; must align 1:1 with ``metadatas``.
            metadatas: One metadata dict per text.

        Silently no-ops (with a log message) when the encoder is unavailable,
        the lists are mismatched, or ``texts`` is empty.
        """
        if not self.encoder:
            logger.warning("No encoder available, cannot add documents")
            return

        if len(texts) != len(metadatas):
            logger.error("Texts and metadatas must have the same length")
            return

        if not texts:
            # Nothing to do; avoids encoding/vstacking a zero-row array.
            return

        try:
            # Generate embeddings for new texts
            new_embeddings = self.encoder.encode(texts)

            # Add to existing data
            self.texts.extend(texts)
            self.metadatas.extend(metadatas)

            if self.embeddings is None:
                self.embeddings = new_embeddings
            else:
                self.embeddings = np.vstack([self.embeddings, new_embeddings])

            # Save to disk
            self._save_data()

            logger.info(f"Added {len(texts)} new documents to RAG store")

        except Exception as e:
            logger.error(f"Error adding documents to RAG store: {e}")

    def query(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
        """Return up to ``k`` stored documents most similar to ``query``.

        Args:
            query: Free-text query to embed and match.
            k: Maximum number of results.

        Returns:
            List of ``{"text", "metadata", "similarity"}`` dicts, best first,
            filtered to cosine similarity > 0.1. Empty list when the store is
            empty, the encoder is unavailable, or an error occurs.
        """
        if not self.encoder or self.embeddings is None or len(self.texts) == 0:
            logger.warning("RAG store is empty or encoder unavailable")
            return []

        try:
            # Encode the query
            query_embedding = self.encoder.encode([query])

            # Calculate similarities against every stored embedding
            similarities = cosine_similarity(query_embedding, self.embeddings)[0]

            # Indices of the k highest similarities, best first
            top_indices = np.argsort(similarities)[::-1][:k]

            results = []
            for idx in top_indices:
                if similarities[idx] > 0.1:  # Minimum similarity threshold
                    results.append({
                        "text": self.texts[idx],
                        "metadata": self.metadatas[idx],
                        "similarity": float(similarities[idx])
                    })

            logger.info(f"Query returned {len(results)} results")
            return results

        except Exception as e:
            logger.error(f"Error querying RAG store: {e}")
            return []

    def get_stats(self) -> Dict[str, Any]:
        """Return basic health/size statistics about the RAG store."""
        return {
            "total_documents": len(self.texts),
            "has_embeddings": self.embeddings is not None,
            "encoder_available": self.encoder is not None,
            "collection_dir": self.collection_dir
        }

    def clear(self):
        """Remove all in-memory data and delete the persisted files."""
        try:
            self.texts = []
            self.metadatas = []
            self.embeddings = None

            # Remove files
            for file_path in [self.embeddings_file, self.texts_file, self.metadata_file]:
                if os.path.exists(file_path):
                    os.remove(file_path)

            logger.info("RAG store cleared")

        except Exception as e:
            logger.error(f"Error clearing RAG store: {e}")

# Utility functions for fraud-specific RAG queries
def build_fraud_context(transaction_data: Dict[str, Any]) -> str:
    """Build a searchable text representation of transaction data.

    Present fields are rendered as space-separated ``label:value`` tokens in a
    fixed order. Note that the ``timestamp`` field is labeled ``time:``.

    Args:
        transaction_data: Transaction fields; only recognized keys are used.

    Returns:
        The joined token string (empty when no recognized field is present).
    """
    # (dict key, output label) pairs — order here fixes the output order.
    field_labels = (
        ("amount", "amount"),
        ("merchant", "merchant"),
        ("category", "category"),
        ("description", "description"),
        ("timestamp", "time"),
    )
    return " ".join(
        f"{label}:{transaction_data[key]}"
        for key, label in field_labels
        if key in transaction_data
    )

def extract_fraud_patterns(rag_results: List[Dict[str, Any]]) -> List[str]:
    """Extract common fraud patterns from RAG results.

    Only results with similarity > 0.7 contribute; each contributes one
    pattern string per recognized metadata field (merchant, amount, category).

    Args:
        rag_results: Dicts as returned by ``RAGStore.query`` — each with
            optional ``metadata`` (dict) and ``similarity`` (float) keys.

    Returns:
        Deduplicated pattern strings in first-seen order.
    """
    patterns = []

    for result in rag_results:
        metadata = result.get('metadata', {})
        similarity = result.get('similarity', 0)

        if similarity > 0.7:  # High similarity threshold
            if 'merchant' in metadata:
                patterns.append(f"Similar merchant: {metadata['merchant']}")

            if 'amount' in metadata:
                patterns.append(f"Similar amount: ${metadata['amount']}")

            if 'category' in metadata:
                patterns.append(f"Similar category: {metadata['category']}")

    # dict.fromkeys dedupes while preserving first-seen order — unlike the
    # previous list(set(...)), whose order varied run-to-run because of
    # string hash randomization.
    return list(dict.fromkeys(patterns))