Spaces:
Sleeping
Sleeping
Commit
·
38c5fec
1
Parent(s):
1165e73
Restructure project to mirror organized directory layout
Browse files- Organize src/ into modular directories: core/, data/, search/, ui/
- Move core logic files to src/core/
- Move data management to src/data/
- Move search functionality to src/search/
- Move UI components to src/ui/
- Update all import paths to reflect new structure
- Enhance .gitignore to exclude non-essential files
- Maintain Hugging Face Spaces compatibility
- .gitignore +49 -25
- app.py +1 -1
- src/{markdown_generator.py → core/markdown_generator.py} +0 -0
- src/{ml_guidance_generator.py → core/ml_guidance_generator.py} +1 -1
- src/{section_validator.py → core/section_validator.py} +1 -1
- src/{threat_intel_tool.py → core/threat_intel_tool.py} +3 -3
- src/{trace_exporter.py → core/trace_exporter.py} +0 -0
- src/{validation_criteria.py → core/validation_criteria.py} +0 -0
- src/data/ml_knowledge_base_builder.py +701 -0
- src/search/bm25_retriever.py +449 -0
- src/search/ml_agentic_retriever.py +975 -0
- src/{ml_workers_retriever.py → search/ml_workers_retriever.py} +0 -0
- src/{app.py → ui/app.py} +2 -2
- wrangler.toml +26 -0
.gitignore
CHANGED
@@ -1,40 +1,64 @@
|
|
1 |
-
#
|
2 |
-
|
3 |
-
*.
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
#
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
*.log
|
11 |
|
12 |
-
# Environment
|
13 |
.env
|
14 |
-
|
15 |
-
|
16 |
-
*.pyc
|
17 |
|
18 |
-
#
|
19 |
.vscode/
|
20 |
.idea/
|
21 |
*.swp
|
22 |
*.swo
|
23 |
|
24 |
-
# OS
|
25 |
.DS_Store
|
26 |
Thumbs.db
|
27 |
|
28 |
-
#
|
29 |
-
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
-
#
|
34 |
-
|
|
|
35 |
|
36 |
-
#
|
37 |
-
|
38 |
|
39 |
-
#
|
40 |
-
|
|
|
1 |
+
# Python
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
*.so
|
6 |
+
.Python
|
7 |
+
build/
|
8 |
+
develop-eggs/
|
9 |
+
dist/
|
10 |
+
downloads/
|
11 |
+
eggs/
|
12 |
+
.eggs/
|
13 |
+
lib/
|
14 |
+
lib64/
|
15 |
+
parts/
|
16 |
+
sdist/
|
17 |
+
var/
|
18 |
+
wheels/
|
19 |
+
*.egg-info/
|
20 |
+
.installed.cfg
|
21 |
+
*.egg
|
22 |
+
MANIFEST
|
23 |
|
24 |
+
# Virtual environments
|
25 |
+
venv/
|
26 |
+
env/
|
27 |
+
ENV/
|
|
|
28 |
|
29 |
+
# Environment variables
|
30 |
.env
|
31 |
+
.env.local
|
32 |
+
.env.production
|
|
|
33 |
|
34 |
+
# IDEs
|
35 |
.vscode/
|
36 |
.idea/
|
37 |
*.swp
|
38 |
*.swo
|
39 |
|
40 |
+
# OS
|
41 |
.DS_Store
|
42 |
Thumbs.db
|
43 |
|
44 |
+
# Application specific
|
45 |
+
/data/
|
46 |
+
docs/
|
47 |
+
tests/
|
48 |
+
scripts/
|
49 |
+
backup_*/
|
50 |
+
PROJECT_STRUCTURE.md
|
51 |
+
main.py
|
52 |
+
|
53 |
+
# Logs
|
54 |
+
*.log
|
55 |
|
56 |
+
# Temporary files
|
57 |
+
*.tmp
|
58 |
+
*.temp
|
59 |
|
60 |
+
# Node modules (if any)
|
61 |
+
node_modules/
|
62 |
|
63 |
+
# Cloudflare
|
64 |
+
.wrangler/
|
app.py
CHANGED
@@ -9,7 +9,7 @@ import os
|
|
9 |
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
|
10 |
|
11 |
# Import and run the main application
|
12 |
-
from app import create_ui
|
13 |
|
14 |
if __name__ == "__main__":
|
15 |
create_ui()
|
|
|
9 |
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
|
10 |
|
11 |
# Import and run the main application
|
12 |
+
from src.ui.app import create_ui
|
13 |
|
14 |
if __name__ == "__main__":
|
15 |
create_ui()
|
src/{markdown_generator.py → core/markdown_generator.py}
RENAMED
File without changes
|
src/{ml_guidance_generator.py → core/ml_guidance_generator.py}
RENAMED
@@ -22,7 +22,7 @@ from datetime import datetime
|
|
22 |
|
23 |
from pydantic import BaseModel, Field, validator
|
24 |
# Import Workers-based retriever (production system)
|
25 |
-
from ml_workers_retriever import MLWorkersRetriever, ThreatCharacteristics
|
26 |
from anthropic import Anthropic
|
27 |
import anthropic
|
28 |
|
|
|
22 |
|
23 |
from pydantic import BaseModel, Field, validator
|
24 |
# Import Workers-based retriever (production system)
|
25 |
+
from src.search.ml_workers_retriever import MLWorkersRetriever, ThreatCharacteristics
|
26 |
from anthropic import Anthropic
|
27 |
import anthropic
|
28 |
|
src/{section_validator.py → core/section_validator.py}
RENAMED
@@ -8,7 +8,7 @@ import anthropic
|
|
8 |
from typing import Optional, List
|
9 |
from datetime import datetime
|
10 |
import re
|
11 |
-
from validation_criteria import SECTION_CRITERIA, VALIDATION_PROMPTS
|
12 |
|
13 |
|
14 |
class SectionValidator:
|
|
|
8 |
from typing import Optional, List
|
9 |
from datetime import datetime
|
10 |
import re
|
11 |
+
from src.core.validation_criteria import SECTION_CRITERIA, VALIDATION_PROMPTS
|
12 |
|
13 |
|
14 |
class SectionValidator:
|
src/{threat_intel_tool.py → core/threat_intel_tool.py}
RENAMED
@@ -11,9 +11,9 @@ import re
|
|
11 |
from pydantic import ValidationError
|
12 |
import time
|
13 |
import random
|
14 |
-
from section_validator import SectionValidator, SectionImprover
|
15 |
-
from ml_guidance_generator import MLGuidanceGenerator, ThreatCharacteristics
|
16 |
-
from trace_exporter import get_trace_exporter
|
17 |
|
18 |
|
19 |
class ThreatIntelTool:
|
|
|
11 |
from pydantic import ValidationError
|
12 |
import time
|
13 |
import random
|
14 |
+
from src.core.section_validator import SectionValidator, SectionImprover
|
15 |
+
from src.core.ml_guidance_generator import MLGuidanceGenerator, ThreatCharacteristics
|
16 |
+
from src.core.trace_exporter import get_trace_exporter
|
17 |
|
18 |
|
19 |
class ThreatIntelTool:
|
src/{trace_exporter.py → core/trace_exporter.py}
RENAMED
File without changes
|
src/{validation_criteria.py → core/validation_criteria.py}
RENAMED
File without changes
|
src/data/ml_knowledge_base_builder.py
ADDED
@@ -0,0 +1,701 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
ML Knowledge Base Builder for SentrySearch
|
3 |
+
|
4 |
+
Builds a production-ready knowledge base from curated ML anomaly detection papers
|
5 |
+
and blog posts. Implements Agentic RAG approach with intelligent content processing.
|
6 |
+
|
7 |
+
Features:
|
8 |
+
- Real content ingestion from URLs
|
9 |
+
- LLM-powered content enrichment
|
10 |
+
- Persistent ChromaDB storage
|
11 |
+
- Question-like chunk processing
|
12 |
+
- Production-ready error handling
|
13 |
+
"""
|
14 |
+
|
15 |
+
import os
|
16 |
+
import json
|
17 |
+
import time
|
18 |
+
import random
|
19 |
+
import hashlib
|
20 |
+
from typing import List, Dict, Optional, Tuple
|
21 |
+
from dataclasses import dataclass, asdict
|
22 |
+
from urllib.parse import urlparse
|
23 |
+
import logging
|
24 |
+
from pathlib import Path
|
25 |
+
|
26 |
+
import requests
|
27 |
+
from bs4 import BeautifulSoup
|
28 |
+
import chromadb
|
29 |
+
from chromadb.config import Settings
|
30 |
+
from anthropic import Anthropic
|
31 |
+
import anthropic
|
32 |
+
import numpy as np
|
33 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
34 |
+
|
35 |
+
# Set up logging
|
36 |
+
logging.basicConfig(level=logging.INFO)
|
37 |
+
logger = logging.getLogger(__name__)
|
38 |
+
|
39 |
+
|
40 |
+
@dataclass
class MLPaperSource:
    """Represents a source ML paper or blog post to ingest into the knowledge base."""
    title: str                # human-readable title of the paper/post
    url: str                  # public URL the content is fetched from
    company: str              # organization that published the work
    year: str                 # publication year, kept as a string for metadata storage
    description: str          # one-line summary of the source
    ml_techniques: List[str]  # technique tags, e.g. ["isolation_forest", "abuse_detection"]
|
49 |
+
|
50 |
+
|
51 |
+
@dataclass
class EnrichedChunk:
    """Represents a processed and enriched document chunk ready for storage."""
    chunk_id: str              # unique id: "{company}_{year}_{index}_{hash prefix}"
    source_title: str          # title of the originating paper/post
    source_url: str            # URL the chunk was extracted from
    company: str               # publishing organization (copied from the source)
    year: str                  # publication year (copied from the source)
    original_content: str      # raw chunk text before LLM enrichment
    enriched_content: str      # LLM rewrite of the chunk in question-like format
    ml_techniques: List[str]   # technique tags inherited from the source
    chunk_summary: str         # 2-line LLM-generated summary of the chunk
    keywords: List[str]        # LLM-generated technical keywords
    chunk_index: int           # position of this chunk within the source document
    content_hash: str          # md5 of original_content, used for deduplication
    bm25_terms: Optional[List[str]] = None     # Additional search terms for BM25
    faq_questions: Optional[List[str]] = None  # FAQ-style questions
|
68 |
+
|
69 |
+
|
70 |
+
class ContentExtractor:
    """Fetches web pages and reduces them to clean, whitespace-normalized text."""

    def __init__(self, timeout: int = 30):
        self.timeout = timeout
        self.session = requests.Session()
        self.session.headers.update({
            'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36'
        })

    def extract_from_url(self, url: str) -> Optional[str]:
        """Fetch *url* and return its cleaned main-content text, or None on failure."""
        try:
            logger.info(f"Extracting content from: {url}")

            response = self.session.get(url, timeout=self.timeout)
            response.raise_for_status()

            soup = BeautifulSoup(response.content, 'html.parser')

            # Strip non-content scaffolding before pulling out text.
            for tag in soup(['script', 'style', 'nav', 'header', 'footer', 'aside']):
                tag.decompose()

            cleaned = self._clean_text(self._extract_main_content(soup))

            logger.info(f"Extracted {len(cleaned)} characters from {url}")
            return cleaned

        except requests.RequestException as e:
            logger.error(f"Request failed for {url}: {e}")
            return None
        except Exception as e:
            logger.error(f"Content extraction failed for {url}: {e}")
            return None

    def _extract_main_content(self, soup: BeautifulSoup) -> str:
        """Return text of the first matching article container, else the body/page."""
        # Common article selectors, tried in order of specificity.
        selectors = (
            'article',
            '[role="main"]',
            '.post-content',
            '.article-content',
            '.entry-content',
            '.content',
            'main',
        )

        for selector in selectors:
            node = soup.select_one(selector)
            if node is not None:
                return node.get_text()

        # Fall back to the whole body (or the full document if no <body>).
        body = soup.find('body')
        return body.get_text() if body else soup.get_text()

    def _clean_text(self, text: str) -> str:
        """Collapse the extracted text into a single space-normalized line."""
        stripped = (line.strip() for line in text.split('\n'))
        joined = ' '.join(line for line in stripped if line)
        # Normalize any remaining runs of whitespace to single spaces.
        return ' '.join(joined.split())
|
145 |
+
|
146 |
+
|
147 |
+
class ContentEnricher:
    """Enriches content chunks using LLM-powered processing.

    Wraps an Anthropic client and turns raw text chunks into a
    question-style rewrite plus search metadata (summary, keywords,
    BM25 terms, FAQ questions). On any API failure a deterministic
    fallback built from the source metadata is returned instead.
    """

    def __init__(self, anthropic_client):
        # Anthropic client instance; only `.messages.create` is used.
        self.client = anthropic_client

    def _api_call_with_retry(self, **kwargs):
        """Make API call with intelligent retry logic using retry-after header.

        Retries up to 3 times on RateLimitError, waiting either the
        server-provided retry-after (plus jitter) or exponential backoff
        capped at 120 seconds. Non-rate-limit errors are re-raised immediately.
        """
        max_retries = 3
        base_delay = 5

        for attempt in range(max_retries):
            try:
                print(f"DEBUG: Content Enricher API call attempt {attempt + 1}/{max_retries}")
                return self.client.messages.create(**kwargs)

            except anthropic.RateLimitError as e:
                if attempt == max_retries - 1:
                    print(f"DEBUG: Content Enricher rate limit exceeded after {max_retries} attempts")
                    raise e

                # Check if the error response has retry-after information
                retry_after = None
                if hasattr(e, 'response') and e.response:
                    retry_after_header = e.response.headers.get('retry-after')
                    if retry_after_header:
                        try:
                            retry_after = float(retry_after_header)
                            print(f"DEBUG: Content Enricher API provided retry-after: {retry_after} seconds")
                        except (ValueError, TypeError):
                            pass

                # Use retry-after if available, otherwise exponential backoff
                if retry_after:
                    delay = retry_after + random.uniform(1, 3)
                else:
                    delay = base_delay * (2 ** attempt) + random.uniform(1, 5)
                    delay = min(delay, 120)

                print(f"DEBUG: Content Enricher rate limit hit. Waiting {delay:.1f} seconds before retry {attempt + 2}")
                time.sleep(delay)

            except Exception as e:
                print(f"DEBUG: Content Enricher non-rate-limit error: {e}")
                raise e

    def enrich_chunk(self, chunk: str, source: MLPaperSource) -> Dict[str, str]:
        """Enrich a chunk with summary, keywords, question-like format, and BM25-optimized metadata.

        Returns a dict with keys: question_format, summary, keywords,
        bm25_terms, faq_questions. Falls back to metadata-derived values
        if the API call or parsing fails.
        """

        prompt = f"""
