""" RAG (Retrieval Augmented Generation) store for fraud pattern matching """ import os import json import logging from typing import List, Dict, Any, Optional import pickle from sentence_transformers import SentenceTransformer import numpy as np from sklearn.metrics.pairwise import cosine_similarity logger = logging.getLogger(__name__) class RAGStore: """Simple RAG store using sentence transformers and local file storage""" def __init__(self, collection_dir: str, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"): self.collection_dir = collection_dir self.model_name = model_name self.embeddings_file = os.path.join(collection_dir, "embeddings.pkl") self.texts_file = os.path.join(collection_dir, "texts.json") self.metadata_file = os.path.join(collection_dir, "metadata.json") os.makedirs(collection_dir, exist_ok=True) # Initialize sentence transformer try: self.encoder = SentenceTransformer(model_name) logger.info(f"Initialized SentenceTransformer: {model_name}") except Exception as e: logger.error(f"Failed to initialize SentenceTransformer: {e}") self.encoder = None # Load existing data self.texts = [] self.metadatas = [] self.embeddings = None self._load_data() def _load_data(self): """Load existing embeddings, texts, and metadata""" try: if os.path.exists(self.texts_file): with open(self.texts_file, 'r') as f: self.texts = json.load(f) if os.path.exists(self.metadata_file): with open(self.metadata_file, 'r') as f: self.metadatas = json.load(f) if os.path.exists(self.embeddings_file): with open(self.embeddings_file, 'rb') as f: self.embeddings = pickle.load(f) logger.info(f"Loaded {len(self.texts)} existing documents") except Exception as e: logger.error(f"Error loading RAG data: {e}") self.texts = [] self.metadatas = [] self.embeddings = None def _save_data(self): """Save embeddings, texts, and metadata to files""" try: with open(self.texts_file, 'w') as f: json.dump(self.texts, f) with open(self.metadata_file, 'w') as f: json.dump(self.metadatas, f, default=str) if self.embeddings is not None: with open(self.embeddings_file, 'wb') as f: pickle.dump(self.embeddings, f) logger.info(f"Saved {len(self.texts)} documents to storage") except Exception as e: logger.error(f"Error saving RAG data: {e}") def add(self, texts: List[str], metadatas: List[Dict[str, Any]]): """Add new documents to the RAG store""" if not self.encoder: logger.warning("No encoder available, cannot add documents") return if len(texts) != len(metadatas): logger.error("Texts and metadatas must have the same length") return try: # Generate embeddings for new texts new_embeddings = self.encoder.encode(texts) # Add to existing data self.texts.extend(texts) self.metadatas.extend(metadatas) if self.embeddings is None: self.embeddings = new_embeddings else: self.embeddings = np.vstack([self.embeddings, new_embeddings]) # Save to disk self._save_data() logger.info(f"Added {len(texts)} new documents to RAG store") except Exception as e: logger.error(f"Error adding documents to RAG store: {e}") def query(self, query: str, k: int = 5) -> List[Dict[str, Any]]: """Query the RAG store for similar documents""" if not self.encoder or self.embeddings is None or len(self.texts) == 0: logger.warning("RAG store is empty or encoder unavailable") return [] try: # Encode the query query_embedding = self.encoder.encode([query]) # Calculate similarities similarities = cosine_similarity(query_embedding, self.embeddings)[0] # Get top k results top_indices = np.argsort(similarities)[::-1][:k] results = [] for idx in top_indices: if similarities[idx] > 0.1: # Minimum similarity threshold results.append({ "text": self.texts[idx], "metadata": self.metadatas[idx], "similarity": float(similarities[idx]) }) logger.info(f"Query returned {len(results)} results") return results except Exception as e: logger.error(f"Error querying RAG store: {e}") return [] def get_stats(self) -> Dict[str, Any]: """Get statistics about the RAG store""" return { "total_documents": len(self.texts), "has_embeddings": self.embeddings is not None, "encoder_available": self.encoder is not None, "collection_dir": self.collection_dir } def clear(self): """Clear all data from the RAG store""" try: self.texts = [] self.metadatas = [] self.embeddings = None # Remove files for file_path in [self.embeddings_file, self.texts_file, self.metadata_file]: if os.path.exists(file_path): os.remove(file_path) logger.info("RAG store cleared") except Exception as e: logger.error(f"Error clearing RAG store: {e}") # Utility functions for fraud-specific RAG queries def build_fraud_context(transaction_data: Dict[str, Any]) -> str: """Build a searchable text representation of transaction data""" parts = [] if 'amount' in transaction_data: parts.append(f"amount:{transaction_data['amount']}") if 'merchant' in transaction_data: parts.append(f"merchant:{transaction_data['merchant']}") if 'category' in transaction_data: parts.append(f"category:{transaction_data['category']}") if 'description' in transaction_data: parts.append(f"description:{transaction_data['description']}") if 'timestamp' in transaction_data: parts.append(f"time:{transaction_data['timestamp']}") return " ".join(parts) def extract_fraud_patterns(rag_results: List[Dict[str, Any]]) -> List[str]: """Extract common fraud patterns from RAG results""" patterns = [] for result in rag_results: metadata = result.get('metadata', {}) similarity = result.get('similarity', 0) if similarity > 0.7: # High similarity threshold if 'merchant' in metadata: patterns.append(f"Similar merchant: {metadata['merchant']}") if 'amount' in metadata: patterns.append(f"Similar amount: ${metadata['amount']}") if 'category' in metadata: patterns.append(f"Similar category: {metadata['category']}") return list(set(patterns)) # Remove duplicates