Spaces:
Sleeping
Sleeping
"""
ML Agentic Retriever for SentrySearch

Implements an agentic RAG approach with intelligent query optimization.

Provides intelligent ML-focused retrieval for threat intelligence with:
- Query optimization for threat-to-ML translation
- Source identification for relevant paper filtering
- Enhanced hybrid retrieval with post-processing
- Context-aware result ranking and structuring

Usage:
    retriever = MLAgenticRetriever(anthropic_client, knowledge_base_path)
    ml_guidance = retriever.get_ml_guidance(threat_characteristics)
"""
| import os | |
| import json | |
| import logging | |
| import time | |
| import random | |
| import hashlib | |
| from typing import List, Dict, Optional, Tuple, Set | |
| from dataclasses import dataclass | |
| import re | |
| from anthropic import Anthropic | |
| import anthropic | |
| from src.data.ml_knowledge_base_builder import KnowledgeBaseStorage | |
| from src.search.bm25_retriever import BM25Retriever, BM25SearchResult | |
# Module-scoped logger named after this module.
logger = logging.getLogger(__name__)
@dataclass
class ThreatCharacteristics:
    """Represents threat characteristics for ML guidance generation.

    Declared as a dataclass because this module constructs instances with
    keyword arguments, which requires the generated ``__init__``.
    """

    threat_name: str
    threat_type: str  # e.g., "malware", "apt", "insider_threat"
    attack_vectors: List[str]  # e.g., ["network", "email", "web"]
    target_assets: List[str]  # e.g., ["user_accounts", "financial_data"]
    behavior_patterns: List[str]  # e.g., ["lateral_movement", "data_exfiltration"]
    time_characteristics: str  # e.g., "persistent", "burst", "periodic"
@dataclass
class OptimizedQuery:
    """Represents an optimized query for ML retrieval.

    Declared as a dataclass because this module constructs instances with
    keyword arguments, which requires the generated ``__init__``.
    """

    original_query: str  # the threat name the queries were derived from
    optimized_queries: List[str]  # ML-focused queries used for retrieval
    ml_focus_areas: List[str]  # technique labels, e.g. "anomaly_detection"
    reasoning: str  # short rationale for the chosen ML approach
@dataclass
class SourceSelection:
    """Represents filtered source selection.

    Declared as a dataclass because this module constructs instances with
    keyword arguments, which requires the generated ``__init__``.
    """

    relevant_companies: List[str]  # company names used to filter search hits
    relevant_years: List[str]  # publication years, as strings
    relevant_techniques: List[str]  # technique labels used for applicability scoring
    reasoning: str  # human-readable explanation of the selection
@dataclass
class MLRetrievalResult:
    """Represents a structured ML retrieval result.

    Declared as a dataclass because this module constructs instances with
    keyword arguments and relies on the field defaults below.
    """

    content: str
    metadata: Dict
    relevance_score: float
    source_paper: str
    ml_techniques: List[str]
    implementation_details: str
    applicability_score: float
    retrieval_method: str = 'unknown'  # 'vector', 'bm25', or 'hybrid'
    bm25_score: float = 0.0
    hybrid_score: float = 0.0
    # None means "no BM25 term information"; callers must treat it like [].
    matched_terms: Optional[List[str]] = None