Analyze this text chunk from a machine learning anomaly detection paper/blog and provide:

1. QUESTION_FORMAT: Rewrite the chunk content as if it's answering questions about the ML approach
2. SUMMARY: A 2-line summary of what this chunk covers
3. KEYWORDS: 5-8 relevant technical keywords (comma-separated)
4. BM25_TERMS: Additional search terms for BM25 retrieval (comma-separated, include variations, synonyms, acronyms)
5. FAQ_QUESTIONS: 2-3 potential questions this chunk could answer (pipe-separated)

Source Context:
- Company: {source.company}
- ML Techniques: {', '.join(source.ml_techniques)}
- Year: {source.year}

Text Chunk:
{chunk[:1500]}

Format your response as:
QUESTION_FORMAT: [rewritten content]
SUMMARY: [summary]
KEYWORDS: [keywords]
BM25_TERMS: [search terms with variations]
FAQ_QUESTIONS: [question1|question2|question3]
"""

        try:
            response = self._api_call_with_retry(
                model="claude-sonnet-4-20250514",
                max_tokens=800,
                messages=[{"role": "user", "content": prompt}]
            )

            # Safe access to response content
            if not response.content or len(response.content) == 0:
                raise ValueError("Empty response from content enrichment API")

            if not hasattr(response.content[0], 'text'):
                raise ValueError("Response content missing text attribute")

            content = response.content[0].text.strip()
            return self._parse_enrichment_response(content)

        except Exception as e:
            logger.error(f"Content enrichment failed: {e}")
            # Fallback enrichment built only from source metadata.
            # Guard against an empty ml_techniques list, which would
            # otherwise raise IndexError here.
            primary_technique = source.ml_techniques[0] if source.ml_techniques else 'anomaly detection'
            return {
                'question_format': chunk,
                'summary': f"Content about {primary_technique} implementation at {source.company}",
                'keywords': ', '.join(source.ml_techniques + [source.company.lower(), 'anomaly detection']),
                'bm25_terms': ', '.join(source.ml_techniques + [source.company.lower(), 'ml', 'detection', 'analysis']),
                'faq_questions': f"How does {source.company} implement {primary_technique}?|What is {primary_technique} used for?"
            }

    def _parse_enrichment_response(self, response: str) -> Dict[str, str]:
        """Parse LLM response into structured enrichment data.

        A line starting with a known label switches the current field;
        unlabeled non-empty lines are appended to the most recent field.
        Labels are stripped only when they appear as a prefix (the previous
        ``str.replace`` removed the label text anywhere in the line).
        """
        field_labels = {
            'QUESTION_FORMAT:': 'question_format',
            'SUMMARY:': 'summary',
            'KEYWORDS:': 'keywords',
            'BM25_TERMS:': 'bm25_terms',
            'FAQ_QUESTIONS:': 'faq_questions',
        }
        result = {name: '' for name in field_labels.values()}
        current_field = None

        for line in response.split('\n'):
            line = line.strip()
            for label, name in field_labels.items():
                if line.startswith(label):
                    current_field = name
                    result[name] = line[len(label):].strip()
                    break
            else:
                # Continuation line for the field currently being filled.
                if current_field and line:
                    result[current_field] += ' ' + line

        return result
|
283 |
+
|
284 |
+
|
285 |
+
class DocumentProcessor:
    """Turns raw document text into LLM-enriched, overlapping chunks."""

    def __init__(self, content_enricher: ContentEnricher, chunk_size: int = 800):
        self.enricher = content_enricher
        self.chunk_size = chunk_size

    def process_document(self, source: MLPaperSource, content: str) -> List[EnrichedChunk]:
        """Split *content* into chunks, enrich each via the LLM, and return them."""
        if not content or len(content) < 100:
            logger.warning(f"Content too short for {source.title}")
            return []

        raw_chunks = self._create_chunks(content, source)

        enriched: List[EnrichedChunk] = []
        for index, text in enumerate(raw_chunks):
            # Hash of the raw chunk text, used for deduplication and id suffix.
            digest = hashlib.md5(text.encode()).hexdigest()

            # LLM-powered enrichment (question rewrite, summary, search terms).
            meta = self.enricher.enrich_chunk(text, source)

            item = EnrichedChunk(
                chunk_id=f"{source.company}_{source.year}_{index}_{digest[:8]}",
                source_title=source.title,
                source_url=source.url,
                company=source.company,
                year=source.year,
                original_content=text,
                enriched_content=meta['question_format'],
                ml_techniques=source.ml_techniques,
                chunk_summary=meta['summary'],
                keywords=meta['keywords'].split(', ') if meta['keywords'] else [],
                chunk_index=index,
                content_hash=digest
            )

            # BM25-specific metadata is attached after construction.
            item.bm25_terms = meta.get('bm25_terms', '').split(', ') if meta.get('bm25_terms') else []
            item.faq_questions = meta.get('faq_questions', '').split('|') if meta.get('faq_questions') else []

            enriched.append(item)

            # Throttle successive LLM calls.
            time.sleep(0.5)

        logger.info(f"Processed {len(enriched)} chunks for {source.title}")
        return enriched

    def _create_chunks(self, content: str, source: MLPaperSource) -> List[str]:
        """Create overlapping chunks from *content* (25% overlap between windows)."""
        step = self.chunk_size - self.chunk_size // 4
        pieces: List[str] = []
        start = 0
        total = len(content)

        while start < total:
            piece = content[start:start + self.chunk_size]

            # Very short tail fragments are dropped.
            if len(piece) >= 200:
                # For non-final windows, prefer ending at a sentence boundary
                # when a period falls in the last 30% of the window.
                if start + self.chunk_size < total:
                    cut = piece.rfind('.')
                    if cut > len(piece) * 0.7:
                        piece = piece[:cut + 1]
                pieces.append(piece.strip())

            start += step

        return pieces
|
360 |
+
|
361 |
+
|
362 |
+
class KnowledgeBaseStorage:
    """Manages persistent storage of the knowledge base.

    Documents live in a single ChromaDB collection under *storage_path*;
    every added chunk is also appended to a JSONL backup file so the
    knowledge base can be rebuilt without re-running enrichment.
    """

    def __init__(self, storage_path: str = "./ml_knowledge_base"):
        self.storage_path = Path(storage_path)
        self.storage_path.mkdir(exist_ok=True)

        # Initialize ChromaDB with persistent (on-disk) storage
        self.chroma_client = chromadb.PersistentClient(
            path=str(self.storage_path / "chroma_db")
        )

        self.collection_name = "ml_anomaly_detection"
        self.collection = None
        self._initialize_collection()

    def _initialize_collection(self):
        """Load the existing collection, or create it on first run."""
        try:
            # Try to get existing collection
            self.collection = self.chroma_client.get_collection(self.collection_name)
            logger.info(f"Loaded existing collection with {self.collection.count()} documents")
        except Exception:
            # get_collection raises when the collection does not exist yet.
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # still propagate instead of being swallowed here.
            self.collection = self.chroma_client.create_collection(
                name=self.collection_name,
                metadata={"description": "ML Anomaly Detection Knowledge Base"}
            )
            logger.info("Created new collection")

    def add_chunks(self, chunks: List[EnrichedChunk]) -> bool:
        """Add enriched chunks to the knowledge base.

        Returns True on success (including an empty input), False if the
        ChromaDB write fails. List-valued metadata is flattened to
        '|'-joined strings because ChromaDB metadata values must be scalars.
        """
        try:
            if not chunks:
                return True

            # Prepare data for ChromaDB
            documents = []
            metadatas = []
            ids = []

            for chunk in chunks:
                # The embedded document text includes metadata so semantic
                # search can match on company/technique context, not just body.
                document_text = f"""
Title: {chunk.source_title}
Company: {chunk.company}
Year: {chunk.year}
ML Techniques: {', '.join(chunk.ml_techniques)}
Keywords: {', '.join(chunk.keywords)}
Summary: {chunk.chunk_summary}

