| """ | |
| train_rl.py β Training + Evaluation Script for Adaptive Alert Triage | |
| ===================================================================== | |
| Runs: | |
| 1. Rule-based baselines (RuleBasedAgent, ImprovedRuleBasedAgent) | |
| 2. PPO RL agent training across all 3 tasks | |
| 3. Saves all results to results.json for the comparison plot | |
| Changes vs previous version: | |
| - Per-task episode budgets: hard gets 3Γ more episodes than easy | |
| because it has 40% chain probability and needs far more samples | |
| to observe enough chain outcomes for the policy to converge. | |
| - --episodes arg now sets the EASY budget; medium and hard scale up. | |
| - Grader is now passed into trainer.train() so terminal scores are | |
| injected into trajectories (was only logged after, not used). | |
| Usage: | |
| python train_rl.py [--episodes 300] [--eval-episodes 20] [--seed 42] | |
| """ | |
from __future__ import annotations

import argparse
import json
import os
import sys
import time
from typing import Dict

import numpy as np

# Make sibling packages importable when running this file directly.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from adaptive_alert_triage.env import AdaptiveAlertTriageEnv
from adaptive_alert_triage.models import Action
from tasks.easy import EasyTaskGrader, SUCCESS_THRESHOLD as EASY_THRESH
from tasks.medium import MediumTaskGrader, SUCCESS_THRESHOLD as MED_THRESH
from tasks.hard import HardTaskGrader, SUCCESS_THRESHOLD as HARD_THRESH
from rl_agent import PPOTrainer
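
# Maps task_id -> (grader class, grader constructor kwargs, success threshold).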
_GRADER_REGISTRY = {
    "easy": (EasyTaskGrader, {}, EASY_THRESH),
    "medium": (MediumTaskGrader, {"max_investigations_per_step": 3}, MED_THRESH),
    "hard": (HardTaskGrader, {}, HARD_THRESH),
}

# Per-task episode budgets.
# Hard needs ~3× easy: 40% chain probability over 50 steps means the agent
# sees ~20 chain alerts per episode but needs hundreds of episodes to learn
# which observable features (age, type, severity pattern) correlate with
# chain membership. Medium gets 1.5× for the resource-constraint curriculum.
_EPISODE_SCALE = {
    "easy": 1.0,
    "medium": 1.5,
    "hard": 3.0,
}
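# Worked example: with the default --episodes 300 this yields budgets of
# easy = 300, medium = int(300 * 1.5) = 450, hard = int(300 * 3.0) = 900.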

# ── Rule-based agents ──────────────────────────────────────────────────────

class SimpleRuleAgent:
    """Stateless baseline: fixed severity/confidence thresholds applied to
    the single highest-severity alert."""

    name = "RuleBased"

    def act(self, obs):
        if not obs.alerts:
            raise ValueError("No alerts")
        alert = max(obs.alerts, key=lambda a: a.visible_severity)
        sev, conf = alert.visible_severity, alert.confidence
        budget = obs.resource_budget
        if sev > 0.75 and conf > 0.70:
            # No investigation budget left: escalate instead.
            if budget is not None and budget <= 0:
                return Action(alert_id=alert.id, action_type="ESCALATE")
            return Action(alert_id=alert.id, action_type="INVESTIGATE")
        if conf < 0.30:
            return Action(alert_id=alert.id, action_type="IGNORE")
        if sev > 0.55:
            return Action(alert_id=alert.id, action_type="ESCALATE")
        return Action(alert_id=alert.id, action_type="DELAY")

    def reset(self):
        pass
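
# ImprovedRuleAgent extends the simple baseline with (1) an age-weighted
# alert-scoring function with a small bonus for SECURITY alerts, (2) fast-
# tracking stale high-severity alerts, and (3) shedding work under high
# system load.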
class ImprovedRuleAgent:
    name = "ImprovedRule"

    def act(self, obs):
        if not obs.alerts:
            raise ValueError("No alerts")

        def score(a):
            s = a.visible_severity * 2.0 + a.age * 0.08
            if a.alert_type == "SECURITY":
                s += 0.1
            return s

        alert = max(obs.alerts, key=score)
        sev, conf, age = alert.visible_severity, alert.confidence, alert.age
        budget, sys_load = obs.resource_budget, obs.system_load
        # Stale high-severity alerts jump the queue.
        if age >= 3 and sev > 0.70:
            if budget is not None and budget <= 0:
                return Action(alert_id=alert.id, action_type="ESCALATE")
            return Action(alert_id=alert.id, action_type="INVESTIGATE")
        # Under heavy load, only the most certain critical alerts get attention.
        if sys_load > 0.85:
            if sev > 0.85 and conf > 0.80:
                return Action(alert_id=alert.id, action_type="INVESTIGATE")
            return Action(alert_id=alert.id, action_type="DELAY")
        if sev > 0.75 and conf > 0.70:
            if budget is not None and budget <= 0:
                return Action(alert_id=alert.id, action_type="ESCALATE")
            return Action(alert_id=alert.id, action_type="INVESTIGATE")
        if conf < 0.30:
            return Action(alert_id=alert.id, action_type="IGNORE")
        if sev > 0.55:
            return Action(alert_id=alert.id, action_type="ESCALATE")
        return Action(alert_id=alert.id, action_type="DELAY")

    def reset(self):
        pass

# ── Evaluation ─────────────────────────────────────────────────────────────
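# A fresh grader is constructed per episode, and seed_offset pins the env
# seeds, so agents evaluated with the same offset see identical episode
# sequences and their scores are directly comparable.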
def evaluate_agent(agent, task_id: str, n_episodes: int, seed_offset: int = 0) -> Dict:
    grader_cls, grader_kwargs, threshold = _GRADER_REGISTRY[task_id]
    env = AdaptiveAlertTriageEnv(task_id=task_id)
    scores = []
    is_hard = (task_id == "hard")
    for ep in range(n_episodes):
        grader = grader_cls(**grader_kwargs)
        if hasattr(agent, "reset"):
            agent.reset()
        obs = env.reset(seed=seed_offset + ep)
        done = False
        while not done:
            if not obs.alerts:
                break
            action = agent.act(obs)
            obs, _r, done, info = env.step(action)
            # The hard grader additionally tracks correlation chains and failures.
            if is_hard:
                grader.update_correlation_state(info.get("correlation_groups", []))
            for ad in info.get("processed_alerts", []):
                grader.process_step(ad, info)
            if is_hard:
                grader.record_failures(info.get("failures_this_step", 0))
        scores.append(grader.get_episode_score())
    arr = np.array(scores)
    return {
        "mean": float(arr.mean()),
        "std": float(arr.std()),
        "min": float(arr.min()),
        "max": float(arr.max()),
        "success_rate": float((arr >= threshold).mean()),
        "scores": scores,
    }

# ── PPO wrapper ────────────────────────────────────────────────────────────
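# Thin adapter giving the trained PPOTrainer the same act()/reset() surface
# as the rule-based agents, so it can be scored through evaluate_agent().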
class PPOAgentWrapper:
    def __init__(self, trainer: PPOTrainer):
        self._trainer = trainer
        self.name = "PPO_LSTM"

    def act(self, obs):
        return self._trainer.act(obs)

    def reset(self):
        self._trainer.reset()

# ── Main ───────────────────────────────────────────────────────────────────
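# Per task: (1) score both rule baselines, (2) train PPO with the task grader
# in the training loop, (3) score the trained policy, then persist everything
# for the comparison plot.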
def run(args):
    results = {}
    for task_id in ["easy", "medium", "hard"]:
        grader_cls, grader_kwargs, threshold = _GRADER_REGISTRY[task_id]

        # Per-task episode budget
        n_episodes = int(args.episodes * _EPISODE_SCALE[task_id])

        print(f"\n{'=' * 60}")
        print(f"TASK: {task_id.upper()} "
              f"(threshold ≥ {threshold}, episodes = {n_episodes})")
        print(f"{'=' * 60}")

        # 1. Rule-based baselines
        print("\n[1/3] Evaluating rule-based baselines…")
        rb_basic_res = evaluate_agent(SimpleRuleAgent(), task_id, args.eval_episodes, seed_offset=100)
        rb_improved_res = evaluate_agent(ImprovedRuleAgent(), task_id, args.eval_episodes, seed_offset=100)
        print(f"  RuleBased    : mean={rb_basic_res['mean']:.3f} "
              f"success={rb_basic_res['success_rate']:.0%}")
        print(f"  ImprovedRule : mean={rb_improved_res['mean']:.3f} "
              f"success={rb_improved_res['success_rate']:.0%}")

        # 2. PPO training (grader passed in so terminal scores hit trajectories)
        print(f"\n[2/3] Training PPO agent ({n_episodes} episodes)…")
        env = AdaptiveAlertTriageEnv(task_id=task_id)
        trainer = PPOTrainer(task_id=task_id, seed=args.seed, lr=3e-4)
        t0 = time.time()
        history = trainer.train(
            env,
            n_episodes=n_episodes,
            grader_cls=grader_cls,
            grader_kwargs=grader_kwargs,
            log_interval=max(1, n_episodes // 10),
            verbose=True,
        )
        elapsed = time.time() - t0
        print(f"  Training done in {elapsed:.1f}s")
        os.makedirs("weights", exist_ok=True)
        trainer.save(f"weights/ppo_{task_id}.json")

        # 3. PPO evaluation
        print(f"\n[3/3] Evaluating PPO agent ({args.eval_episodes} episodes)…")
        ppo_agent = PPOAgentWrapper(trainer)
        ppo_res = evaluate_agent(ppo_agent, task_id, args.eval_episodes, seed_offset=200)
        print(f"  PPO          : mean={ppo_res['mean']:.3f} "
              f"success={ppo_res['success_rate']:.0%}")

        results[task_id] = {
            "threshold": threshold,
            "n_episodes": n_episodes,
            "rule_basic": rb_basic_res,
            "rule_improved": rb_improved_res,
            "ppo": ppo_res,
            "training": {
                "episode_rewards": history["episode_rewards"],
                "episode_scores": history["episode_scores"],
                "policy_losses": history["policy_losses"],
                "entropies": history["entropies"],
            },
        }

    # Save results
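    # The trainer's history lists may hold numpy scalars/arrays, which the
    # stdlib json module cannot serialize; downcast them to plain Python types.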
    class _NumpyEncoder(json.JSONEncoder):
        def default(self, obj):
            if isinstance(obj, np.floating):
                return float(obj)
            if isinstance(obj, np.integer):
                return int(obj)
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            return super().default(obj)
    os.makedirs("results", exist_ok=True)
    out_path = "results/comparison_results.json"
    with open(out_path, "w") as f:
        json.dump(results, f, indent=2, cls=_NumpyEncoder)
    print(f"\n✓ Results saved to {out_path}")

    # Summary table
    print(f"\n{'=' * 60}")
    print("FINAL COMPARISON SUMMARY")
    print(f"{'=' * 60}")
    print(f"{'Task':<10} {'Agent':<16} {'Mean':>8} {'Std':>7} {'Pass%':>8}")
    print("─" * 55)
    for task_id, res in results.items():
        for name, key in [("RuleBased", "rule_basic"),
                          ("ImprovedRule", "rule_improved"),
                          ("PPO+LSTM", "ppo")]:
            r = res[key]
            print(f"{task_id:<10} {name:<16} "
                  f"{r['mean']:>8.3f} "
                  f"{r['std']:>7.3f} "
                  f"{r['success_rate'] * 100:>7.1f}%")
    print("─" * 55)
    return results

def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--episodes", type=int, default=300,
                   help="Episode budget for easy task; medium=1.5×, hard=3×")
    p.add_argument("--eval-episodes", type=int, default=20)
    p.add_argument("--seed", type=int, default=42)
    return p.parse_args()

if __name__ == "__main__":
    args = parse_args()
    run(args)