class QueryOptimizer:
    """Optimizes queries to focus on ML anomaly detection approaches."""

    def __init__(self, anthropic_client):
        """Store the client whose ``messages.create`` is used for LLM calls."""
        self.client = anthropic_client

    def _api_call_with_retry(self, **kwargs):
        """Call ``client.messages.create`` and retry when rate limited.

        Args:
            **kwargs: Forwarded unchanged to ``client.messages.create``.

        Returns:
            Whatever ``client.messages.create`` returns.

        Raises:
            anthropic.RateLimitError: If the final attempt is still rate limited.
            Exception: Any other error is re-raised immediately without retry.
        """
        max_retries = 3
        base_delay = 5

        for attempt in range(max_retries):
            try:
                print(f"DEBUG: Query Optimizer API call attempt {attempt + 1}/{max_retries}")
                return self.client.messages.create(**kwargs)
            except anthropic.RateLimitError as e:
                if attempt == max_retries - 1:
                    print(f"DEBUG: Query Optimizer rate limit exceeded after {max_retries} attempts")
                    raise

                # Check if the error response has retry-after information
                retry_after = None
                response = getattr(e, 'response', None)
                if response is not None:
                    retry_after_header = response.headers.get('retry-after')
                    if retry_after_header:
                        try:
                            retry_after = float(retry_after_header)
                            print(f"DEBUG: Query Optimizer API provided retry-after: {retry_after} seconds")
                        except (ValueError, TypeError):
                            # Unparseable header: fall back to exponential backoff.
                            pass

                # Use retry-after if available, otherwise exponential backoff.
                # ``is not None`` so that an explicit "retry-after: 0" is honoured.
                if retry_after is not None:
                    delay = retry_after + random.uniform(1, 3)
                else:
                    delay = base_delay * (2 ** attempt) + random.uniform(1, 5)
                    # NOTE(review): the indentation of this cap was lost in the
                    # source this was recovered from; it is applied to the
                    # computed backoff only so a server-provided retry-after is
                    # honoured in full - confirm.
                    delay = min(delay, 120)

                print(f"DEBUG: Query Optimizer rate limit hit. Waiting {delay:.1f} seconds before retry {attempt + 2}")
                time.sleep(delay)
            except Exception as e:
                print(f"DEBUG: Query Optimizer non-rate-limit error: {e}")
                raise

    def optimize_query(self, threat_characteristics: "ThreatCharacteristics") -> "OptimizedQuery":
        """Convert threat characteristics into ML-focused queries.

        Sends a prompt built from the threat fields to the model and parses the
        reply. Any failure (API error, empty reply, parse error) is logged and
        replaced by a single generic fallback query, so this method does not
        raise for those cases.
        """
        prompt = f"""
You are an expert in both cybersecurity threats and machine learning anomaly detection.
Convert this threat information into 3-5 specific queries about ML approaches for detection.
Threat Information:
- Name: {threat_characteristics.threat_name}
- Type: {threat_characteristics.threat_type}
- Attack Vectors: {', '.join(threat_characteristics.attack_vectors)}
- Target Assets: {', '.join(threat_characteristics.target_assets)}
- Behavior Patterns: {', '.join(threat_characteristics.behavior_patterns)}
- Time Characteristics: {threat_characteristics.time_characteristics}
Generate queries that focus on:
1. Specific ML techniques for detecting this threat type
2. Feature engineering approaches for the attack vectors
3. Behavioral analysis methods for the patterns observed
4. Implementation considerations for the target environment
Format your response as:
QUERIES:
1. [Query 1]
2. [Query 2]
3. [Query 3]
etc.
ML_FOCUS_AREAS: [comma-separated focus areas]
REASONING: [1-2 sentences explaining the ML approach rationale]
"""
        try:
            response = self._api_call_with_retry(
                model="claude-sonnet-4-20250514",
                max_tokens=600,
                messages=[{"role": "user", "content": prompt}]
            )

            # Safe access to response content
            if not response.content:
                raise ValueError("Empty response from API")

            content = response.content[0].text.strip()
            return self._parse_optimization_response(content, threat_characteristics)
        except Exception as e:
            logger.error(f"Query optimization failed: {e}")
            # Fallback to simple query
            fallback_query = f"Machine learning approaches for detecting {threat_characteristics.threat_name}"
            return OptimizedQuery(
                original_query=threat_characteristics.threat_name,
                optimized_queries=[fallback_query],
                ml_focus_areas=["anomaly_detection"],
                reasoning="Fallback query due to optimization failure"
            )

    @staticmethod
    def _parse_sections(response: str) -> Tuple[List[str], List[str], str]:
        """Split an LLM reply into (queries, focus areas, reasoning).

        Recognises the ``QUERIES:``, ``ML_FOCUS_AREAS:`` and ``REASONING:``
        headers at the start of a line. Queries are the numbered lines that
        follow ``QUERIES:``; focus areas are the comma-separated values on the
        ``ML_FOCUS_AREAS:`` line; reasoning is the rest of the ``REASONING:``
        line plus any following non-empty lines. Empty focus-area entries are
        dropped so that a bare header yields an empty list.
        """
        queries: List[str] = []
        ml_focus_areas: List[str] = []
        reasoning_parts: List[str] = []
        current_section = None

        for raw_line in response.split('\n'):
            line = raw_line.strip()

            if line.startswith('QUERIES:'):
                current_section = 'queries'
                continue
            if line.startswith('ML_FOCUS_AREAS:'):
                current_section = 'focus'
                candidates = line[len('ML_FOCUS_AREAS:'):].split(',')
                ml_focus_areas = [area.strip() for area in candidates if area.strip()]
                continue
            if line.startswith('REASONING:'):
                current_section = 'reasoning'
                reasoning_parts = [line[len('REASONING:'):].strip()]
                continue

            if not line:
                continue

            if current_section == 'queries':
                # Extract query from numbered list
                query_match = re.match(r'\d+\.\s*(.+)', line)
                if query_match:
                    queries.append(query_match.group(1).strip())
            elif current_section == 'reasoning':
                reasoning_parts.append(line)

        return queries, ml_focus_areas, ' '.join(reasoning_parts).strip()

    def _parse_optimization_response(self, response: str,
                                     threat_characteristics: "ThreatCharacteristics") -> "OptimizedQuery":
        """Parse the LLM response into structured query optimization.

        Falls back to one generic query when no numbered query was found and to
        ``["anomaly_detection"]`` when no focus area was found.
        """
        queries, ml_focus_areas, reasoning = self._parse_sections(response)

        # Ensure we have at least one query
        if not queries:
            queries = [f"Machine learning detection approaches for {threat_characteristics.threat_name}"]

        return OptimizedQuery(
            original_query=threat_characteristics.threat_name,
            optimized_queries=queries,
            ml_focus_areas=ml_focus_areas if ml_focus_areas else ["anomaly_detection"],
            reasoning=reasoning
        )