Content: {chunk.enriched_content}
""".strip()

                documents.append(document_text)
                metadatas.append({
                    'source_title': chunk.source_title,
                    'source_url': chunk.source_url,
                    'company': chunk.company,
                    'year': chunk.year,
                    'ml_techniques': '|'.join(chunk.ml_techniques),
                    'keywords': '|'.join(chunk.keywords),
                    'chunk_summary': chunk.chunk_summary,
                    'chunk_index': chunk.chunk_index,
                    'content_hash': chunk.content_hash,
                    'bm25_terms': '|'.join(chunk.bm25_terms) if chunk.bm25_terms else '',
                    'faq_questions': '|'.join(chunk.faq_questions) if chunk.faq_questions else ''
                })
                ids.append(chunk.chunk_id)

            # Add to ChromaDB
            self.collection.add(
                documents=documents,
                metadatas=metadatas,
                ids=ids
            )

            # Save chunk details as JSON backup
            self._save_chunks_backup(chunks)

            logger.info(f"Added {len(chunks)} chunks to knowledge base")
            return True

        except Exception as e:
            logger.error(f"Failed to add chunks to knowledge base: {e}")
            return False

    def _save_chunks_backup(self, chunks: List[EnrichedChunk]):
        """Append chunk details to the JSONL backup file (one object per line)."""
        backup_file = self.storage_path / "chunks_backup.jsonl"

        with open(backup_file, 'a', encoding='utf-8') as f:
            for chunk in chunks:
                f.write(json.dumps(asdict(chunk), ensure_ascii=False) + '\n')

    def get_stats(self) -> Dict:
        """Return knowledge base statistics (counts and distinct metadata values).

        On failure returns ``{'error': <message>}`` instead of raising.
        """
        try:
            count = self.collection.count()

            # Aggregate unique companies, years, and techniques across all chunks.
            if count > 0:
                results = self.collection.get(include=['metadatas'])
                companies = set()
                years = set()
                ml_techniques = set()

                for metadata in results['metadatas']:
                    companies.add(metadata.get('company', ''))
                    years.add(metadata.get('year', ''))
                    techniques = metadata.get('ml_techniques', '').split('|')
                    ml_techniques.update([t for t in techniques if t])

                return {
                    'total_chunks': count,
                    'companies': sorted(list(companies)),
                    'years': sorted(list(years)),
                    'ml_techniques': sorted(list(ml_techniques)),
                    'storage_path': str(self.storage_path)
                }
            else:
                return {
                    'total_chunks': 0,
                    'companies': [],
                    'years': [],
                    'ml_techniques': [],
                    'storage_path': str(self.storage_path)
                }

        except Exception as e:
            logger.error(f"Failed to get stats: {e}")
            return {'error': str(e)}

    def search(self, query: str, n_results: int = 10) -> List[Dict]:
        """Semantic-search the knowledge base; returns [] on failure.

        Each result dict carries the document text, its metadata, the raw
        embedding distance, and a similarity score in (0, 1].
        """
        try:
            results = self.collection.query(
                query_texts=[query],
                n_results=n_results,
                include=['documents', 'metadatas', 'distances']
            )

            search_results = []
            for i, doc in enumerate(results['documents'][0]):
                search_results.append({
                    'document': doc,
                    'metadata': results['metadatas'][0][i],
                    'distance': results['distances'][0][i],
                    'score': 1 / (1 + results['distances'][0][i])  # Convert distance to similarity
                })

            return search_results

        except Exception as e:
            logger.error(f"Search failed: {e}")
            return []
|
518 |
+
|
519 |
+
|
520 |
+
def get_curated_ml_sources() -> List[MLPaperSource]:
    """Get the curated list of ML anomaly detection sources.

    The catalog is kept as plain (title, url, company, year, description,
    techniques) tuples and converted into MLPaperSource records on return.
    """
    catalog = [
        ("Detecting Performance Anomalies in External Firmware Deployments",
         "https://netflixtechblog.com/detecting-performance-anomalies-in-external-firmware-deployments-ed41b1bfcf46",
         "Netflix", "2019",
         "Netflix's approach to detecting anomalies in firmware performance using ML",
         ["statistical_analysis", "anomaly_detection", "performance_monitoring"]),
        ("Detecting and Preventing Abuse on LinkedIn using Isolation Forests",
         "https://engineering.linkedin.com/blog/2019/isolation-forest",
         "LinkedIn", "2019",
         "LinkedIn's implementation of isolation forests for abuse detection",
         ["isolation_forest", "unsupervised_learning", "abuse_detection"]),
        ("How Does Spam Protection Work on Stack Exchange?",
         "https://stackoverflow.blog/2020/06/25/how-does-spam-protection-work-on-stack-exchange/",
         "Stack Exchange", "2020",
         "Stack Exchange's ML-based spam detection system",
         ["text_classification", "nlp", "spam_detection"]),
        ("Blocking Slack Invite Spam With Machine Learning",
         "https://slack.engineering/blocking-slack-invite-spam-with-machine-learning/",
         "Slack", "2020",
         "Slack's ML approach to preventing invite spam",
         ["classification", "feature_engineering", "spam_detection"]),
        ("Cloudflare Bot Management: Machine Learning and More",
         "https://blog.cloudflare.com/cloudflare-bot-management-machine-learning-and-more/",
         "Cloudflare", "2020",
         "Cloudflare's ML-powered bot detection and management",
         ["behavioral_analysis", "traffic_analysis", "bot_detection"]),
        ("Graph for Fraud Detection",
         "https://engineering.grab.com/graph-for-fraud-detection",
         "Grab", "2022",
         "Grab's graph-based approach to fraud detection",
         ["graph_ml", "fraud_detection", "network_analysis"]),
        ("Machine Learning for Fraud Detection in Streaming Services",
         "https://netflixtechblog.com/machine-learning-for-fraud-detection-in-streaming-services-b0b4ef3be3f6",
         "Netflix", "2023",
         "Netflix's ML approach to detecting fraud in streaming services",
         ["fraud_detection", "streaming_analytics", "behavioral_analysis"]),
        ("Data Generation and Sampling Strategies",
         "https://blog.cloudflare.com/data-generation-and-sampling-strategies/",
         "Cloudflare", "2023",
         "Cloudflare's data generation and sampling strategies for ML training",
         ["data_generation", "sampling", "training_data"]),
        ("Machine Learning Mobile Traffic Bots",
         "https://blog.cloudflare.com/machine-learning-mobile-traffic-bots/",
         "Cloudflare", "2023",
         "Cloudflare's ML approach to detecting mobile traffic bots",
         ["bot_detection", "mobile_traffic", "behavioral_analysis"]),
        ("Project Radar: Intelligent Early Fraud Detection",
         "https://www.uber.com/blog/project-radar-intelligent-early-fraud-detection/",
         "Uber", "2023",
         "Uber's Project Radar for intelligent early fraud detection",
         ["fraud_detection", "early_detection", "real_time_ml"]),
    ]

    return [
        MLPaperSource(title=title, url=url, company=company, year=year,
                      description=description, ml_techniques=techniques)
        for title, url, company, year, description, techniques in catalog
    ]
|
607 |
+
|
608 |
+
|
609 |
+
def main():
    """Main function to build the ML knowledge base.

    Pipeline: for each curated source URL, extract the page content,
    enrich and chunk it, and store the chunks in the knowledge base;
    then print summary stats and run a few smoke-test searches.
    Requires the ANTHROPIC_API_KEY environment variable.
    """

    # Initialize components
    api_key = os.getenv('ANTHROPIC_API_KEY')
    if not api_key:
        logger.error("ANTHROPIC_API_KEY environment variable not set")
        return

    print("π¨ Building ML Anomaly Detection Knowledge Base")
    print("=" * 50)

    # Initialize components
    anthropic_client = Anthropic(api_key=api_key)
    content_extractor = ContentExtractor()
    content_enricher = ContentEnricher(anthropic_client)
    document_processor = DocumentProcessor(content_enricher)
    knowledge_base = KnowledgeBaseStorage()

    # Get current stats (chunks already present from previous runs)
    current_stats = knowledge_base.get_stats()
    print(f"π Current knowledge base: {current_stats['total_chunks']} chunks")

    # Get sources to process
    sources = get_curated_ml_sources()
    print(f"π Processing {len(sources)} ML sources...")

    # Process each source. The build is best-effort: per-source failures
    # are reported and skipped rather than aborting the whole run.
    total_chunks_added = 0
    successful_sources = 0

    for i, source in enumerate(sources, 1):
        print(f"\nπ [{i}/{len(sources)}] Processing: {source.title}")
        print(f"   Company: {source.company} | Year: {source.year}")

        # Extract content (network fetch; may fail or return empty)
        content = content_extractor.extract_from_url(source.url)

        if not content:
            print(f"   β Failed to extract content")
            continue

        print(f"   π Extracted {len(content):,} characters")

        # Process into chunks (includes LLM-based enrichment)
        chunks = document_processor.process_document(source, content)

        if not chunks:
            print(f"   β No chunks generated")
            continue

        print(f"   π§© Generated {len(chunks)} chunks")

        # Add to knowledge base; counts only count fully-stored sources
        if knowledge_base.add_chunks(chunks):
            total_chunks_added += len(chunks)
            successful_sources += 1
            print(f"   β Added to knowledge base")
        else:
            print(f"   β Failed to add to knowledge base")

    # Final stats
    print(f"\nπ Knowledge Base Build Complete!")
    print("=" * 50)
    print(f"Sources processed: {successful_sources}/{len(sources)}")
    print(f"Total chunks added: {total_chunks_added}")

    final_stats = knowledge_base.get_stats()
    print(f"Final knowledge base size: {final_stats['total_chunks']} chunks")
    print(f"Companies: {', '.join(final_stats['companies'])}")
    print(f"Years: {', '.join(final_stats['years'])}")
    print(f"Storage location: {final_stats['storage_path']}")

    # Test search — quick end-to-end sanity check of the stored index
    print(f"\nπ Testing search functionality...")
    test_queries = [
        "How does Netflix detect performance anomalies?",
        "What ML techniques work for fraud detection?",
        "Isolation forest implementation details"
    ]

    for query in test_queries:
        results = knowledge_base.search(query, n_results=3)
        print(f"\nQuery: '{query}'")
        print(f"Results: {len(results)} found")
        if results:
            top_result = results[0]
            print(f"Top match: {top_result['metadata']['company']} - {top_result['metadata']['source_title'][:60]}...")
            print(f"Score: {top_result['score']:.3f}")


if __name__ == "__main__":
    main()
|
src/search/bm25_retriever.py
ADDED
@@ -0,0 +1,449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
BM25 Retriever for SentrySearch
|
3 |
+
|
4 |
+
Implements BM25-based retrieval with enriched metadata support for the agentic RAG system.
|
5 |
+
Provides complementary keyword-based retrieval alongside vector search for enhanced precision.
|
6 |
+
|
7 |
+
Features:
|
8 |
+
- BM25 algorithm for exact keyword matching
|
9 |
+
- Enriched metadata indexing (summaries, FAQs, keywords)
|
10 |
+
- Integration with existing knowledge base
|
11 |
+
- Result scoring and ranking
|
12 |
+
- Efficient document preprocessing
|
13 |
+
"""
|
14 |
+
|
15 |
+
import os
|
16 |
+
import json
|
17 |
+
import logging
|
18 |
+
import time
|
19 |
+
from typing import List, Dict, Optional, Tuple, Set
|
20 |
+
from dataclasses import dataclass
|
21 |
+
from pathlib import Path
|
22 |
+
import re
|
23 |
+
import pickle
|
24 |
+
|
25 |
+
from rank_bm25 import BM25Okapi
|
26 |
+
import numpy as np
|
27 |
+
from src.data.ml_knowledge_base_builder import KnowledgeBaseStorage
|
28 |
+
|
29 |
+
logger = logging.getLogger(__name__)
|
30 |
+
|
31 |
+
|
32 |
+
@dataclass
class BM25Document:
    """Represents a document optimized for BM25 retrieval.

    Holds both the raw chunk text and a metadata-enriched variant used for
    keyword indexing, plus the pre-tokenized form so the index can be
    cached and reloaded without re-running preprocessing.
    """
    doc_id: str  # ChromaDB document id (unique key into the knowledge base)
    content: str  # original chunk text as stored in the knowledge base
    enriched_content: str  # content augmented with weighted metadata terms
    metadata: Dict  # raw ChromaDB metadata for the chunk
    keywords: List[str]  # keywords parsed from the '|'-delimited metadata field
    summary: str  # short chunk summary from metadata (may be empty)
    preprocessed_tokens: List[str]  # lowercased, stopword-filtered tokens for BM25
|
42 |
+
|
43 |
+
|
44 |
+
@dataclass
class BM25SearchResult:
    """Represents a BM25 search result.

    Carries both the raw BM25 score (unbounded, corpus-dependent) and a
    normalized relevance score so callers can combine it with other
    retrieval signals.
    """
    doc_id: str  # id of the matched document
    content: str  # original chunk text of the match
    metadata: Dict  # knowledge-base metadata for the matched chunk
    bm25_score: float  # raw BM25Okapi score for the query
    matched_terms: List[str]  # query tokens that appear in the document
    relevance_score: float  # Normalized score (heuristic, clamped to 0-1)
