ricomanifesto committed
Commit 38c5fec · Parent(s): 1165e73

Restructure project to mirror organized directory layout


- Organize src/ into modular directories: core/, data/, search/, ui/
- Move core logic files to src/core/
- Move data management to src/data/
- Move search functionality to src/search/
- Move UI components to src/ui/
- Update all import paths to reflect new structure (see the import example below)
- Enhance .gitignore to exclude non-essential files
- Maintain Hugging Face Spaces compatibility
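
As an illustration, the import-path updates follow the pattern below (taken from the ml_guidance_generator diff further down; the same package-style rewrite applies to the other moved modules):

    # Before the restructure: flat module imports resolved relative to src/
    from ml_workers_retriever import MLWorkersRetriever, ThreatCharacteristics

    # After the restructure: package-style imports rooted at the project directory
    from src.search.ml_workers_retriever import MLWorkersRetriever, ThreatCharacteristics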

.gitignore CHANGED
@@ -1,40 +1,64 @@
- # Documentation (exclude all markdown except README)
- docs/
- *.md
- !README.md

- # Data files
- data/
- traces/
- *.db
- *.log

- # Environment and secrets
  .env
- venv/
- __pycache__/
- *.pyc

- # IDE files
  .vscode/
  .idea/
  *.swp
  *.swo

- # OS files
  .DS_Store
  Thumbs.db

- # Build artifacts
- dist/
- build/
- *.egg-info/

- # Cloudflare Workers (if you don't want to include them)
- wrangler.toml

- # Migration scripts (if not needed for users)
- scripts/migration/

- # Test files (optional)
- tests/

+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST

+ # Virtual environments
+ venv/
+ env/
+ ENV/

+ # Environment variables
  .env
+ .env.local
+ .env.production

+ # IDEs
  .vscode/
  .idea/
  *.swp
  *.swo

+ # OS
  .DS_Store
  Thumbs.db

+ # Application specific
+ /data/
+ docs/
+ tests/
+ scripts/
+ backup_*/
+ PROJECT_STRUCTURE.md
+ main.py

+ # Logs
+ *.log

+ # Temporary files
+ *.tmp
+ *.temp

+ # Node modules (if any)
+ node_modules/

+ # Cloudflare
+ .wrangler/
app.py CHANGED
@@ -9,7 +9,7 @@ import os
  sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))

  # Import and run the main application
- from app import create_ui
+ from src.ui.app import create_ui

  if __name__ == "__main__":
      create_ui()
src/{markdown_generator.py → core/markdown_generator.py} RENAMED
File without changes
src/{ml_guidance_generator.py → core/ml_guidance_generator.py} RENAMED
@@ -22,7 +22,7 @@ from datetime import datetime

  from pydantic import BaseModel, Field, validator
  # Import Workers-based retriever (production system)
- from ml_workers_retriever import MLWorkersRetriever, ThreatCharacteristics
+ from src.search.ml_workers_retriever import MLWorkersRetriever, ThreatCharacteristics
  from anthropic import Anthropic
  import anthropic
src/{section_validator.py → core/section_validator.py} RENAMED
@@ -8,7 +8,7 @@ import anthropic
  from typing import Optional, List
  from datetime import datetime
  import re
- from validation_criteria import SECTION_CRITERIA, VALIDATION_PROMPTS
+ from src.core.validation_criteria import SECTION_CRITERIA, VALIDATION_PROMPTS


  class SectionValidator:
src/{threat_intel_tool.py → core/threat_intel_tool.py} RENAMED
@@ -11,9 +11,9 @@ import re
  from pydantic import ValidationError
  import time
  import random
- from section_validator import SectionValidator, SectionImprover
- from ml_guidance_generator import MLGuidanceGenerator, ThreatCharacteristics
- from trace_exporter import get_trace_exporter
+ from src.core.section_validator import SectionValidator, SectionImprover
+ from src.core.ml_guidance_generator import MLGuidanceGenerator, ThreatCharacteristics
+ from src.core.trace_exporter import get_trace_exporter


  class ThreatIntelTool:
src/{trace_exporter.py → core/trace_exporter.py} RENAMED
File without changes
src/{validation_criteria.py → core/validation_criteria.py} RENAMED
File without changes
src/data/ml_knowledge_base_builder.py ADDED
@@ -0,0 +1,701 @@
"""
ML Knowledge Base Builder for SentrySearch

Builds a production-ready knowledge base from curated ML anomaly detection papers
and blog posts. Implements Agentic RAG approach with intelligent content processing.

Features:
- Real content ingestion from URLs
- LLM-powered content enrichment
- Persistent ChromaDB storage
- Question-like chunk processing
- Production-ready error handling
"""

import os
import json
import time
import random
import hashlib
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass, asdict
from urllib.parse import urlparse
import logging
from pathlib import Path

import requests
from bs4 import BeautifulSoup
import chromadb
from chromadb.config import Settings
from anthropic import Anthropic
import anthropic
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class MLPaperSource:
    """Represents a source ML paper or blog post"""
    title: str
    url: str
    company: str
    year: str
    description: str
    ml_techniques: List[str]


@dataclass
class EnrichedChunk:
    """Represents a processed and enriched document chunk"""
    chunk_id: str
    source_title: str
    source_url: str
    company: str
    year: str
    original_content: str
    enriched_content: str  # Question-like format
    ml_techniques: List[str]
    chunk_summary: str
    keywords: List[str]
    chunk_index: int
    content_hash: str
    bm25_terms: List[str] = None  # Additional search terms for BM25
    faq_questions: List[str] = None  # FAQ-style questions


class ContentExtractor:
    """Extracts and cleans content from web pages"""

    def __init__(self, timeout: int = 30):
        self.timeout = timeout
        self.session = requests.Session()
        self.session.headers.update({
            'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36'
        })

    def extract_from_url(self, url: str) -> Optional[str]:
        """Extract clean text content from a URL"""
        try:
            logger.info(f"Extracting content from: {url}")

            response = self.session.get(url, timeout=self.timeout)
            response.raise_for_status()

            # Parse HTML
            soup = BeautifulSoup(response.content, 'html.parser')

            # Remove unwanted elements
            for element in soup(['script', 'style', 'nav', 'header', 'footer', 'aside']):
                element.decompose()

            # Extract main content
            content = self._extract_main_content(soup)

            # Clean and normalize text
            cleaned_content = self._clean_text(content)

            logger.info(f"Extracted {len(cleaned_content)} characters from {url}")
            return cleaned_content

        except requests.RequestException as e:
            logger.error(f"Request failed for {url}: {e}")
            return None
        except Exception as e:
            logger.error(f"Content extraction failed for {url}: {e}")
            return None

    def _extract_main_content(self, soup: BeautifulSoup) -> str:
        """Extract main content from parsed HTML"""

        # Try common article selectors
        content_selectors = [
            'article',
            '[role="main"]',
            '.post-content',
            '.article-content',
            '.entry-content',
            '.content',
            'main'
        ]

        for selector in content_selectors:
            content_elem = soup.select_one(selector)
            if content_elem:
                return content_elem.get_text()

        # Fallback to body content
        body = soup.find('body')
        return body.get_text() if body else soup.get_text()

    def _clean_text(self, text: str) -> str:
        """Clean and normalize extracted text"""
        # Remove extra whitespace
        lines = [line.strip() for line in text.split('\n')]
        lines = [line for line in lines if line]

        # Join lines and normalize spaces
        cleaned = ' '.join(lines)
        cleaned = ' '.join(cleaned.split())

        return cleaned


class ContentEnricher:
    """Enriches content using LLM-powered processing"""

    def __init__(self, anthropic_client):
        self.client = anthropic_client

    def _api_call_with_retry(self, **kwargs):
        """Make API call with intelligent retry logic using retry-after header"""
        max_retries = 3
        base_delay = 5

        for attempt in range(max_retries):
            try:
                print(f"DEBUG: Content Enricher API call attempt {attempt + 1}/{max_retries}")
                return self.client.messages.create(**kwargs)

            except anthropic.RateLimitError as e:
                if attempt == max_retries - 1:
                    print(f"DEBUG: Content Enricher rate limit exceeded after {max_retries} attempts")
                    raise e

                # Check if the error response has retry-after information
                retry_after = None
                if hasattr(e, 'response') and e.response:
                    retry_after_header = e.response.headers.get('retry-after')
                    if retry_after_header:
                        try:
                            retry_after = float(retry_after_header)
                            print(f"DEBUG: Content Enricher API provided retry-after: {retry_after} seconds")
                        except (ValueError, TypeError):
                            pass

                # Use retry-after if available, otherwise exponential backoff
                if retry_after:
                    delay = retry_after + random.uniform(1, 3)
                else:
                    delay = base_delay * (2 ** attempt) + random.uniform(1, 5)
                delay = min(delay, 120)

                print(f"DEBUG: Content Enricher rate limit hit. Waiting {delay:.1f} seconds before retry {attempt + 2}")
                time.sleep(delay)

            except Exception as e:
                print(f"DEBUG: Content Enricher non-rate-limit error: {e}")
                raise e

    def enrich_chunk(self, chunk: str, source: MLPaperSource) -> Dict[str, str]:
        """Enrich a chunk with summary, keywords, question-like format, and BM25-optimized metadata"""

        prompt = f"""
Analyze this text chunk from a machine learning anomaly detection paper/blog and provide:

1. QUESTION_FORMAT: Rewrite the chunk content as if it's answering questions about the ML approach
2. SUMMARY: A 2-line summary of what this chunk covers
3. KEYWORDS: 5-8 relevant technical keywords (comma-separated)
4. BM25_TERMS: Additional search terms for BM25 retrieval (comma-separated, include variations, synonyms, acronyms)
5. FAQ_QUESTIONS: 2-3 potential questions this chunk could answer (pipe-separated)

Source Context:
- Company: {source.company}
- ML Techniques: {', '.join(source.ml_techniques)}
- Year: {source.year}

Text Chunk:
{chunk[:1500]}

Format your response as:
QUESTION_FORMAT: [rewritten content]
SUMMARY: [summary]
KEYWORDS: [keywords]
BM25_TERMS: [search terms with variations]
FAQ_QUESTIONS: [question1|question2|question3]
"""

        try:
            response = self._api_call_with_retry(
                model="claude-sonnet-4-20250514",
                max_tokens=800,
                messages=[{"role": "user", "content": prompt}]
            )

            # Safe access to response content
            if not response.content or len(response.content) == 0:
                raise ValueError("Empty response from content enrichment API")

            if not hasattr(response.content[0], 'text'):
                raise ValueError("Response content missing text attribute")

            content = response.content[0].text.strip()
            return self._parse_enrichment_response(content)

        except Exception as e:
            logger.error(f"Content enrichment failed: {e}")
            # Return fallback enrichment
            return {
                'question_format': chunk,
                'summary': f"Content about {source.ml_techniques[0]} implementation at {source.company}",
                'keywords': ', '.join(source.ml_techniques + [source.company.lower(), 'anomaly detection']),
                'bm25_terms': ', '.join(source.ml_techniques + [source.company.lower(), 'ml', 'detection', 'analysis']),
                'faq_questions': f"How does {source.company} implement {source.ml_techniques[0]}?|What is {source.ml_techniques[0]} used for?"
            }

    def _parse_enrichment_response(self, response: str) -> Dict[str, str]:
        """Parse LLM response into structured enrichment data"""
        result = {
            'question_format': '',
            'summary': '',
            'keywords': '',
            'bm25_terms': '',
            'faq_questions': ''
        }

        lines = response.split('\n')
        current_field = None

        for line in lines:
            line = line.strip()
            if line.startswith('QUESTION_FORMAT:'):
                current_field = 'question_format'
                result[current_field] = line.replace('QUESTION_FORMAT:', '').strip()
            elif line.startswith('SUMMARY:'):
                current_field = 'summary'
                result[current_field] = line.replace('SUMMARY:', '').strip()
            elif line.startswith('KEYWORDS:'):
                current_field = 'keywords'
                result[current_field] = line.replace('KEYWORDS:', '').strip()
            elif line.startswith('BM25_TERMS:'):
                current_field = 'bm25_terms'
                result[current_field] = line.replace('BM25_TERMS:', '').strip()
            elif line.startswith('FAQ_QUESTIONS:'):
                current_field = 'faq_questions'
                result[current_field] = line.replace('FAQ_QUESTIONS:', '').strip()
            elif current_field and line:
                result[current_field] += ' ' + line

        return result


class DocumentProcessor:
    """Processes documents into enriched chunks"""

    def __init__(self, content_enricher: ContentEnricher, chunk_size: int = 800):
        self.enricher = content_enricher
        self.chunk_size = chunk_size

    def process_document(self, source: MLPaperSource, content: str) -> List[EnrichedChunk]:
        """Process a document into enriched chunks"""

        if not content or len(content) < 100:
            logger.warning(f"Content too short for {source.title}")
            return []

        # Create chunks
        chunks = self._create_chunks(content, source)

        # Enrich each chunk
        enriched_chunks = []
        for i, chunk_content in enumerate(chunks):

            # Generate content hash for deduplication
            content_hash = hashlib.md5(chunk_content.encode()).hexdigest()

            # Enrich with LLM
            enrichment = self.enricher.enrich_chunk(chunk_content, source)

            chunk = EnrichedChunk(
                chunk_id=f"{source.company}_{source.year}_{i}_{content_hash[:8]}",
                source_title=source.title,
                source_url=source.url,
                company=source.company,
                year=source.year,
                original_content=chunk_content,
                enriched_content=enrichment['question_format'],
                ml_techniques=source.ml_techniques,
                chunk_summary=enrichment['summary'],
                keywords=enrichment['keywords'].split(', ') if enrichment['keywords'] else [],
                chunk_index=i,
                content_hash=content_hash
            )

            # Add BM25-specific metadata to chunk
            chunk.bm25_terms = enrichment.get('bm25_terms', '').split(', ') if enrichment.get('bm25_terms') else []
            chunk.faq_questions = enrichment.get('faq_questions', '').split('|') if enrichment.get('faq_questions') else []

            enriched_chunks.append(chunk)

            # Rate limiting for API calls
            time.sleep(0.5)

        logger.info(f"Processed {len(enriched_chunks)} chunks for {source.title}")
        return enriched_chunks

    def _create_chunks(self, content: str, source: MLPaperSource) -> List[str]:
        """Create overlapping chunks from content"""
        chunks = []
        overlap = self.chunk_size // 4  # 25% overlap

        for i in range(0, len(content), self.chunk_size - overlap):
            chunk = content[i:i + self.chunk_size]

            # Skip very short chunks
            if len(chunk) < 200:
                continue

            # Try to break at sentence boundaries
            if i + self.chunk_size < len(content):
                last_period = chunk.rfind('.')
                if last_period > len(chunk) * 0.7:  # If period is in last 30%
                    chunk = chunk[:last_period + 1]

            chunks.append(chunk.strip())

        return chunks


class KnowledgeBaseStorage:
    """Manages persistent storage of the knowledge base"""

    def __init__(self, storage_path: str = "./ml_knowledge_base"):
        self.storage_path = Path(storage_path)
        self.storage_path.mkdir(exist_ok=True)

        # Initialize ChromaDB with persistent storage
        self.chroma_client = chromadb.PersistentClient(
            path=str(self.storage_path / "chroma_db")
        )

        self.collection_name = "ml_anomaly_detection"
        self.collection = None
        self._initialize_collection()

    def _initialize_collection(self):
        """Initialize or get existing collection"""
        try:
            # Try to get existing collection
            self.collection = self.chroma_client.get_collection(self.collection_name)
            logger.info(f"Loaded existing collection with {self.collection.count()} documents")
        except:
            # Create new collection
            self.collection = self.chroma_client.create_collection(
                name=self.collection_name,
                metadata={"description": "ML Anomaly Detection Knowledge Base"}
            )
            logger.info("Created new collection")

    def add_chunks(self, chunks: List[EnrichedChunk]) -> bool:
        """Add enriched chunks to the knowledge base"""
        try:
            if not chunks:
                return True

            # Prepare data for ChromaDB
            documents = []
            metadatas = []
            ids = []

            for chunk in chunks:
                # Create enriched document text
                document_text = f"""
Title: {chunk.source_title}
Company: {chunk.company}
Year: {chunk.year}
ML Techniques: {', '.join(chunk.ml_techniques)}
Keywords: {', '.join(chunk.keywords)}
Summary: {chunk.chunk_summary}

Content: {chunk.enriched_content}
""".strip()

                documents.append(document_text)
                metadatas.append({
                    'source_title': chunk.source_title,
                    'source_url': chunk.source_url,
                    'company': chunk.company,
                    'year': chunk.year,
                    'ml_techniques': '|'.join(chunk.ml_techniques),
                    'keywords': '|'.join(chunk.keywords),
                    'chunk_summary': chunk.chunk_summary,
                    'chunk_index': chunk.chunk_index,
                    'content_hash': chunk.content_hash,
                    'bm25_terms': '|'.join(chunk.bm25_terms) if chunk.bm25_terms else '',
                    'faq_questions': '|'.join(chunk.faq_questions) if chunk.faq_questions else ''
                })
                ids.append(chunk.chunk_id)

            # Add to ChromaDB
            self.collection.add(
                documents=documents,
                metadatas=metadatas,
                ids=ids
            )

            # Save chunk details as JSON backup
            self._save_chunks_backup(chunks)

            logger.info(f"Added {len(chunks)} chunks to knowledge base")
            return True

        except Exception as e:
            logger.error(f"Failed to add chunks to knowledge base: {e}")
            return False

    def _save_chunks_backup(self, chunks: List[EnrichedChunk]):
        """Save chunk details as JSON backup"""
        backup_file = self.storage_path / "chunks_backup.jsonl"

        with open(backup_file, 'a', encoding='utf-8') as f:
            for chunk in chunks:
                f.write(json.dumps(asdict(chunk), ensure_ascii=False) + '\n')

    def get_stats(self) -> Dict:
        """Get knowledge base statistics"""
        try:
            count = self.collection.count()

            # Get unique companies and years
            if count > 0:
                results = self.collection.get(include=['metadatas'])
                companies = set()
                years = set()
                ml_techniques = set()

                for metadata in results['metadatas']:
                    companies.add(metadata.get('company', ''))
                    years.add(metadata.get('year', ''))
                    techniques = metadata.get('ml_techniques', '').split('|')
                    ml_techniques.update([t for t in techniques if t])

                return {
                    'total_chunks': count,
                    'companies': sorted(list(companies)),
                    'years': sorted(list(years)),
                    'ml_techniques': sorted(list(ml_techniques)),
                    'storage_path': str(self.storage_path)
                }
            else:
                return {
                    'total_chunks': 0,
                    'companies': [],
                    'years': [],
                    'ml_techniques': [],
                    'storage_path': str(self.storage_path)
                }

        except Exception as e:
            logger.error(f"Failed to get stats: {e}")
            return {'error': str(e)}

    def search(self, query: str, n_results: int = 10) -> List[Dict]:
        """Search the knowledge base"""
        try:
            results = self.collection.query(
                query_texts=[query],
                n_results=n_results,
                include=['documents', 'metadatas', 'distances']
            )

            search_results = []
            for i, doc in enumerate(results['documents'][0]):
                search_results.append({
                    'document': doc,
                    'metadata': results['metadatas'][0][i],
                    'distance': results['distances'][0][i],
                    'score': 1 / (1 + results['distances'][0][i])  # Convert distance to similarity
                })

            return search_results

        except Exception as e:
            logger.error(f"Search failed: {e}")
            return []


def get_curated_ml_sources() -> List[MLPaperSource]:
    """Get the curated list of ML anomaly detection sources"""

    sources = [
        MLPaperSource(
            title="Detecting Performance Anomalies in External Firmware Deployments",
            url="https://netflixtechblog.com/detecting-performance-anomalies-in-external-firmware-deployments-ed41b1bfcf46",
            company="Netflix",
            year="2019",
            description="Netflix's approach to detecting anomalies in firmware performance using ML",
            ml_techniques=["statistical_analysis", "anomaly_detection", "performance_monitoring"]
        ),
        MLPaperSource(
            title="Detecting and Preventing Abuse on LinkedIn using Isolation Forests",
            url="https://engineering.linkedin.com/blog/2019/isolation-forest",
            company="LinkedIn",
            year="2019",
            description="LinkedIn's implementation of isolation forests for abuse detection",
            ml_techniques=["isolation_forest", "unsupervised_learning", "abuse_detection"]
        ),
        MLPaperSource(
            title="How Does Spam Protection Work on Stack Exchange?",
            url="https://stackoverflow.blog/2020/06/25/how-does-spam-protection-work-on-stack-exchange/",
            company="Stack Exchange",
            year="2020",
            description="Stack Exchange's ML-based spam detection system",
            ml_techniques=["text_classification", "nlp", "spam_detection"]
        ),
        MLPaperSource(
            title="Blocking Slack Invite Spam With Machine Learning",
            url="https://slack.engineering/blocking-slack-invite-spam-with-machine-learning/",
            company="Slack",
            year="2020",
            description="Slack's ML approach to preventing invite spam",
            ml_techniques=["classification", "feature_engineering", "spam_detection"]
        ),
        MLPaperSource(
            title="Cloudflare Bot Management: Machine Learning and More",
            url="https://blog.cloudflare.com/cloudflare-bot-management-machine-learning-and-more/",
            company="Cloudflare",
            year="2020",
            description="Cloudflare's ML-powered bot detection and management",
            ml_techniques=["behavioral_analysis", "traffic_analysis", "bot_detection"]
        ),
        MLPaperSource(
            title="Graph for Fraud Detection",
            url="https://engineering.grab.com/graph-for-fraud-detection",
            company="Grab",
            year="2022",
            description="Grab's graph-based approach to fraud detection",
            ml_techniques=["graph_ml", "fraud_detection", "network_analysis"]
        ),
        MLPaperSource(
            title="Machine Learning for Fraud Detection in Streaming Services",
            url="https://netflixtechblog.com/machine-learning-for-fraud-detection-in-streaming-services-b0b4ef3be3f6",
            company="Netflix",
            year="2023",
            description="Netflix's ML approach to detecting fraud in streaming services",
            ml_techniques=["fraud_detection", "streaming_analytics", "behavioral_analysis"]
        ),
        MLPaperSource(
            title="Data Generation and Sampling Strategies",
            url="https://blog.cloudflare.com/data-generation-and-sampling-strategies/",
            company="Cloudflare",
            year="2023",
            description="Cloudflare's data generation and sampling strategies for ML training",
            ml_techniques=["data_generation", "sampling", "training_data"]
        ),
        MLPaperSource(
            title="Machine Learning Mobile Traffic Bots",
            url="https://blog.cloudflare.com/machine-learning-mobile-traffic-bots/",
            company="Cloudflare",
            year="2023",
            description="Cloudflare's ML approach to detecting mobile traffic bots",
            ml_techniques=["bot_detection", "mobile_traffic", "behavioral_analysis"]
        ),
        MLPaperSource(
            title="Project Radar: Intelligent Early Fraud Detection",
            url="https://www.uber.com/blog/project-radar-intelligent-early-fraud-detection/",
            company="Uber",
            year="2023",
            description="Uber's Project Radar for intelligent early fraud detection",
            ml_techniques=["fraud_detection", "early_detection", "real_time_ml"]
        )
    ]

    return sources


def main():
    """Main function to build the ML knowledge base"""

    # Initialize components
    api_key = os.getenv('ANTHROPIC_API_KEY')
    if not api_key:
        logger.error("ANTHROPIC_API_KEY environment variable not set")
        return

    print("🔨 Building ML Anomaly Detection Knowledge Base")
    print("=" * 50)

    # Initialize components
    anthropic_client = Anthropic(api_key=api_key)
    content_extractor = ContentExtractor()
    content_enricher = ContentEnricher(anthropic_client)
    document_processor = DocumentProcessor(content_enricher)
    knowledge_base = KnowledgeBaseStorage()

    # Get current stats
    current_stats = knowledge_base.get_stats()
    print(f"📊 Current knowledge base: {current_stats['total_chunks']} chunks")

    # Get sources to process
    sources = get_curated_ml_sources()
    print(f"📚 Processing {len(sources)} ML sources...")

    # Process each source
    total_chunks_added = 0
    successful_sources = 0

    for i, source in enumerate(sources, 1):
        print(f"\n🔄 [{i}/{len(sources)}] Processing: {source.title}")
        print(f"   Company: {source.company} | Year: {source.year}")

        # Extract content
        content = content_extractor.extract_from_url(source.url)

        if not content:
            print(f"   ❌ Failed to extract content")
            continue

        print(f"   📝 Extracted {len(content):,} characters")

        # Process into chunks
        chunks = document_processor.process_document(source, content)

        if not chunks:
            print(f"   ❌ No chunks generated")
            continue

        print(f"   🧩 Generated {len(chunks)} chunks")

        # Add to knowledge base
        if knowledge_base.add_chunks(chunks):
            total_chunks_added += len(chunks)
            successful_sources += 1
            print(f"   ✅ Added to knowledge base")
        else:
            print(f"   ❌ Failed to add to knowledge base")

    # Final stats
    print(f"\n🎉 Knowledge Base Build Complete!")
    print("=" * 50)
    print(f"Sources processed: {successful_sources}/{len(sources)}")
    print(f"Total chunks added: {total_chunks_added}")

    final_stats = knowledge_base.get_stats()
    print(f"Final knowledge base size: {final_stats['total_chunks']} chunks")
    print(f"Companies: {', '.join(final_stats['companies'])}")
    print(f"Years: {', '.join(final_stats['years'])}")
    print(f"Storage location: {final_stats['storage_path']}")

    # Test search
    print(f"\n🔍 Testing search functionality...")
    test_queries = [
        "How does Netflix detect performance anomalies?",
        "What ML techniques work for fraud detection?",
        "Isolation forest implementation details"
    ]

    for query in test_queries:
        results = knowledge_base.search(query, n_results=3)
        print(f"\nQuery: '{query}'")
        print(f"Results: {len(results)} found")
        if results:
            top_result = results[0]
            print(f"Top match: {top_result['metadata']['company']} - {top_result['metadata']['source_title'][:60]}...")
            print(f"Score: {top_result['score']:.3f}")


if __name__ == "__main__":
    main()
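
For reference, a minimal usage sketch of the storage layer defined above, assuming the knowledge base has already been built and populated (for example by running this module's main()):

    from src.data.ml_knowledge_base_builder import KnowledgeBaseStorage

    kb = KnowledgeBaseStorage()              # opens ./ml_knowledge_base/chroma_db
    print(kb.get_stats()['total_chunks'])    # number of enriched chunks stored
    for hit in kb.search("isolation forest abuse detection", n_results=3):
        print(hit['score'], hit['metadata']['company'], hit['metadata']['source_title'])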
src/search/bm25_retriever.py ADDED
@@ -0,0 +1,449 @@
"""
BM25 Retriever for SentrySearch

Implements BM25-based retrieval with enriched metadata support for the agentic RAG system.
Provides complementary keyword-based retrieval alongside vector search for enhanced precision.

Features:
- BM25 algorithm for exact keyword matching
- Enriched metadata indexing (summaries, FAQs, keywords)
- Integration with existing knowledge base
- Result scoring and ranking
- Efficient document preprocessing
"""

import os
import json
import logging
import time
from typing import List, Dict, Optional, Tuple, Set
from dataclasses import dataclass
from pathlib import Path
import re
import pickle

from rank_bm25 import BM25Okapi
import numpy as np
from src.data.ml_knowledge_base_builder import KnowledgeBaseStorage

logger = logging.getLogger(__name__)


@dataclass
class BM25Document:
    """Represents a document optimized for BM25 retrieval"""
    doc_id: str
    content: str
    enriched_content: str  # Enhanced with metadata
    metadata: Dict
    keywords: List[str]
    summary: str
    preprocessed_tokens: List[str]


@dataclass
class BM25SearchResult:
    """Represents a BM25 search result"""
    doc_id: str
    content: str
    metadata: Dict
    bm25_score: float
    matched_terms: List[str]
    relevance_score: float  # Normalized score


class BM25Preprocessor:
    """Preprocesses documents for BM25 indexing"""

    def __init__(self):
        # Common stopwords for technical content
        self.stopwords = {
            'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
            'of', 'with', 'by', 'is', 'are', 'was', 'were', 'be', 'been', 'being',
            'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'could',
            'should', 'may', 'might', 'can', 'this', 'that', 'these', 'those',
            'i', 'you', 'he', 'she', 'it', 'we', 'they', 'me', 'him', 'her', 'us', 'them'
        }

    def preprocess_text(self, text: str) -> List[str]:
        """Preprocess text for BM25 indexing"""
        if not text:
            return []

        # Convert to lowercase
        text = text.lower()

        # Keep technical terms and alphanumeric sequences
        # Split on whitespace and punctuation but preserve underscores and hyphens in technical terms
        tokens = re.findall(r'\b[a-zA-Z0-9_-]+\b', text)

        # Filter tokens
        filtered_tokens = []
        for token in tokens:
            # Skip stopwords
            if token in self.stopwords:
                continue

            # Skip very short tokens unless they're technical (like ML, AI, etc.)
            if len(token) < 2 and not token.isupper():
                continue

            # Skip very long tokens (likely noise)
            if len(token) > 50:
                continue

            filtered_tokens.append(token)

        return filtered_tokens

    def create_enriched_content(self, chunk_data: Dict) -> str:
        """Create enriched content optimized for BM25 search"""
        content_parts = []

        # Original content
        if chunk_data.get('enriched_content'):
            content_parts.append(chunk_data['enriched_content'])

        # Add weighted metadata
        metadata = chunk_data.get('metadata', {})

        # Company name (high weight)
        if company := metadata.get('company'):
            content_parts.append(f"{company} {company} {company}")  # Triple weight

        # ML techniques (high weight)
        if ml_techniques := metadata.get('ml_techniques'):
            techniques = ml_techniques.split('|') if isinstance(ml_techniques, str) else []
            for technique in techniques:
                if technique.strip():
                    # Double weight for techniques
                    content_parts.append(f"{technique} {technique}")

        # Keywords (medium weight)
        if keywords := metadata.get('keywords'):
            keyword_list = keywords.split('|') if isinstance(keywords, str) else []
            content_parts.extend(keyword_list)

        # Summary (medium weight)
        if summary := metadata.get('chunk_summary'):
            content_parts.append(summary)

        # Source title (medium weight)
        if title := metadata.get('source_title'):
            content_parts.append(title)

        return ' '.join(content_parts)


class BM25Retriever:
    """BM25-based retriever with enriched metadata support"""

    def __init__(self, knowledge_base: KnowledgeBaseStorage,
                 storage_path: str = "./ml_knowledge_base"):
        self.knowledge_base = knowledge_base
        self.storage_path = Path(storage_path)
        self.preprocessor = BM25Preprocessor()

        # BM25 components
        self.bm25_index = None
        self.documents = []
        self.doc_lookup = {}  # doc_id -> document mapping

        # Storage files
        self.bm25_cache_file = self.storage_path / "bm25_index.pkl"
        self.docs_cache_file = self.storage_path / "bm25_documents.json"

        # Initialize
        self._initialize_bm25_index()

    def _initialize_bm25_index(self):
        """Initialize or load existing BM25 index"""
        try:
            # Try to load cached index
            if self._load_cached_index():
                logger.info(f"Loaded cached BM25 index with {len(self.documents)} documents")
                return

            # Build new index
            logger.info("Building new BM25 index...")
            self._build_bm25_index()

        except Exception as e:
            logger.error(f"Failed to initialize BM25 index: {e}")
            self.bm25_index = None
            self.documents = []

    def _load_cached_index(self) -> bool:
        """Load cached BM25 index if available"""
        if not (self.bm25_cache_file.exists() and self.docs_cache_file.exists()):
            return False

        try:
            # Load BM25 index
            with open(self.bm25_cache_file, 'rb') as f:
                self.bm25_index = pickle.load(f)

            # Load documents
            with open(self.docs_cache_file, 'r', encoding='utf-8') as f:
                docs_data = json.load(f)

            # Reconstruct documents
            self.documents = []
            self.doc_lookup = {}

            for doc_data in docs_data:
                doc = BM25Document(**doc_data)
                self.documents.append(doc)
                self.doc_lookup[doc.doc_id] = doc

            return True

        except Exception as e:
            logger.warning(f"Failed to load cached BM25 index: {e}")
            return False

    def _build_bm25_index(self):
        """Build BM25 index from knowledge base"""
        try:
            # Get all documents from ChromaDB
            results = self.knowledge_base.collection.get(
                include=['documents', 'metadatas']
            )

            if not results['ids']:
                logger.warning("No documents found in knowledge base")
                return

            logger.info(f"Processing {len(results['ids'])} documents for BM25 indexing...")

            # Process each document
            bm25_documents = []
            tokenized_docs = []

            for i, doc_id in enumerate(results['ids']):
                try:
                    # Create enriched content for BM25
                    chunk_data = {
                        'enriched_content': results['documents'][i],
                        'metadata': results['metadatas'][i]
                    }

                    enriched_content = self.preprocessor.create_enriched_content(chunk_data)

                    # Preprocess for BM25
                    tokens = self.preprocessor.preprocess_text(enriched_content)

                    if not tokens:  # Skip empty documents
                        continue

                    # Create BM25 document
                    bm25_doc = BM25Document(
                        doc_id=doc_id,
                        content=results['documents'][i],
                        enriched_content=enriched_content,
                        metadata=results['metadatas'][i],
                        keywords=results['metadatas'][i].get('keywords', '').split('|'),
                        summary=results['metadatas'][i].get('chunk_summary', ''),
                        preprocessed_tokens=tokens
                    )

                    bm25_documents.append(bm25_doc)
                    tokenized_docs.append(tokens)

                except Exception as e:
                    logger.warning(f"Failed to process document {doc_id}: {e}")
                    continue

            if not bm25_documents:
                logger.error("No valid documents processed for BM25")
                return

            # Build BM25 index
            logger.info(f"Building BM25 index with {len(bm25_documents)} documents...")
            self.bm25_index = BM25Okapi(tokenized_docs)
            self.documents = bm25_documents

            # Create lookup dictionary
            self.doc_lookup = {doc.doc_id: doc for doc in self.documents}

            # Cache the index
            self._cache_bm25_index()

            logger.info(f"BM25 index built successfully with {len(self.documents)} documents")

        except Exception as e:
            logger.error(f"Failed to build BM25 index: {e}")
            self.bm25_index = None
            self.documents = []

    def _cache_bm25_index(self):
        """Cache BM25 index to disk"""
        try:
            # Cache BM25 index
            with open(self.bm25_cache_file, 'wb') as f:
                pickle.dump(self.bm25_index, f)

            # Cache documents (convert to JSON-serializable format)
            docs_data = []
            for doc in self.documents:
                doc_dict = {
                    'doc_id': doc.doc_id,
                    'content': doc.content,
                    'enriched_content': doc.enriched_content,
                    'metadata': doc.metadata,
                    'keywords': doc.keywords,
                    'summary': doc.summary,
                    'preprocessed_tokens': doc.preprocessed_tokens
                }
                docs_data.append(doc_dict)

            with open(self.docs_cache_file, 'w', encoding='utf-8') as f:
                json.dump(docs_data, f, ensure_ascii=False, indent=2)

            logger.info("BM25 index cached successfully")

        except Exception as e:
            logger.warning(f"Failed to cache BM25 index: {e}")

    def search(self, query: str, n_results: int = 10,
               min_score: float = 0.0) -> List[BM25SearchResult]:
        """Search using BM25 algorithm"""
        if not self.bm25_index or not self.documents:
            logger.warning("BM25 index not available")
            return []

        try:
            # Preprocess query
            query_tokens = self.preprocessor.preprocess_text(query)

            if not query_tokens:
                logger.warning("No valid tokens in query")
                return []

            # Get BM25 scores
            scores = self.bm25_index.get_scores(query_tokens)

            # Create results with scores
            results = []
            for i, score in enumerate(scores):
                if score <= min_score:
                    continue

                doc = self.documents[i]

                # Find matched terms
                matched_terms = self._find_matched_terms(query_tokens, doc.preprocessed_tokens)

                # Calculate relevance score (normalized)
                relevance_score = min(score / 10.0, 1.0)  # Normalize to 0-1 range

                result = BM25SearchResult(
                    doc_id=doc.doc_id,
                    content=doc.content,
                    metadata=doc.metadata,
                    bm25_score=score,
                    matched_terms=matched_terms,
                    relevance_score=relevance_score
                )

                results.append(result)

            # Sort by BM25 score (descending)
            results.sort(key=lambda x: x.bm25_score, reverse=True)

            # Return top N results
            return results[:n_results]

        except Exception as e:
            logger.error(f"BM25 search failed: {e}")
            return []

    def _find_matched_terms(self, query_tokens: List[str], doc_tokens: List[str]) -> List[str]:
        """Find which query terms matched in the document"""
        doc_token_set = set(doc_tokens)
        matched = [token for token in query_tokens if token in doc_token_set]
        return matched

    def get_stats(self) -> Dict:
        """Get BM25 retriever statistics"""
        return {
            'total_documents': len(self.documents),
            'index_available': self.bm25_index is not None,
            'cache_files_exist': {
                'index': self.bm25_cache_file.exists(),
                'documents': self.docs_cache_file.exists()
            },
            'storage_path': str(self.storage_path)
        }

    def rebuild_index(self):
        """Force rebuild of BM25 index"""
        logger.info("Rebuilding BM25 index...")

        # Clear existing index
        self.bm25_index = None
        self.documents = []
        self.doc_lookup = {}

        # Remove cache files
        try:
            if self.bm25_cache_file.exists():
                self.bm25_cache_file.unlink()
            if self.docs_cache_file.exists():
                self.docs_cache_file.unlink()
        except Exception as e:
            logger.warning(f"Failed to remove cache files: {e}")

        # Rebuild
        self._build_bm25_index()


def main():
    """Test the BM25 retriever"""

    print("🔍 Testing BM25 Retriever")
    print("=" * 40)

    # Initialize knowledge base and BM25 retriever
    knowledge_base = KnowledgeBaseStorage()
    bm25_retriever = BM25Retriever(knowledge_base)

    # Get stats
    stats = bm25_retriever.get_stats()
    print(f"📊 BM25 Index Stats:")
    print(f"   Documents: {stats['total_documents']}")
    print(f"   Index Available: {stats['index_available']}")

    if not stats['index_available']:
        print("❌ BM25 index not available")
        return

    # Test queries
    test_queries = [
        "Netflix anomaly detection",
        "isolation forest LinkedIn",
        "fraud detection machine learning",
        "bot detection Cloudflare",
        "spam classification",
        "graph neural networks"
    ]

    print(f"\n🔍 Testing BM25 Search:")
    for query in test_queries:
        print(f"\nQuery: '{query}'")

        results = bm25_retriever.search(query, n_results=3)
        print(f"Results: {len(results)} found")

        for i, result in enumerate(results, 1):
            company = result.metadata.get('company', 'Unknown')
            title = result.metadata.get('source_title', 'No title')[:50]
            print(f"  {i}. {company} - {title}...")
            print(f"     BM25 Score: {result.bm25_score:.3f}")
            print(f"     Matched Terms: {', '.join(result.matched_terms[:5])}")

    print(f"\n✅ BM25 retriever test complete!")


if __name__ == "__main__":
    main()
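
A small sketch of refreshing the keyword index after the knowledge base changes; rebuild_index() removes the cached pickle/JSON files and re-indexes the current ChromaDB contents, as implemented above (assumes the knowledge base directory already exists and is populated):

    from src.data.ml_knowledge_base_builder import KnowledgeBaseStorage
    from src.search.bm25_retriever import BM25Retriever

    bm25 = BM25Retriever(KnowledgeBaseStorage())   # loads the cached index if one exists
    bm25.rebuild_index()                           # drop caches and rebuild from ChromaDB
    for result in bm25.search("fraud detection graph", n_results=5):
        print(result.bm25_score, result.metadata.get('company'), result.matched_terms)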
src/search/ml_agentic_retriever.py ADDED
@@ -0,0 +1,975 @@
1
+ """
2
+ ML Agentic Retriever for SentrySearch
3
+
4
+ Implements Agentic RAG approach with intelligent query optimization.
5
+ Provides intelligent ML-focused retrieval for threat intelligence with:
6
+ - Query optimization for threat-to-ML translation
7
+ - Source identification for relevant paper filtering
8
+ - Enhanced hybrid retrieval with post-processing
9
+ - Context-aware result ranking and structuring
10
+
11
+ Usage:
12
+ retriever = MLAgenticRetriever(anthropic_client, knowledge_base)
13
+ ml_guidance = retriever.get_ml_guidance(threat_characteristics)
14
+ """
15
+
16
+ import os
17
+ import json
18
+ import logging
19
+ import time
20
+ import random
21
+ import hashlib
22
+ from typing import List, Dict, Optional, Tuple, Set
23
+ from dataclasses import dataclass
24
+ import re
25
+
26
+ from anthropic import Anthropic
27
+ import anthropic
28
+ from src.data.ml_knowledge_base_builder import KnowledgeBaseStorage
29
+ from src.search.bm25_retriever import BM25Retriever, BM25SearchResult
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
+ @dataclass
35
+ class ThreatCharacteristics:
36
+ """Represents threat characteristics for ML guidance generation"""
37
+ threat_name: str
38
+ threat_type: str # e.g., "malware", "apt", "insider_threat"
39
+ attack_vectors: List[str] # e.g., ["network", "email", "web"]
40
+ target_assets: List[str] # e.g., ["user_accounts", "financial_data"]
41
+ behavior_patterns: List[str] # e.g., ["lateral_movement", "data_exfiltration"]
42
+ time_characteristics: str # e.g., "persistent", "burst", "periodic"
43
+
44
+
45
+ @dataclass
46
+ class OptimizedQuery:
47
+ """Represents an optimized query for ML retrieval"""
48
+ original_query: str
49
+ optimized_queries: List[str]
50
+ ml_focus_areas: List[str]
51
+ reasoning: str
52
+
53
+
54
+ @dataclass
55
+ class SourceSelection:
56
+ """Represents filtered source selection"""
57
+ relevant_companies: List[str]
58
+ relevant_years: List[str]
59
+ relevant_techniques: List[str]
60
+ reasoning: str
61
+
62
+
63
+ @dataclass
64
+ class MLRetrievalResult:
65
+ """Represents a structured ML retrieval result"""
66
+ content: str
67
+ metadata: Dict
68
+ relevance_score: float
69
+ source_paper: str
70
+ ml_techniques: List[str]
71
+ implementation_details: str
72
+ applicability_score: float
73
+ retrieval_method: str = 'unknown' # 'vector', 'bm25', or 'hybrid'
74
+ bm25_score: float = 0.0
75
+ hybrid_score: float = 0.0
76
+ matched_terms: List[str] = None
77
+
78
+
79
+ class QueryOptimizer:
80
+ """Optimizes queries to focus on ML anomaly detection approaches"""
81
+
82
+ def __init__(self, anthropic_client):
83
+ self.client = anthropic_client
84
+
85
+ def _api_call_with_retry(self, **kwargs):
86
+ """Make API call with intelligent retry logic using retry-after header"""
87
+ max_retries = 3
88
+ base_delay = 5
89
+
90
+ for attempt in range(max_retries):
91
+ try:
92
+ print(f"DEBUG: Query Optimizer API call attempt {attempt + 1}/{max_retries}")
93
+ return self.client.messages.create(**kwargs)
94
+
95
+ except anthropic.RateLimitError as e:
96
+ if attempt == max_retries - 1:
97
+ print(f"DEBUG: Query Optimizer rate limit exceeded after {max_retries} attempts")
98
+ raise e
99
+
100
+ # Check if the error response has retry-after information
101
+ retry_after = None
102
+ if hasattr(e, 'response') and e.response:
103
+ retry_after_header = e.response.headers.get('retry-after')
104
+ if retry_after_header:
105
+ try:
106
+ retry_after = float(retry_after_header)
107
+ print(f"DEBUG: Query Optimizer API provided retry-after: {retry_after} seconds")
108
+ except (ValueError, TypeError):
109
+ pass
110
+
111
+ # Use retry-after if available, otherwise exponential backoff
112
+ if retry_after:
113
+ delay = retry_after + random.uniform(1, 3)
114
+ else:
115
+ delay = base_delay * (2 ** attempt) + random.uniform(1, 5)
116
+ delay = min(delay, 120)
117
+
118
+ print(f"DEBUG: Query Optimizer rate limit hit. Waiting {delay:.1f} seconds before retry {attempt + 2}")
119
+ time.sleep(delay)
120
+
121
+ except Exception as e:
122
+ print(f"DEBUG: Query Optimizer non-rate-limit error: {e}")
123
+ raise e
124
+
125
+ def optimize_query(self, threat_characteristics: ThreatCharacteristics) -> OptimizedQuery:
126
+ """Convert threat characteristics into ML-focused queries"""
127
+
128
+ prompt = f"""
129
+ You are an expert in both cybersecurity threats and machine learning anomaly detection.
130
+ Convert this threat information into 3-5 specific queries about ML approaches for detection.
131
+
132
+ Threat Information:
133
+ - Name: {threat_characteristics.threat_name}
134
+ - Type: {threat_characteristics.threat_type}
135
+ - Attack Vectors: {', '.join(threat_characteristics.attack_vectors)}
136
+ - Target Assets: {', '.join(threat_characteristics.target_assets)}
137
+ - Behavior Patterns: {', '.join(threat_characteristics.behavior_patterns)}
138
+ - Time Characteristics: {threat_characteristics.time_characteristics}
139
+
140
+ Generate queries that focus on:
141
+ 1. Specific ML techniques for detecting this threat type
142
+ 2. Feature engineering approaches for the attack vectors
143
+ 3. Behavioral analysis methods for the patterns observed
144
+ 4. Implementation considerations for the target environment
145
+
146
+ Format your response as:
147
+ QUERIES:
148
+ 1. [Query 1]
149
+ 2. [Query 2]
150
+ 3. [Query 3]
151
+ etc.
152
+
153
+ ML_FOCUS_AREAS: [comma-separated focus areas]
154
+
155
+ REASONING: [1-2 sentences explaining the ML approach rationale]
156
+ """
157
+
158
+ try:
159
+ response = self._api_call_with_retry(
160
+ model="claude-sonnet-4-20250514",
161
+ max_tokens=600,
162
+ messages=[{"role": "user", "content": prompt}]
163
+ )
164
+
165
+ # Safe access to response content
166
+ if not response.content or len(response.content) == 0:
167
+ raise ValueError("Empty response from API")
168
+
169
+ content = response.content[0].text.strip()
170
+ return self._parse_optimization_response(content, threat_characteristics)
171
+
172
+ except Exception as e:
173
+ logger.error(f"Query optimization failed: {e}")
174
+ # Fallback to simple query
175
+ fallback_query = f"Machine learning approaches for detecting {threat_characteristics.threat_name}"
176
+ return OptimizedQuery(
177
+ original_query=threat_characteristics.threat_name,
178
+ optimized_queries=[fallback_query],
179
+ ml_focus_areas=["anomaly_detection"],
180
+ reasoning="Fallback query due to optimization failure"
181
+ )
182
+
183
+ def _parse_optimization_response(self, response: str, threat_characteristics: ThreatCharacteristics) -> OptimizedQuery:
184
+ """Parse the LLM response into structured query optimization"""
185
+
186
+ queries = []
187
+ ml_focus_areas = []
188
+ reasoning = ""
189
+
190
+ lines = response.split('\n')
191
+ current_section = None
192
+
193
+ for line in lines:
194
+ line = line.strip()
195
+
196
+ if line.startswith('QUERIES:'):
197
+ current_section = 'queries'
198
+ continue
199
+ elif line.startswith('ML_FOCUS_AREAS:'):
200
+ current_section = 'focus'
201
+ ml_focus_areas = [area.strip() for area in line.replace('ML_FOCUS_AREAS:', '').split(',')]
202
+ continue
203
+ elif line.startswith('REASONING:'):
204
+ current_section = 'reasoning'
205
+ reasoning = line.replace('REASONING:', '').strip()
206
+ continue
207
+
208
+ if current_section == 'queries' and line:
209
+ # Extract query from numbered list
210
+ query_match = re.match(r'\d+\.\s*(.+)', line)
211
+ if query_match:
212
+ queries.append(query_match.group(1).strip())
213
+ elif current_section == 'reasoning' and line:
214
+ reasoning += ' ' + line
215
+
216
+ # Ensure we have at least one query
217
+ if not queries:
218
+ queries = [f"Machine learning detection approaches for {threat_characteristics.threat_name}"]
219
+
220
+ return OptimizedQuery(
221
+ original_query=threat_characteristics.threat_name,
222
+ optimized_queries=queries,
223
+ ml_focus_areas=ml_focus_areas if ml_focus_areas else ["anomaly_detection"],
224
+ reasoning=reasoning.strip()
225
+ )
226
+
227
+
228
+ class SourceIdentifier:
229
+ """Identifies most relevant papers/sources for a given threat"""
230
+
231
+ def __init__(self, knowledge_base: KnowledgeBaseStorage):
232
+ self.knowledge_base = knowledge_base
233
+ self._load_source_mappings()
234
+
235
+ def _load_source_mappings(self):
236
+ """Load mappings between threat types and relevant sources"""
237
+
238
+ # Company expertise mappings
239
+ self.company_expertise = {
240
+ 'Netflix': ['performance_monitoring', 'infrastructure_anomalies', 'streaming_security'],
241
+ 'LinkedIn': ['user_behavior', 'abuse_detection', 'social_platform_security'],
242
+ 'Slack': ['communication_security', 'invite_spam', 'workspace_security'],
243
+ 'Cloudflare': ['network_security', 'bot_detection', 'traffic_analysis'],
244
+ 'Uber': ['fraud_detection', 'real_time_systems', 'human_in_the_loop'],
245
+ 'Grab': ['financial_fraud', 'graph_analysis', 'transaction_security'],
246
+ 'OLX Group': ['marketplace_fraud', 'deep_learning', 'user_verification'],
247
+ 'Stack Exchange': ['content_moderation', 'spam_detection', 'community_security'],
248
+ 'Mercari': ['e_commerce_security', 'content_moderation', 'automated_review']
249
+ }
250
+
251
+ # Attack vector to technique mappings
252
+ self.attack_vector_techniques = {
253
+ 'network': ['traffic_analysis', 'graph_ml', 'behavioral_analysis'],
254
+ 'email': ['text_classification', 'nlp', 'spam_detection'],
255
+ 'web': ['bot_detection', 'traffic_analysis', 'behavioral_analysis'],
256
+ 'insider': ['user_behavior', 'behavioral_analysis', 'anomaly_detection'],
257
+ 'financial': ['fraud_detection', 'transaction_analysis', 'graph_ml']
258
+ }
259
+
260
+ # Threat type to ML approach mappings
261
+ self.threat_ml_mappings = {
262
+ 'malware': ['behavioral_analysis', 'static_analysis', 'dynamic_analysis'],
263
+ 'apt': ['behavioral_analysis', 'network_analysis', 'long_term_patterns'],
264
+ 'fraud': ['fraud_detection', 'transaction_analysis', 'user_behavior'],
265
+ 'spam': ['text_classification', 'content_moderation', 'nlp'],
266
+ 'abuse': ['user_behavior', 'abuse_detection', 'behavioral_analysis']
267
+ }
268
+
269
+ def identify_relevant_sources(self, optimized_query: OptimizedQuery,
270
+ threat_characteristics: ThreatCharacteristics) -> SourceSelection:
271
+ """Identify most relevant papers/sources for the optimized queries"""
272
+
273
+ relevant_companies = set()
274
+ relevant_techniques = set()
275
+
276
+ # Map attack vectors to companies
277
+ for vector in threat_characteristics.attack_vectors:
278
+ for company, expertise in self.company_expertise.items():
279
+ for expert_area in expertise:
280
+ if any(keyword in expert_area for keyword in [vector, threat_characteristics.threat_type]):
281
+ relevant_companies.add(company)
282
+
283
+ # Add techniques for this attack vector
284
+ if vector in self.attack_vector_techniques:
285
+ relevant_techniques.update(self.attack_vector_techniques[vector])
286
+
287
+ # Map threat type to companies and techniques
288
+ threat_type = threat_characteristics.threat_type.lower()
289
+ for company, expertise in self.company_expertise.items():
290
+ if any(threat_type in expert_area or expert_area in threat_type for expert_area in expertise):
291
+ relevant_companies.add(company)
292
+
293
+ # Add techniques for threat type
294
+ if threat_type in self.threat_ml_mappings:
295
+ relevant_techniques.update(self.threat_ml_mappings[threat_type])
296
+
297
+ # Add techniques from ML focus areas
298
+ relevant_techniques.update(optimized_query.ml_focus_areas)
299
+
300
+ # Always include available companies from knowledge base
301
+ stats = self.knowledge_base.get_stats()
302
+ available_companies = set(stats['companies'])
303
+
304
+ # For now, include all available companies to ensure we get results
305
+ # In production, you can make this more selective
306
+ relevant_companies.update(available_companies)
307
+
308
+ # Fall back to all available companies if none matched
309
+ if not relevant_companies:
310
+ relevant_companies = available_companies
311
+
312
+ # Generate reasoning
313
+ reasoning = f"Selected companies based on expertise in {threat_characteristics.threat_type} and {', '.join(threat_characteristics.attack_vectors)}. Focus on {', '.join(list(relevant_techniques)[:3])} techniques."
314
+
315
+ return SourceSelection(
316
+ relevant_companies=list(relevant_companies),
317
+ relevant_years=['2019', '2020', '2021', '2022'], # All available years
318
+ relevant_techniques=list(relevant_techniques),
319
+ reasoning=reasoning
320
+ )
321
+
322
+
323
+ class EnhancedRetriever:
324
+ """Enhanced hybrid retriever with vector search + BM25 and post-processing"""
325
+
326
+ def __init__(self, knowledge_base: KnowledgeBaseStorage):
327
+ self.knowledge_base = knowledge_base
328
+ self.bm25_retriever = BM25Retriever(knowledge_base)
329
+
330
+ def retrieve_with_context(self, optimized_query: OptimizedQuery,
331
+ source_selection: SourceSelection,
332
+ max_results: int = 10,
333
+ trace_exporter=None) -> List[MLRetrievalResult]:
334
+ """Hybrid retrieval using both vector search and BM25, with context-aware processing"""
335
+
336
+ vector_results = []
337
+ bm25_results = []
338
+
339
+ # Log query optimization to trace
340
+ if trace_exporter:
341
+ trace_exporter.log_query_optimization(
342
+ optimized_query.original_query,
343
+ optimized_query.optimized_queries,
344
+ optimized_query.reasoning,
345
+ optimized_query.ml_focus_areas
346
+ )
347
+
348
+ # 1. Vector Search - Search with each optimized query
349
+ if trace_exporter:
350
+ trace_exporter.log_stage_start("vector_retrieval")
351
+
352
+ for query in optimized_query.optimized_queries:
353
+ results = self.knowledge_base.search(query, n_results=max_results)
354
+
355
+ for result in results:
356
+ # Filter by relevant sources
357
+ metadata = result['metadata']
358
+ company = metadata.get('company', '')
359
+
360
+ if company in source_selection.relevant_companies:
361
+ ml_result = self._create_ml_result(result, optimized_query, source_selection)
362
+ ml_result.retrieval_method = 'vector'
363
+ vector_results.append(ml_result)
364
+
365
+ if trace_exporter:
366
+ trace_exporter.log_stage_end("vector_retrieval", result_count=len(vector_results))
367
+ # Convert to format expected by trace exporter
368
+ vector_trace_results = [self._convert_to_trace_format(r) for r in vector_results]
369
+ trace_exporter.log_retrieval_results(vector_trace_results, "vector")
370
+
371
+ # 2. BM25 Search - Search with each optimized query
372
+ if trace_exporter:
373
+ trace_exporter.log_stage_start("bm25_retrieval")
374
+
375
+ for query in optimized_query.optimized_queries:
376
+ bm25_search_results = self.bm25_retriever.search(query, n_results=max_results)
377
+
378
+ for bm25_result in bm25_search_results:
379
+ # Filter by relevant sources
380
+ company = bm25_result.metadata.get('company', '')
381
+
382
+ if company in source_selection.relevant_companies:
383
+ # Convert BM25 result to ML result format
384
+ ml_result = self._create_ml_result_from_bm25(bm25_result, optimized_query, source_selection)
385
+ ml_result.retrieval_method = 'bm25'
386
+ bm25_results.append(ml_result)
387
+
388
+ if trace_exporter:
389
+ trace_exporter.log_stage_end("bm25_retrieval", result_count=len(bm25_results))
390
+ # Convert to format expected by trace exporter
391
+ bm25_trace_results = [self._convert_to_trace_format(r) for r in bm25_results]
392
+ trace_exporter.log_retrieval_results(bm25_trace_results, "bm25")
393
+
394
+ # 3. Combine and deduplicate results
395
+ if trace_exporter:
396
+ trace_exporter.log_stage_start("hybrid_fusion")
397
+
398
+ all_results = self._fuse_hybrid_results(vector_results, bm25_results)
399
+
400
+ # 4. Post-process results
401
+ processed_results = self._post_process_results(all_results)
402
+
403
+ # 5. Rank by hybrid score combining relevance and applicability
404
+ ranked_results = sorted(processed_results,
405
+ key=lambda x: x.hybrid_score,
406
+ reverse=True)
407
+
408
+ final_results = ranked_results[:max_results]
409
+
410
+ # Log hybrid fusion results
411
+ if trace_exporter:
412
+ trace_exporter.log_stage_end("hybrid_fusion", final_result_count=len(final_results))
413
+ # Convert to format expected by trace exporter
414
+ hybrid_trace_results = [self._convert_to_trace_format(r) for r in final_results]
415
+ trace_exporter.log_retrieval_results(hybrid_trace_results, "hybrid")
416
+
417
+ # Log hybrid scoring if available
418
+ if final_results:
419
+ avg_vector_score = sum(r.relevance_score for r in final_results) / len(final_results)
420
+ avg_bm25_score = sum(getattr(r, 'bm25_score', 0) for r in final_results) / len(final_results)
421
+ avg_hybrid_score = sum(r.hybrid_score for r in final_results) / len(final_results)
422
+ avg_applicability_score = sum(r.applicability_score for r in final_results) / len(final_results)
423
+
424
+ trace_exporter.log_hybrid_scoring(
425
+ avg_vector_score, avg_bm25_score, avg_hybrid_score, avg_applicability_score
426
+ )
427
+
428
+ return final_results
429
+
430
+ def _create_ml_result(self, search_result: Dict,
431
+ optimized_query: OptimizedQuery,
432
+ source_selection: SourceSelection) -> MLRetrievalResult:
433
+ """Convert search result to structured ML result"""
434
+
435
+ metadata = search_result['metadata']
436
+
437
+ # Calculate applicability score based on technique overlap
438
+ ml_techniques_raw = metadata.get('ml_techniques', '')
439
+ # Ensure we have a string before splitting
440
+ ml_techniques_str = str(ml_techniques_raw) if ml_techniques_raw else ''
441
+ paper_techniques = set(str(t).strip() for t in ml_techniques_str.split('|') if str(t).strip())
442
+ relevant_techniques = set(source_selection.relevant_techniques)
443
+
444
+ technique_overlap = len(paper_techniques.intersection(relevant_techniques))
445
+ applicability_score = min(technique_overlap / max(len(relevant_techniques), 1), 1.0)
446
+
447
+ # Extract implementation details from content
448
+ content = search_result['document']
449
+ implementation_details = self._extract_implementation_details(content)
450
+
451
+ result = MLRetrievalResult(
452
+ content=content,
453
+ metadata=metadata,
454
+ relevance_score=search_result['score'],
455
+ source_paper=metadata.get('source_title', ''),
456
+ ml_techniques=list(paper_techniques),
457
+ implementation_details=implementation_details,
458
+ applicability_score=applicability_score,
459
+ retrieval_method='vector'
460
+ )
461
+
462
+ # Calculate hybrid score (for vector results, this is just the combination)
463
+ result.hybrid_score = (result.relevance_score * 0.6 + result.applicability_score * 0.4)
464
+
465
+ return result
466
+
467
+ def _extract_implementation_details(self, content: str) -> str:
468
+ """Extract key implementation details from content"""
469
+
470
+ # Look for implementation-specific keywords
471
+ impl_keywords = [
472
+ 'architecture', 'framework', 'algorithm', 'model',
473
+ 'feature', 'training', 'deployment', 'performance',
474
+ 'accuracy', 'precision', 'recall', 'latency'
475
+ ]
476
+
477
+ sentences = content.split('.')
478
+ impl_sentences = []
479
+
480
+ for sentence in sentences:
481
+ if any(keyword in sentence.lower() for keyword in impl_keywords):
482
+ impl_sentences.append(sentence.strip())
483
+
484
+ return '. '.join(impl_sentences[:3]) # Top 3 relevant sentences
485
+
486
+ def _post_process_results(self, results: List[MLRetrievalResult]) -> List[MLRetrievalResult]:
487
+ """Post-process results for deduplication and enhancement"""
488
+
489
+ # Simple deduplication by content hash
490
+ seen_hashes = set()
491
+ deduplicated = []
492
+
493
+ for result in results:
494
+ content_hash = hash(result.content[:200]) # Hash first 200 chars
495
+
496
+ if content_hash not in seen_hashes:
497
+ seen_hashes.add(content_hash)
498
+ deduplicated.append(result)
499
+
500
+ return deduplicated
501
+
502
+ def _convert_to_trace_format(self, ml_result: MLRetrievalResult) -> Dict:
503
+ """Convert MLRetrievalResult to format expected by trace exporter"""
504
+ return {
505
+ "content": ml_result.content,
506
+ "metadata": ml_result.metadata,
507
+ "score": ml_result.relevance_score,
508
+ "method": ml_result.retrieval_method,
509
+ "matched_terms": getattr(ml_result, 'matched_terms', None),
510
+ "source_company": ml_result.metadata.get('company'),
511
+ "ml_techniques": ml_result.ml_techniques
512
+ }
513
+
514
+ def _create_ml_result_from_bm25(self, bm25_result: BM25SearchResult,
515
+ optimized_query: OptimizedQuery,
516
+ source_selection: SourceSelection) -> MLRetrievalResult:
517
+ """Convert BM25 search result to ML result format"""
518
+
519
+ metadata = bm25_result.metadata
520
+
521
+ # Calculate applicability score based on technique overlap
522
+ ml_techniques_raw = metadata.get('ml_techniques', '')
523
+ ml_techniques_str = str(ml_techniques_raw) if ml_techniques_raw else ''
524
+ paper_techniques = set(str(t).strip() for t in ml_techniques_str.split('|') if str(t).strip())
525
+ relevant_techniques = set(source_selection.relevant_techniques)
526
+
527
+ technique_overlap = len(paper_techniques.intersection(relevant_techniques))
528
+ applicability_score = min(technique_overlap / max(len(relevant_techniques), 1), 1.0)
529
+
530
+ # Extract implementation details from content
531
+ implementation_details = self._extract_implementation_details(bm25_result.content)
532
+
533
+ result = MLRetrievalResult(
534
+ content=bm25_result.content,
535
+ metadata=metadata,
536
+ relevance_score=bm25_result.relevance_score,
537
+ source_paper=metadata.get('source_title', ''),
538
+ ml_techniques=list(paper_techniques),
539
+ implementation_details=implementation_details,
540
+ applicability_score=applicability_score,
541
+ retrieval_method='bm25',
542
+ bm25_score=bm25_result.bm25_score,
543
+ matched_terms=bm25_result.matched_terms
544
+ )
545
+
546
+ # Calculate hybrid score (for BM25 results, give more weight to exact matches)
547
+ bm25_weight = min(bm25_result.bm25_score / 5.0, 1.0) # Normalize BM25 score
548
+ term_match_bonus = len(bm25_result.matched_terms) * 0.1 # Bonus for matched terms
549
+ result.hybrid_score = (bm25_weight * 0.5 + result.applicability_score * 0.4 + term_match_bonus)
550
+
551
+ return result
552
+
553
+ def _fuse_hybrid_results(self, vector_results: List[MLRetrievalResult],
554
+ bm25_results: List[MLRetrievalResult]) -> List[MLRetrievalResult]:
555
+ """Fuse vector and BM25 results using reciprocal rank fusion"""
556
+
557
+ # Create dictionaries for fast lookup
558
+ vector_lookup = {self._get_result_key(r): r for r in vector_results}
559
+ bm25_lookup = {self._get_result_key(r): r for r in bm25_results}
560
+
561
+ # Get all unique result keys
562
+ all_keys = set(vector_lookup.keys()) | set(bm25_lookup.keys())
563
+
564
+ fused_results = []
565
+
566
+ for key in all_keys:
567
+ vector_result = vector_lookup.get(key)
568
+ bm25_result = bm25_lookup.get(key)
569
+
570
+ if vector_result and bm25_result:
571
+ # Both methods found this result - create hybrid
572
+ hybrid_result = self._create_hybrid_result(vector_result, bm25_result)
573
+ fused_results.append(hybrid_result)
574
+ elif vector_result:
575
+ # Only vector search found this
576
+ fused_results.append(vector_result)
577
+ elif bm25_result:
578
+ # Only BM25 found this
579
+ fused_results.append(bm25_result)
580
+
581
+ return fused_results
582
+
583
+ def _get_result_key(self, result: MLRetrievalResult) -> str:
584
+ """Generate a unique key for a result based on content hash"""
585
+ # Use first 100 characters of content as key to detect near-duplicates
586
+ content_key = result.content[:100] if result.content else ""
587
+ return hashlib.md5(content_key.encode()).hexdigest()[:16]
588
+
589
+ def _create_hybrid_result(self, vector_result: MLRetrievalResult,
590
+ bm25_result: MLRetrievalResult) -> MLRetrievalResult:
591
+ """Create a hybrid result by combining vector and BM25 results"""
592
+
593
+ # Use vector result as base and enhance with BM25 data
594
+ hybrid_result = MLRetrievalResult(
595
+ content=vector_result.content,
596
+ metadata=vector_result.metadata,
597
+ relevance_score=max(vector_result.relevance_score, bm25_result.relevance_score),
598
+ source_paper=vector_result.source_paper,
599
+ ml_techniques=vector_result.ml_techniques,
600
+ implementation_details=vector_result.implementation_details,
601
+ applicability_score=max(vector_result.applicability_score, bm25_result.applicability_score),
602
+ retrieval_method='hybrid',
603
+ bm25_score=bm25_result.bm25_score,
604
+ matched_terms=bm25_result.matched_terms
605
+ )
606
+
607
+ # Calculate enhanced hybrid score using reciprocal rank fusion concept
608
+ vector_score = vector_result.hybrid_score
609
+ bm25_score = bm25_result.hybrid_score
610
+
611
+ # Reciprocal-rank-fusion-style combination with k=60 (standard value)
612
+ k = 60
613
+ vector_contribution = vector_score / (k + vector_score * 100)  # higher score -> larger contribution
614
+ bm25_contribution = bm25_score / (k + bm25_score * 100)
615
+ 
616
+ # Combine contributions and add a bonus for results found by both methods
617
+ hybrid_result.hybrid_score = (vector_contribution + bm25_contribution) + 0.2  # 0.2 bonus for hybrid
618
+
619
+ return hybrid_result
620
+
621
+
622
+ class MLAgenticRetriever:
623
+ """Main agentic retriever orchestrating all components"""
624
+
625
+ def __init__(self, anthropic_client, knowledge_base_path: str = "./ml_knowledge_base"):
626
+ self.client = anthropic_client
627
+ self.knowledge_base = KnowledgeBaseStorage(knowledge_base_path)
628
+
629
+ # Initialize agents
630
+ self.query_optimizer = QueryOptimizer(anthropic_client)
631
+ self.source_identifier = SourceIdentifier(self.knowledge_base)
632
+ self.enhanced_retriever = EnhancedRetriever(self.knowledge_base)
633
+
634
+ logger.info("ML Agentic Retriever initialized")
635
+
636
+ def get_ml_guidance(self, threat_characteristics: ThreatCharacteristics, trace_exporter=None) -> Dict:
637
+ """Get comprehensive ML guidance for threat detection"""
638
+
639
+ try:
640
+ logger.info(f"Getting ML guidance for: {threat_characteristics.threat_name}")
641
+
642
+ # Step 1: Query Optimization
643
+ optimized_query = self.query_optimizer.optimize_query(threat_characteristics)
644
+ logger.info(f"Generated {len(optimized_query.optimized_queries)} optimized queries")
645
+
646
+ # Step 2: Source Identification
647
+ source_selection = self.source_identifier.identify_relevant_sources(
648
+ optimized_query, threat_characteristics
649
+ )
650
+ logger.info(f"Identified {len(source_selection.relevant_companies)} relevant companies")
651
+
652
+ # Step 3: Enhanced Retrieval
653
+ ml_results = self.enhanced_retriever.retrieve_with_context(
654
+ optimized_query, source_selection, max_results=8,
655
+ trace_exporter=trace_exporter
656
+ )
657
+ logger.info(f"Retrieved {len(ml_results)} relevant ML approaches")
658
+
659
+ # Step 4: Structure results
660
+ guidance = self._structure_ml_guidance(
661
+ threat_characteristics, optimized_query, source_selection, ml_results
662
+ )
663
+
664
+ return guidance
665
+
666
+ except Exception as e:
667
+ logger.error(f"ML guidance generation failed: {e}")
668
+ return self._create_fallback_guidance(threat_characteristics)
669
+
670
+ def get_enhanced_ml_guidance(self, threat_characteristics: ThreatCharacteristics,
671
+ complete_threat_data: Dict, trace_exporter=None) -> Dict:
672
+ """Get enhanced ML guidance leveraging complete threat intelligence context"""
673
+
674
+ try:
675
+ logger.info(f"Getting enhanced ML guidance for: {threat_characteristics.threat_name}")
676
+
677
+ # Step 1: Enhanced Query Optimization with threat context
678
+ optimized_query = self._optimize_query_with_context(threat_characteristics, complete_threat_data)
679
+ logger.info(f"Generated {len(optimized_query.optimized_queries)} context-enhanced queries")
680
+
681
+ # Step 2: Enhanced Source Identification
682
+ source_selection = self.source_identifier.identify_relevant_sources(
683
+ optimized_query, threat_characteristics
684
+ )
685
+ logger.info(f"Identified {len(source_selection.relevant_companies)} relevant companies")
686
+
687
+ # Step 3: Enhanced Retrieval with threat context
688
+ ml_results = self.enhanced_retriever.retrieve_with_context(
689
+ optimized_query, source_selection, max_results=10, # More results for enhanced mode
690
+ trace_exporter=trace_exporter
691
+ )
692
+ logger.info(f"Retrieved {len(ml_results)} relevant ML approaches")
693
+
694
+ # Step 4: Structure results with enhanced context
695
+ guidance = self._structure_enhanced_ml_guidance(
696
+ threat_characteristics, optimized_query, source_selection, ml_results, complete_threat_data
697
+ )
698
+
699
+ return guidance
700
+
701
+ except Exception as e:
702
+ logger.error(f"Enhanced ML guidance generation failed: {e}")
703
+ return self._create_enhanced_fallback_guidance(threat_characteristics, complete_threat_data)
704
+
705
+ def _structure_ml_guidance(self, threat_characteristics: ThreatCharacteristics,
706
+ optimized_query: OptimizedQuery,
707
+ source_selection: SourceSelection,
708
+ ml_results: List[MLRetrievalResult]) -> Dict:
709
+ """Structure the ML guidance into organized sections"""
710
+
711
+ # Group results by ML technique
712
+ techniques_map = {}
713
+ for result in ml_results:
714
+ for technique in result.ml_techniques:
715
+ if technique not in techniques_map:
716
+ techniques_map[technique] = []
717
+ techniques_map[technique].append(result)
718
+
719
+ # Create structured guidance
720
+ guidance = {
721
+ 'threat_name': threat_characteristics.threat_name,
722
+ 'ml_approaches': [],
723
+ 'implementation_considerations': [],
724
+ 'source_papers': [],
725
+ 'query_optimization': {
726
+ 'original_query': optimized_query.original_query,
727
+ 'optimized_queries': optimized_query.optimized_queries,
728
+ 'reasoning': optimized_query.reasoning
729
+ },
730
+ 'source_selection': {
731
+ 'relevant_companies': source_selection.relevant_companies,
732
+ 'reasoning': source_selection.reasoning
733
+ }
734
+ }
735
+
736
+ # Add ML approaches
737
+ for technique, results in techniques_map.items():
738
+ if results: # Only include techniques with results
739
+ best_result = max(results, key=lambda x: x.applicability_score)
740
+
741
+ approach = {
742
+ 'technique': str(technique),
743
+ 'description': best_result.implementation_details,
744
+ 'source_company': str(best_result.metadata.get('company', '')),
745
+ 'source_paper': str(best_result.source_paper),
746
+ 'applicability_score': best_result.applicability_score,
747
+ 'relevance_score': best_result.relevance_score,
748
+ 'retrieval_method': best_result.retrieval_method,
749
+ 'hybrid_score': best_result.hybrid_score,
750
+ 'bm25_score': best_result.bm25_score,
751
+ 'matched_terms': best_result.matched_terms if best_result.matched_terms else []
752
+ }
753
+
754
+ guidance['ml_approaches'].append(approach)
755
+
756
+ # Add implementation considerations
757
+ for result in ml_results[:3]: # Top 3 results
758
+ consideration = {
759
+ 'aspect': f"{result.metadata.get('company', '')} Implementation",
760
+ 'details': result.implementation_details,
761
+ 'source': result.source_paper
762
+ }
763
+ guidance['implementation_considerations'].append(consideration)
764
+
765
+ # Add source papers
766
+ seen_papers = set()
767
+ for result in ml_results:
768
+ paper_title = result.source_paper
769
+ if paper_title not in seen_papers:
770
+ seen_papers.add(paper_title)
771
+ paper_info = {
772
+ 'title': paper_title,
773
+ 'company': result.metadata.get('company', ''),
774
+ 'year': result.metadata.get('year', ''),
775
+ 'url': result.metadata.get('source_url', ''),
776
+ 'techniques': result.ml_techniques
777
+ }
778
+ guidance['source_papers'].append(paper_info)
779
+
780
+ return guidance
781
+
782
+ def _create_fallback_guidance(self, threat_characteristics: ThreatCharacteristics) -> Dict:
783
+ """Create fallback guidance when main pipeline fails"""
784
+
785
+ return {
786
+ 'threat_name': threat_characteristics.threat_name,
787
+ 'ml_approaches': [{
788
+ 'technique': 'anomaly_detection',
789
+ 'description': 'General anomaly detection approaches using statistical methods and machine learning',
790
+ 'source_company': 'General',
791
+ 'source_paper': 'Fallback recommendation',
792
+ 'applicability_score': 0.5,
793
+ 'relevance_score': 0.5
794
+ }],
795
+ 'implementation_considerations': [{
796
+ 'aspect': 'General Implementation',
797
+ 'details': 'Consider implementing statistical anomaly detection as a baseline approach',
798
+ 'source': 'Fallback recommendation'
799
+ }],
800
+ 'source_papers': [],
801
+ 'error': 'ML guidance generation failed - fallback recommendations provided'
802
+ }
803
+
804
+
805
+ def create_test_threat_characteristics() -> ThreatCharacteristics:
806
+ """Create test threat characteristics for validation"""
807
+
808
+ return ThreatCharacteristics(
809
+ threat_name="ShadowPad",
810
+ threat_type="malware",
811
+ attack_vectors=["network", "lateral_movement"],
812
+ target_assets=["corporate_networks", "sensitive_data"],
813
+ behavior_patterns=["persistence", "data_exfiltration", "command_control"],
814
+ time_characteristics="persistent"
815
+ )
816
+
817
+
818
+ # Add enhanced methods to MLAgenticRetriever class
819
+ def _optimize_query_with_context(self, threat_characteristics: ThreatCharacteristics,
820
+ complete_threat_data: Dict) -> OptimizedQuery:
821
+ """Create enhanced queries using complete threat intelligence context"""
822
+
823
+ # Extract additional context for query enhancement
824
+ context_elements = []
825
+
826
+ # Technical capabilities
827
+ if tech_details := complete_threat_data.get('technicalDetails'):
828
+ if capabilities := tech_details.get('capabilities'):
829
+ # Ensure capabilities is a list before slicing
830
+ if isinstance(capabilities, list):
831
+ cap_names = [cap.get('name', str(cap)) if isinstance(cap, dict) else str(cap)
832
+ for cap in capabilities[:3]]
833
+ context_elements.extend(cap_names)
834
+
835
+ # C2 protocols
836
+ if c2_data := complete_threat_data.get('commandAndControl'):
837
+ if methods := c2_data.get('communicationMethods'):
838
+ # Ensure methods is a list before slicing
839
+ if isinstance(methods, list):
840
+ protocols = [method.get('protocol', str(method)) if isinstance(method, dict) else str(method)
841
+ for method in methods[:2]]
842
+ context_elements.extend(protocols)
843
+
844
+ # Use the regular optimizer with enhanced threat characteristics
845
+ enhanced_characteristics = ThreatCharacteristics(
846
+ threat_name=threat_characteristics.threat_name,
847
+ threat_type=threat_characteristics.threat_type,
848
+ attack_vectors=threat_characteristics.attack_vectors + context_elements[:2],
849
+ target_assets=threat_characteristics.target_assets,
850
+ behavior_patterns=threat_characteristics.behavior_patterns + context_elements[2:4],
851
+ time_characteristics=threat_characteristics.time_characteristics
852
+ )
853
+
854
+ return self.query_optimizer.optimize_query(enhanced_characteristics)
855
+
856
+ def _structure_enhanced_ml_guidance(self, threat_characteristics: ThreatCharacteristics,
857
+ optimized_query: OptimizedQuery,
858
+ source_selection: SourceSelection,
859
+ ml_results: List, # MLRetrievalResult type
860
+ complete_threat_data: Dict) -> Dict:
861
+ """Structure enhanced ML guidance with threat context"""
862
+
863
+ # Start with regular structuring
864
+ guidance = self._structure_ml_guidance(
865
+ threat_characteristics, optimized_query, source_selection, ml_results
866
+ )
867
+
868
+ # Enhance with threat context
869
+ guidance['threat_context_applied'] = True
870
+ guidance['context_sources'] = {
871
+ 'technical_details': bool(complete_threat_data.get('technicalDetails')),
872
+ 'command_and_control': bool(complete_threat_data.get('commandAndControl')),
873
+ 'detection_and_mitigation': bool(complete_threat_data.get('detectionAndMitigation')),
874
+ 'forensic_artifacts': bool(complete_threat_data.get('forensicArtifacts'))
875
+ }
876
+
877
+ # Add threat-specific implementation considerations
878
+ if tech_details := complete_threat_data.get('technicalDetails'):
879
+ if os_data := tech_details.get('operatingSystems'):
880
+ # Ensure os_data is a list before slicing
881
+ if isinstance(os_data, list):
882
+ os_names = [entry.get('name', str(entry)) if isinstance(entry, dict) else str(entry) for entry in os_data[:2]]
883
+ guidance['implementation_considerations'].append({
884
+ 'aspect': 'OS Compatibility',
885
+ 'details': f'Ensure ML models are trained on {", ".join(os_names)} environments for optimal detection.',
886
+ 'source': 'Threat Intelligence Profile'
887
+ })
888
+
889
+ return guidance
890
+
891
+ def _create_enhanced_fallback_guidance(self, threat_characteristics: ThreatCharacteristics,
892
+ complete_threat_data: Dict) -> Dict:
893
+ """Create enhanced fallback guidance with threat context"""
894
+
895
+ fallback = self._create_fallback_guidance(threat_characteristics)
896
+
897
+ # Add context-aware recommendations
898
+ fallback['threat_context_applied'] = True
899
+ fallback['enhanced_fallback'] = True
900
+
901
+ # Add context-specific ML approaches
902
+ if complete_threat_data.get('commandAndControl'):
903
+ fallback['ml_approaches'].append({
904
+ 'technique': 'C2 Traffic Analysis',
905
+ 'source_company': 'Context-Derived',
906
+ 'description': 'ML-based detection of command and control communication patterns identified in the threat profile.',
907
+ 'applicability_score': 0.8
908
+ })
909
+
910
+ if complete_threat_data.get('forensicArtifacts'):
911
+ fallback['ml_approaches'].append({
912
+ 'technique': 'Artifact-Based Detection',
913
+ 'source_company': 'Context-Derived',
914
+ 'description': 'Machine learning models trained on forensic artifacts specific to this threat.',
915
+ 'applicability_score': 0.7
916
+ })
917
+
918
+ return fallback
919
+
920
+ # Monkey patch these methods onto the MLAgenticRetriever class
921
+ MLAgenticRetriever._optimize_query_with_context = _optimize_query_with_context
922
+ MLAgenticRetriever._structure_enhanced_ml_guidance = _structure_enhanced_ml_guidance
923
+ MLAgenticRetriever._create_enhanced_fallback_guidance = _create_enhanced_fallback_guidance
924
+
925
+
926
+ def main():
927
+ """Test the ML Agentic Retriever"""
928
+
929
+ # Initialize
930
+ api_key = os.getenv('ANTHROPIC_API_KEY')
931
+ if not api_key:
932
+ print("Error: ANTHROPIC_API_KEY environment variable not set")
933
+ return
934
+
935
+ print("πŸ€– Testing ML Agentic Retriever")
936
+ print("=" * 40)
937
+
938
+ # Create components
939
+ anthropic_client = Anthropic(api_key=api_key)
940
+ retriever = MLAgenticRetriever(anthropic_client)
941
+
942
+ # Test with sample threat
943
+ threat = create_test_threat_characteristics()
944
+
945
+ print(f"🎯 Testing with threat: {threat.threat_name}")
946
+ print(f" Type: {threat.threat_type}")
947
+ print(f" Attack Vectors: {', '.join(threat.attack_vectors)}")
948
+ print(f" Behavior Patterns: {', '.join(threat.behavior_patterns)}")
949
+
950
+ # Get ML guidance
951
+ guidance = retriever.get_ml_guidance(threat)
952
+
953
+ print(f"\n🧠 ML Guidance Generated:")
954
+ print(f" ML Approaches: {len(guidance['ml_approaches'])}")
955
+ print(f" Implementation Considerations: {len(guidance['implementation_considerations'])}")
956
+ print(f" Source Papers: {len(guidance['source_papers'])}")
957
+
958
+ # Show details
959
+ if guidance['ml_approaches']:
960
+ print(f"\nπŸ“Š Top ML Approaches:")
961
+ for i, approach in enumerate(guidance['ml_approaches'][:3], 1):
962
+ print(f" {i}. {approach['technique']} ({approach['source_company']})")
963
+ print(f" Applicability: {approach['applicability_score']:.2f}")
964
+ print(f" Description: {approach['description'][:100]}...")
965
+
966
+ if guidance['source_papers']:
967
+ print(f"\nπŸ“š Source Papers:")
968
+ for paper in guidance['source_papers'][:3]:
969
+ print(f" β€’ {paper['company']} ({paper['year']}): {paper['title'][:60]}...")
970
+
971
+ print(f"\nβœ… Agentic retrieval test complete!")
972
+
973
+
974
+ if __name__ == "__main__":
975
+ main()
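
Note on the fusion step in _create_hybrid_result above: the hybrid score there is approximated directly from the vector and BM25 scores rather than from rank positions. Classic reciprocal rank fusion operates on each result's position in the two ranked lists, so an item contributes 1 / (k + rank) per list and items found by both retrievers accumulate a higher fused score. A minimal rank-based sketch for reference (illustrative only, not part of this commit; the function and variable names are hypothetical):

    def reciprocal_rank_fusion(vector_ranking, bm25_ranking, k=60):
        """Fuse two ranked lists of result keys into one list sorted by fused score."""
        fused = {}
        for ranking in (vector_ranking, bm25_ranking):
            for rank, key in enumerate(ranking, start=1):
                # Each list contributes 1 / (k + rank); keys present in both lists accumulate more.
                fused[key] = fused.get(key, 0.0) + 1.0 / (k + rank)
        return sorted(fused.items(), key=lambda item: item[1], reverse=True)

    # Example: "b" appears near the top of both rankings, so it fuses highest.
    print(reciprocal_rank_fusion(["a", "b", "c"], ["b", "c", "d"]))
    # [('b', ...), ('c', ...), ('a', ...), ('d', ...)]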
src/{ml_workers_retriever.py β†’ search/ml_workers_retriever.py} RENAMED
File without changes
src/{app.py β†’ ui/app.py} RENAMED
@@ -2,8 +2,8 @@
2
  πŸ” SentrySearch - Threat Intelligence Profile Generator
3
  """
4
  import gradio as gr
5
- from threat_intel_tool import ThreatIntelTool
6
- from markdown_generator import generate_markdown
7
 
8
 
9
  def generate_threat_profile(api_key, tool_name, enable_quality_control, progress=gr.Progress()):
 
2
  πŸ” SentrySearch - Threat Intelligence Profile Generator
3
  """
4
  import gradio as gr
5
+ from src.core.threat_intel_tool import ThreatIntelTool
6
+ from src.core.markdown_generator import generate_markdown
7
 
8
 
9
  def generate_threat_profile(api_key, tool_name, enable_quality_control, progress=gr.Progress()):
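
The updated imports above (src.core.threat_intel_tool, src.core.markdown_generator) assume that src/ and its subdirectories resolve as importable packages. A minimal layout sketch, assuming regular packages with __init__.py files and showing only the modules visible in this diff (the repository may instead rely on implicit namespace packages):

    src/
        __init__.py
        core/
            __init__.py
            threat_intel_tool.py
            markdown_generator.py
        search/
            __init__.py
            ml_workers_retriever.py
        ui/
            __init__.py
            app.py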
wrangler.toml ADDED
@@ -0,0 +1,26 @@
1
+ # Cloudflare Workers configuration for SentrySearch Hybrid Search
2
+
3
+ name = "sentry-search-hybrid"
4
+ main = "worker.js"
5
+ compatibility_date = "2023-12-01"
6
+
7
+ # KV Namespace binding
8
+ [[kv_namespaces]]
9
+ binding = "SENTRY_KV"
10
+ id = "f97c6c83f96f4b548307ccb1ffaa2668"
11
+ preview_id = "3fdb32abdd9d472d9d0506297dde4587"
12
+
13
+ # Environment variables (non-secret)
14
+ [vars]
15
+ ENVIRONMENT = "production"
16
+ MAX_SEARCH_RESULTS = "50"
17
+ DEFAULT_SEARCH_TIMEOUT = "30000"
18
+
19
+ # Routes (configure after deployment)
20
+ # routes = [
21
+ # { pattern = "sentry-search.your-domain.com/*", zone_name = "your-domain.com" }
22
+ # ]
23
+
24
+ # Resource limits (removed for free plan compatibility)
25
+ # [limits]
26
+ # cpu_ms = 30000
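
With this configuration in place, local preview and deployment are typically driven by the Wrangler CLI: npx wrangler dev runs the Worker locally and npx wrangler deploy publishes it (older Wrangler releases use wrangler publish instead). Secrets such as API tokens should be supplied with npx wrangler secret put <NAME> rather than being added to the [vars] block, which is intended for non-secret values only. Exact command syntax depends on the installed Wrangler version.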