class SourceIdentifier:
    """Identifies most relevant papers/sources for a given threat."""

    def __init__(self, knowledge_base: "KnowledgeBaseStorage"):
        """Keep the knowledge base handle and build the static lookup tables."""
        self.knowledge_base = knowledge_base
        self._load_source_mappings()

    def _load_source_mappings(self):
        """Load mappings between threat types and relevant sources."""
        # Company expertise mappings
        self.company_expertise = {
            'Netflix': ['performance_monitoring', 'infrastructure_anomalies', 'streaming_security'],
            'LinkedIn': ['user_behavior', 'abuse_detection', 'social_platform_security'],
            'Slack': ['communication_security', 'invite_spam', 'workspace_security'],
            'Cloudflare': ['network_security', 'bot_detection', 'traffic_analysis'],
            'Uber': ['fraud_detection', 'real_time_systems', 'human_in_the_loop'],
            'Grab': ['financial_fraud', 'graph_analysis', 'transaction_security'],
            'OLX Group': ['marketplace_fraud', 'deep_learning', 'user_verification'],
            'Stack Exchange': ['content_moderation', 'spam_detection', 'community_security'],
            'Mercari': ['e_commerce_security', 'content_moderation', 'automated_review']
        }

        # Attack vector to technique mappings
        self.attack_vector_techniques = {
            'network': ['traffic_analysis', 'graph_ml', 'behavioral_analysis'],
            'email': ['text_classification', 'nlp', 'spam_detection'],
            'web': ['bot_detection', 'traffic_analysis', 'behavioral_analysis'],
            'insider': ['user_behavior', 'behavioral_analysis', 'anomaly_detection'],
            'financial': ['fraud_detection', 'transaction_analysis', 'graph_ml']
        }

        # Threat type to ML approach mappings
        self.threat_ml_mappings = {
            'malware': ['behavioral_analysis', 'static_analysis', 'dynamic_analysis'],
            'apt': ['behavioral_analysis', 'network_analysis', 'long_term_patterns'],
            'fraud': ['fraud_detection', 'transaction_analysis', 'user_behavior'],
            'spam': ['text_classification', 'content_moderation', 'nlp'],
            'abuse': ['user_behavior', 'abuse_detection', 'behavioral_analysis']
        }

    def _match_expertise(self, optimized_query: "OptimizedQuery",
                         threat_characteristics: "ThreatCharacteristics") -> Tuple[Set[str], Set[str]]:
        """Match the threat against the static tables.

        Attack vectors and threat type are lower-cased before matching because
        every table key and expertise label is lower-case. Empty keywords are
        skipped: ``'' in some_string`` is always true and would otherwise match
        every company.

        Returns:
            ``(companies, techniques)`` as sets. Companies are those with an
            expertise label containing an attack vector or the threat type (or
            contained in the threat type). Techniques come from the attack
            vector table, the threat type table and the query's focus areas.
        """
        threat_type = str(threat_characteristics.threat_type or '').strip().lower()
        vectors = [
            str(vector).strip().lower()
            for vector in threat_characteristics.attack_vectors
            if vector and str(vector).strip()
        ]

        keywords = set(vectors)
        if threat_type:
            keywords.add(threat_type)

        companies: Set[str] = set()
        for company, expertise in self.company_expertise.items():
            for expert_area in expertise:
                keyword_hit = any(keyword in expert_area for keyword in keywords)
                contained_in_type = bool(threat_type) and expert_area in threat_type
                if keyword_hit or contained_in_type:
                    companies.add(company)
                    break

        techniques: Set[str] = set()
        for vector in vectors:
            techniques.update(self.attack_vector_techniques.get(vector, ()))
        techniques.update(self.threat_ml_mappings.get(threat_type, ()))
        # Add techniques from ML focus areas
        techniques.update(area for area in optimized_query.ml_focus_areas if area)

        return companies, techniques

    def identify_relevant_sources(self, optimized_query: "OptimizedQuery",
                                  threat_characteristics: "ThreatCharacteristics") -> "SourceSelection":
        """Identify most relevant papers/sources for the optimized queries.

        The company list is the union of the expertise matches and every
        company the knowledge base reports, so retrieval is never filtered
        down to nothing. Companies and techniques are returned sorted so the
        result (including the reasoning text) is deterministic.
        """
        relevant_companies, relevant_techniques = self._match_expertise(
            optimized_query, threat_characteristics
        )

        # Always include available companies from knowledge base.
        # For now, include all available companies to ensure we get results.
        # In production, you can make this more selective.
        stats = self.knowledge_base.get_stats()
        relevant_companies.update(stats['companies'])

        company_list = sorted(relevant_companies)
        technique_list = sorted(relevant_techniques)

        # Generate reasoning
        reasoning = f"Selected companies based on expertise in {threat_characteristics.threat_type} and {', '.join(threat_characteristics.attack_vectors)}. Focus on {', '.join(technique_list[:3])} techniques."

        return SourceSelection(
            relevant_companies=company_list,
            relevant_years=['2019', '2020', '2021', '2022'],  # All available years
            relevant_techniques=technique_list,
            reasoning=reasoning
        )