|
53 |
+
|
54 |
+
|
55 |
+
class BM25Preprocessor:
    """Preprocesses documents for BM25 indexing.

    Responsibilities:
    - Tokenize and filter text (stopwords, noise tokens) for BM25 scoring.
    - Build an "enriched" text representation that repeats high-signal
      metadata terms so they carry more weight in BM25 term frequencies.
    """

    def __init__(self):
        # Common stopwords for technical content. Kept deliberately small:
        # aggressive stopword removal hurts recall on technical phrases.
        self.stopwords = {
            'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
            'of', 'with', 'by', 'is', 'are', 'was', 'were', 'be', 'been', 'being',
            'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'could',
            'should', 'may', 'might', 'can', 'this', 'that', 'these', 'those',
            'i', 'you', 'he', 'she', 'it', 'we', 'they', 'me', 'him', 'her', 'us', 'them'
        }

    def preprocess_text(self, text: str) -> List[str]:
        """Tokenize and filter text for BM25 indexing.

        Tokens are lowercased and stopword-filtered. Very short tokens
        (< 2 chars) are kept only when they were fully uppercase in the
        original text (single-letter acronyms such as "R"); very long
        tokens (> 50 chars) are dropped as noise.

        Args:
            text: Raw text to tokenize; empty/None yields [].

        Returns:
            List of lowercased tokens in original order.
        """
        if not text:
            return []

        # Tokenize BEFORE lowercasing so the uppercase-acronym check below
        # can see the original casing. (Bug fix: previously the whole text
        # was lowercased first, so token.isupper() was always False and
        # short uppercase acronyms were unconditionally dropped, defeating
        # the stated intent of the filter.)
        raw_tokens = re.findall(r'\b[a-zA-Z0-9_-]+\b', text)

        filtered_tokens = []
        for raw_token in raw_tokens:
            token = raw_token.lower()

            # Skip stopwords
            if token in self.stopwords:
                continue

            # Skip very short tokens unless they were uppercase acronyms
            if len(token) < 2 and not raw_token.isupper():
                continue

            # Skip very long tokens (likely noise)
            if len(token) > 50:
                continue

            filtered_tokens.append(token)

        return filtered_tokens

    def create_enriched_content(self, chunk_data: Dict) -> str:
        """Build a search-optimized text blob from a chunk and its metadata.

        High-signal fields are repeated to boost their BM25 term frequency:
        company name x3, each ML technique x2; keywords, summary and source
        title are appended once.

        Args:
            chunk_data: Dict with optional 'enriched_content' (str) and
                'metadata' (dict) keys.

        Returns:
            Single space-joined string combining content and weighted
            metadata terms.
        """
        content_parts = []

        # Original content
        if chunk_data.get('enriched_content'):
            content_parts.append(chunk_data['enriched_content'])

        metadata = chunk_data.get('metadata', {})

        # Company name (high weight — tripled)
        if company := metadata.get('company'):
            content_parts.append(f"{company} {company} {company}")

        # ML techniques (high weight — doubled). Accept both the
        # '|'-delimited string form stored in ChromaDB and a plain list.
        # (Bug fix: list-valued metadata was previously silently discarded.)
        if ml_techniques := metadata.get('ml_techniques'):
            if isinstance(ml_techniques, str):
                techniques = ml_techniques.split('|')
            else:
                techniques = list(ml_techniques)
            for technique in techniques:
                if technique.strip():
                    content_parts.append(f"{technique} {technique}")

        # Keywords (medium weight) — same dual-format handling
        if keywords := metadata.get('keywords'):
            if isinstance(keywords, str):
                keyword_list = keywords.split('|')
            else:
                keyword_list = list(keywords)
            content_parts.extend(keyword_list)

        # Summary (medium weight)
        if summary := metadata.get('chunk_summary'):
            content_parts.append(summary)

        # Source title (medium weight)
        if title := metadata.get('source_title'):
            content_parts.append(title)

        return ' '.join(content_parts)
|
136 |
+
|
137 |
+
|
138 |
+
class BM25Retriever:
    """BM25-based retriever with enriched metadata support.

    Wraps a rank_bm25 BM25Okapi index built from the ChromaDB-backed
    knowledge base. The index and its documents are cached on disk
    (pickle + JSON) so startup does not require a full rebuild.

    NOTE(review): the cache is never invalidated automatically when the
    underlying knowledge base changes — call rebuild_index() after
    modifying the knowledge base.
    """

    def __init__(self, knowledge_base: KnowledgeBaseStorage,
                 storage_path: str = "./ml_knowledge_base"):
        self.knowledge_base = knowledge_base
        self.storage_path = Path(storage_path)
        self.preprocessor = BM25Preprocessor()

        # BM25 components
        self.bm25_index = None  # rank_bm25.BM25Okapi instance (None until built)
        self.documents = []  # list of BM25Document, index-aligned with the BM25 corpus
        self.doc_lookup = {}  # doc_id -> document mapping

        # Storage files
        self.bm25_cache_file = self.storage_path / "bm25_index.pkl"
        self.docs_cache_file = self.storage_path / "bm25_documents.json"

        # Initialize (load cache or build from the knowledge base)
        self._initialize_bm25_index()

    def _initialize_bm25_index(self):
        """Initialize or load existing BM25 index"""
        try:
            # Try to load cached index
            if self._load_cached_index():
                logger.info(f"Loaded cached BM25 index with {len(self.documents)} documents")
                return

            # Build new index
            logger.info("Building new BM25 index...")
            self._build_bm25_index()

        except Exception as e:
            # Degrade gracefully: search() returns [] when the index is
            # unavailable, so initialization errors are logged, not raised.
            logger.error(f"Failed to initialize BM25 index: {e}")
            self.bm25_index = None
            self.documents = []

    def _load_cached_index(self) -> bool:
        """Load cached BM25 index if available.

        Returns:
            True when both cache files exist and deserialize cleanly;
            False otherwise (caller falls back to a full rebuild).
        """
        if not (self.bm25_cache_file.exists() and self.docs_cache_file.exists()):
            return False

        try:
            # Load BM25 index.
            # SECURITY NOTE: pickle.load executes arbitrary code embedded in
            # the file — the cache path must remain local and trusted.
            with open(self.bm25_cache_file, 'rb') as f:
                self.bm25_index = pickle.load(f)

            # Load documents
            with open(self.docs_cache_file, 'r', encoding='utf-8') as f:
                docs_data = json.load(f)

            # Reconstruct documents and the id lookup together
            self.documents = []
            self.doc_lookup = {}

            for doc_data in docs_data:
                doc = BM25Document(**doc_data)
                self.documents.append(doc)
                self.doc_lookup[doc.doc_id] = doc

            return True

        except Exception as e:
            logger.warning(f"Failed to load cached BM25 index: {e}")
            return False

    def _build_bm25_index(self):
        """Build BM25 index from knowledge base.

        Pulls every document + metadata pair from the ChromaDB collection,
        enriches and tokenizes each one, then builds a BM25Okapi index and
        caches it to disk. Per-document failures are skipped.
        """
        try:
            # Get all documents from ChromaDB
            results = self.knowledge_base.collection.get(
                include=['documents', 'metadatas']
            )

            if not results['ids']:
                logger.warning("No documents found in knowledge base")
                return

            logger.info(f"Processing {len(results['ids'])} documents for BM25 indexing...")

            # Process each document; the two lists below must stay
            # index-aligned (BM25 scores are returned by position).
            bm25_documents = []
            tokenized_docs = []

            for i, doc_id in enumerate(results['ids']):
                try:
                    # Create enriched content for BM25
                    chunk_data = {
                        'enriched_content': results['documents'][i],
                        'metadata': results['metadatas'][i]
                    }

                    enriched_content = self.preprocessor.create_enriched_content(chunk_data)

                    # Preprocess for BM25
                    tokens = self.preprocessor.preprocess_text(enriched_content)

                    if not tokens:  # Skip empty documents
                        continue

                    # Create BM25 document
                    bm25_doc = BM25Document(
                        doc_id=doc_id,
                        content=results['documents'][i],
                        enriched_content=enriched_content,
                        metadata=results['metadatas'][i],
                        keywords=results['metadatas'][i].get('keywords', '').split('|'),
                        summary=results['metadatas'][i].get('chunk_summary', ''),
                        preprocessed_tokens=tokens
                    )

                    bm25_documents.append(bm25_doc)
                    tokenized_docs.append(tokens)

                except Exception as e:
                    logger.warning(f"Failed to process document {doc_id}: {e}")
                    continue

            if not bm25_documents:
                logger.error("No valid documents processed for BM25")
                return

            # Build BM25 index
            logger.info(f"Building BM25 index with {len(bm25_documents)} documents...")
            self.bm25_index = BM25Okapi(tokenized_docs)
            self.documents = bm25_documents

            # Create lookup dictionary
            self.doc_lookup = {doc.doc_id: doc for doc in self.documents}

            # Cache the index for fast startup next time
            self._cache_bm25_index()

            logger.info(f"BM25 index built successfully with {len(self.documents)} documents")

        except Exception as e:
            logger.error(f"Failed to build BM25 index: {e}")
            self.bm25_index = None
            self.documents = []

    def _cache_bm25_index(self):
        """Cache BM25 index to disk.

        Writes the BM25Okapi object via pickle and the documents as JSON.
        Failures are logged but non-fatal (the in-memory index still works).
        """
        try:
            # Cache BM25 index
            with open(self.bm25_cache_file, 'wb') as f:
                pickle.dump(self.bm25_index, f)

            # Cache documents (convert to JSON-serializable format)
            docs_data = []
            for doc in self.documents:
                doc_dict = {
                    'doc_id': doc.doc_id,
                    'content': doc.content,
                    'enriched_content': doc.enriched_content,
                    'metadata': doc.metadata,
                    'keywords': doc.keywords,
                    'summary': doc.summary,
                    'preprocessed_tokens': doc.preprocessed_tokens
                }
                docs_data.append(doc_dict)

            with open(self.docs_cache_file, 'w', encoding='utf-8') as f:
                json.dump(docs_data, f, ensure_ascii=False, indent=2)

            logger.info("BM25 index cached successfully")

        except Exception as e:
            logger.warning(f"Failed to cache BM25 index: {e}")

    def search(self, query: str, n_results: int = 10,
               min_score: float = 0.0) -> List[BM25SearchResult]:
        """Search using BM25 algorithm.

        Args:
            query: Free-text query; tokenized with the same preprocessor
                used at index time.
            n_results: Maximum number of results to return.
            min_score: Exclusive lower bound on the raw BM25 score.
                NOTE: the comparison is `score <= min_score`, so with the
                default 0.0 documents scoring exactly zero are excluded.

        Returns:
            Top-N results sorted by descending BM25 score; [] on any
            failure or when the index is unavailable.
        """
        if not self.bm25_index or not self.documents:
            logger.warning("BM25 index not available")
            return []

        try:
            # Preprocess query
            query_tokens = self.preprocessor.preprocess_text(query)

            if not query_tokens:
                logger.warning("No valid tokens in query")
                return []

            # Get BM25 scores (one per corpus document, position-aligned)
            scores = self.bm25_index.get_scores(query_tokens)

            # Create results with scores
            results = []
            for i, score in enumerate(scores):
                if score <= min_score:
                    continue

                doc = self.documents[i]

                # Find matched terms
                matched_terms = self._find_matched_terms(query_tokens, doc.preprocessed_tokens)

                # Calculate relevance score (normalized). The /10 divisor is
                # a heuristic; raw BM25 scores are corpus-dependent.
                relevance_score = min(score / 10.0, 1.0)  # Normalize to 0-1 range

                result = BM25SearchResult(
                    doc_id=doc.doc_id,
                    content=doc.content,
                    metadata=doc.metadata,
                    bm25_score=score,
                    matched_terms=matched_terms,
                    relevance_score=relevance_score
                )

                results.append(result)

            # Sort by BM25 score (descending)
            results.sort(key=lambda x: x.bm25_score, reverse=True)

            # Return top N results
            return results[:n_results]

        except Exception as e:
            logger.error(f"BM25 search failed: {e}")
            return []

    def _find_matched_terms(self, query_tokens: List[str], doc_tokens: List[str]) -> List[str]:
        """Find which query terms matched in the document.

        Preserves query-token order; duplicates in the query are kept.
        """
        doc_token_set = set(doc_tokens)
        matched = [token for token in query_tokens if token in doc_token_set]
        return matched

    def get_stats(self) -> Dict:
        """Get BM25 retriever statistics (document count, index/cache state)."""
        return {
            'total_documents': len(self.documents),
            'index_available': self.bm25_index is not None,
            'cache_files_exist': {
                'index': self.bm25_cache_file.exists(),
                'documents': self.docs_cache_file.exists()
            },
            'storage_path': str(self.storage_path)
        }

    def rebuild_index(self):
        """Force rebuild of BM25 index.

        Clears in-memory state, deletes cache files, then rebuilds from
        the current contents of the knowledge base.
        """
        logger.info("Rebuilding BM25 index...")

        # Clear existing index
        self.bm25_index = None
        self.documents = []
        self.doc_lookup = {}

        # Remove cache files so a stale cache cannot be reloaded later
        try:
            if self.bm25_cache_file.exists():
                self.bm25_cache_file.unlink()
            if self.docs_cache_file.exists():
                self.docs_cache_file.unlink()
        except Exception as e:
            logger.warning(f"Failed to remove cache files: {e}")

        # Rebuild
        self._build_bm25_index()
