| """ |
| ThreatHunter 雙層記憶學習 Tool |
| ============================== |
| |
| Layer 1: JSON 持久化(穩定保底,Day 1 起可用) |
| Layer 2: LlamaIndex RAG(語義搜尋,ENABLE_MEMORY_RAG=true 啟用) |
| |
| 遵循文件: |
| - FINAL_PLAN.md §支柱 3(Feedback Loops:雙層記憶學習系統) |
| - leader_plan.md §記憶學習系統 |
| """ |
|
|
| import json |
| import logging |
| from pathlib import Path |
| from datetime import datetime, timezone |
| from typing import Any |
|
|
| from config import MEMORY_DIR, ENABLE_MEMORY_RAG, SIMILARITY_THRESHOLD |
| from crewai.tools import tool |
|
|
| logger = logging.getLogger("threathunter.memory") |
|
|
| |
| try: |
| from sandbox.memory_sanitizer import sanitize_memory_write as _sanitize_write |
| _MEM_SANITIZER_OK = True |
| except ImportError: |
| def _sanitize_write(data, agent_name=''): |
| return True, data, 'ok' |
| _MEM_SANITIZER_OK = False |
|
|
| |
| VALID_AGENT_NAMES = {"scout", "analyst", "advisor", "critic", "orchestrator"} |
|
|
|
|
| |
| def _get_memory_path(agent_name: str) -> Path: |
| """取得指定 Agent 的記憶 JSON 路徑""" |
| return MEMORY_DIR / f"{agent_name}_memory.json" |
|
|
|
|
| def _load_json(path: Path) -> dict: |
| """安全載入 JSON,檔案不存在或損壞時回傳空 dict""" |
| if not path.exists(): |
| return {} |
| try: |
| content = path.read_text(encoding="utf-8") |
| if not content.strip(): |
| return {} |
| return json.loads(content) |
| except (json.JSONDecodeError, OSError) as e: |
| logger.warning("[WARN] Memory file read failed %s: %s, returning empty", path, e) |
| return {} |
|
|
|
|
| def _save_json(path: Path, data: dict) -> None: |
| """安全寫入 JSON(先寫臨時檔再 rename,防止寫入中斷導致損壞)""" |
| path.parent.mkdir(parents=True, exist_ok=True) |
| temp_path = path.with_suffix(".tmp") |
| try: |
| temp_path.write_text( |
| json.dumps(data, ensure_ascii=False, indent=2), |
| encoding="utf-8", |
| ) |
| temp_path.replace(path) |
| except OSError as e: |
| logger.error("[FAIL] Memory file write failed %s: %s", path, e) |
| if temp_path.exists(): |
| temp_path.unlink() |
| raise |
|
|
|
|
| |
| _rag_index = None |
| _rag_query_engine = None |
|
|
|
|
| def _init_rag() -> None: |
| """延遲初始化 LlamaIndex RAG(只在第一次呼叫時執行)""" |
| global _rag_index, _rag_query_engine |
|
|
| if not ENABLE_MEMORY_RAG: |
| return |
| if _rag_index is not None: |
| return |
|
|
| try: |
| from llama_index.core import VectorStoreIndex, StorageContext, Settings |
| from llama_index.core import load_index_from_storage |
|
|
| |
| |
| try: |
| from llama_index.embeddings.huggingface import HuggingFaceEmbedding |
| Settings.embed_model = HuggingFaceEmbedding( |
| model_name="BAAI/bge-small-en-v1.5" |
| ) |
| logger.info("[OK] Embedding: HuggingFace BAAI/bge-small-en-v1.5 (local free)") |
| except ImportError: |
| logger.warning("[WARN] HuggingFace embedding not installed, trying OpenAI embedding") |
|
|
| |
| try: |
| from llama_index.core.llms import MockLLM |
| Settings.llm = MockLLM() |
| except Exception: |
| Settings.llm = None |
|
|
| vector_store_path = MEMORY_DIR / "vector_store" |
|
|
| if (vector_store_path / "docstore.json").exists(): |
| storage_context = StorageContext.from_defaults( |
| persist_dir=str(vector_store_path) |
| ) |
| _rag_index = load_index_from_storage(storage_context) |
| logger.info("[OK] LlamaIndex vector index loaded") |
| _rag_query_engine = _rag_index.as_query_engine(similarity_top_k=3) |
| else: |
| _rag_index = VectorStoreIndex([]) |
| logger.info("[OK] LlamaIndex vector index created (empty)") |
| |
| _rag_query_engine = _rag_index.as_query_engine(similarity_top_k=3) |
| _backfill_from_json_history() |
| |
| _rag_query_engine = _rag_index.as_query_engine(similarity_top_k=3) |
|
|
| except ImportError: |
| logger.warning("[WARN] LlamaIndex not installed, RAG disabled") |
| except Exception as e: |
| logger.warning("[WARN] LlamaIndex init failed: %s", e) |
|
|
|
|
| def _backfill_from_json_history() -> None: |
| """ |
| 首次啟用 RAG 時,將所有已有的 *_memory.json 批次回填進 LlamaIndex。 |
| 只在 vector_store 新建(非載入)時執行一次,之後透過 persist 保存。 |
| |
| 設計原則(Harness Engineering — Feedback Loops 支柱): |
| - 確保歷史掃描記錄不因 RAG 冷啟動而遺失語義感知 |
| - 失敗不阻塞(Graceful Degradation) |
| """ |
| if _rag_index is None: |
| return |
|
|
| total_inserted = 0 |
| for agent_name in VALID_AGENT_NAMES: |
| json_path = _get_memory_path(agent_name) |
| if not json_path.exists(): |
| continue |
|
|
| data = _load_json(json_path) |
| if not data: |
| continue |
|
|
| |
| _rag_insert(agent_name, data) |
| total_inserted += 1 |
|
|
| |
| for hist_scan in data.get("history", []): |
| if isinstance(hist_scan, dict) and hist_scan: |
| _rag_insert(agent_name, hist_scan) |
| total_inserted += 1 |
|
|
| if total_inserted > 0: |
| logger.info("[OK] RAG history backfill done: %d scan records vectorized", total_inserted) |
| else: |
| logger.info("[INFO] RAG backfill: no historical JSON memory to backfill (first scan)") |
|
|
|
|
| def _extract_package_names(tech_stack: str) -> set[str]: |
| """ |
| 從技術棧字串中提取套件名稱(小寫、去版本號)。 |
| |
| 範例: |
| 'Django 4.2, Redis 7.0' -> {'django', 'redis'} |
| 'Spring Boot 3.1 和 Node.js 18' -> {'spring', 'boot', 'node.js'} |
| """ |
| if not tech_stack: |
| return set() |
| names = set() |
| for part in tech_stack.replace(",", " ").split(): |
| clean = part.strip().lower() |
| |
| if clean and not clean.replace(".", "").replace("-", "").isdigit(): |
| names.add(clean) |
| return names |
|
|
|
|
| def _rag_insert(agent_name: str, data: dict) -> None: |
| """將資料插入 LlamaIndex 向量索引(雙寫的 Layer 2),含 tech_stack 元資料""" |
| if not ENABLE_MEMORY_RAG or _rag_index is None: |
| return |
| try: |
| from llama_index.core import Document |
|
|
| |
| tech_stack = ( |
| data.get("tech_stack", "") |
| or data.get("tech_stack_input", "") |
| or "" |
| ) |
| |
| if isinstance(tech_stack, list): |
| tech_stack = ", ".join(str(t) for t in tech_stack) |
|
|
| doc = Document( |
| text=json.dumps(data, ensure_ascii=False), |
| metadata={ |
| "agent": agent_name, |
| "timestamp": data.get("timestamp", ""), |
| "scan_id": data.get("scan_id", ""), |
| "tech_stack": str(tech_stack), |
| }, |
| ) |
| _rag_index.insert(doc) |
| vector_store_path = MEMORY_DIR / "vector_store" |
| _rag_index.storage_context.persist(persist_dir=str(vector_store_path)) |
| logger.info("[OK] RAG index updated: %s (tech_stack=%s)", agent_name, tech_stack[:50]) |
| except Exception as e: |
| logger.warning("[WARN] RAG write failed (JSON layer unaffected): %s", e) |
|
|
|
|
| def _rag_search(query: str, tech_stack: str | None = None) -> str: |
| """ |
| 語義搜尋(帶安全閥 + 技術棧相關性過濾)。 |
| |
| Args: |
| query: 搜尋查詢 |
| tech_stack: 當前掃描的技術棧(用於過濾不相關歷史) |
| """ |
| if not ENABLE_MEMORY_RAG: |
| return "RAG disabled (ENABLE_MEMORY_RAG=false)" |
|
|
| _init_rag() |
| if _rag_index is None: |
| return "RAG index unavailable" |
|
|
| try: |
| doc_count = len(_rag_index.docstore.docs) if hasattr(_rag_index, "docstore") else 0 |
| if doc_count == 0: |
| return "No history available (vector index empty)" |
|
|
| response = _rag_query_engine.query(query) |
|
|
| |
| if hasattr(response, "source_nodes") and response.source_nodes: |
| scores = [n.score for n in response.source_nodes if n.score is not None] |
| max_score = max(scores) if scores else 0 |
| if max_score < SIMILARITY_THRESHOLD: |
| return ( |
| f"No relevant history found" |
| f" (max_similarity {max_score:.2f} < threshold {SIMILARITY_THRESHOLD})" |
| ) |
|
|
| |
| if tech_stack: |
| current_packages = _extract_package_names(tech_stack) |
| if current_packages: |
| filtered_nodes = [] |
| for node in response.source_nodes: |
| node_tech = node.metadata.get("tech_stack", "") |
| if not node_tech: |
| |
| filtered_nodes.append(node) |
| continue |
| node_packages = _extract_package_names(node_tech) |
| |
| if current_packages & node_packages: |
| filtered_nodes.append(node) |
| else: |
| logger.info( |
| "[FILTER] Excluded history: %s (no overlap with %s)", |
| node_tech[:50], tech_stack[:50], |
| ) |
|
|
| if not filtered_nodes: |
| return ( |
| f"No relevant history for current tech stack" |
| f" (filtered {len(response.source_nodes)} results, 0 matched)" |
| ) |
| response.source_nodes = filtered_nodes |
|
|
| return str(response) |
|
|
| except Exception as e: |
| logger.warning("[WARN] RAG search failed: %s", e) |
| return f"RAG search failed: {e}" |
|
|
|
|
| |
| @tool("read_memory") |
| def read_memory(agent_name: str) -> str: |
| """ |
| 讀取指定 Agent 的歷史記憶(JSON Layer 1:穩定保底)。 |
| 0 份歷史回傳空 JSON,Agent 可據此判斷是否為第一次掃描。 |
| |
| Args: |
| agent_name: Agent 名稱(scout / analyst / advisor / critic / orchestrator) |
| |
| Returns: |
| JSON 字串格式的歷史記憶 |
| """ |
| agent_name = agent_name.strip().lower() |
| if agent_name not in VALID_AGENT_NAMES: |
| logger.warning("[WARN] Invalid agent name: %s", agent_name) |
| return json.dumps({}, ensure_ascii=False) |
|
|
| data = _load_json(_get_memory_path(agent_name)) |
|
|
| if not data: |
| logger.info("[INFO] %s has no history (first scan)", agent_name) |
| else: |
| logger.info("[OK] %s memory loaded (scan_id: %s)", agent_name, data.get('scan_id', 'N/A')) |
|
|
| return json.dumps(data, ensure_ascii=False, indent=2) |
|
|
|
|
| @tool("write_memory") |
| def write_memory(agent_name: str, data: str) -> str: |
| """ |
| 寫入 Agent 記憶(雙寫:JSON + LlamaIndex)。自動添加 timestamp。 |
| |
| Args: |
| agent_name: Agent 名稱(scout / analyst / advisor / critic / orchestrator) |
| data: JSON 字串格式的記憶資料 |
| |
| Returns: |
| 寫入結果訊息 |
| """ |
| agent_name = agent_name.strip().lower() |
| if agent_name not in VALID_AGENT_NAMES: |
| return f"[FAIL] Invalid agent name: {agent_name} (allowed: {VALID_AGENT_NAMES})" |
|
|
| try: |
| memory_data = json.loads(data) if isinstance(data, str) else data |
| except json.JSONDecodeError as e: |
| return f"[FAIL] JSON format error: {e}" |
|
|
| |
| is_safe, clean_data, reason = _sanitize_write(memory_data, agent_name) |
| if not is_safe: |
| logger.warning('[MEMORY][SANDBOX] Write BLOCKED: %s', reason) |
| return '[BLOCKED] Memory write rejected by Sandbox: ' + reason |
| memory_data = clean_data |
|
|
| memory_data["timestamp"] = datetime.now(timezone.utc).isoformat() |
|
|
| |
| try: |
| existing = _load_json(_get_memory_path(agent_name)) |
| history = existing.get("history", []) |
|
|
| |
| if existing and "scan_id" in existing: |
| old_entry = {k: v for k, v in existing.items() if k != "history"} |
| history.append(old_entry) |
| if len(history) > 50: |
| history = history[-50:] |
|
|
| memory_data["history"] = history |
| _save_json(_get_memory_path(agent_name), memory_data) |
| logger.info("[OK] %s memory saved to JSON (Layer 1 | history=%d records)", agent_name, len(history)) |
| except Exception as e: |
| return f"[FAIL] JSON write failed: {e}" |
| |
| _rag_insert(agent_name, memory_data) |
|
|
| return f"[OK] {agent_name} memory saved (timestamp: {memory_data['timestamp']})" |
|
|
|
|
| @tool("history_search") |
| def history_search(query: str, tech_stack: str = "") -> str: |
| """ |
| 語義搜尋歷史安全報告(帶技術棧過濾)。 |
| 帶安全閥:索引為空 / 分數太低 / 技術棧不匹配 / RAG 未啟用 → 回傳提示。 |
| |
| Args: |
| query: 搜尋查詢(例如:"Django SSRF 歷史") |
| tech_stack: 當前掃描的技術棧(例如:"Django 4.2, Redis 7.0"), |
| 用於過濾不相關的歷史記錄 |
| |
| Returns: |
| 搜尋結果或安全提示 |
| """ |
| return _rag_search(query, tech_stack if tech_stack else None) |
|
|