class EnhancedRetriever:
    """Enhanced hybrid retriever with vector search + BM25 and post-processing."""

    def __init__(self, knowledge_base: "KnowledgeBaseStorage"):
        """Keep the knowledge base and build a BM25 retriever over it."""
        self.knowledge_base = knowledge_base
        self.bm25_retriever = BM25Retriever(knowledge_base)

    def retrieve_with_context(self, optimized_query: "OptimizedQuery",
                              source_selection: "SourceSelection",
                              max_results: int = 10,
                              trace_exporter=None) -> List["MLRetrievalResult"]:
        """Hybrid retrieval using both vector search and BM25.

        Every optimized query is run against both back ends, hits are kept only
        when their ``company`` metadata is in ``source_selection``, the two
        result sets are fused and deduplicated, and the top ``max_results`` by
        ``hybrid_score`` are returned. When ``trace_exporter`` is given, each
        stage is reported to it.
        """
        # Built once: membership is tested for every hit of every query.
        allowed_companies = set(source_selection.relevant_companies)

        # Log query optimization to trace
        if trace_exporter:
            trace_exporter.log_query_optimization(
                optimized_query.original_query,
                optimized_query.optimized_queries,
                optimized_query.reasoning,
                optimized_query.ml_focus_areas
            )

        # 1. Vector Search - Search with each optimized query
        if trace_exporter:
            trace_exporter.log_stage_start("vector_retrieval")

        vector_results = []
        for query in optimized_query.optimized_queries:
            for hit in self.knowledge_base.search(query, n_results=max_results):
                # Filter by relevant sources
                if hit['metadata'].get('company', '') in allowed_companies:
                    vector_results.append(
                        self._create_ml_result(hit, optimized_query, source_selection)
                    )

        if trace_exporter:
            trace_exporter.log_stage_end("vector_retrieval", result_count=len(vector_results))
            # Convert to format expected by trace exporter
            vector_trace_results = [self._convert_to_trace_format(r) for r in vector_results]
            trace_exporter.log_retrieval_results(vector_trace_results, "vector")

        # 2. BM25 Search - Search with each optimized query
        if trace_exporter:
            trace_exporter.log_stage_start("bm25_retrieval")

        bm25_results = []
        for query in optimized_query.optimized_queries:
            for hit in self.bm25_retriever.search(query, n_results=max_results):
                # Filter by relevant sources
                if hit.metadata.get('company', '') in allowed_companies:
                    bm25_results.append(
                        self._create_ml_result_from_bm25(hit, optimized_query, source_selection)
                    )

        if trace_exporter:
            trace_exporter.log_stage_end("bm25_retrieval", result_count=len(bm25_results))
            # Convert to format expected by trace exporter
            bm25_trace_results = [self._convert_to_trace_format(r) for r in bm25_results]
            trace_exporter.log_retrieval_results(bm25_trace_results, "bm25")

        # 3. Combine and deduplicate results
        if trace_exporter:
            trace_exporter.log_stage_start("hybrid_fusion")

        all_results = self._fuse_hybrid_results(vector_results, bm25_results)

        # 4. Post-process results
        processed_results = self._post_process_results(all_results)

        # 5. Rank by hybrid score combining relevance and applicability
        ranked_results = sorted(processed_results,
                                key=lambda x: x.hybrid_score,
                                reverse=True)
        final_results = ranked_results[:max_results]

        # Log hybrid fusion results
        if trace_exporter:
            trace_exporter.log_stage_end("hybrid_fusion", final_result_count=len(final_results))
            # Convert to format expected by trace exporter
            hybrid_trace_results = [self._convert_to_trace_format(r) for r in final_results]
            trace_exporter.log_retrieval_results(hybrid_trace_results, "hybrid")

            # Log hybrid scoring if available
            if final_results:
                count = len(final_results)
                avg_vector_score = sum(r.relevance_score for r in final_results) / count
                avg_bm25_score = sum(getattr(r, 'bm25_score', 0) for r in final_results) / count
                avg_hybrid_score = sum(r.hybrid_score for r in final_results) / count
                avg_applicability_score = sum(r.applicability_score for r in final_results) / count
                trace_exporter.log_hybrid_scoring(
                    avg_vector_score, avg_bm25_score, avg_hybrid_score, avg_applicability_score
                )

        return final_results

    @staticmethod
    def _score_applicability(metadata: Dict, relevant_techniques) -> Tuple[List[str], float]:
        """Return the paper's techniques and their overlap with the wanted ones.

        ``metadata['ml_techniques']`` is read as a ``|``-separated string. The
        score is the fraction of ``relevant_techniques`` that the paper covers,
        in ``[0.0, 1.0]``; it is ``0.0`` when there are no relevant techniques.
        Techniques are returned sorted for a stable order.
        """
        raw = metadata.get('ml_techniques', '')
        # Ensure we have a string before splitting
        raw_str = str(raw) if raw else ''
        paper_techniques = {part.strip() for part in raw_str.split('|') if part.strip()}
        wanted = set(relevant_techniques)
        overlap = len(paper_techniques & wanted)
        score = min(overlap / max(len(wanted), 1), 1.0)
        return sorted(paper_techniques), score

    def _create_ml_result(self, search_result: Dict,
                          optimized_query: "OptimizedQuery",
                          source_selection: "SourceSelection") -> "MLRetrievalResult":
        """Convert a vector search hit (dict) to a structured ML result.

        Reads the ``'metadata'``, ``'document'`` and ``'score'`` keys of
        ``search_result``.
        """
        metadata = search_result['metadata']
        paper_techniques, applicability_score = self._score_applicability(
            metadata, source_selection.relevant_techniques
        )

        # Extract implementation details from content
        content = search_result['document']

        result = MLRetrievalResult(
            content=content,
            metadata=metadata,
            relevance_score=search_result['score'],
            source_paper=metadata.get('source_title', ''),
            ml_techniques=paper_techniques,
            implementation_details=self._extract_implementation_details(content),
            applicability_score=applicability_score,
            retrieval_method='vector'
        )

        # For vector-only results the ranking score blends relevance and applicability.
        result.hybrid_score = (result.relevance_score * 0.6 + result.applicability_score * 0.4)
        return result

    def _extract_implementation_details(self, content: str) -> str:
        """Return up to three sentences of ``content`` that mention an
        implementation keyword, joined with ``'. '``."""
        # Look for implementation-specific keywords
        impl_keywords = [
            'architecture', 'framework', 'algorithm', 'model',
            'feature', 'training', 'deployment', 'performance',
            'accuracy', 'precision', 'recall', 'latency'
        ]

        impl_sentences = [
            sentence.strip()
            for sentence in content.split('.')
            if any(keyword in sentence.lower() for keyword in impl_keywords)
        ]
        return '. '.join(impl_sentences[:3])  # Top 3 relevant sentences

    def _post_process_results(self, results: List["MLRetrievalResult"]) -> List["MLRetrievalResult"]:
        """Drop results whose first 200 characters of content were already seen.

        The prefix itself is the dedup key (not its ``hash()``), so two
        different prefixes can never be confused by a hash collision.
        """
        seen_prefixes = set()
        deduplicated = []
        for result in results:
            prefix = (result.content or "")[:200]
            if prefix not in seen_prefixes:
                seen_prefixes.add(prefix)
                deduplicated.append(result)
        return deduplicated

    def _convert_to_trace_format(self, ml_result: "MLRetrievalResult") -> Dict:
        """Convert MLRetrievalResult to format expected by trace exporter."""
        return {
            "content": ml_result.content,
            "metadata": ml_result.metadata,
            "score": ml_result.relevance_score,
            "method": ml_result.retrieval_method,
            "matched_terms": getattr(ml_result, 'matched_terms', None),
            "source_company": ml_result.metadata.get('company'),
            "ml_techniques": ml_result.ml_techniques
        }

    def _create_ml_result_from_bm25(self, bm25_result: "BM25SearchResult",
                                    optimized_query: "OptimizedQuery",
                                    source_selection: "SourceSelection") -> "MLRetrievalResult":
        """Convert a BM25 search result to the ML result format.

        Reads ``metadata``, ``content``, ``relevance_score``, ``bm25_score``
        and ``matched_terms`` from ``bm25_result``.
        """
        metadata = bm25_result.metadata
        paper_techniques, applicability_score = self._score_applicability(
            metadata, source_selection.relevant_techniques
        )

        result = MLRetrievalResult(
            content=bm25_result.content,
            metadata=metadata,
            relevance_score=bm25_result.relevance_score,
            source_paper=metadata.get('source_title', ''),
            ml_techniques=paper_techniques,
            implementation_details=self._extract_implementation_details(bm25_result.content),
            applicability_score=applicability_score,
            retrieval_method='bm25',
            bm25_score=bm25_result.bm25_score,
            matched_terms=bm25_result.matched_terms
        )

        # Calculate hybrid score (for BM25 results, give more weight to exact matches)
        bm25_weight = min(bm25_result.bm25_score / 5.0, 1.0)  # Normalize BM25 score
        # matched_terms may be None; treat that as "no matched terms".
        term_match_bonus = len(bm25_result.matched_terms or []) * 0.1  # Bonus for matched terms
        result.hybrid_score = (bm25_weight * 0.5 + result.applicability_score * 0.4 + term_match_bonus)
        return result

    def _fuse_hybrid_results(self, vector_results: List["MLRetrievalResult"],
                             bm25_results: List["MLRetrievalResult"]) -> List["MLRetrievalResult"]:
        """Merge vector and BM25 results, one entry per distinct content key.

        Within each list, duplicates (same ``_get_result_key``) collapse to the
        one with the highest ``hybrid_score``. Keys present in both lists are
        combined via ``_create_hybrid_result``. Output order is deterministic:
        vector keys in first-seen order, then BM25-only keys in first-seen
        order.
        """
        def best_by_key(results):
            best = {}
            for candidate in results:
                key = self._get_result_key(candidate)
                current = best.get(key)
                if current is None or candidate.hybrid_score > current.hybrid_score:
                    best[key] = candidate
            return best

        vector_lookup = best_by_key(vector_results)
        bm25_lookup = best_by_key(bm25_results)

        fused_results = []
        for key, vector_result in vector_lookup.items():
            bm25_result = bm25_lookup.get(key)
            if bm25_result is not None:
                # Both methods found this result - create hybrid
                fused_results.append(self._create_hybrid_result(vector_result, bm25_result))
            else:
                # Only vector search found this
                fused_results.append(vector_result)

        # Only BM25 found these
        fused_results.extend(
            bm25_result for key, bm25_result in bm25_lookup.items()
            if key not in vector_lookup
        )
        return fused_results

    def _get_result_key(self, result: "MLRetrievalResult") -> str:
        """Generate a key for a result from the first 100 characters of its content."""
        # Use first 100 characters of content as key to detect near-duplicates
        content_key = result.content[:100] if result.content else ""
        return hashlib.md5(content_key.encode()).hexdigest()[:16]

    @staticmethod
    def _fused_score(vector_score: float, bm25_score: float, both_bonus: float = 0.2) -> float:
        """Combine the two per-method ranking scores for a doubly-found result.

        Uses the mean of the two scores plus ``both_bonus``. This is monotonic
        (a higher component score never lowers the result) and stays on the
        same scale as the single-method ``hybrid_score`` values it is sorted
        against. The previous ``1 / (k + score * 100)`` form decreased as the
        scores increased.
        """
        return (vector_score + bm25_score) / 2.0 + both_bonus

    def _create_hybrid_result(self, vector_result: "MLRetrievalResult",
                              bm25_result: "MLRetrievalResult") -> "MLRetrievalResult":
        """Create a hybrid result by combining vector and BM25 results.

        The vector result supplies content, metadata and techniques; relevance
        and applicability take the larger of the two values; BM25 supplies
        ``bm25_score`` and ``matched_terms``.
        """
        # Use vector result as base and enhance with BM25 data
        hybrid_result = MLRetrievalResult(
            content=vector_result.content,
            metadata=vector_result.metadata,
            relevance_score=max(vector_result.relevance_score, bm25_result.relevance_score),
            source_paper=vector_result.source_paper,
            ml_techniques=vector_result.ml_techniques,
            implementation_details=vector_result.implementation_details,
            applicability_score=max(vector_result.applicability_score, bm25_result.applicability_score),
            retrieval_method='hybrid',
            bm25_score=bm25_result.bm25_score,
            matched_terms=bm25_result.matched_terms
        )

        hybrid_result.hybrid_score = self._fused_score(
            vector_result.hybrid_score, bm25_result.hybrid_score
        )
        return hybrid_result
