Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| IAMSentinel RL Training Example | |
| ================================= | |
| Demonstrates how to connect a local RL training loop to the | |
| remote IAMSentinel OpenEnv server (Hugging Face Spaces or local Docker). | |
| This implements a simple LLM-guided policy (REINFORCE-style) using the | |
| OpenAI API as the policy network, with episode-level reward signals. | |
| The same pattern works with any RL framework (Stable-Baselines3, RLlib, | |
| CleanRL) — just replace the policy network. | |
| Setup: | |
| # Option A β local docker | |
| docker build -t iamsentinel . && docker run -p 7860:7860 iamsentinel | |
| # Option B β HF Space (set HF_SPACE_URL env var) | |
| export HF_SPACE_URL=https://<username>-iamsentinel.hf.space | |
| # Run training | |
| export OPENAI_API_KEY=sk-... | |
| python scripts/rl_training_example.py --episodes 20 --task task1 | |
| Architecture: | |
| βββββββββββββββββββββββββββββββββββ | |
| β Local Machine (trainer) β | |
| β β | |
| β ββββββββββββ βββββββββββββ β | |
| β β Policy β β Replay β β ββββββββββββββββββββββββ | |
| β β (GPT-4o) β β Buffer β ββββββββββΊβ IAMSentinel Server β | |
| β ββββββββββββ βββββββββββββ β HTTP β (HF Space / Docker) β | |
| β ββββββββββββββββββββββββββββ β ββββββββββββββββββββββββ | |
| β β Episode Logger / Scorer β β | |
| β ββββββββββββββββββββββββββββ β | |
| βββββββββββββββββββββββββββββββββββ | |
| """ | |
| import argparse | |
| import json | |
| import os | |
| import sys | |
| import time | |
| import statistics | |
| from collections import defaultdict | |
| from typing import Optional | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from iamsentinel.client import IAMSentinelClient, IAMSentinelClientError | |
| try: | |
| from openai import OpenAI | |
| HAS_OPENAI = True | |
| except ImportError: | |
| HAS_OPENAI = False | |
| # ββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Replay buffer (stores episodes for training) | |
| # ββββββββββββββββββββββββββββββββββββββββββββββ | |
class Episode:
    """Container for a single rollout: transitions plus aggregate reward stats."""

    def __init__(self, task_id: str, seed: int):
        self.task_id = task_id
        self.seed = seed
        # One dict per step: obs, action, reward (total), step_reward, next_obs, done.
        self.transitions: list[dict] = []
        self.total_reward = 0.0
        self.final_score = 0.0
        self.steps = 0

    def add(self, obs: dict, action: dict, reward: dict,
            next_obs: dict, done: bool):
        """Record one environment transition and update running totals."""
        total = reward["total"]
        self.transitions.append({
            "obs": obs,
            "action": action,
            "reward": total,
            "step_reward": reward.get("step_reward", 0.0),
            "next_obs": next_obs,
            "done": done,
        })
        self.total_reward += total
        self.steps += 1
        # On the terminal step the server's "total" is taken as the episode score.
        if done and reward.get("total") is not None:
            self.final_score = total
class ReplayBuffer:
    """Bounded FIFO store of completed episodes with simple score statistics."""

    def __init__(self, max_episodes: int = 100):
        self.episodes: list[Episode] = []
        self.max_episodes = max_episodes

    def add(self, episode: Episode):
        """Append an episode, evicting the oldest once capacity is exceeded."""
        self.episodes.append(episode)
        while len(self.episodes) > self.max_episodes:
            del self.episodes[0]

    def mean_score(self, last_n: int = 10) -> float:
        """Mean final score over the most recent `last_n` episodes (0.0 if empty)."""
        tail = self.episodes[-last_n:]
        if not tail:
            return 0.0
        return statistics.mean(e.final_score for e in tail)

    def task_scores(self) -> dict[str, list[float]]:
        """Group all stored final scores by their task id."""
        grouped: dict[str, list[float]] = defaultdict(list)
        for episode in self.episodes:
            grouped[episode.task_id].append(episode.final_score)
        return dict(grouped)
| # ββββββββββββββββββββββββββββββββββββββββββββββ | |
| # LLM Policy | |
| # ββββββββββββββββββββββββββββββββββββββββββββββ | |
# System prompt for the LLM policy. It pins the output contract (exactly one
# JSON action per turn) and enumerates the environment's action schema, with
# brief per-task strategy hints at the end.
# NOTE(review): this action list must stay in sync with the actions the
# IAMSentinel server actually accepts — verify against the environment code.
POLICY_SYSTEM = """You are an IAM security AI agent. You interact with a cloud IAM
environment by outputting ONE JSON action per turn.
Your goal: identify security vulnerabilities and complete the assigned task.
Output ONLY a valid JSON action block like:
{"action": "list_principals", "kind": "all"}
Available actions:
- {"action": "list_principals", "kind": "all"|"user"|"role"}
- {"action": "list_policies", "principal_arn": null}
- {"action": "get_policy", "policy_arn": "<arn>"}
- {"action": "get_principal", "principal_arn": "<arn>"}
- {"action": "get_role_trust", "role_arn": "<arn>"}
- {"action": "query_audit_log", "filter": {"severity":"critical","event_name":"..."}, "limit": 20}
- {"action": "trace_escalation_path", "from_principal_arn": "<arn>", "to_principal_arn": null}
- {"action": "flag_finding", "finding_type": "wildcard_policy"|"mfa_disabled"|"stale_admin_role"|"privilege_escalation_path"|"exposed_trust_policy", "severity": "critical", "description": "...", "affected_principal_arn": null, "evidence": []}
- {"action": "attribute_attack", "compromised_principal_arn":"<arn>","attack_technique":"...","mitre_techniques":["T1078.004"],"lateral_movement_path":["<arn1>","<arn2>"],"containment_actions":["disable_user:<arn>"]}
Be systematic. For Task 1: scan all principals and policies for misconfigs.
For Task 2: find iam:PassRole chains. For Task 3: query critical/high severity logs first."""
| def _format_obs_for_policy(obs: dict, step: int, prev_reward: float = 0.0) -> str: | |
| """Format observation into LLM-friendly text.""" | |
| lines = [ | |
| f"Step {step}/{obs.get('max_steps', '?')} | Budget: {obs.get('budget_remaining', '?')}", | |
| f"Task: {obs.get('task_description', '')[:120]}", | |
| ] | |
| if prev_reward != 0: | |
| lines.append(f"Last reward signal: {prev_reward:+.3f}") | |
| findings = obs.get("findings", []) | |
| if findings: | |
| lines.append(f"Findings logged ({len(findings)}):") | |
| for f in findings[-3:]: | |
| lines.append(f" [{f['severity']}] {f['finding_type']}: {f['description'][:60]}") | |
| if obs.get("hints"): | |
| lines.append("Hints: " + " | ".join(obs["hints"])) | |
| if obs.get("principals"): | |
| lines.append(f"Principals ({len(obs['principals'])}):") | |
| for p in obs["principals"][:6]: | |
| mfa = "MFAβ" if p.get("mfa_enabled") else "MFAβ" | |
| lines.append( | |
| f" {p['kind']}: {p['name']} | {mfa} | " | |
| f"inactive={p['last_active_days']}d | " | |
| f"policies={len(p.get('policies',[]))}" | |
| ) | |
| if obs.get("policies"): | |
| lines.append(f"Policies ({len(obs['policies'])}):") | |
| for p in obs["policies"][:6]: | |
| wc = "β WILDCARD" if p.get("is_wildcard") else "" | |
| acts = [] | |
| for stmt in p.get("statements", []): | |
| acts.extend(stmt.get("actions", [])) | |
| lines.append(f" {p['name']} {wc} | arn={p['arn']} | actions={acts[:4]}") | |
| if obs.get("audit_events"): | |
| lines.append(f"Audit events ({len(obs['audit_events'])}):") | |
| for e in obs["audit_events"][:8]: | |
| lines.append( | |
| f" [{e.get('severity','?')}] {e['event_time'][-8:]} | " | |
| f"{e['event_name']} | {e['principal_name']} | ip={e['source_ip']}" | |
| ) | |
| if obs.get("escalation_paths"): | |
| lines.append(f"Escalation paths found: {len(obs['escalation_paths'])}") | |
| for ep in obs["escalation_paths"][:2]: | |
| path_str = " β ".join(a.split("/")[-1] for a in ep.get("path", [])) | |
| lines.append(f" {path_str} (risk={ep.get('risk_score', '?')})") | |
| lines.append("\nOutput ONE JSON action:") | |
| return "\n".join(lines) | |
| def _extract_action(text: str) -> Optional[dict]: | |
| """Extract JSON action from LLM output.""" | |
| import re | |
| for pattern in [ | |
| r"```(?:json)?\s*(\{.*?\})\s*```", | |
| r"(\{[^{}]*\"action\"[^{}]*\})", | |
| ]: | |
| m = re.search(pattern, text, re.DOTALL) | |
| if m: | |
| try: | |
| return json.loads(m.group(1)) | |
| except Exception: | |
| pass | |
| # Greedy fallback | |
| for s in range(len(text)): | |
| if text[s] == "{": | |
| for e in range(len(text), s, -1): | |
| try: | |
| obj = json.loads(text[s:e]) | |
| if "action" in obj: | |
| return obj | |
| except Exception: | |
| continue | |
| return None | |
| # ββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Episode runner | |
| # ββββββββββββββββββββββββββββββββββββββββββββββ | |
def run_episode(
    client: IAMSentinelClient,
    task_id: str,
    seed: int,
    model: str,
    openai_client,  # an OpenAI client instance; untyped so openai stays optional
    verbose: bool = False,
) -> Episode:
    """Run one complete episode and return the filled Episode object.

    The LLM is the policy: each step the latest observation is rendered to
    text, appended to the chat history, and the model's JSON reply is parsed
    into an environment action. Transient LLM errors and unparseable replies
    retry the step; a server error aborts the episode early.
    """
    episode = Episode(task_id=task_id, seed=seed)
    obs = client.reset(task_id=task_id, seed=seed, complexity="medium")
    messages = [{"role": "system", "content": POLICY_SYSTEM}]
    prev_reward = 0.0
    done = False
    step = 0
    # NOTE(review): 40 is only a client-side fallback cap; presumably the
    # server enforces its own max_steps — confirm against the environment.
    max_steps = obs.get("max_steps", 40)
    while not done and step < max_steps:
        step += 1
        user_msg = _format_obs_for_policy(obs, step, prev_reward)
        messages.append({"role": "user", "content": user_msg})
        # Get action from policy
        try:
            resp = openai_client.chat.completions.create(
                model=model,
                messages=messages[-20:],  # sliding window context
                temperature=0.3,
                max_tokens=400,
            )
            text = resp.choices[0].message.content
            messages.append({"role": "assistant", "content": text})
        except Exception as ex:
            # Transient API failure: back off briefly; the loop re-enters and
            # appends a fresh user message for the same observation.
            if verbose:
                print(f" LLM error: {ex}")
            time.sleep(2)
            continue
        action = _extract_action(text)
        if action is None:
            if verbose:
                print(f" [Step {step}] Failed to parse action")
            # Nudge the model back onto the required JSON-only format.
            messages.append({
                "role": "user",
                "content": "Could not parse JSON. Output ONLY a valid JSON action."
            })
            continue
        # Execute action
        try:
            # info is returned by the client but unused here.
            next_obs, reward, done, info = client.step(action)
        except IAMSentinelClientError as ex:
            # Server-side failure is treated as unrecoverable for this episode.
            if verbose:
                print(f" [Step {step}] Server error: {ex}")
            break
        prev_reward = reward.get("step_reward", 0.0)
        episode.add(obs, action, reward, next_obs, done)
        if verbose:
            final = f" | FINAL={reward['total']:.3f}" if done else ""
            print(
                f" [Step {step:02d}] {action.get('action','?'):<28} "
                f"r={prev_reward:+.3f}{final}"
            )
        obs = next_obs
        time.sleep(0.2)  # rate limit
    return episode
| # ββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Training loop | |
| # ββββββββββββββββββββββββββββββββββββββββββββββ | |
def _connect_or_exit(server_url: str) -> IAMSentinelClient:
    """Build the environment client and verify the server is reachable.

    Exits the process with instructions for starting a local server when the
    health check fails.
    """
    client = IAMSentinelClient(base_url=server_url)
    try:
        health = client.health()
        # NOTE(review): the "β" glyphs below look like mojibake of check/cross
        # marks — preserved as-is; confirm the original source encoding.
        print(f"β Connected to IAMSentinel server: {server_url}")
        print(f" Status: {health['status']} | Active sessions: {health['sessions']}")
    except IAMSentinelClientError as e:
        print(f"β Cannot reach server at {server_url}")
        print(f" Error: {e}")
        print("\nTo start a local server:")
        print(" docker build -t iamsentinel . && docker run -p 7860:7860 iamsentinel")
        sys.exit(1)
    return client


def _print_task_breakdown(buffer: ReplayBuffer) -> None:
    """Print per-task mean scores accumulated so far."""
    print("\n π Per-task mean scores:")
    for tid, scores in buffer.task_scores().items():
        print(f" {tid}: mean={statistics.mean(scores):.3f} "
              f"over {len(scores)} episodes")
    print()


def _print_final_summary(buffer: ReplayBuffer, tasks: list[str]) -> dict[str, list[float]]:
    """Print the end-of-training per-task summary and return scores by task."""
    print(f"\n{'='*65}")
    print("TRAINING COMPLETE β Final Summary")
    print(f"{'='*65}")
    task_scores = buffer.task_scores()
    for tid in tasks:
        scores = task_scores.get(tid, [])
        if scores:
            print(
                f" {tid}: mean={statistics.mean(scores):.3f} "
                f"| best={max(scores):.3f} "
                f"| worst={min(scores):.3f} "
                f"| n={len(scores)}"
            )
    return task_scores


def _save_results(output_path: str, server_url: str, tasks: list[str],
                  model: str, n_episodes: int, all_results: list,
                  task_scores: dict[str, list[float]]) -> None:
    """Write run config, per-episode log, and final per-task stats to JSON."""
    payload = {
        "config": {
            "server_url": server_url,
            "tasks": tasks,
            "model": model,
            "n_episodes": n_episodes,
        },
        "episodes": all_results,
        "final_task_scores": {
            tid: {
                "mean": round(statistics.mean(s), 4),
                "best": round(max(s), 4),
                "n": len(s),
            }
            for tid, s in task_scores.items()
        },
    }
    with open(output_path, "w") as f:
        json.dump(payload, f, indent=2)
    print(f"\nResults saved β {output_path}")


def train(
    server_url: str,
    tasks: list[str],
    n_episodes: int,
    seeds: list[int],
    model: str,
    verbose: bool,
    output_path: Optional[str],
):
    """Run the episode-collection training loop against a remote server.

    Cycles round-robin through `tasks` and `seeds`, runs `n_episodes`
    episodes with the LLM policy, tracks scores in a replay buffer, prints
    periodic progress, and optionally writes a JSON results file.

    Exits the process when the openai package, the OPENAI_API_KEY variable,
    or the environment server is unavailable.

    Returns the list of per-episode result dicts.
    """
    if not HAS_OPENAI:
        print("ERROR: pip install openai")
        sys.exit(1)
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        print("ERROR: Set OPENAI_API_KEY environment variable")
        sys.exit(1)
    openai_client = OpenAI(api_key=api_key)
    client = _connect_or_exit(server_url)

    buffer = ReplayBuffer(max_episodes=200)
    all_results = []
    print(f"\n{'='*65}")
    print("IAMSentinel RL Training")
    print(f"Tasks: {tasks} | Episodes: {n_episodes} | Model: {model}")
    print(f"{'='*65}\n")

    for ep_idx in range(n_episodes):
        # Round-robin over tasks and seeds so each combination recurs evenly.
        task_id = tasks[ep_idx % len(tasks)]
        seed = seeds[ep_idx % len(seeds)]
        episode_num = ep_idx + 1
        print(f"Episode {episode_num:03d}/{n_episodes} | task={task_id} | seed={seed}")
        episode = run_episode(
            client, task_id, seed, model, openai_client, verbose
        )
        buffer.add(episode)
        all_results.append({
            "episode": episode_num,
            "task_id": task_id,
            "seed": seed,
            "steps": episode.steps,
            "total_reward": round(episode.total_reward, 4),
            "final_score": round(episode.final_score, 4),
        })
        mean_10 = buffer.mean_score(last_n=10)
        print(
            f" Score={episode.final_score:.3f} | "
            f"Steps={episode.steps} | "
            f"Moving avg(10)={mean_10:.3f}"
        )
        # Print per-task breakdown every 5 episodes
        if episode_num % 5 == 0:
            _print_task_breakdown(buffer)

    task_scores = _print_final_summary(buffer, tasks)
    if output_path:
        _save_results(output_path, server_url, tasks, model, n_episodes,
                      all_results, task_scores)
    return all_results
| # ββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Entry point | |
| # ββββββββββββββββββββββββββββββββββββββββββββββ | |
def main():
    """CLI entry point: parse arguments and launch the training loop."""
    space_url = os.environ.get("HF_SPACE_URL", "")
    fallback_url = space_url or "http://localhost:7860"

    parser = argparse.ArgumentParser(description="IAMSentinel RL Training")
    parser.add_argument("--server", default=fallback_url,
                        help="Server URL (default: $HF_SPACE_URL or http://localhost:7860)")
    parser.add_argument("--task", default="all",
                        help="task1|task2|task3|all")
    parser.add_argument("--episodes", type=int, default=15,
                        help="Total training episodes")
    parser.add_argument("--seeds", default="42,123,456,789,1337",
                        help="Comma-separated seeds to cycle through")
    parser.add_argument("--model", default="gpt-4o-mini",
                        help="OpenAI model to use as policy")
    parser.add_argument("--output", default="training_results.json",
                        help="Output file for results")
    parser.add_argument("--verbose", action="store_true",
                        help="Print step-level details")
    args = parser.parse_args()

    # "all" expands to the full task roster; anything else trains one task.
    if args.task == "all":
        selected_tasks = ["task1", "task2", "task3"]
    else:
        selected_tasks = [args.task]
    seed_list = [int(token) for token in args.seeds.split(",")]

    train(
        server_url=args.server,
        tasks=selected_tasks,
        n_episodes=args.episodes,
        seeds=seed_list,
        model=args.model,
        verbose=args.verbose,
        output_path=args.output,
    )


if __name__ == "__main__":
    main()