|
399 |
+
|
400 |
+
|
401 |
+
def main():
    """Test the BM25 retriever.

    Smoke test: builds/loads the BM25 index from the local knowledge base,
    prints index stats, then runs a handful of sample queries and prints
    the top matches. Intended to be run manually as a script.
    """

    print("π Testing BM25 Retriever")
    print("=" * 40)

    # Initialize knowledge base and BM25 retriever
    knowledge_base = KnowledgeBaseStorage()
    bm25_retriever = BM25Retriever(knowledge_base)

    # Get stats
    stats = bm25_retriever.get_stats()
    print(f"π BM25 Index Stats:")
    print(f"   Documents: {stats['total_documents']}")
    print(f"   Index Available: {stats['index_available']}")

    if not stats['index_available']:
        print("β BM25 index not available")
        return

    # Test queries — chosen to cover companies/techniques in the corpus
    test_queries = [
        "Netflix anomaly detection",
        "isolation forest LinkedIn",
        "fraud detection machine learning",
        "bot detection Cloudflare",
        "spam classification",
        "graph neural networks"
    ]

    print(f"\nπ Testing BM25 Search:")
    for query in test_queries:
        print(f"\nQuery: '{query}'")

        results = bm25_retriever.search(query, n_results=3)
        print(f"Results: {len(results)} found")

        for i, result in enumerate(results, 1):
            company = result.metadata.get('company', 'Unknown')
            title = result.metadata.get('source_title', 'No title')[:50]
            print(f"  {i}. {company} - {title}...")
            print(f"     BM25 Score: {result.bm25_score:.3f}")
            print(f"     Matched Terms: {', '.join(result.matched_terms[:5])}")

    print(f"\nβ BM25 retriever test complete!")


if __name__ == "__main__":
    main()