class MLAgenticRetriever:
    """Main agentic retriever orchestrating all components."""

    def __init__(self, anthropic_client, knowledge_base_path: str = "./ml_knowledge_base"):
        """Open the knowledge base at ``knowledge_base_path`` and build the agents."""
        self.client = anthropic_client
        self.knowledge_base = KnowledgeBaseStorage(knowledge_base_path)

        # Initialize agents
        self.query_optimizer = QueryOptimizer(anthropic_client)
        self.source_identifier = SourceIdentifier(self.knowledge_base)
        self.enhanced_retriever = EnhancedRetriever(self.knowledge_base)

        logger.info("ML Agentic Retriever initialized")

    def get_ml_guidance(self, threat_characteristics: "ThreatCharacteristics", trace_exporter=None) -> Dict:
        """Get comprehensive ML guidance for threat detection.

        Runs query optimization, source identification, retrieval and
        structuring. Any exception is logged with its traceback and replaced by
        the fallback guidance, so this method does not raise for pipeline
        failures.
        """
        try:
            logger.info(f"Getting ML guidance for: {threat_characteristics.threat_name}")

            # Step 1: Query Optimization
            optimized_query = self.query_optimizer.optimize_query(threat_characteristics)
            logger.info(f"Generated {len(optimized_query.optimized_queries)} optimized queries")

            # Step 2: Source Identification
            source_selection = self.source_identifier.identify_relevant_sources(
                optimized_query, threat_characteristics
            )
            logger.info(f"Identified {len(source_selection.relevant_companies)} relevant companies")

            # Step 3: Enhanced Retrieval
            ml_results = self.enhanced_retriever.retrieve_with_context(
                optimized_query, source_selection, max_results=8,
                trace_exporter=trace_exporter
            )
            logger.info(f"Retrieved {len(ml_results)} relevant ML approaches")

            # Step 4: Structure results
            return self._structure_ml_guidance(
                threat_characteristics, optimized_query, source_selection, ml_results
            )
        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception(f"ML guidance generation failed: {e}")
            return self._create_fallback_guidance(threat_characteristics)

    def get_enhanced_ml_guidance(self, threat_characteristics: "ThreatCharacteristics",
                                 complete_threat_data: Dict, trace_exporter=None) -> Dict:
        """Get enhanced ML guidance leveraging complete threat intelligence context.

        NOTE(review): ``_optimize_query_with_context``,
        ``_structure_enhanced_ml_guidance`` and
        ``_create_enhanced_fallback_guidance`` are not defined in this class
        body; they are presumably attached to the class later in the module -
        confirm.
        """
        try:
            logger.info(f"Getting enhanced ML guidance for: {threat_characteristics.threat_name}")

            # Step 1: Enhanced Query Optimization with threat context
            optimized_query = self._optimize_query_with_context(threat_characteristics, complete_threat_data)
            logger.info(f"Generated {len(optimized_query.optimized_queries)} context-enhanced queries")

            # Step 2: Enhanced Source Identification
            source_selection = self.source_identifier.identify_relevant_sources(
                optimized_query, threat_characteristics
            )
            logger.info(f"Identified {len(source_selection.relevant_companies)} relevant companies")

            # Step 3: Enhanced Retrieval with threat context
            ml_results = self.enhanced_retriever.retrieve_with_context(
                optimized_query, source_selection, max_results=10,  # More results for enhanced mode
                trace_exporter=trace_exporter
            )
            logger.info(f"Retrieved {len(ml_results)} relevant ML approaches")

            # Step 4: Structure results with enhanced context
            return self._structure_enhanced_ml_guidance(
                threat_characteristics, optimized_query, source_selection, ml_results, complete_threat_data
            )
        except Exception as e:
            logger.exception(f"Enhanced ML guidance generation failed: {e}")
            return self._create_enhanced_fallback_guidance(threat_characteristics, complete_threat_data)

    def _structure_ml_guidance(self, threat_characteristics: "ThreatCharacteristics",
                               optimized_query: "OptimizedQuery",
                               source_selection: "SourceSelection",
                               ml_results: List["MLRetrievalResult"]) -> Dict:
        """Structure the ML guidance into organized sections.

        Returns a dict with ``threat_name``, ``ml_approaches`` (one entry per
        technique, taken from the result with the highest applicability),
        ``implementation_considerations`` (first three results),
        ``source_papers`` (one entry per distinct paper title),
        ``query_optimization`` and ``source_selection``.
        """
        # Group results by ML technique
        techniques_map = {}
        for result in ml_results:
            for technique in result.ml_techniques:
                techniques_map.setdefault(technique, []).append(result)

        # Create structured guidance
        guidance = {
            'threat_name': threat_characteristics.threat_name,
            'ml_approaches': [],
            'implementation_considerations': [],
            'source_papers': [],
            'query_optimization': {
                'original_query': optimized_query.original_query,
                'optimized_queries': optimized_query.optimized_queries,
                'reasoning': optimized_query.reasoning
            },
            'source_selection': {
                'relevant_companies': source_selection.relevant_companies,
                'reasoning': source_selection.reasoning
            }
        }

        # Add ML approaches
        for technique, results in techniques_map.items():
            best_result = max(results, key=lambda x: x.applicability_score)
            guidance['ml_approaches'].append({
                'technique': str(technique),
                'description': best_result.implementation_details,
                'source_company': str(best_result.metadata.get('company', '')),
                'source_paper': str(best_result.source_paper),
                'applicability_score': best_result.applicability_score,
                'relevance_score': best_result.relevance_score,
                'retrieval_method': best_result.retrieval_method,
                'hybrid_score': best_result.hybrid_score,
                'bm25_score': best_result.bm25_score,
                'matched_terms': best_result.matched_terms if best_result.matched_terms else []
            })

        # Add implementation considerations
        for result in ml_results[:3]:  # Top 3 results
            guidance['implementation_considerations'].append({
                'aspect': f"{result.metadata.get('company', '')} Implementation",
                'details': result.implementation_details,
                'source': result.source_paper
            })

        # Add source papers
        seen_papers = set()
        for result in ml_results:
            paper_title = result.source_paper
            if paper_title in seen_papers:
                continue
            seen_papers.add(paper_title)
            guidance['source_papers'].append({
                'title': paper_title,
                'company': result.metadata.get('company', ''),
                'year': result.metadata.get('year', ''),
                'url': result.metadata.get('source_url', ''),
                'techniques': result.ml_techniques
            })

        return guidance

    def _create_fallback_guidance(self, threat_characteristics: "ThreatCharacteristics") -> Dict:
        """Create fallback guidance when main pipeline fails.

        Carries every key that ``_structure_ml_guidance`` returns (with neutral
        defaults) plus an ``error`` message, so consumers can read both shapes
        the same way.
        """
        return {
            'threat_name': threat_characteristics.threat_name,
            'ml_approaches': [{
                'technique': 'anomaly_detection',
                'description': 'General anomaly detection approaches using statistical methods and machine learning',
                'source_company': 'General',
                'source_paper': 'Fallback recommendation',
                'applicability_score': 0.5,
                'relevance_score': 0.5,
                'retrieval_method': 'unknown',
                'hybrid_score': 0.0,
                'bm25_score': 0.0,
                'matched_terms': []
            }],
            'implementation_considerations': [{
                'aspect': 'General Implementation',
                'details': 'Consider implementing statistical anomaly detection as a baseline approach',
                'source': 'Fallback recommendation'
            }],
            'source_papers': [],
            'query_optimization': {
                'original_query': threat_characteristics.threat_name,
                'optimized_queries': [],
                'reasoning': ''
            },
            'source_selection': {
                'relevant_companies': [],
                'reasoning': ''
            },
            'error': 'ML guidance generation failed - fallback recommendations provided'
        }
