""" | |
RAG (Retrieval Augmented Generation) store for fraud pattern matching | |
""" | |
import os | |
import json | |
import logging | |
from typing import List, Dict, Any, Optional | |
import pickle | |
from sentence_transformers import SentenceTransformer | |
import numpy as np | |
from sklearn.metrics.pairwise import cosine_similarity | |
logger = logging.getLogger(__name__) | |


class RAGStore:
    """Simple RAG store using sentence transformers and local file storage"""

    def __init__(self, collection_dir: str, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        self.collection_dir = collection_dir
        self.model_name = model_name
        self.embeddings_file = os.path.join(collection_dir, "embeddings.pkl")
        self.texts_file = os.path.join(collection_dir, "texts.json")
        self.metadata_file = os.path.join(collection_dir, "metadata.json")
        os.makedirs(collection_dir, exist_ok=True)

        # Initialize the sentence transformer; fall back to None so the
        # store degrades gracefully if the model cannot be loaded
        try:
            self.encoder = SentenceTransformer(model_name)
            logger.info(f"Initialized SentenceTransformer: {model_name}")
        except Exception as e:
            logger.error(f"Failed to initialize SentenceTransformer: {e}")
            self.encoder = None

        # Load existing data
        self.texts = []
        self.metadatas = []
        self.embeddings = None
        self._load_data()

    def _load_data(self):
        """Load existing embeddings, texts, and metadata"""
        try:
            if os.path.exists(self.texts_file):
                with open(self.texts_file, 'r') as f:
                    self.texts = json.load(f)
            if os.path.exists(self.metadata_file):
                with open(self.metadata_file, 'r') as f:
                    self.metadatas = json.load(f)
            if os.path.exists(self.embeddings_file):
                with open(self.embeddings_file, 'rb') as f:
                    self.embeddings = pickle.load(f)
            logger.info(f"Loaded {len(self.texts)} existing documents")
        except Exception as e:
            logger.error(f"Error loading RAG data: {e}")
            self.texts = []
            self.metadatas = []
            self.embeddings = None

    def _save_data(self):
        """Save embeddings, texts, and metadata to files"""
        try:
            with open(self.texts_file, 'w') as f:
                json.dump(self.texts, f)
            with open(self.metadata_file, 'w') as f:
                json.dump(self.metadatas, f, default=str)
            if self.embeddings is not None:
                with open(self.embeddings_file, 'wb') as f:
                    pickle.dump(self.embeddings, f)
            logger.info(f"Saved {len(self.texts)} documents to storage")
        except Exception as e:
            logger.error(f"Error saving RAG data: {e}")

    def add(self, texts: List[str], metadatas: List[Dict[str, Any]]):
        """Add new documents to the RAG store"""
        if self.encoder is None:
            logger.warning("No encoder available, cannot add documents")
            return
        if len(texts) != len(metadatas):
            logger.error("Texts and metadatas must have the same length")
            return
        try:
            # Generate embeddings for the new texts
            new_embeddings = self.encoder.encode(texts)
            # Append to the in-memory store
            self.texts.extend(texts)
            self.metadatas.extend(metadatas)
            if self.embeddings is None:
                self.embeddings = new_embeddings
            else:
                self.embeddings = np.vstack([self.embeddings, new_embeddings])
            # Persist to disk
            self._save_data()
            logger.info(f"Added {len(texts)} new documents to RAG store")
        except Exception as e:
            logger.error(f"Error adding documents to RAG store: {e}")

    def query(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
        """Query the RAG store for similar documents"""
        if self.encoder is None or self.embeddings is None or len(self.texts) == 0:
            logger.warning("RAG store is empty or encoder unavailable")
            return []
        try:
            # Encode the query
            query_embedding = self.encoder.encode([query])
            # Cosine similarity between the query and every stored document
            similarities = cosine_similarity(query_embedding, self.embeddings)[0]
            # Take the top-k indices, highest similarity first
            top_indices = np.argsort(similarities)[::-1][:k]
            results = []
            for idx in top_indices:
                if similarities[idx] > 0.1:  # Minimum similarity threshold
                    results.append({
                        "text": self.texts[idx],
                        "metadata": self.metadatas[idx],
                        "similarity": float(similarities[idx])
                    })
            logger.info(f"Query returned {len(results)} results")
            return results
        except Exception as e:
            logger.error(f"Error querying RAG store: {e}")
            return []

    def get_stats(self) -> Dict[str, Any]:
        """Get statistics about the RAG store"""
        return {
            "total_documents": len(self.texts),
            "has_embeddings": self.embeddings is not None,
            "encoder_available": self.encoder is not None,
            "collection_dir": self.collection_dir
        }

    def clear(self):
        """Clear all data from the RAG store"""
        try:
            self.texts = []
            self.metadatas = []
            self.embeddings = None
            # Remove persisted files
            for file_path in [self.embeddings_file, self.texts_file, self.metadata_file]:
                if os.path.exists(file_path):
                    os.remove(file_path)
            logger.info("RAG store cleared")
        except Exception as e:
            logger.error(f"Error clearing RAG store: {e}")


# Utility functions for fraud-specific RAG queries

def build_fraud_context(transaction_data: Dict[str, Any]) -> str:
    """Build a searchable text representation of transaction data"""
    parts = []
    if 'amount' in transaction_data:
        parts.append(f"amount:{transaction_data['amount']}")
    if 'merchant' in transaction_data:
        parts.append(f"merchant:{transaction_data['merchant']}")
    if 'category' in transaction_data:
        parts.append(f"category:{transaction_data['category']}")
    if 'description' in transaction_data:
        parts.append(f"description:{transaction_data['description']}")
    if 'timestamp' in transaction_data:
        parts.append(f"time:{transaction_data['timestamp']}")
    return " ".join(parts)


def extract_fraud_patterns(rag_results: List[Dict[str, Any]]) -> List[str]:
    """Extract common fraud patterns from RAG results"""
    patterns = []
    for result in rag_results:
        metadata = result.get('metadata', {})
        similarity = result.get('similarity', 0)
        if similarity > 0.7:  # High similarity threshold
            if 'merchant' in metadata:
                patterns.append(f"Similar merchant: {metadata['merchant']}")
            if 'amount' in metadata:
                patterns.append(f"Similar amount: ${metadata['amount']}")
            if 'category' in metadata:
                patterns.append(f"Similar category: {metadata['category']}")
    return list(dict.fromkeys(patterns))  # Remove duplicates, preserving order
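

# --- Usage sketch ---
# A minimal end-to-end example of the intended flow, assuming the
# all-MiniLM-L6-v2 model can be downloaded. The directory name and the
# sample transactions below are hypothetical.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    store = RAGStore(collection_dir="./fraud_rag_store")  # hypothetical path

    # Index a few known-fraud transactions as searchable context strings
    known_fraud = [
        {"amount": 1999.00, "merchant": "QuickCash Wire", "category": "transfer"},
        {"amount": 49.99, "merchant": "GiftCardHub", "category": "gift_cards"},
    ]
    store.add(
        texts=[build_fraud_context(tx) for tx in known_fraud],
        metadatas=known_fraud,
    )

    # Query with a new transaction and surface any high-similarity patterns
    incoming = {"amount": 52.50, "merchant": "GiftCardHub", "category": "gift_cards"}
    results = store.query(build_fraud_context(incoming), k=3)
    print(extract_fraud_patterns(results))
    print(store.get_stats())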