|
src/search/ml_agentic_retriever.py
ADDED
@@ -0,0 +1,975 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
ML Agentic Retriever for SentrySearch
|
3 |
+
|
4 |
+
Implements Agentic RAG approach with intelligent query optimization.
|
5 |
+
Provides intelligent ML-focused retrieval for threat intelligence with:
|
6 |
+
- Query optimization for threat-to-ML translation
|
7 |
+
- Source identification for relevant paper filtering
|
8 |
+
- Enhanced hybrid retrieval with post-processing
|
9 |
+
- Context-aware result ranking and structuring
|
10 |
+
|
11 |
+
Usage:
|
12 |
+
retriever = MLAgenticRetriever(anthropic_client, knowledge_base)
|
13 |
+
ml_guidance = retriever.get_ml_guidance(threat_characteristics)
|
14 |
+
"""
|
15 |
+
|
16 |
+
import os
|
17 |
+
import json
|
18 |
+
import logging
|
19 |
+
import time
|
20 |
+
import random
|
21 |
+
import hashlib
|
22 |
+
from typing import List, Dict, Optional, Tuple, Set
|
23 |
+
from dataclasses import dataclass
|
24 |
+
import re
|
25 |
+
|
26 |
+
from anthropic import Anthropic
|
27 |
+
import anthropic
|
28 |
+
from src.data.ml_knowledge_base_builder import KnowledgeBaseStorage
|
29 |
+
from src.search.bm25_retriever import BM25Retriever, BM25SearchResult
|
30 |
+
|
31 |
+
logger = logging.getLogger(__name__)
|
32 |
+
|
33 |
+
|
34 |
+
@dataclass
class ThreatCharacteristics:
    """Represents threat characteristics for ML guidance generation.

    Plain input record: fields are free-form strings/lists supplied by the
    caller and interpolated directly into LLM prompts downstream.
    """
    threat_name: str  # human-readable threat name (e.g. "Emotet")
    threat_type: str  # e.g., "malware", "apt", "insider_threat"
    attack_vectors: List[str]  # e.g., ["network", "email", "web"]
    target_assets: List[str]  # e.g., ["user_accounts", "financial_data"]
    behavior_patterns: List[str]  # e.g., ["lateral_movement", "data_exfiltration"]
    time_characteristics: str  # e.g., "persistent", "burst", "periodic"
|
43 |
+
|
44 |
+
|
45 |
+
@dataclass
class OptimizedQuery:
    """Represents an optimized query for ML retrieval.

    Produced by QueryOptimizer: one original threat query expanded into
    several ML-focused retrieval queries plus the rationale.
    """
    original_query: str  # the threat name/query the optimization started from
    optimized_queries: List[str]  # ML-focused queries (always at least one)
    ml_focus_areas: List[str]  # e.g. ["anomaly_detection", "behavioral_analysis"]
    reasoning: str  # short LLM-provided rationale (may be empty on fallback)
|
52 |
+
|
53 |
+
|
54 |
+
@dataclass
class SourceSelection:
    """Represents filtered source selection.

    Output of source identification: which slices of the knowledge base
    (companies, years, techniques) are most relevant to a threat.
    """
    relevant_companies: List[str]  # companies whose papers to prioritize
    relevant_years: List[str]  # publication years to prioritize
    relevant_techniques: List[str]  # ML technique tags to prioritize
    reasoning: str  # short rationale for the selection
|
61 |
+
|
62 |
+
|
63 |
+
@dataclass
class MLRetrievalResult:
    """Represents a structured ML retrieval result.

    Unified result record for the hybrid retriever: carries the vector
    relevance score plus optional BM25/hybrid scores depending on which
    retrieval path produced it.
    """
    content: str  # chunk text returned to the caller
    metadata: Dict  # raw knowledge-base metadata for the chunk
    relevance_score: float  # vector-search relevance (normalized)
    source_paper: str  # title of the originating paper/blog post
    ml_techniques: List[str]  # techniques tagged on the source
    implementation_details: str  # extracted implementation notes
    applicability_score: float  # how applicable to the specific threat
    retrieval_method: str = 'unknown'  # 'vector', 'bm25', or 'hybrid'
    bm25_score: float = 0.0  # raw BM25 score (0.0 when keyword search unused)
    hybrid_score: float = 0.0  # combined score (0.0 when not hybrid)
    matched_terms: Optional[List[str]] = None  # BM25-matched query terms; None when keyword search unused
|
77 |
+
|
78 |
+
|
79 |
+
class QueryOptimizer:
|
80 |
+
"""Optimizes queries to focus on ML anomaly detection approaches"""
|
81 |
+
|
82 |
+
def __init__(self, anthropic_client):
|
83 |
+
self.client = anthropic_client
|
84 |
+
|
85 |
+
def _api_call_with_retry(self, **kwargs):
|
86 |
+
"""Make API call with intelligent retry logic using retry-after header"""
|
87 |
+
max_retries = 3
|
88 |
+
base_delay = 5
|
89 |
+
|
90 |
+
for attempt in range(max_retries):
|
91 |
+
try:
|
92 |
+
print(f"DEBUG: Query Optimizer API call attempt {attempt + 1}/{max_retries}")
|
93 |
+
return self.client.messages.create(**kwargs)
|
94 |
+
|
95 |
+
except anthropic.RateLimitError as e:
|
96 |
+
if attempt == max_retries - 1:
|
97 |
+
print(f"DEBUG: Query Optimizer rate limit exceeded after {max_retries} attempts")
|
98 |
+
raise e
|
99 |
+
|
100 |
+
# Check if the error response has retry-after information
|
101 |
+
retry_after = None
|
102 |
+
if hasattr(e, 'response') and e.response:
|
103 |
+
retry_after_header = e.response.headers.get('retry-after')
|
104 |
+
if retry_after_header:
|
105 |
+
try:
|
106 |
+
retry_after = float(retry_after_header)
|
107 |
+
print(f"DEBUG: Query Optimizer API provided retry-after: {retry_after} seconds")
|
108 |
+
except (ValueError, TypeError):
|
109 |
+
pass
|
110 |
+
|
111 |
+
# Use retry-after if available, otherwise exponential backoff
|
112 |
+
if retry_after:
|
113 |
+
delay = retry_after + random.uniform(1, 3)
|
114 |
+
else:
|
115 |
+
delay = base_delay * (2 ** attempt) + random.uniform(1, 5)
|
116 |
+
delay = min(delay, 120)
|
117 |
+
|
118 |
+
print(f"DEBUG: Query Optimizer rate limit hit. Waiting {delay:.1f} seconds before retry {attempt + 2}")
|
119 |
+
time.sleep(delay)
|
120 |
+
|
121 |
+
except Exception as e:
|
122 |
+
print(f"DEBUG: Query Optimizer non-rate-limit error: {e}")
|
123 |
+
raise e
|
124 |
+
|
125 |
+
def optimize_query(self, threat_characteristics: ThreatCharacteristics) -> OptimizedQuery:
|
126 |
+
"""Convert threat characteristics into ML-focused queries"""
|
127 |
+
|
128 |
+
prompt = f"""
|
129 |
+
You are an expert in both cybersecurity threats and machine learning anomaly detection.
|
130 |
+
Convert this threat information into 3-5 specific queries about ML approaches for detection.
|
131 |
+
|
132 |
+
Threat Information:
|
133 |
+
- Name: {threat_characteristics.threat_name}
|
134 |
+
- Type: {threat_characteristics.threat_type}
|
135 |
+
- Attack Vectors: {', '.join(threat_characteristics.attack_vectors)}
|
136 |
+
- Target Assets: {', '.join(threat_characteristics.target_assets)}
|
137 |
+
- Behavior Patterns: {', '.join(threat_characteristics.behavior_patterns)}
|
138 |
+
- Time Characteristics: {threat_characteristics.time_characteristics}
|
139 |
+
|
140 |
+
Generate queries that focus on:
|
141 |
+
1. Specific ML techniques for detecting this threat type
|
142 |
+
2. Feature engineering approaches for the attack vectors
|
143 |
+
3. Behavioral analysis methods for the patterns observed
|
144 |
+
4. Implementation considerations for the target environment
|
145 |
+
|
146 |
+
Format your response as:
|
147 |
+
QUERIES:
|
148 |
+
1. [Query 1]
|
149 |
+
2. [Query 2]
|
150 |
+
3. [Query 3]
|
151 |
+
etc.
|
152 |
+
|
153 |
+
ML_FOCUS_AREAS: [comma-separated focus areas]
|
154 |
+
|
155 |
+
REASONING: [1-2 sentences explaining the ML approach rationale]
|
156 |
+
"""
|
157 |
+
|
158 |
+
try:
|
159 |
+
response = self._api_call_with_retry(
|
160 |
+
model="claude-sonnet-4-20250514",
|
161 |
+
max_tokens=600,
|
162 |
+
messages=[{"role": "user", "content": prompt}]
|
163 |
+
)
|
164 |
+
|
165 |
+
# Safe access to response content
|
166 |
+
if not response.content or len(response.content) == 0:
|
167 |
+
raise ValueError("Empty response from API")
|
168 |
+
|
169 |
+
content = response.content[0].text.strip()
|
170 |
+
return self._parse_optimization_response(content, threat_characteristics)
|
171 |
+
|
172 |
+
except Exception as e:
|
173 |
+
logger.error(f"Query optimization failed: {e}")
|
174 |
+
# Fallback to simple query
|
175 |
+
fallback_query = f"Machine learning approaches for detecting {threat_characteristics.threat_name}"
|
176 |
+
return OptimizedQuery(
|
177 |
+
original_query=threat_characteristics.threat_name,
|
178 |
+
optimized_queries=[fallback_query],
|
179 |
+
ml_focus_areas=["anomaly_detection"],
|
180 |
+
reasoning="Fallback query due to optimization failure"
|
181 |
+
)
|
182 |
+
|
183 |
+
def _parse_optimization_response(self, response: str, threat_characteristics: ThreatCharacteristics) -> OptimizedQuery:
    """Parse the LLM response into a structured OptimizedQuery.

    Expects the QUERIES / ML_FOCUS_AREAS / REASONING sections produced by the
    optimization prompt. Malformed or missing sections degrade gracefully to
    sensible defaults instead of raising.

    Args:
        response: Raw text returned by the LLM.
        threat_characteristics: Threat profile; used for the original query
            and for the fallback query when no queries can be parsed.

    Returns:
        OptimizedQuery with at least one query and at least one focus area.
    """

    queries = []
    ml_focus_areas = []
    reasoning = ""

    lines = response.split('\n')
    current_section = None

    for line in lines:
        line = line.strip()

        if line.startswith('QUERIES:'):
            current_section = 'queries'
            continue
        elif line.startswith('ML_FOCUS_AREAS:'):
            current_section = 'focus'
            # Fix: drop empty entries produced by trailing commas or a blank
            # section (e.g. "ML_FOCUS_AREAS: "). Previously [''] was truthy
            # and defeated the ["anomaly_detection"] fallback below.
            ml_focus_areas = [
                area.strip()
                for area in line.replace('ML_FOCUS_AREAS:', '').split(',')
                if area.strip()
            ]
            continue
        elif line.startswith('REASONING:'):
            current_section = 'reasoning'
            reasoning = line.replace('REASONING:', '').strip()
            continue

        if current_section == 'queries' and line:
            # Extract query text from a numbered list entry ("1. ...").
            query_match = re.match(r'\d+\.\s*(.+)', line)
            if query_match:
                queries.append(query_match.group(1).strip())
        elif current_section == 'reasoning' and line:
            # REASONING may span multiple lines; keep accumulating.
            reasoning += ' ' + line

    # Ensure we always return at least one usable query.
    if not queries:
        queries = [f"Machine learning detection approaches for {threat_characteristics.threat_name}"]

    return OptimizedQuery(
        original_query=threat_characteristics.threat_name,
        optimized_queries=queries,
        ml_focus_areas=ml_focus_areas if ml_focus_areas else ["anomaly_detection"],
        reasoning=reasoning.strip()
    )
|
226 |
+
|
227 |
+
|
228 |
+
class SourceIdentifier:
    """Identifies most relevant papers/sources for a given threat"""

    def __init__(self, knowledge_base: KnowledgeBaseStorage):
        self.knowledge_base = knowledge_base
        self._load_source_mappings()

    def _load_source_mappings(self):
        """Load mappings between threat types and relevant sources"""

        # Company expertise mappings
        self.company_expertise = {
            'Netflix': ['performance_monitoring', 'infrastructure_anomalies', 'streaming_security'],
            'LinkedIn': ['user_behavior', 'abuse_detection', 'social_platform_security'],
            'Slack': ['communication_security', 'invite_spam', 'workspace_security'],
            'Cloudflare': ['network_security', 'bot_detection', 'traffic_analysis'],
            'Uber': ['fraud_detection', 'real_time_systems', 'human_in_the_loop'],
            'Grab': ['financial_fraud', 'graph_analysis', 'transaction_security'],
            'OLX Group': ['marketplace_fraud', 'deep_learning', 'user_verification'],
            'Stack Exchange': ['content_moderation', 'spam_detection', 'community_security'],
            'Mercari': ['e_commerce_security', 'content_moderation', 'automated_review']
        }

        # Attack vector to technique mappings
        self.attack_vector_techniques = {
            'network': ['traffic_analysis', 'graph_ml', 'behavioral_analysis'],
            'email': ['text_classification', 'nlp', 'spam_detection'],
            'web': ['bot_detection', 'traffic_analysis', 'behavioral_analysis'],
            'insider': ['user_behavior', 'behavioral_analysis', 'anomaly_detection'],
            'financial': ['fraud_detection', 'transaction_analysis', 'graph_ml']
        }

        # Threat type to ML approach mappings
        self.threat_ml_mappings = {
            'malware': ['behavioral_analysis', 'static_analysis', 'dynamic_analysis'],
            'apt': ['behavioral_analysis', 'network_analysis', 'long_term_patterns'],
            'fraud': ['fraud_detection', 'transaction_analysis', 'user_behavior'],
            'spam': ['text_classification', 'content_moderation', 'nlp'],
            'abuse': ['user_behavior', 'abuse_detection', 'behavioral_analysis']
        }

    def identify_relevant_sources(self, optimized_query: OptimizedQuery,
                                  threat_characteristics: ThreatCharacteristics) -> SourceSelection:
        """Identify most relevant papers/sources for the optimized queries"""

        companies = set()
        techniques = set()

        raw_threat_type = threat_characteristics.threat_type

        # Companies whose expertise areas mention an attack vector or the raw
        # threat type; also collect the techniques mapped to each vector.
        for vector in threat_characteristics.attack_vectors:
            companies.update(
                name
                for name, areas in self.company_expertise.items()
                if any(vector in area or raw_threat_type in area for area in areas)
            )
            techniques.update(self.attack_vector_techniques.get(vector, []))

        # Companies matched on the lower-cased threat type (substring either way).
        threat_type = raw_threat_type.lower()
        companies.update(
            name
            for name, areas in self.company_expertise.items()
            if any(threat_type in area or area in threat_type for area in areas)
        )
        techniques.update(self.threat_ml_mappings.get(threat_type, []))

        # Techniques suggested by the query optimizer are always considered.
        techniques.update(optimized_query.ml_focus_areas)

        # Always include available companies from knowledge base
        available_companies = set(self.knowledge_base.get_stats()['companies'])

        # For now, include all available companies to ensure we get results
        # In production, you can make this more selective
        companies.update(available_companies)

        # If still empty somehow, use all available
        if not companies:
            companies = available_companies

        # Generate reasoning
        reasoning = f"Selected companies based on expertise in {threat_characteristics.threat_type} and {', '.join(threat_characteristics.attack_vectors)}. Focus on {', '.join(list(techniques)[:3])} techniques."

        return SourceSelection(
            relevant_companies=list(companies),
            relevant_years=['2019', '2020', '2021', '2022'],  # All available years
            relevant_techniques=list(techniques),
            reasoning=reasoning
        )
|
321 |
+
|
322 |
+
|
323 |
+
class EnhancedRetriever:
    """Enhanced hybrid retriever with vector search + BM25 and post-processing.

    Runs every optimized query through both the vector knowledge base and a
    BM25 index, filters hits to the selected source companies, fuses the two
    result streams, deduplicates, and ranks by a combined hybrid score.
    """

    def __init__(self, knowledge_base: KnowledgeBaseStorage):
        self.knowledge_base = knowledge_base
        self.bm25_retriever = BM25Retriever(knowledge_base)

    def retrieve_with_context(self, optimized_query: OptimizedQuery,
                              source_selection: SourceSelection,
                              max_results: int = 10,
                              trace_exporter=None) -> List[MLRetrievalResult]:
        """Hybrid retrieval using both vector search and BM25, with context-aware processing.

        Args:
            optimized_query: LLM-optimized queries and focus areas.
            source_selection: Companies/techniques used to filter hits.
            max_results: Cap per individual search and on the final list.
            trace_exporter: Optional tracer; every use is None-guarded.

        Returns:
            Up to ``max_results`` results sorted by descending hybrid_score.
        """

        vector_results = []
        bm25_results = []

        # Log query optimization to trace
        if trace_exporter:
            trace_exporter.log_query_optimization(
                optimized_query.original_query,
                optimized_query.optimized_queries,
                optimized_query.reasoning,
                optimized_query.ml_focus_areas
            )

        # 1. Vector Search - Search with each optimized query
        if trace_exporter:
            trace_exporter.log_stage_start("vector_retrieval")

        for query in optimized_query.optimized_queries:
            results = self.knowledge_base.search(query, n_results=max_results)

            for result in results:
                # Filter by relevant sources
                metadata = result['metadata']
                company = metadata.get('company', '')

                if company in source_selection.relevant_companies:
                    ml_result = self._create_ml_result(result, optimized_query, source_selection)
                    ml_result.retrieval_method = 'vector'
                    vector_results.append(ml_result)

        if trace_exporter:
            trace_exporter.log_stage_end("vector_retrieval", result_count=len(vector_results))
            # Convert to format expected by trace exporter
            vector_trace_results = [self._convert_to_trace_format(r) for r in vector_results]
            trace_exporter.log_retrieval_results(vector_trace_results, "vector")

        # 2. BM25 Search - Search with each optimized query
        if trace_exporter:
            trace_exporter.log_stage_start("bm25_retrieval")

        for query in optimized_query.optimized_queries:
            bm25_search_results = self.bm25_retriever.search(query, n_results=max_results)

            for bm25_result in bm25_search_results:
                # Filter by relevant sources
                company = bm25_result.metadata.get('company', '')

                if company in source_selection.relevant_companies:
                    # Convert BM25 result to ML result format
                    ml_result = self._create_ml_result_from_bm25(bm25_result, optimized_query, source_selection)
                    ml_result.retrieval_method = 'bm25'
                    bm25_results.append(ml_result)

        if trace_exporter:
            trace_exporter.log_stage_end("bm25_retrieval", result_count=len(bm25_results))
            # Convert to format expected by trace exporter
            bm25_trace_results = [self._convert_to_trace_format(r) for r in bm25_results]
            trace_exporter.log_retrieval_results(bm25_trace_results, "bm25")

        # 3. Combine and deduplicate results
        if trace_exporter:
            trace_exporter.log_stage_start("hybrid_fusion")

        all_results = self._fuse_hybrid_results(vector_results, bm25_results)

        # 4. Post-process results
        processed_results = self._post_process_results(all_results)

        # 5. Rank by hybrid score combining relevance and applicability
        ranked_results = sorted(processed_results,
                                key=lambda x: x.hybrid_score,
                                reverse=True)

        final_results = ranked_results[:max_results]

        # Log hybrid fusion results
        if trace_exporter:
            trace_exporter.log_stage_end("hybrid_fusion", final_result_count=len(final_results))
            # Convert to format expected by trace exporter
            hybrid_trace_results = [self._convert_to_trace_format(r) for r in final_results]
            trace_exporter.log_retrieval_results(hybrid_trace_results, "hybrid")

            # Log hybrid scoring if available
            if final_results:
                avg_vector_score = sum(r.relevance_score for r in final_results) / len(final_results)
                avg_bm25_score = sum(getattr(r, 'bm25_score', 0) for r in final_results) / len(final_results)
                avg_hybrid_score = sum(r.hybrid_score for r in final_results) / len(final_results)
                avg_applicability_score = sum(r.applicability_score for r in final_results) / len(final_results)

                trace_exporter.log_hybrid_scoring(
                    avg_vector_score, avg_bm25_score, avg_hybrid_score, avg_applicability_score
                )

        return final_results

    def _create_ml_result(self, search_result: Dict,
                          optimized_query: OptimizedQuery,
                          source_selection: SourceSelection) -> MLRetrievalResult:
        """Convert a vector search result dict to a structured ML result."""

        metadata = search_result['metadata']

        # Calculate applicability score based on technique overlap
        ml_techniques_raw = metadata.get('ml_techniques', '')
        # Ensure we have a string before splitting (metadata values may not be str)
        ml_techniques_str = str(ml_techniques_raw) if ml_techniques_raw else ''
        paper_techniques = set(str(t).strip() for t in ml_techniques_str.split('|') if str(t).strip())
        relevant_techniques = set(source_selection.relevant_techniques)

        technique_overlap = len(paper_techniques.intersection(relevant_techniques))
        # max(..., 1) guards against division by zero when no techniques selected.
        applicability_score = min(technique_overlap / max(len(relevant_techniques), 1), 1.0)

        # Extract implementation details from content
        content = search_result['document']
        implementation_details = self._extract_implementation_details(content)

        result = MLRetrievalResult(
            content=content,
            metadata=metadata,
            relevance_score=search_result['score'],
            source_paper=metadata.get('source_title', ''),
            ml_techniques=list(paper_techniques),
            implementation_details=implementation_details,
            applicability_score=applicability_score,
            retrieval_method='vector'
        )

        # Hybrid score for a vector-only result: weighted relevance + applicability.
        result.hybrid_score = (result.relevance_score * 0.6 + result.applicability_score * 0.4)

        return result

    def _extract_implementation_details(self, content: str) -> str:
        """Extract up to three implementation-relevant sentences from content."""

        # Look for implementation-specific keywords
        impl_keywords = [
            'architecture', 'framework', 'algorithm', 'model',
            'feature', 'training', 'deployment', 'performance',
            'accuracy', 'precision', 'recall', 'latency'
        ]

        sentences = content.split('.')
        impl_sentences = []

        for sentence in sentences:
            if any(keyword in sentence.lower() for keyword in impl_keywords):
                impl_sentences.append(sentence.strip())

        return '. '.join(impl_sentences[:3])  # Top 3 relevant sentences

    def _post_process_results(self, results: List[MLRetrievalResult]) -> List[MLRetrievalResult]:
        """Post-process results for deduplication (keeps first occurrence)."""

        # Simple deduplication by content hash
        seen_hashes = set()
        deduplicated = []

        for result in results:
            content_hash = hash(result.content[:200])  # Hash first 200 chars

            if content_hash not in seen_hashes:
                seen_hashes.add(content_hash)
                deduplicated.append(result)

        return deduplicated

    def _convert_to_trace_format(self, ml_result: MLRetrievalResult) -> Dict:
        """Convert MLRetrievalResult to format expected by trace exporter."""
        return {
            "content": ml_result.content,
            "metadata": ml_result.metadata,
            "score": ml_result.relevance_score,
            "method": ml_result.retrieval_method,
            "matched_terms": getattr(ml_result, 'matched_terms', None),
            "source_company": ml_result.metadata.get('company'),
            "ml_techniques": ml_result.ml_techniques
        }

    def _create_ml_result_from_bm25(self, bm25_result: BM25SearchResult,
                                    optimized_query: OptimizedQuery,
                                    source_selection: SourceSelection) -> MLRetrievalResult:
        """Convert BM25 search result to ML result format."""

        metadata = bm25_result.metadata

        # Calculate applicability score based on technique overlap
        ml_techniques_raw = metadata.get('ml_techniques', '')
        ml_techniques_str = str(ml_techniques_raw) if ml_techniques_raw else ''
        paper_techniques = set(str(t).strip() for t in ml_techniques_str.split('|') if str(t).strip())
        relevant_techniques = set(source_selection.relevant_techniques)

        technique_overlap = len(paper_techniques.intersection(relevant_techniques))
        applicability_score = min(technique_overlap / max(len(relevant_techniques), 1), 1.0)

        # Extract implementation details from content
        implementation_details = self._extract_implementation_details(bm25_result.content)

        result = MLRetrievalResult(
            content=bm25_result.content,
            metadata=metadata,
            relevance_score=bm25_result.relevance_score,
            source_paper=metadata.get('source_title', ''),
            ml_techniques=list(paper_techniques),
            implementation_details=implementation_details,
            applicability_score=applicability_score,
            retrieval_method='bm25',
            bm25_score=bm25_result.bm25_score,
            matched_terms=bm25_result.matched_terms
        )

        # Hybrid score for BM25 results: weight exact-match strength plus a
        # small bonus per matched query term.
        bm25_weight = min(bm25_result.bm25_score / 5.0, 1.0)  # Normalize BM25 score
        term_match_bonus = len(bm25_result.matched_terms) * 0.1  # Bonus for matched terms
        result.hybrid_score = (bm25_weight * 0.5 + result.applicability_score * 0.4 + term_match_bonus)

        return result

    def _fuse_hybrid_results(self, vector_results: List[MLRetrievalResult],
                             bm25_results: List[MLRetrievalResult]) -> List[MLRetrievalResult]:
        """Fuse vector and BM25 results, merging hits found by both methods."""

        # Create dictionaries for fast lookup
        vector_lookup = {self._get_result_key(r): r for r in vector_results}
        bm25_lookup = {self._get_result_key(r): r for r in bm25_results}

        # Get all unique result keys
        all_keys = set(vector_lookup.keys()) | set(bm25_lookup.keys())

        fused_results = []

        for key in all_keys:
            vector_result = vector_lookup.get(key)
            bm25_result = bm25_lookup.get(key)

            if vector_result and bm25_result:
                # Both methods found this result - create hybrid
                fused_results.append(self._create_hybrid_result(vector_result, bm25_result))
            elif vector_result:
                # Only vector search found this
                fused_results.append(vector_result)
            elif bm25_result:
                # Only BM25 found this
                fused_results.append(bm25_result)

        return fused_results

    def _get_result_key(self, result: MLRetrievalResult) -> str:
        """Generate a unique key for a result based on content hash."""
        # Use first 100 characters of content as key to detect near-duplicates
        content_key = result.content[:100] if result.content else ""
        return hashlib.md5(content_key.encode()).hexdigest()[:16]

    def _create_hybrid_result(self, vector_result: MLRetrievalResult,
                              bm25_result: MLRetrievalResult) -> MLRetrievalResult:
        """Create a hybrid result by combining vector and BM25 results."""

        # Use vector result as base and enhance with BM25 data
        hybrid_result = MLRetrievalResult(
            content=vector_result.content,
            metadata=vector_result.metadata,
            relevance_score=max(vector_result.relevance_score, bm25_result.relevance_score),
            source_paper=vector_result.source_paper,
            ml_techniques=vector_result.ml_techniques,
            implementation_details=vector_result.implementation_details,
            applicability_score=max(vector_result.applicability_score, bm25_result.applicability_score),
            retrieval_method='hybrid',
            bm25_score=bm25_result.bm25_score,
            matched_terms=bm25_result.matched_terms
        )

        # Fix: the previous formula used 1/(k + score*100), which DECREASES as
        # the component scores increase and caps dual-hit scores at ~0.23 even
        # with the +0.2 bonus - below typical single-method scores (~0.5-1.0).
        # That inverted the ranking and systematically buried results confirmed
        # by both retrievers. Instead, take the best component score and add a
        # consensus bonus so dual hits outrank comparable single-method hits.
        hybrid_result.hybrid_score = max(vector_result.hybrid_score, bm25_result.hybrid_score) + 0.2

        return hybrid_result
|
620 |
+
|
621 |
+
|
622 |
+
class MLAgenticRetriever:
    """Main agentic retriever orchestrating all components.

    Pipeline: LLM query optimization -> source (company/technique) selection
    -> hybrid vector+BM25 retrieval -> structured guidance dict. Failures at
    any stage fall back to a static guidance dict with the same schema.
    """

    def __init__(self, anthropic_client, knowledge_base_path: str = "./ml_knowledge_base"):
        self.client = anthropic_client
        self.knowledge_base = KnowledgeBaseStorage(knowledge_base_path)

        # Initialize agents
        self.query_optimizer = QueryOptimizer(anthropic_client)
        self.source_identifier = SourceIdentifier(self.knowledge_base)
        self.enhanced_retriever = EnhancedRetriever(self.knowledge_base)

        logger.info("ML Agentic Retriever initialized")

    def get_ml_guidance(self, threat_characteristics: ThreatCharacteristics, trace_exporter=None) -> Dict:
        """Get comprehensive ML guidance for threat detection.

        Returns a guidance dict (see _structure_ml_guidance); never raises -
        pipeline failures produce the fallback guidance instead.
        """

        try:
            logger.info(f"Getting ML guidance for: {threat_characteristics.threat_name}")

            # Step 1: Query Optimization
            optimized_query = self.query_optimizer.optimize_query(threat_characteristics)
            logger.info(f"Generated {len(optimized_query.optimized_queries)} optimized queries")

            # Step 2: Source Identification
            source_selection = self.source_identifier.identify_relevant_sources(
                optimized_query, threat_characteristics
            )
            logger.info(f"Identified {len(source_selection.relevant_companies)} relevant companies")

            # Step 3: Enhanced Retrieval
            ml_results = self.enhanced_retriever.retrieve_with_context(
                optimized_query, source_selection, max_results=8,
                trace_exporter=trace_exporter
            )
            logger.info(f"Retrieved {len(ml_results)} relevant ML approaches")

            # Step 4: Structure results
            guidance = self._structure_ml_guidance(
                threat_characteristics, optimized_query, source_selection, ml_results
            )

            return guidance

        except Exception as e:
            # Boundary handler: log and degrade to fallback rather than crash.
            logger.error(f"ML guidance generation failed: {e}")
            return self._create_fallback_guidance(threat_characteristics)

    def get_enhanced_ml_guidance(self, threat_characteristics: ThreatCharacteristics,
                                 complete_threat_data: Dict, trace_exporter=None) -> Dict:
        """Get enhanced ML guidance leveraging complete threat intelligence context.

        NOTE(review): relies on _optimize_query_with_context /
        _structure_enhanced_ml_guidance / _create_enhanced_fallback_guidance,
        which are attached to this class at module level - confirm the
        attachment happens before this method is called.
        """

        try:
            logger.info(f"Getting enhanced ML guidance for: {threat_characteristics.threat_name}")

            # Step 1: Enhanced Query Optimization with threat context
            optimized_query = self._optimize_query_with_context(threat_characteristics, complete_threat_data)
            logger.info(f"Generated {len(optimized_query.optimized_queries)} context-enhanced queries")

            # Step 2: Enhanced Source Identification
            source_selection = self.source_identifier.identify_relevant_sources(
                optimized_query, threat_characteristics
            )
            logger.info(f"Identified {len(source_selection.relevant_companies)} relevant companies")

            # Step 3: Enhanced Retrieval with threat context
            ml_results = self.enhanced_retriever.retrieve_with_context(
                optimized_query, source_selection, max_results=10,  # More results for enhanced mode
                trace_exporter=trace_exporter
            )
            logger.info(f"Retrieved {len(ml_results)} relevant ML approaches")

            # Step 4: Structure results with enhanced context
            guidance = self._structure_enhanced_ml_guidance(
                threat_characteristics, optimized_query, source_selection, ml_results, complete_threat_data
            )

            return guidance

        except Exception as e:
            logger.error(f"Enhanced ML guidance generation failed: {e}")
            return self._create_enhanced_fallback_guidance(threat_characteristics, complete_threat_data)

    def _structure_ml_guidance(self, threat_characteristics: ThreatCharacteristics,
                               optimized_query: OptimizedQuery,
                               source_selection: SourceSelection,
                               ml_results: List[MLRetrievalResult]) -> Dict:
        """Structure the ML guidance into organized sections.

        Output keys: threat_name, ml_approaches, implementation_considerations,
        source_papers, query_optimization, source_selection.
        """

        # Group results by ML technique
        techniques_map = {}
        for result in ml_results:
            for technique in result.ml_techniques:
                if technique not in techniques_map:
                    techniques_map[technique] = []
                techniques_map[technique].append(result)

        # Create structured guidance
        guidance = {
            'threat_name': threat_characteristics.threat_name,
            'ml_approaches': [],
            'implementation_considerations': [],
            'source_papers': [],
            'query_optimization': {
                'original_query': optimized_query.original_query,
                'optimized_queries': optimized_query.optimized_queries,
                'reasoning': optimized_query.reasoning
            },
            'source_selection': {
                'relevant_companies': source_selection.relevant_companies,
                'reasoning': source_selection.reasoning
            }
        }

        # Add ML approaches: for each technique, surface the most applicable result.
        for technique, results in techniques_map.items():
            if results:  # Only include techniques with results
                best_result = max(results, key=lambda x: x.applicability_score)

                approach = {
                    'technique': str(technique),
                    'description': best_result.implementation_details,
                    'source_company': str(best_result.metadata.get('company', '')),
                    'source_paper': str(best_result.source_paper),
                    'applicability_score': best_result.applicability_score,
                    'relevance_score': best_result.relevance_score,
                    'retrieval_method': best_result.retrieval_method,
                    'hybrid_score': best_result.hybrid_score,
                    'bm25_score': best_result.bm25_score,
                    'matched_terms': best_result.matched_terms if best_result.matched_terms else []
                }

                guidance['ml_approaches'].append(approach)

        # Add implementation considerations
        for result in ml_results[:3]:  # Top 3 results
            consideration = {
                'aspect': f"{result.metadata.get('company', '')} Implementation",
                'details': result.implementation_details,
                'source': result.source_paper
            }
            guidance['implementation_considerations'].append(consideration)

        # Add source papers (deduplicated by title, first occurrence wins)
        seen_papers = set()
        for result in ml_results:
            paper_title = result.source_paper
            if paper_title not in seen_papers:
                seen_papers.add(paper_title)
                paper_info = {
                    'title': paper_title,
                    'company': result.metadata.get('company', ''),
                    'year': result.metadata.get('year', ''),
                    'url': result.metadata.get('source_url', ''),
                    'techniques': result.ml_techniques
                }
                guidance['source_papers'].append(paper_info)

        return guidance

    def _create_fallback_guidance(self, threat_characteristics: ThreatCharacteristics) -> Dict:
        """Create fallback guidance when main pipeline fails.

        Fix: the fallback now carries the same top-level schema as the success
        path (query_optimization / source_selection keys, and the full
        per-approach key set), so consumers that index into the guidance dict
        don't KeyError exactly when the pipeline has already failed.
        """

        fallback_query = f"Machine learning approaches for detecting {threat_characteristics.threat_name}"

        return {
            'threat_name': threat_characteristics.threat_name,
            'ml_approaches': [{
                'technique': 'anomaly_detection',
                'description': 'General anomaly detection approaches using statistical methods and machine learning',
                'source_company': 'General',
                'source_paper': 'Fallback recommendation',
                'applicability_score': 0.5,
                'relevance_score': 0.5,
                # Keys always present on the success path; safe defaults here.
                'retrieval_method': 'fallback',
                'hybrid_score': 0.5,
                'bm25_score': None,
                'matched_terms': []
            }],
            'implementation_considerations': [{
                'aspect': 'General Implementation',
                'details': 'Consider implementing statistical anomaly detection as a baseline approach',
                'source': 'Fallback recommendation'
            }],
            'source_papers': [],
            'query_optimization': {
                'original_query': threat_characteristics.threat_name,
                'optimized_queries': [fallback_query],
                'reasoning': 'Fallback query due to guidance pipeline failure'
            },
            'source_selection': {
                'relevant_companies': [],
                'reasoning': 'Source selection unavailable - fallback recommendations provided'
            },
            'error': 'ML guidance generation failed - fallback recommendations provided'
        }
|
803 |
+
|
804 |
+
|
805 |
+
def create_test_threat_characteristics() -> ThreatCharacteristics:
    """Create test threat characteristics for validation"""

    # Fixed ShadowPad-style profile: a persistent network implant showing
    # lateral movement, persistence, exfiltration, and C2 behaviors.
    profile = {
        'threat_name': "ShadowPad",
        'threat_type': "malware",
        'attack_vectors': ["network", "lateral_movement"],
        'target_assets': ["corporate_networks", "sensitive_data"],
        'behavior_patterns': ["persistence", "data_exfiltration", "command_control"],
        'time_characteristics': "persistent",
    }
    return ThreatCharacteristics(**profile)
|
816 |
+
|
817 |
+
|
818 |
+
# Add enhanced methods to MLAgenticRetriever class
|
819 |
+
def _optimize_query_with_context(self, threat_characteristics: ThreatCharacteristics,
                                 complete_threat_data: Dict) -> OptimizedQuery:
    """Create enhanced queries using complete threat intelligence context"""

    def _labels(entries, key, limit):
        # Pull human-readable labels from a list of dicts/strings; non-list
        # (or missing) input contributes nothing.
        if not isinstance(entries, list):
            return []
        return [entry.get(key, str(entry)) if isinstance(entry, dict) else str(entry)
                for entry in entries[:limit]]

    context_elements = []

    # Technical capabilities
    tech_details = complete_threat_data.get('technicalDetails')
    if tech_details:
        capabilities = tech_details.get('capabilities')
        if capabilities:
            context_elements.extend(_labels(capabilities, 'name', 3))

    # C2 protocols
    c2_data = complete_threat_data.get('commandAndControl')
    if c2_data:
        methods = c2_data.get('communicationMethods')
        if methods:
            context_elements.extend(_labels(methods, 'protocol', 2))

    # Use the regular optimizer with enhanced threat characteristics:
    # the first two context elements widen the attack vectors, the next
    # two widen the behavior patterns.
    enhanced_characteristics = ThreatCharacteristics(
        threat_name=threat_characteristics.threat_name,
        threat_type=threat_characteristics.threat_type,
        attack_vectors=threat_characteristics.attack_vectors + context_elements[:2],
        target_assets=threat_characteristics.target_assets,
        behavior_patterns=threat_characteristics.behavior_patterns + context_elements[2:4],
        time_characteristics=threat_characteristics.time_characteristics
    )

    return self.query_optimizer.optimize_query(enhanced_characteristics)
|
855 |
+
|
856 |
+
def _structure_enhanced_ml_guidance(self, threat_characteristics: ThreatCharacteristics,
|
857 |
+
optimized_query: OptimizedQuery,
|
858 |
+
source_selection: SourceSelection,
|
859 |
+
ml_results: List, # MLRetrievalResult type
|
860 |
+
complete_threat_data: Dict) -> Dict:
|
861 |
+
"""Structure enhanced ML guidance with threat context"""
|
862 |
+
|
863 |
+
# Start with regular structuring
|
864 |
+
guidance = self._structure_ml_guidance(
|
865 |
+
threat_characteristics, optimized_query, source_selection, ml_results
|
866 |
+
)
|
867 |
+
|
868 |
+
# Enhance with threat context
|
869 |
+
guidance['threat_context_applied'] = True
|
870 |
+
guidance['context_sources'] = {
|
871 |
+
'technical_details': bool(complete_threat_data.get('technicalDetails')),
|
872 |
+
'command_and_control': bool(complete_threat_data.get('commandAndControl')),
|
873 |
+
'detection_and_mitigation': bool(complete_threat_data.get('detectionAndMitigation')),
|
874 |
+
'forensic_artifacts': bool(complete_threat_data.get('forensicArtifacts'))
|
875 |
+
}
|
876 |
+
|
877 |
+
# Add threat-specific implementation considerations
|
878 |
+
if tech_details := complete_threat_data.get('technicalDetails'):
|
879 |
+
if os_data := tech_details.get('operatingSystems'):
|
880 |
+
# Ensure os_data is a list before slicing
|
881 |
+
if isinstance(os_data, list):
|
882 |
+
os_names = [os.get('name', str(os)) if isinstance(os, dict) else str(os) for os in os_data[:2]]
|
883 |
+
guidance['implementation_considerations'].append({
|
884 |
+
'aspect': 'OS Compatibility',
|
885 |
+
'details': f'Ensure ML models are trained on {", ".join(os_names)} environments for optimal detection.',
|
886 |
+
'source': 'Threat Intelligence Profile'
|
887 |
+
})
|
888 |
+
|
889 |
+
return guidance
|
890 |
+
|
891 |
+
def _create_enhanced_fallback_guidance(self, threat_characteristics: ThreatCharacteristics,
|
892 |
+
complete_threat_data: Dict) -> Dict:
|
893 |
+
"""Create enhanced fallback guidance with threat context"""
|
894 |
+
|
895 |
+
fallback = self._create_fallback_guidance(threat_characteristics)
|
896 |
+
|
897 |
+
# Add context-aware recommendations
|
898 |
+
fallback['threat_context_applied'] = True
|
899 |
+
fallback['enhanced_fallback'] = True
|
900 |
+
|
901 |
+
# Add context-specific ML approaches
|
902 |
+
if complete_threat_data.get('commandAndControl'):
|
903 |
+
fallback['ml_approaches'].append({
|
904 |
+
'technique': 'C2 Traffic Analysis',
|
905 |
+
'source_company': 'Context-Derived',
|
906 |
+
'description': 'ML-based detection of command and control communication patterns identified in the threat profile.',
|
907 |
+
'applicability_score': 0.8
|
908 |
+
})
|
909 |
+
|
910 |
+
if complete_threat_data.get('forensicArtifacts'):
|
911 |
+
fallback['ml_approaches'].append({
|
912 |
+
'technique': 'Artifact-Based Detection',
|
913 |
+
'source_company': 'Context-Derived',
|
914 |
+
'description': 'Machine learning models trained on forensic artifacts specific to this threat.',
|
915 |
+
'applicability_score': 0.7
|
916 |
+
})
|
917 |
+
|
918 |
+
return fallback
|
919 |
+
|
920 |
+
# Monkey patch these methods onto the MLAgenticRetriever class
# NOTE(review): these functions are defined at module level and attached
# post-hoc; defining them inside the class body would be clearer — confirm
# nothing relies on them being importable as free functions.
MLAgenticRetriever._optimize_query_with_context = _optimize_query_with_context
MLAgenticRetriever._structure_enhanced_ml_guidance = _structure_enhanced_ml_guidance
MLAgenticRetriever._create_enhanced_fallback_guidance = _create_enhanced_fallback_guidance
|
924 |
+
|
925 |
+
|
926 |
+
def main():
|
927 |
+
"""Test the ML Agentic Retriever"""
|
928 |
+
|
929 |
+
# Initialize
|
930 |
+
api_key = os.getenv('ANTHROPIC_API_KEY')
|
931 |
+
if not api_key:
|
932 |
+
print("Error: ANTHROPIC_API_KEY environment variable not set")
|
933 |
+
return
|
934 |
+
|
935 |
+
print("π€ Testing ML Agentic Retriever")
|
936 |
+
print("=" * 40)
|
937 |
+
|
938 |
+
# Create components
|
939 |
+
anthropic_client = Anthropic(api_key=api_key)
|
940 |
+
retriever = MLAgenticRetriever(anthropic_client)
|
941 |
+
|
942 |
+
# Test with sample threat
|
943 |
+
threat = create_test_threat_characteristics()
|
944 |
+
|
945 |
+
print(f"π― Testing with threat: {threat.threat_name}")
|
946 |
+
print(f" Type: {threat.threat_type}")
|
947 |
+
print(f" Attack Vectors: {', '.join(threat.attack_vectors)}")
|
948 |
+
print(f" Behavior Patterns: {', '.join(threat.behavior_patterns)}")
|
949 |
+
|
950 |
+
# Get ML guidance
|
951 |
+
guidance = retriever.get_ml_guidance(threat)
|
952 |
+
|
953 |
+
print(f"\nπ§ ML Guidance Generated:")
|
954 |
+
print(f" ML Approaches: {len(guidance['ml_approaches'])}")
|
955 |
+
print(f" Implementation Considerations: {len(guidance['implementation_considerations'])}")
|
956 |
+
print(f" Source Papers: {len(guidance['source_papers'])}")
|
957 |
+
|
958 |
+
# Show details
|
959 |
+
if guidance['ml_approaches']:
|
960 |
+
print(f"\nπ Top ML Approaches:")
|
961 |
+
for i, approach in enumerate(guidance['ml_approaches'][:3], 1):
|
962 |
+
print(f" {i}. {approach['technique']} ({approach['source_company']})")
|
963 |
+
print(f" Applicability: {approach['applicability_score']:.2f}")
|
964 |
+
print(f" Description: {approach['description'][:100]}...")
|
965 |
+
|
966 |
+
if guidance['source_papers']:
|
967 |
+
print(f"\nπ Source Papers:")
|
968 |
+
for paper in guidance['source_papers'][:3]:
|
969 |
+
print(f" β’ {paper['company']} ({paper['year']}): {paper['title'][:60]}...")
|
970 |
+
|
971 |
+
print(f"\nβ
Agentic retrieval test complete!")
|
972 |
+
|
973 |
+
|
974 |
+
if __name__ == "__main__":
|
975 |
+
main()
|
src/{ml_workers_retriever.py β search/ml_workers_retriever.py}
RENAMED
File without changes
|
src/{app.py β ui/app.py}
RENAMED
@@ -2,8 +2,8 @@
|
|
2 |
π SentrySearch - Threat Intelligence Profile Generator
|
3 |
"""
|
4 |
import gradio as gr
|
5 |
-
from threat_intel_tool import ThreatIntelTool
|
6 |
-
from markdown_generator import generate_markdown
|
7 |
|
8 |
|
9 |
def generate_threat_profile(api_key, tool_name, enable_quality_control, progress=gr.Progress()):
|
|
|
2 |
π SentrySearch - Threat Intelligence Profile Generator
|
3 |
"""
|
4 |
import gradio as gr
|
5 |
+
from src.core.threat_intel_tool import ThreatIntelTool
|
6 |
+
from src.core.markdown_generator import generate_markdown
|
7 |
|
8 |
|
9 |
def generate_threat_profile(api_key, tool_name, enable_quality_control, progress=gr.Progress()):
|
wrangler.toml
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Cloudflare Workers configuration for SentrySearch Hybrid Search
|
2 |
+
|
3 |
+
name = "sentry-search-hybrid"
|
4 |
+
main = "worker.js"
|
5 |
+
compatibility_date = "2023-12-01"
|
6 |
+
|
7 |
+
# KV Namespace binding
|
8 |
+
[[kv_namespaces]]
|
9 |
+
binding = "SENTRY_KV"
|
10 |
+
id = "f97c6c83f96f4b548307ccb1ffaa2668"
|
11 |
+
preview_id = "3fdb32abdd9d472d9d0506297dde4587"
|
12 |
+
|
13 |
+
# Environment variables (non-secret)
|
14 |
+
[vars]
|
15 |
+
ENVIRONMENT = "production"
|
16 |
+
MAX_SEARCH_RESULTS = "50"
|
17 |
+
DEFAULT_SEARCH_TIMEOUT = "30000"
|
18 |
+
|
19 |
+
# Routes (configure after deployment)
|
20 |
+
# routes = [
|
21 |
+
# { pattern = "sentry-search.your-domain.com/*", zone_name = "your-domain.com" }
|
22 |
+
# ]
|
23 |
+
|
24 |
+
# Resource limits (removed for free plan compatibility)
|
25 |
+
# [limits]
|
26 |
+
# cpu_ms = 30000
|