def create_test_threat_characteristics() -> ThreatCharacteristics:
    """Create test threat characteristics for validation.

    Returns a fixed ``ThreatCharacteristics`` describing the "ShadowPad"
    malware sample used as a fixture.
    """
    fixture = {
        "threat_name": "ShadowPad",
        "threat_type": "malware",
        "attack_vectors": ["network", "lateral_movement"],
        "target_assets": ["corporate_networks", "sensitive_data"],
        "behavior_patterns": ["persistence", "data_exfiltration", "command_control"],
        "time_characteristics": "persistent",
    }
    return ThreatCharacteristics(**fixture)
| # Add enhanced methods to MLAgenticRetriever class | |
def _optimize_query_with_context(self, threat_characteristics: ThreatCharacteristics,
                                 complete_threat_data: Dict) -> OptimizedQuery:
    """Create enhanced queries using complete threat intelligence context.

    Up to three capability names (``technicalDetails.capabilities``) followed by
    up to two C2 protocols (``commandAndControl.communicationMethods``) are
    collected into one ordered list of context terms. The first two terms are
    appended to the attack vectors and the next two to the behavior patterns of
    a new ``ThreatCharacteristics``; the input object is not modified.

    Args:
        threat_characteristics: Base threat description to enrich.
        complete_threat_data: Threat profile mapping. Sections that are missing
            or not shaped as expected (section not a dict, entries not a list)
            contribute no context terms instead of raising.

    Returns:
        Whatever ``self.query_optimizer.optimize_query`` returns for the
        enriched characteristics.
    """
    def _context_terms(section, list_key, label_key, limit):
        """Return up to ``limit`` string labels from ``section[list_key]``."""
        if not isinstance(section, dict):
            return []
        entries = section.get(list_key)
        if not isinstance(entries, list):
            return []
        terms = []
        for entry in entries[:limit]:
            label = entry.get(label_key) if isinstance(entry, dict) else None
            # Fall back to the entry's own text when no usable label exists, and
            # always coerce to str so later string joins cannot fail.
            terms.append(str(entry if label is None else label))
        return terms

    threat_data = complete_threat_data if isinstance(complete_threat_data, dict) else {}

    # Order matters: capabilities first, then C2 protocols (see slicing below).
    context_elements = (
        _context_terms(threat_data.get('technicalDetails'), 'capabilities', 'name', 3)
        + _context_terms(threat_data.get('commandAndControl'), 'communicationMethods', 'protocol', 2)
    )

    # Use the regular optimizer with enhanced threat characteristics
    enhanced_characteristics = ThreatCharacteristics(
        threat_name=threat_characteristics.threat_name,
        threat_type=threat_characteristics.threat_type,
        attack_vectors=list(threat_characteristics.attack_vectors or []) + context_elements[:2],
        target_assets=threat_characteristics.target_assets,
        behavior_patterns=list(threat_characteristics.behavior_patterns or []) + context_elements[2:4],
        time_characteristics=threat_characteristics.time_characteristics,
    )
    return self.query_optimizer.optimize_query(enhanced_characteristics)
| def _structure_enhanced_ml_guidance(self, threat_characteristics: ThreatCharacteristics, | |
| optimized_query: OptimizedQuery, | |
| source_selection: SourceSelection, | |
| ml_results: List, # MLRetrievalResult type | |
| complete_threat_data: Dict) -> Dict: | |
| """Structure enhanced ML guidance with threat context""" | |
| # Start with regular structuring | |
| guidance = self._structure_ml_guidance( | |
| threat_characteristics, optimized_query, source_selection, ml_results | |
| ) | |
| # Enhance with threat context | |
| guidance['threat_context_applied'] = True | |
| guidance['context_sources'] = { | |
| 'technical_details': bool(complete_threat_data.get('technicalDetails')), | |
| 'command_and_control': bool(complete_threat_data.get('commandAndControl')), | |
| 'detection_and_mitigation': bool(complete_threat_data.get('detectionAndMitigation')), | |
| 'forensic_artifacts': bool(complete_threat_data.get('forensicArtifacts')) | |
| } | |
| # Add threat-specific implementation considerations | |
| if tech_details := complete_threat_data.get('technicalDetails'): | |
| if os_data := tech_details.get('operatingSystems'): | |
| # Ensure os_data is a list before slicing | |
| if isinstance(os_data, list): | |
| os_names = [os.get('name', str(os)) if isinstance(os, dict) else str(os) for os in os_data[:2]] | |
| guidance['implementation_considerations'].append({ | |
| 'aspect': 'OS Compatibility', | |
| 'details': f'Ensure ML models are trained on {", ".join(os_names)} environments for optimal detection.', | |
| 'source': 'Threat Intelligence Profile' | |
| }) | |
| return guidance | |
| def _create_enhanced_fallback_guidance(self, threat_characteristics: ThreatCharacteristics, | |
| complete_threat_data: Dict) -> Dict: | |
| """Create enhanced fallback guidance with threat context""" | |
| fallback = self._create_fallback_guidance(threat_characteristics) | |
| # Add context-aware recommendations | |
| fallback['threat_context_applied'] = True | |
| fallback['enhanced_fallback'] = True | |
| # Add context-specific ML approaches | |
| if complete_threat_data.get('commandAndControl'): | |
| fallback['ml_approaches'].append({ | |
| 'technique': 'C2 Traffic Analysis', | |
| 'source_company': 'Context-Derived', | |
| 'description': 'ML-based detection of command and control communication patterns identified in the threat profile.', | |
| 'applicability_score': 0.8 | |
| }) | |
| if complete_threat_data.get('forensicArtifacts'): | |
| fallback['ml_approaches'].append({ | |
| 'technique': 'Artifact-Based Detection', | |
| 'source_company': 'Context-Derived', | |
| 'description': 'Machine learning models trained on forensic artifacts specific to this threat.', | |
| 'applicability_score': 0.7 | |
| }) | |
| return fallback | |
# Bind the module-level helper functions defined above as methods of
# MLAgenticRetriever, each under its own function name.
for _enhanced_method in (
    _optimize_query_with_context,
    _structure_enhanced_ml_guidance,
    _create_enhanced_fallback_guidance,
):
    setattr(MLAgenticRetriever, _enhanced_method.__name__, _enhanced_method)
del _enhanced_method  # keep the loop variable out of the module namespace
def main():
    """Run a manual smoke test of the ML Agentic Retriever and print a summary.

    Reads ``ANTHROPIC_API_KEY`` from the environment; when it is not set, an
    error line is printed and the function returns without doing anything else.
    Otherwise it builds an ``MLAgenticRetriever``, requests guidance for the
    sample threat from ``create_test_threat_characteristics`` and prints the
    result counts plus up to three ML approaches and three source papers.
    """
    # Initialize
    api_key = os.getenv('ANTHROPIC_API_KEY')
    if not api_key:
        print("Error: ANTHROPIC_API_KEY environment variable not set")
        return

    print("🤖 Testing ML Agentic Retriever")
    print("=" * 40)

    # Create components
    anthropic_client = Anthropic(api_key=api_key)
    retriever = MLAgenticRetriever(anthropic_client)

    # Test with sample threat
    threat = create_test_threat_characteristics()
    print(f"🎯 Testing with threat: {threat.threat_name}")
    print(f" Type: {threat.threat_type}")
    print(f" Attack Vectors: {', '.join(threat.attack_vectors)}")
    print(f" Behavior Patterns: {', '.join(threat.behavior_patterns)}")

    # Get ML guidance
    guidance = retriever.get_ml_guidance(threat)
    print("\n🧠 ML Guidance Generated:")
    print(f" ML Approaches: {len(guidance['ml_approaches'])}")
    print(f" Implementation Considerations: {len(guidance['implementation_considerations'])}")
    print(f" Source Papers: {len(guidance['source_papers'])}")

    # Show details (first three entries of each list)
    if guidance['ml_approaches']:
        print("\n📊 Top ML Approaches:")
        for i, approach in enumerate(guidance['ml_approaches'][:3], 1):
            print(f" {i}. {approach['technique']} ({approach['source_company']})")
            print(f" Applicability: {approach['applicability_score']:.2f}")
            print(f" Description: {approach['description'][:100]}...")

    if guidance['source_papers']:
        print("\n📚 Source Papers:")
        for paper in guidance['source_papers'][:3]:
            print(f" • {paper['company']} ({paper['year']}): {paper['title'][:60]}...")

    print("\n✅ Agentic retrieval test complete!")


if __name__ == "__main__":
    main()