Ajayyy00 commited on
Commit
bb0d7fd
·
1 Parent(s): f6c80b9

Initial commit: CyberSOC Enterprise Environment Baseline

Browse files
Dockerfile ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ # Create user with home directory
4
+ RUN useradd -m -u 1000 user
5
+ USER user
6
+ ENV PATH="/home/user/.local/bin:$PATH"
7
+
8
+ WORKDIR /app
9
+
10
+ # Copy requirements and install
11
+ COPY --chown=user ./requirements.txt requirements.txt
12
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
13
+
14
+ # Copy all environment files
15
+ COPY --chown=user . /app
16
+
17
+ # The hackathon expects the OpenEnv Server to run on 7860 for Spaces Gradio endpoints
18
+ # We will use uvicorn to host the app which complies with the spec
19
+ CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,11 +1,48 @@
1
  ---
2
- title: CyberSOC
3
- emoji: 👁
4
- colorFrom: yellow
5
- colorTo: green
6
  sdk: docker
7
- pinned: false
8
- license: mit
9
  ---
10
 
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: CyberSOC Enterprise Environment
3
+ emoji: 🛡️
4
+ colorFrom: blue
5
+ colorTo: indigo
6
  sdk: docker
7
+ app_port: 7860
 
8
  ---
9
 
10
+ # CyberSOC: Enterprise Network Defense Environment 🛡️
11
+
12
+ CyberSOC is a highly scalable, production-grade OpenEnv reinforcement learning environment designed to evaluate AI agents on their ability to perform Incident Response across a 500-node enterprise network.
13
+
14
+ ## 🌟 Hackathon Highlights
15
+
16
+ This is not a toy benchmark. This environment models real-world enterprise infrastructure:
17
+
18
+ 1. **Massive Procedural Variety (1,000 Tasks):**
19
+ Instead of hardcoded puzzles, CyberSOC features a seed-based procedural generation engine. We dynamically spin up **1000 unique network topologies** containing a mix of 12 distinct attack vectors (from Supply Chain to Ransomware). This guarantees that agents cannot overfit.
20
+
21
+ 2. **Dense, Business-Aligned Grading:**
22
+ Unlike simple pass/fail benchmarks, CyberSOC uses intelligent reward shaping. Agents earn rewards for hunting down malicious processes and blocking IOCs mid-investigation. However, they are heavily penalized for increasing "Business Downtime" (quarantining healthy subnets haphazardly). They must balance security guarantees with business continuity.
23
+
24
+ 3. **Complex State & Action Space:**
25
+ Agents must use structured tools (Pydantic models) to traverse the environment:
26
+ - `query_host`: Map the active topology.
27
+ - `run_forensics`: Scrape memory and process lists.
28
+ - `kill_process` & `block_ioc`: Perform active containment.
29
+ - `isolate_segment`: Implement extreme fail-safes.
30
+ - `submit_containment_plan`: Formulate a final executive overview.
31
+
32
+ 4. **Flawless Inference Benchmarking:**
33
+ The included `inference.py` provides an out-of-the-box evaluation loop. We have successfully benchmarked state-of-the-art LLMs (like Qwen2.5-72B and Llama-3.3-70B) natively within this environment using standard OpenAI/Groq clients.
34
+
35
+ ## 🚀 Running the Environment
36
+
37
+ This repository is fully packaged as a Docker container.
38
+
39
+ ### Local Execution:
40
+ ```bash
41
+ python inference.py
42
+ ```
43
+
44
+ ### Agent Configuration
45
+ To run your own agent, define:
46
+ `API_KEY` - Your LLM token
47
+ `API_BASE_URL` - The endpoint you are hitting
48
+ `MODEL_NAME` - Target identifier
__init__.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """CyberSOCEnv — Enterprise Cybersecurity Operations Center Environment."""
8
+
9
+ from .client import CyberSOCClient
10
+ from .models import (
11
+ SOCObservation,
12
+ SOCActionWrapper,
13
+ SOCState,
14
+ QueryHost,
15
+ IsolateSegment,
16
+ BlockIOC,
17
+ RunForensics,
18
+ KillProcess,
19
+ SubmitContainmentPlan,
20
+ )
21
+
22
+ __all__ = [
23
+ "CyberSOCClient",
24
+ "SOCObservation",
25
+ "SOCActionWrapper",
26
+ "SOCState",
27
+ "QueryHost",
28
+ "IsolateSegment",
29
+ "BlockIOC",
30
+ "RunForensics",
31
+ "KillProcess",
32
+ "SubmitContainmentPlan",
33
+ ]
__pycache__/__init__.cpython-311.pyc ADDED
Binary file (717 Bytes). View file
 
__pycache__/client.cpython-311.pyc ADDED
Binary file (5.1 kB). View file
 
__pycache__/inference.cpython-311.pyc ADDED
Binary file (15.7 kB). View file
 
__pycache__/models.cpython-311.pyc ADDED
Binary file (20.5 kB). View file
 
client.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """CyberSOCEnv Client — connects to the SOC environment server."""
8
+
9
+ from typing import Dict
10
+
11
+ from openenv.core import EnvClient
12
+ from openenv.core.client_types import StepResult
13
+
14
+ from .models import (
15
+ SOCObservation,
16
+ SOCActionWrapper,
17
+ SOCState,
18
+ Alert,
19
+ Severity,
20
+ ThreatType,
21
+ NetworkTopology,
22
+ ForensicsResult,
23
+ TimelineEntry,
24
+ )
25
+
26
+
27
+ class CyberSOCClient(
28
+ EnvClient[SOCActionWrapper, SOCObservation, SOCState]
29
+ ):
30
+ """
31
+ Client for the CyberSOCEnv environment.
32
+
33
+ Connects via WebSocket to the SOC environment server for
34
+ low-latency, persistent-session interaction.
35
+
36
+ Example:
37
+ >>> with CyberSOCClient(base_url="http://localhost:8000") as client:
38
+ ... result = client.reset()
39
+ ... print(result.observation.alert_queue)
40
+ ...
41
+ ... from play.models import QueryHost
42
+ ... result = client.step(SOCActionWrapper(type="query_host", hostname="WS-001"))
43
+ ... print(result.observation.host_forensics)
44
+ """
45
+
46
+ def _step_payload(self, action: SOCActionWrapper) -> Dict:
47
+ """Convert SOCActionWrapper to JSON payload for step message."""
48
+ return action.model_dump(exclude_none=True)
49
+
50
+ def _parse_result(self, payload: Dict) -> StepResult[SOCObservation]:
51
+ """Parse server response into StepResult[SOCObservation]."""
52
+ obs_data = payload.get("observation", {})
53
+
54
+ # Parse alerts
55
+ alerts = [Alert(**a) for a in obs_data.get("alert_queue", [])]
56
+
57
+ # Parse network topology
58
+ topo_data = obs_data.get("network_topology", {})
59
+ topology = NetworkTopology(**topo_data) if topo_data else NetworkTopology()
60
+
61
+ # Parse forensics (may be None)
62
+ forensics_data = obs_data.get("host_forensics")
63
+ forensics = ForensicsResult(**forensics_data) if forensics_data else None
64
+
65
+ # Parse timeline
66
+ timeline = [TimelineEntry(**t) for t in obs_data.get("timeline", [])]
67
+
68
+ observation = SOCObservation(
69
+ alert_queue=alerts,
70
+ network_topology=topology,
71
+ host_forensics=forensics,
72
+ timeline=timeline,
73
+ business_impact_score=obs_data.get("business_impact_score", 0.0),
74
+ step_count=obs_data.get("step_count", 0),
75
+ active_threats=obs_data.get("active_threats", []),
76
+ max_steps=obs_data.get("max_steps", 30),
77
+ task_id=obs_data.get("task_id", "easy"),
78
+ total_reward=obs_data.get("total_reward", 0.0),
79
+ final_score=obs_data.get("final_score"),
80
+ grade_breakdown=obs_data.get("grade_breakdown"),
81
+ done=payload.get("done", False),
82
+ reward=payload.get("reward"),
83
+ )
84
+
85
+ return StepResult(
86
+ observation=observation,
87
+ reward=payload.get("reward"),
88
+ done=payload.get("done", False),
89
+ )
90
+
91
+ def _parse_state(self, payload: Dict) -> SOCState:
92
+ """Parse server response into SOCState."""
93
+ return SOCState(
94
+ episode_id=payload.get("episode_id"),
95
+ step_count=payload.get("step_count", 0),
96
+ task_id=payload.get("task_id", "easy"),
97
+ total_reward=payload.get("total_reward", 0.0),
98
+ business_impact=payload.get("business_impact", 0.0),
99
+ )
demo_scripted.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Scripted CyberSOCEnv Demo — runs optimal actions for all 3 tasks
4
+ without requiring an LLM. Demonstrates the full environment pipeline.
5
+ """
6
+ import asyncio
7
+ import json
8
+ import websockets
9
+
10
+ WS_URL = "ws://127.0.0.1:8000/ws"
11
+
12
+ # Pre-scripted optimal action sequences for each task
13
+ SCRIPTS = {
14
+ "easy": [
15
+ {"type": "query_host", "hostname": "WS-042"},
16
+ {"type": "run_forensics", "hostname": "WS-042"},
17
+ {"type": "kill_process", "hostname": "WS-042", "process_name": "cryptolocker.exe"},
18
+ {"type": "block_ioc", "ioc_value": "e99a18c428cb38d5f260853678922e03", "ioc_type": "hash"},
19
+ {"type": "submit_containment_plan",
20
+ "plan": [{"threat_id": "T-EASY-001", "actions_taken": ["killed cryptolocker.exe", "blocked hash", "ran forensics"], "root_cause": "ransomware via user download", "confidence": 0.95}],
21
+ "executive_summary": "Single ransomware on WS-042 fully contained. Process killed, IOC blocked."},
22
+ ],
23
+ "medium": [
24
+ {"type": "run_forensics", "hostname": "WS-017"},
25
+ {"type": "kill_process", "hostname": "WS-017", "process_name": "powershell.exe"},
26
+ {"type": "kill_process", "hostname": "WS-017", "process_name": "mimikatz.exe"},
27
+ {"type": "block_ioc", "ioc_value": "evil-login.example.com", "ioc_type": "domain"},
28
+ {"type": "block_ioc", "ioc_value": "d41d8cd98f00b204e9800998ecf8427e", "ioc_type": "hash"},
29
+ {"type": "run_forensics", "hostname": "DEV-033"},
30
+ {"type": "kill_process", "hostname": "DEV-033", "process_name": "svchost_backdoor.exe"},
31
+ {"type": "run_forensics", "hostname": "FIN-012"},
32
+ {"type": "kill_process", "hostname": "FIN-012", "process_name": "svchost_backdoor.exe"},
33
+ {"type": "block_ioc", "ioc_value": "203.0.113.50", "ioc_type": "ip"},
34
+ {"type": "block_ioc", "ioc_value": "aabbccdd11223344eeff5566778899aa", "ioc_type": "hash"},
35
+ {"type": "block_ioc", "ioc_value": "112233445566778899aabbccddeeff00", "ioc_type": "hash"},
36
+ {"type": "submit_containment_plan",
37
+ "plan": [
38
+ {"threat_id": "T-MED-001", "actions_taken": ["killed powershell.exe", "blocked evil-login.example.com"], "root_cause": "phishing email with macro", "confidence": 0.9},
39
+ {"threat_id": "T-MED-002", "actions_taken": ["killed mimikatz.exe", "blocked hash"], "root_cause": "credential theft via Mimikatz", "confidence": 0.95},
40
+ {"threat_id": "T-MED-003", "actions_taken": ["killed svchost_backdoor on DEV-033 and FIN-012", "blocked C2 IP"], "root_cause": "lateral movement using stolen creds", "confidence": 0.9},
41
+ ],
42
+ "executive_summary": "Multi-stage attack contained: phishing -> cred theft -> lateral movement across 3 hosts."},
43
+ ],
44
+ "hard": [
45
+ {"type": "block_ioc", "ioc_value": "198.51.100.77", "ioc_type": "ip"},
46
+ {"type": "block_ioc", "ioc_value": "cdn-update.malware-c2.net", "ioc_type": "domain"},
47
+ {"type": "run_forensics", "hostname": "EXEC-003"},
48
+ {"type": "kill_process", "hostname": "EXEC-003", "process_name": "outlook_macro.exe"},
49
+ {"type": "kill_process", "hostname": "EXEC-003", "process_name": "svchost_c2.exe"},
50
+ {"type": "run_forensics", "hostname": "WS-088"},
51
+ {"type": "kill_process", "hostname": "WS-088", "process_name": "svchost_c2.exe"},
52
+ {"type": "run_forensics", "hostname": "SRV-002"},
53
+ {"type": "kill_process", "hostname": "SRV-002", "process_name": "exploit_kernel.exe"},
54
+ {"type": "kill_process", "hostname": "SRV-002", "process_name": "data_pump.exe"},
55
+ {"type": "block_ioc", "ioc_value": "203.0.113.99", "ioc_type": "ip"},
56
+ {"type": "block_ioc", "ioc_value": "exfil.malware-c2.net", "ioc_type": "domain"},
57
+ {"type": "run_forensics", "hostname": "FIN-008"},
58
+ {"type": "kill_process", "hostname": "FIN-008", "process_name": "data_pump.exe"},
59
+ {"type": "run_forensics", "hostname": "SRV-010"},
60
+ {"type": "kill_process", "hostname": "SRV-010", "process_name": "blackcat_ransom.exe"},
61
+ {"type": "kill_process", "hostname": "SRV-015", "process_name": "blackcat_ransom.exe"},
62
+ {"type": "block_ioc", "ioc_value": "deadbeef0123456789abcdef01234567", "ioc_type": "hash"},
63
+ {"type": "block_ioc", "ioc_value": "cafebabe9876543210fedcba98765432", "ioc_type": "hash"},
64
+ {"type": "submit_containment_plan",
65
+ "plan": [
66
+ {"threat_id": "T-HARD-001", "actions_taken": ["killed outlook_macro.exe", "blocked C2"], "root_cause": "spearphishing executive VP", "confidence": 0.95},
67
+ {"threat_id": "T-HARD-002", "actions_taken": ["killed svchost_c2.exe on 2 hosts", "blocked C2 domains"], "root_cause": "C2 beaconing via encrypted channel", "confidence": 0.9},
68
+ {"threat_id": "T-HARD-003", "actions_taken": ["killed exploit_kernel.exe"], "root_cause": "kernel exploit for privilege escalation on SRV-002", "confidence": 0.9},
69
+ {"threat_id": "T-HARD-004", "actions_taken": ["killed data_pump.exe on SRV-002 and FIN-008", "blocked exfil IP/domain"], "root_cause": "data exfiltration of PII and financial records", "confidence": 0.85},
70
+ {"threat_id": "T-HARD-005", "actions_taken": ["killed blackcat_ransom.exe on SRV-010 and SRV-015"], "root_cause": "BlackCat ransomware deployment on production storage", "confidence": 0.95},
71
+ ],
72
+ "executive_summary": "APT campaign fully contained: initial access via exec phishing, C2 cut, privilege escalation stopped, exfiltration blocked, ransomware neutralized."},
73
+ ],
74
+ }
75
+
76
+
77
+ async def run_task(task_id: str):
78
+ """Run a single task with scripted optimal actions."""
79
+ print(f"\n{'='*60}")
80
+ print(f"[START] task={task_id} env=cybersocenv model=scripted-optimal")
81
+ print(f"{'='*60}")
82
+
83
+ async with websockets.connect(WS_URL) as ws:
84
+ # Reset
85
+ await ws.send(json.dumps({"type": "reset", "data": {"task_id": task_id}}))
86
+ resp = json.loads(await ws.recv())
87
+ data = resp.get("data", {})
88
+ obs = data.get("observation", {})
89
+ print(f" Reset: alerts={len(obs.get('alert_queue', []))}, threats={obs.get('active_threats', [])}")
90
+
91
+ # Execute scripted actions
92
+ rewards = []
93
+ for i, action in enumerate(SCRIPTS[task_id], 1):
94
+ await ws.send(json.dumps({"type": "step", "data": action}))
95
+ resp = json.loads(await ws.recv())
96
+
97
+ if resp.get("type") == "error":
98
+ print(f" [STEP] step={i} action={action['type']} ERROR: {resp.get('data', {})}")
99
+ continue
100
+
101
+ data = resp.get("data", {})
102
+ obs = data.get("observation", {})
103
+ reward = data.get("reward", 0)
104
+ done = data.get("done", False)
105
+ rewards.append(reward)
106
+
107
+ print(f" [STEP] step={i} action={action['type']} reward={reward:.2f} done={done}")
108
+
109
+ if done:
110
+ score = obs.get("final_score") or 0.0
111
+ breakdown = obs.get("grade_breakdown") or {}
112
+ print(f"\n FINAL SCORE: {score:.4f}")
113
+ if breakdown:
114
+ print(f" Threats contained: {breakdown.get('threats_contained')}/{breakdown.get('total_threats')}")
115
+ print(f" IOCs blocked: {breakdown.get('iocs_blocked')}")
116
+ print(f" Hosts forensics: {breakdown.get('hosts_forensics')}")
117
+ print(f" Business impact: {breakdown.get('business_impact'):.4f}")
118
+ total_r = sum(rewards)
119
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
120
+ print(f" [END] success={score > 0.5} steps={i} score={score:.3f} rewards={rewards_str}")
121
+ break
122
+
123
+ await ws.send(json.dumps({"type": "close"}))
124
+
125
+ return score if 'score' in dir() else 0.0
126
+
127
+
128
+ async def main():
129
+ print("# CyberSOCEnv Scripted Optimal Agent Demo")
130
+ print("# This runs pre-computed optimal actions for all 3 tasks")
131
+ print("# to demonstrate the environment and grading pipeline.\n")
132
+
133
+ scores = {}
134
+ for task_id in ["easy", "medium", "hard"]:
135
+ score = await run_task(task_id)
136
+ scores[task_id] = score
137
+
138
+ print(f"\n{'='*60}")
139
+ print("# FINAL RESULTS")
140
+ print(f"{'='*60}")
141
+ for tid, s in scores.items():
142
+ print(f" {tid:8s}: {s:.4f}")
143
+ avg = sum(scores.values()) / len(scores)
144
+ print(f"\n Average: {avg:.4f}")
145
+
146
+
147
+ if __name__ == "__main__":
148
+ asyncio.run(main())
eval_100.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import asyncio
3
+ import json
4
+ import os
5
+ import sys
6
+
7
+ from openai import OpenAI
8
+ from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type
9
+
10
+ # Ensure we can import from play
11
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
12
+
13
+ from play.inference import run_episode, API_BASE_URL, API_KEY, MODEL_NAME
14
+ from play.server.task_generator import list_generated_task_ids
15
+
16
+ RESULTS_FILE = "d:\\MetaNew\\play\\eval_results_100.json"
17
+
18
+ @retry(
19
+ wait=wait_exponential(multiplier=2, min=5, max=120),
20
+ stop=stop_after_attempt(10),
21
+ retry=retry_if_exception_type(Exception)
22
+ )
23
+ async def resilient_run_episode(client, task_id):
24
+ try:
25
+ return await run_episode(client, task_id)
26
+ except Exception as exc:
27
+ if "429" in str(exc) or "RateLimit" in str(exc):
28
+ print(f"\n[!] Rate limit hit for {task_id}. Backing off... ({exc})\n", flush=True)
29
+ raise # Trigger tenacity retry
30
+ print(f"\n[!] Unknown error directly in resilient wrapper: {exc}\n", flush=True)
31
+ raise
32
+
33
+ async def main():
34
+ print(f"=== Starting Batch Evaluation (100 Tasks) ===", flush=True)
35
+ print(f"Model: {MODEL_NAME} | Endpoint: {API_BASE_URL}", flush=True)
36
+
37
+ client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
38
+
39
+ # Load checkpoint
40
+ results = {}
41
+ if os.path.exists(RESULTS_FILE):
42
+ with open(RESULTS_FILE, "r") as f:
43
+ results = json.load(f)
44
+ print(f"Loaded {len(results)} existing results from checkpoint.", flush=True)
45
+
46
+ tasks = list_generated_task_ids(100)
47
+
48
+ success_count = 0
49
+ total_score = 0.0
50
+
51
+ for i, task_id in enumerate(tasks):
52
+ if task_id in results:
53
+ score = results[task_id]["score"]
54
+ if results[task_id]["success"]:
55
+ success_count += 1
56
+ total_score += score
57
+ continue
58
+
59
+ print(f"\n--- Evaluating {i+1}/100: {task_id} ---", flush=True)
60
+ try:
61
+ success, steps, score, rewards = await resilient_run_episode(client, task_id)
62
+
63
+ results[task_id] = {
64
+ "success": success,
65
+ "steps": steps,
66
+ "score": score,
67
+ "rewards": rewards
68
+ }
69
+
70
+ if success:
71
+ success_count += 1
72
+ total_score += score
73
+
74
+ # Save checkpoint
75
+ with open(RESULTS_FILE, "w") as f:
76
+ json.dump(results, f, indent=2)
77
+
78
+ except Exception as e:
79
+ print(f"\n[FATAL] FAILED TO EVALUATE {task_id} AFTER RETRIES. Error: {e}", flush=True)
80
+ break
81
+
82
+ completed = len(results)
83
+ if completed > 0:
84
+ print(f"\n=== FINAL EVALUATION SUMMARY ===", flush=True)
85
+ print(f"Tasks Completed: {completed}/100", flush=True)
86
+ print(f"Success Rate: {success_count}/{completed} ({(success_count/completed)*100:.1f}%)", flush=True)
87
+ print(f"Average Score: {total_score/completed:.3f}", flush=True)
88
+
89
+ if __name__ == "__main__":
90
+ asyncio.run(main())
inference.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the BSD-style license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ """
9
+ CyberSOCEnv Baseline Inference Script.
10
+
11
+ HACKATHON RULES:
12
+ - File must be named inference.py in the project root
13
+ - Must use OpenAI Client for all LLM calls
14
+ - Must emit structured stdout logs: [START], [STEP], [END]
15
+ - Runtime < 20 minutes
16
+ - Must work on vcpu=2, memory=8gb
17
+
18
+ Environment Variables:
19
+ API_BASE_URL - The API endpoint for the LLM
20
+ MODEL_NAME - The model identifier to use for inference
21
+ HF_TOKEN - Your Hugging Face / API key
22
+ """
23
+
24
+ import asyncio
25
+ import json
26
+ import os
27
+ import textwrap
28
+ from typing import Any, Dict, List, Optional
29
+
30
+ from openai import OpenAI
31
+
32
+ from play.models import SOCActionWrapper, SOCObservation
33
+ from play.server.play_environment import CyberSOCEnvironment
34
+
35
+ # =============================================================================
36
+ # Configuration (from environment variables)
37
+ # =============================================================================
38
+
39
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY") or ""
40
+ API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
41
+ MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
42
+
43
+ BENCHMARK = "cybersocenv"
44
+ TASKS = ["easy", "medium", "hard"]
45
+ MAX_STEPS = {"easy": 15, "medium": 25, "hard": 30}
46
+ TEMPERATURE = 0.1
47
+ MAX_TOKENS = 1024
48
+
49
+ # Scoring: normalize rewards to [0, 1]
50
+ MAX_POSSIBLE_REWARD = 2.0 # Approximate max reward per episode
51
+ SUCCESS_SCORE_THRESHOLD = 0.3
52
+
53
+ # =============================================================================
54
+ # System Prompt
55
+ # =============================================================================
56
+
57
+ SYSTEM_PROMPT = textwrap.dedent("""
58
+ You are an expert Cybersecurity SOC (Security Operations Center) Analyst AI.
59
+ You are responding to security incidents on a 500-node enterprise network.
60
+
61
+ Your goal: Investigate alerts, contain all threats, and submit a containment plan — while minimizing business downtime.
62
+
63
+ Available Actions (respond with exactly ONE JSON object per turn):
64
+
65
+ 1. Query a host: {"type": "query_host", "hostname": "<HOST>"}
66
+ 2. Isolate a segment (causes downtime): {"type": "isolate_segment", "subnet": "<SUBNET>", "reason": "<WHY>"}
67
+ 3. Block an IOC: {"type": "block_ioc", "ioc_value": "<VALUE>", "ioc_type": "ip|domain|hash"}
68
+ 4. Run forensics: {"type": "run_forensics", "hostname": "<HOST>"}
69
+ 5. Kill a process: {"type": "kill_process", "hostname": "<HOST>", "process_name": "<PROC>"}
70
+ 6. Submit containment plan (ends episode): {"type": "submit_containment_plan", "plan": [{"threat_id": "<ID>", "actions_taken": [...], "root_cause": "<CAUSE>", "confidence": 0.0-1.0}], "executive_summary": "<SUMMARY>"}
71
+
72
+ Rules:
73
+ - Respond with ONLY a valid JSON object. No markdown, no explanation.
74
+ - Investigate before acting. Query hosts and run forensics to gather evidence.
75
+ - Block IOCs (IPs, domains, hashes) found in alerts and forensics.
76
+ - Kill malicious processes found via forensics.
77
+ - Avoid unnecessary subnet isolation — it increases business impact.
78
+ - Submit the containment plan once you've contained all threats.
79
+ - You have a limited number of steps. Be efficient.
80
+ """).strip()
81
+
82
+
83
+ # =============================================================================
84
+ # Logging Helpers (EXACT hackathon format — lowercase booleans, null errors)
85
+ # =============================================================================
86
+
87
+ def log_start(task: str, env: str, model: str) -> None:
88
+ print(f"[START] task={task} env={env} model={model}", flush=True)
89
+
90
+
91
+ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
92
+ error_val = error if error else "null"
93
+ done_val = str(done).lower()
94
+ print(
95
+ f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
96
+ flush=True,
97
+ )
98
+
99
+
100
+ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
101
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
102
+ print(
103
+ f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}",
104
+ flush=True,
105
+ )
106
+
107
+
108
+ # =============================================================================
109
+ # Observation Formatting for LLM
110
+ # =============================================================================
111
+
112
+ def format_observation(obs: SOCObservation) -> str:
113
+ """Format observation into readable text for the LLM."""
114
+ parts = []
115
+
116
+ # Alert queue
117
+ if obs.alert_queue:
118
+ parts.append(f"## Active Alerts ({len(obs.alert_queue)}):")
119
+ for a in obs.alert_queue:
120
+ parts.append(
121
+ f" - [{a.severity.value.upper()}] {a.alert_id} "
122
+ f"on {a.source_host} ({a.subnet}): {a.description}"
123
+ )
124
+ if a.ioc_indicators:
125
+ parts.append(f" IOCs: {', '.join(a.ioc_indicators)}")
126
+
127
+ # Network topology
128
+ topo = obs.network_topology
129
+ parts.append(f"\n## Network Status:")
130
+ parts.append(f" Compromised: {topo.compromised_count} | "
131
+ f"Isolated: {topo.isolated_count} | "
132
+ f"Online: {topo.online_count}")
133
+
134
+ # Forensics
135
+ if obs.host_forensics:
136
+ f = obs.host_forensics
137
+ parts.append(f"\n## Forensics Result ({f.hostname}):")
138
+ parts.append(f" Compromised: {f.is_compromised}")
139
+ parts.append(f" Malicious processes: {f.malicious_processes}")
140
+ parts.append(f" Suspicious files: {f.suspicious_files}")
141
+ parts.append(f" Network connections: {f.network_connections}")
142
+ parts.append(f" Memory artifacts: {f.memory_artifacts}")
143
+
144
+ # Active threats
145
+ parts.append(f"\n## Active Threats: {obs.active_threats if obs.active_threats else 'None (all contained!)'}")
146
+ parts.append(f"## Business Impact: {obs.business_impact_score:.2f}")
147
+ parts.append(f"## Step: {obs.step_count} / {obs.max_steps}")
148
+
149
+ # Timeline (last 5)
150
+ if obs.timeline:
151
+ parts.append(f"\n## Recent Actions:")
152
+ for t in obs.timeline[-5:]:
153
+ parts.append(f" Step {t.step}: {t.action_type} -> {t.target} (reward={t.reward:.2f})")
154
+
155
+ return "\n".join(parts)
156
+
157
+
158
+ def parse_llm_action(content: str) -> Dict[str, Any]:
159
+ """Parse the LLM's response into a valid action dict."""
160
+ content = content.strip()
161
+ if content.startswith("```"):
162
+ lines = content.split("\n")
163
+ lines = [l for l in lines if not l.strip().startswith("```")]
164
+ content = "\n".join(lines).strip()
165
+
166
+ try:
167
+ action = json.loads(content)
168
+ if isinstance(action, dict) and "type" in action:
169
+ return action
170
+ except json.JSONDecodeError:
171
+ pass
172
+
173
+ # Try to find JSON in the response
174
+ for start in range(len(content)):
175
+ if content[start] == "{":
176
+ for end in range(len(content), start, -1):
177
+ if content[end - 1] == "}":
178
+ try:
179
+ action = json.loads(content[start:end])
180
+ if isinstance(action, dict) and "type" in action:
181
+ return action
182
+ except json.JSONDecodeError:
183
+ continue
184
+
185
+ raise ValueError(f"Could not parse action from LLM response: {content[:200]}")
186
+
187
+
188
+ def get_model_action(
189
+ client: OpenAI,
190
+ step: int,
191
+ obs: SOCObservation,
192
+ task_id: str,
193
+ history: List[str],
194
+ ) -> str:
195
+ """Get the next action from the LLM."""
196
+ obs_text = format_observation(obs)
197
+
198
+ if step == 1:
199
+ user_content = (
200
+ f"## Incident Briefing (Task: {task_id.upper()})\n\n"
201
+ f"{obs_text}\n\n"
202
+ f"Analyze the alerts and begin your investigation. Respond with a single JSON action."
203
+ )
204
+ else:
205
+ user_content = (
206
+ f"## Observation after your action:\n\n"
207
+ f"{obs_text}\n\n"
208
+ f"Continue your investigation. Respond with a single JSON action."
209
+ )
210
+
211
+ try:
212
+ completion = client.chat.completions.create(
213
+ model=MODEL_NAME,
214
+ messages=[
215
+ {"role": "system", "content": SYSTEM_PROMPT},
216
+ {"role": "user", "content": user_content},
217
+ ],
218
+ temperature=TEMPERATURE,
219
+ max_tokens=MAX_TOKENS,
220
+ stream=False,
221
+ )
222
+ text = (completion.choices[0].message.content or "").strip()
223
+ return text if text else '{"type": "query_host", "hostname": "WS-001"}'
224
+ except Exception as exc:
225
+ if "429" in str(exc) or "RateLimit" in str(exc):
226
+ raise # Let the batch runner handle rate limits
227
+ print(f"[DEBUG] Model request failed: {exc}", flush=True)
228
+ return '{"type": "query_host", "hostname": "WS-001"}'
229
+
230
+
231
+ # =============================================================================
232
+ # Episode Runner
233
+ # =============================================================================
234
+
235
+ async def run_episode(client: OpenAI, task_id: str) -> tuple:
236
+ """Run a single episode. Returns (success, steps, score, rewards)."""
237
+ env = CyberSOCEnvironment()
238
+ history: List[str] = []
239
+ rewards: List[float] = []
240
+ steps_taken = 0
241
+ score = 0.0
242
+ success = False
243
+
244
+ log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
245
+
246
+ try:
247
+ # Reset environment
248
+ obs = env.reset(task_id=task_id)
249
+
250
+ max_steps = MAX_STEPS.get(task_id, 30)
251
+
252
+ for step in range(1, max_steps + 1):
253
+ if obs.done:
254
+ break
255
+
256
+ # Get action from LLM
257
+ llm_response = get_model_action(client, step, obs, task_id, history)
258
+
259
+ # Parse and execute
260
+ error = None
261
+ action_str = "unknown"
262
+ reward = 0.0
263
+
264
+ try:
265
+ action_dict = parse_llm_action(llm_response)
266
+ action_str = action_dict.get("type", "unknown")
267
+ action = SOCActionWrapper(**action_dict)
268
+ obs = env.step(action)
269
+ reward = obs.reward or 0.0
270
+ done = obs.done
271
+ except Exception as exc:
272
+ error = str(exc)[:200]
273
+ done = False
274
+ reward = 0.0
275
+
276
+ rewards.append(reward)
277
+ steps_taken = step
278
+
279
+ log_step(step=step, action=action_str, reward=reward, done=done, error=error)
280
+
281
+ history.append(f"Step {step}: {action_str} -> reward {reward:+.2f}")
282
+
283
+ if done:
284
+ break
285
+
286
+ # Calculate score from final_score if available, else normalize rewards
287
+ if obs.final_score is not None:
288
+ score = obs.final_score
289
+ else:
290
+ score = sum(rewards) / MAX_POSSIBLE_REWARD if MAX_POSSIBLE_REWARD > 0 else 0.0
291
+
292
+ score = min(max(score, 0.0), 1.0) # clamp to [0, 1]
293
+ success = score >= SUCCESS_SCORE_THRESHOLD
294
+
295
+ finally:
296
+ log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
297
+
298
+ return success, steps_taken, score, rewards
299
+
300
+
301
+ # =============================================================================
302
+ # Main
303
+ # =============================================================================
304
+
305
+ async def main() -> None:
306
+ """Run baseline inference across all tasks."""
307
+ client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
308
+
309
+ total_scores = {}
310
+ for task_id in TASKS:
311
+ success, steps, score, rewards = await run_episode(client, task_id)
312
+ total_scores[task_id] = score
313
+
314
+ # Print summary
315
+ avg = sum(total_scores.values()) / len(total_scores) if total_scores else 0.0
316
+ print(f"\n# Summary: avg_score={avg:.3f}", flush=True)
317
+ for tid, s in total_scores.items():
318
+ print(f"# {tid}: {s:.3f}", flush=True)
319
+
320
+
321
+ if __name__ == "__main__":
322
+ asyncio.run(main())
models.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Data models for the CyberSOCEnv — Enterprise Cybersecurity Operations Center.
9
+
10
+ Defines strict Pydantic models for:
11
+ - Observation: What the agent sees (alerts, forensics, network state, business impact)
12
+ - Action: What the agent can do (discriminated union of 6 action types)
13
+ - Internal state: Deterministic network graph, attack chains, and task tracking
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ from enum import Enum
19
+ from typing import Annotated, Any, Dict, List, Literal, Optional, Union
20
+
21
+ from openenv.core.env_server.types import Action, Observation, State
22
+ from pydantic import BaseModel, ConfigDict, Field
23
+
24
+
25
+ # =============================================================================
26
+ # Enums
27
+ # =============================================================================
28
+
29
+
30
+ class Severity(str, Enum):
31
+ """SIEM alert severity levels."""
32
+ LOW = "low"
33
+ MEDIUM = "medium"
34
+ HIGH = "high"
35
+ CRITICAL = "critical"
36
+
37
+
38
+ class ThreatType(str, Enum):
39
+ """Classification of threat types in the SOC environment."""
40
+ RANSOMWARE = "ransomware"
41
+ PHISHING = "phishing"
42
+ CREDENTIAL_THEFT = "credential_theft"
43
+ LATERAL_MOVEMENT = "lateral_movement"
44
+ C2_COMMUNICATION = "c2_communication"
45
+ DATA_EXFILTRATION = "data_exfiltration"
46
+ PRIVILEGE_ESCALATION = "privilege_escalation"
47
+ MALWARE = "malware"
48
+ CRYPTOMINING = "cryptomining"
49
+ SUPPLY_CHAIN = "supply_chain"
50
+ INSIDER_THREAT = "insider_threat"
51
+ WEBSHELL = "webshell"
52
+ BOTNET = "botnet"
53
+
54
+
55
+ class HostStatus(str, Enum):
56
+ """Host operational status."""
57
+ ONLINE = "online"
58
+ COMPROMISED = "compromised"
59
+ ISOLATED = "isolated"
60
+ OFFLINE = "offline"
61
+
62
+
63
+ class SubnetRole(str, Enum):
64
+ """Business function of a network subnet."""
65
+ CORPORATE = "corporate"
66
+ ENGINEERING = "engineering"
67
+ FINANCE = "finance"
68
+ DMZ = "dmz"
69
+ DATACENTER = "datacenter"
70
+ EXECUTIVE = "executive"
71
+
72
+
73
+ # =============================================================================
74
+ # Alert & Network Sub-Models (used in Observation)
75
+ # =============================================================================
76
+
77
+
78
+ class Alert(BaseModel):
79
+ """A single SIEM/EDR alert in the queue."""
80
+ model_config = ConfigDict(extra="forbid")
81
+
82
+ alert_id: str = Field(..., description="Unique alert identifier")
83
+ timestamp: str = Field(..., description="ISO-8601 timestamp of the alert")
84
+ source_host: str = Field(..., description="Hostname that generated the alert")
85
+ severity: Severity = Field(..., description="Alert severity level")
86
+ threat_type: ThreatType = Field(..., description="Classified threat type")
87
+ description: str = Field(..., description="Human-readable alert description")
88
+ ioc_indicators: List[str] = Field(
89
+ default_factory=list,
90
+ description="Indicators of compromise (IPs, hashes, domains)",
91
+ )
92
+ subnet: str = Field(..., description="Subnet where the alert originated")
93
+ is_acknowledged: bool = Field(default=False, description="Whether the SOC analyst has acknowledged this alert")
94
+
95
+
96
+ class HostInfo(BaseModel):
97
+ """Summary information about a single network host."""
98
+ model_config = ConfigDict(extra="forbid")
99
+
100
+ hostname: str = Field(..., description="Host FQDN")
101
+ ip_address: str = Field(..., description="IPv4 address")
102
+ subnet: str = Field(..., description="Subnet the host belongs to")
103
+ role: SubnetRole = Field(..., description="Business function")
104
+ status: HostStatus = Field(default=HostStatus.ONLINE, description="Current status")
105
+ running_processes: List[str] = Field(default_factory=list, description="Running process names")
106
+ open_ports: List[int] = Field(default_factory=list, description="Open TCP ports")
107
+ criticality: float = Field(
108
+ default=0.5, ge=0.0, le=1.0,
109
+ description="Business criticality score (0=low, 1=mission-critical)",
110
+ )
111
+
112
+
113
+ class NetworkTopology(BaseModel):
114
+ """Summarized view of the 500-node enterprise network."""
115
+ model_config = ConfigDict(extra="forbid")
116
+
117
+ total_hosts: int = Field(default=500, description="Total hosts in the network")
118
+ subnets: Dict[str, int] = Field(
119
+ default_factory=dict,
120
+ description="Map of subnet name -> host count",
121
+ )
122
+ compromised_count: int = Field(default=0, description="Number of compromised hosts")
123
+ isolated_count: int = Field(default=0, description="Number of isolated hosts")
124
+ online_count: int = Field(default=500, description="Number of online hosts")
125
+
126
+
127
+ class ForensicsResult(BaseModel):
128
+ """Results from running forensics on a host."""
129
+ model_config = ConfigDict(extra="forbid")
130
+
131
+ hostname: str = Field(..., description="Analyzed host")
132
+ malicious_processes: List[str] = Field(default_factory=list, description="Detected malicious processes")
133
+ suspicious_files: List[str] = Field(default_factory=list, description="Suspicious file paths found")
134
+ network_connections: List[str] = Field(
135
+ default_factory=list,
136
+ description="Suspicious outbound connections (ip:port)",
137
+ )
138
+ registry_modifications: List[str] = Field(default_factory=list, description="Modified registry keys")
139
+ memory_artifacts: List[str] = Field(default_factory=list, description="In-memory IOCs found")
140
+ is_compromised: bool = Field(default=False, description="Whether forensics confirm compromise")
141
+
142
+
143
+ class TimelineEntry(BaseModel):
144
+ """A single entry in the analyst action timeline."""
145
+ model_config = ConfigDict(extra="forbid")
146
+
147
+ step: int = Field(..., description="Step number when this action was taken")
148
+ action_type: str = Field(..., description="Type of action taken")
149
+ target: str = Field(..., description="Target of the action (host, subnet, IOC)")
150
+ result: str = Field(..., description="Outcome description")
151
+ reward: float = Field(default=0.0, description="Reward received for this action")
152
+
153
+
154
+ # =============================================================================
155
+ # Observation
156
+ # =============================================================================
157
+
158
+
159
+ class SOCObservation(Observation):
160
+ """What the SOC agent sees at each step.
161
+
162
+ Extends OpenEnv Observation (inherits: done, reward, metadata).
163
+ """
164
+
165
+ alert_queue: List[Alert] = Field(
166
+ default_factory=list,
167
+ description="Current queue of unresolved SIEM/EDR alerts",
168
+ )
169
+ network_topology: NetworkTopology = Field(
170
+ default_factory=NetworkTopology,
171
+ description="Summary of the enterprise network state",
172
+ )
173
+ host_forensics: Optional[ForensicsResult] = Field(
174
+ default=None,
175
+ description="Forensics results if RunForensics was the last action, else None",
176
+ )
177
+ timeline: List[TimelineEntry] = Field(
178
+ default_factory=list,
179
+ description="Chronological log of all actions taken in this episode",
180
+ )
181
+ business_impact_score: float = Field(
182
+ default=0.0, ge=0.0, le=1.0,
183
+ description="Current business impact (0=no impact, 1=catastrophic outage)",
184
+ )
185
+ step_count: int = Field(default=0, ge=0, description="Current step number")
186
+ active_threats: List[str] = Field(
187
+ default_factory=list,
188
+ description="List of threat IDs that are still active/uncontained",
189
+ )
190
+ max_steps: int = Field(default=30, description="Maximum steps allowed in this episode")
191
+ task_id: str = Field(default="easy", description="Current task identifier")
192
+ total_reward: float = Field(default=0.0, description="Accumulated episode reward")
193
+ final_score: Optional[float] = Field(
194
+ default=None,
195
+ description="Post-episode grader score (0.0-1.0). Only set when done=True and plan submitted.",
196
+ )
197
+ grade_breakdown: Optional[Dict[str, Any]] = Field(
198
+ default=None,
199
+ description="Detailed grading breakdown. Only set when done=True and plan submitted.",
200
+ )
201
+
202
+
203
+ # =============================================================================
204
+ # Actions (Discriminated Union)
205
+ # =============================================================================
206
+
207
+
208
+ class QueryHost(Action):
209
+ """Query a specific host for status, processes, and connections."""
210
+ type: Literal["query_host"] = Field(default="query_host", description="Action discriminator")
211
+ hostname: str = Field(..., description="Target hostname to query")
212
+
213
+
214
+ class IsolateSegment(Action):
215
+ """Isolate an entire network segment from the rest of the network."""
216
+ type: Literal["isolate_segment"] = Field(default="isolate_segment", description="Action discriminator")
217
+ subnet: str = Field(..., description="Subnet name to isolate (e.g. 'finance', 'engineering')")
218
+ reason: str = Field(default="", description="Justification for isolation")
219
+
220
+
221
+ class BlockIOC(Action):
222
+ """Block an Indicator of Compromise at the perimeter firewall."""
223
+ type: Literal["block_ioc"] = Field(default="block_ioc", description="Action discriminator")
224
+ ioc_value: str = Field(..., description="The IOC to block (IP, domain, or file hash)")
225
+ ioc_type: Literal["ip", "domain", "hash"] = Field(..., description="Type of IOC")
226
+
227
+
228
+ class RunForensics(Action):
229
+ """Run deep forensic analysis on a specific host."""
230
+ type: Literal["run_forensics"] = Field(default="run_forensics", description="Action discriminator")
231
+ hostname: str = Field(..., description="Target hostname for forensics")
232
+
233
+
234
+ class KillProcess(Action):
235
+ """Terminate a specific process on a host."""
236
+ type: Literal["kill_process"] = Field(default="kill_process", description="Action discriminator")
237
+ hostname: str = Field(..., description="Host where the process is running")
238
+ process_name: str = Field(..., description="Name of the process to terminate")
239
+
240
+
241
+ class ContainmentEntry(BaseModel):
242
+ """A single entry in the containment plan."""
243
+ model_config = ConfigDict(extra="forbid")
244
+
245
+ threat_id: str = Field(..., description="Threat being addressed")
246
+ actions_taken: List[str] = Field(..., description="List of actions taken to contain this threat")
247
+ root_cause: str = Field(..., description="Identified root cause")
248
+ confidence: float = Field(
249
+ ..., ge=0.0, le=1.0,
250
+ description="Confidence in the containment (0-1)",
251
+ )
252
+
253
+
254
+ class SubmitContainmentPlan(Action):
255
+ """Submit the final containment plan to end the episode."""
256
+ type: Literal["submit_containment_plan"] = Field(
257
+ default="submit_containment_plan", description="Action discriminator"
258
+ )
259
+ plan: List[ContainmentEntry] = Field(
260
+ ..., description="The containment plan addressing all identified threats"
261
+ )
262
+ executive_summary: str = Field(
263
+ ..., description="Brief executive summary for CISO reporting"
264
+ )
265
+
266
+
267
+ # Discriminated union of all SOC actions
268
+ SOCAction = Annotated[
269
+ Union[QueryHost, IsolateSegment, BlockIOC, RunForensics, KillProcess, SubmitContainmentPlan],
270
+ Field(discriminator="type"),
271
+ ]
272
+
273
+ # Wrapper model so OpenEnv's create_app can accept it as a single Action class
274
+ class SOCActionWrapper(Action):
275
+ """Wrapper that deserializes the discriminated union action.
276
+
277
+ OpenEnv's create_app expects a single Action subclass. This wrapper
278
+ uses a discriminated union field so the HTTP/WS layer can parse
279
+ any of the 6 action types from a flat JSON payload.
280
+
281
+ Client sends: {"action": {"type": "query_host", "hostname": "WS-001"}}
282
+ The wrapper validates -> QueryHost(hostname="WS-001")
283
+ """
284
+ type: str = Field(..., description="Action type discriminator")
285
+
286
+ model_config = ConfigDict(extra="allow") # Allow action-specific fields
287
+
288
+ def to_typed_action(self) -> Union[QueryHost, IsolateSegment, BlockIOC, RunForensics, KillProcess, SubmitContainmentPlan]:
289
+ """Convert the raw wrapper into the correctly typed action."""
290
+ data = self.model_dump(exclude={"metadata"})
291
+ action_map = {
292
+ "query_host": QueryHost,
293
+ "isolate_segment": IsolateSegment,
294
+ "block_ioc": BlockIOC,
295
+ "run_forensics": RunForensics,
296
+ "kill_process": KillProcess,
297
+ "submit_containment_plan": SubmitContainmentPlan,
298
+ }
299
+ cls = action_map.get(data["type"])
300
+ if cls is None:
301
+ raise ValueError(
302
+ f"Unknown action type: {data['type']}. "
303
+ f"Valid types: {list(action_map.keys())}"
304
+ )
305
+ return cls(**data)
306
+
307
+
308
+ # =============================================================================
309
+ # Internal State (not exposed to agent directly)
310
+ # =============================================================================
311
+
312
+
313
+ class SOCState(State):
314
+ """Internal environment state tracking the attack simulation.
315
+
316
+ Extends OpenEnv State (inherits: episode_id, step_count).
317
+ Uses extra='allow' from base State.
318
+ """
319
+
320
+ task_id: str = Field(default="easy", description="Current task: 'easy', 'medium', or 'hard'")
321
+ max_steps: int = Field(default=30, description="Maximum steps for this episode")
322
+ total_reward: float = Field(default=0.0, description="Accumulated reward")
323
+ business_impact: float = Field(default=0.0, ge=0.0, le=1.0, description="Current business impact score")
324
+ contained_threats: List[str] = Field(default_factory=list, description="Threat IDs that have been contained")
325
+ active_threats: List[str] = Field(default_factory=list, description="Currently active threat IDs")
326
+ blocked_iocs: List[str] = Field(default_factory=list, description="IOCs blocked at perimeter")
327
+ isolated_subnets: List[str] = Field(default_factory=list, description="Isolated network segments")
328
+ forensics_run: List[str] = Field(default_factory=list, description="Hosts that had forensics run")
329
+ killed_processes: List[Dict[str, str]] = Field(default_factory=list, description="Processes killed")
330
+ queried_hosts: List[str] = Field(default_factory=list, description="Hosts queried")
331
+ timeline: List[Dict[str, Any]] = Field(default_factory=list, description="Action timeline")
332
+ is_done: bool = Field(default=False, description="Whether episode has ended")
333
+ submitted_plan: bool = Field(default=False, description="Whether containment plan was submitted")
openenv.yaml ADDED
The diff for this file is too large to render. See raw diff
 
openenv_play.egg-info/PKG-INFO ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.4
2
+ Name: openenv-play
3
+ Version: 0.1.0
4
+ Summary: Play environment for OpenEnv
5
+ Requires-Python: >=3.10
6
+ Requires-Dist: openenv-core[core]>=0.2.2
7
+ Provides-Extra: dev
8
+ Requires-Dist: pytest>=8.0.0; extra == "dev"
9
+ Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
openenv_play.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ README.md
2
+ pyproject.toml
3
+ ./__init__.py
4
+ ./client.py
5
+ ./models.py
6
+ openenv_play.egg-info/PKG-INFO
7
+ openenv_play.egg-info/SOURCES.txt
8
+ openenv_play.egg-info/dependency_links.txt
9
+ openenv_play.egg-info/entry_points.txt
10
+ openenv_play.egg-info/requires.txt
11
+ openenv_play.egg-info/top_level.txt
12
+ server/__init__.py
13
+ server/app.py
14
+ server/play_environment.py
openenv_play.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
openenv_play.egg-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [console_scripts]
2
+ server = play.server.app:main
openenv_play.egg-info/requires.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ openenv-core[core]>=0.2.2
2
+
3
+ [dev]
4
+ pytest>=8.0.0
5
+ pytest-cov>=4.0.0
openenv_play.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ play
pyproject.toml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ [build-system]
8
+ requires = ["setuptools>=45", "wheel"]
9
+ build-backend = "setuptools.build_meta"
10
+
11
+ [project]
12
+ name = "openenv-cybersocenv"
13
+ version = "0.1.0"
14
+ description = "CyberSOCEnv — Enterprise SOC Incident Response environment for OpenEnv"
15
+ requires-python = ">=3.10"
16
+ dependencies = [
17
+ # Core OpenEnv runtime (provides FastAPI server + HTTP client types)
18
+ "openenv-core[core]>=0.2.2",
19
+ # Inference dependencies
20
+ "openai>=1.0.0",
21
+ "websockets>=12.0",
22
+ ]
23
+
24
+ [project.optional-dependencies]
25
+ dev = [
26
+ "pytest>=8.0.0",
27
+ "pytest-cov>=4.0.0",
28
+ ]
29
+
30
+ [project.scripts]
31
+ server = "play.server.app:main"
32
+
33
+ [tool.setuptools]
34
+ include-package-data = true
35
+ packages = ["play", "play.server"]
36
+ package-dir = { "play" = ".", "play.server" = "server" }
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn[standard]
3
+ pydantic
4
+ networkx
5
+ websockets
6
+ openai
7
+ tenacity
server/Dockerfile ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Multi-stage build using openenv-base
8
+ # This Dockerfile is flexible and works for both:
9
+ # - In-repo environments (with local OpenEnv sources)
10
+ # - Standalone environments (with openenv from PyPI/Git)
11
+ # The build script (openenv build) handles context detection and sets appropriate build args.
12
+
13
+ ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
14
+ FROM ${BASE_IMAGE} AS builder
15
+
16
+ WORKDIR /app
17
+
18
+ # Ensure git is available (required for installing dependencies from VCS)
19
+ RUN apt-get update && \
20
+ apt-get install -y --no-install-recommends git && \
21
+ rm -rf /var/lib/apt/lists/*
22
+
23
+ # Build argument to control whether we're building standalone or in-repo
24
+ ARG BUILD_MODE=in-repo
25
+ ARG ENV_NAME=play
26
+
27
+ # Copy environment code (always at root of build context)
28
+ COPY . /app/env
29
+
30
+ # For in-repo builds, openenv is already vendored in the build context
31
+ # For standalone builds, openenv will be installed via pyproject.toml
32
+ WORKDIR /app/env
33
+
34
+ # Ensure uv is available (for local builds where base image lacks it)
35
+ RUN if ! command -v uv >/dev/null 2>&1; then \
36
+ curl -LsSf https://astral.sh/uv/install.sh | sh && \
37
+ mv /root/.local/bin/uv /usr/local/bin/uv && \
38
+ mv /root/.local/bin/uvx /usr/local/bin/uvx; \
39
+ fi
40
+
41
+ # Install dependencies using uv sync
42
+ # If uv.lock exists, use it; otherwise resolve on the fly
43
+ RUN --mount=type=cache,target=/root/.cache/uv \
44
+ if [ -f uv.lock ]; then \
45
+ uv sync --frozen --no-install-project --no-editable; \
46
+ else \
47
+ uv sync --no-install-project --no-editable; \
48
+ fi
49
+
50
+ RUN --mount=type=cache,target=/root/.cache/uv \
51
+ if [ -f uv.lock ]; then \
52
+ uv sync --frozen --no-editable; \
53
+ else \
54
+ uv sync --no-editable; \
55
+ fi
56
+
57
+ # Final runtime stage
58
+ FROM ${BASE_IMAGE}
59
+
60
+ WORKDIR /app
61
+
62
+ # Copy the virtual environment from builder
63
+ COPY --from=builder /app/env/.venv /app/.venv
64
+
65
+ # Copy the environment code
66
+ COPY --from=builder /app/env /app/env
67
+
68
+ # Set PATH to use the virtual environment
69
+ ENV PATH="/app/.venv/bin:$PATH"
70
+
71
+ # Set PYTHONPATH so imports work correctly
72
+ ENV PYTHONPATH="/app/env:$PYTHONPATH"
73
+
74
+ # Health check
75
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
76
+ CMD curl -f http://localhost:8000/health || exit 1
77
+
78
+ # Run the FastAPI server
79
+ # The module path is constructed to work with the /app/env structure
80
+ CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
server/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """CyberSOCEnv server components."""
8
+
9
+ from .play_environment import CyberSOCEnvironment
10
+
11
+ __all__ = ["CyberSOCEnvironment"]
server/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (296 Bytes). View file
 
server/__pycache__/app.cpython-311.pyc ADDED
Binary file (2.4 kB). View file
 
server/__pycache__/graders.cpython-311.pyc ADDED
Binary file (7.07 kB). View file
 
server/__pycache__/play_environment.cpython-311.pyc ADDED
Binary file (25.4 kB). View file
 
server/__pycache__/task_generator.cpython-311.pyc ADDED
Binary file (28.9 kB). View file
 
server/__pycache__/tasks.cpython-311.pyc ADDED
Binary file (10.8 kB). View file
 
server/app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ FastAPI application for the CyberSOCEnv Environment.
9
+
10
+ Endpoints:
11
+ - POST /reset: Reset the environment (pass task_id in body)
12
+ - POST /step: Execute an action
13
+ - GET /state: Get current environment state
14
+ - GET /schema: Get action/observation schemas
15
+ - WS /ws: WebSocket endpoint for persistent sessions
16
+
17
+ Usage:
18
+ # Development (with auto-reload):
19
+ uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
20
+
21
+ # Production:
22
+ uvicorn server.app:app --host 0.0.0.0 --port 8000 --workers 4
23
+ """
24
+
25
+ try:
26
+ from openenv.core.env_server.http_server import create_app
27
+ except Exception as e: # pragma: no cover
28
+ raise ImportError(
29
+ "openenv is required. Install with: pip install 'openenv-core[core]'"
30
+ ) from e
31
+
32
+ try:
33
+ from ..models import SOCObservation, SOCActionWrapper
34
+ from .play_environment import CyberSOCEnvironment
35
+ except (ImportError, ModuleNotFoundError):
36
+ from models import SOCObservation, SOCActionWrapper
37
+ from server.play_environment import CyberSOCEnvironment
38
+
39
+
40
+ # Create the app with the CyberSOCEnv environment
41
+ app = create_app(
42
+ CyberSOCEnvironment,
43
+ SOCActionWrapper,
44
+ SOCObservation,
45
+ env_name="cybersocenv",
46
+ max_concurrent_envs=4,
47
+ )
48
+
49
+
50
+ def main(host: str = "0.0.0.0", port: int = 8000):
51
+ """Entry point for direct execution.
52
+
53
+ Usage:
54
+ python -m play.server.app
55
+ python -m play.server.app --port 8001
56
+ """
57
+ import uvicorn
58
+ uvicorn.run(app, host=host, port=port)
59
+
60
+
61
+ if __name__ == "__main__":
62
+ main()
server/graders.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Deterministic post-episode graders for CyberSOCEnv.
9
+
10
+ Each grader returns a float in [0.0, 1.0] based on how well the agent
11
+ contained the incident. Grading is entirely deterministic — same actions
12
+ always produce the same score.
13
+
14
+ Scoring breakdown:
15
+ - Threat containment (40%): Did the agent kill all malicious processes?
16
+ - IOC blocking (20%): Were critical IOCs blocked at the perimeter?
17
+ - Forensic coverage (15%): Were compromised hosts analyzed?
18
+ - Business impact (15%): Was unnecessary downtime avoided?
19
+ - Plan quality (10%): Did the final plan correctly identify root causes?
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ from typing import Any, Dict, List
25
+
26
+
27
+ def grade_episode(
28
+ task_id: str,
29
+ task_def: Dict[str, Any],
30
+ killed_processes: List[Dict[str, str]],
31
+ blocked_iocs: List[str],
32
+ forensics_run: List[str],
33
+ isolated_subnets: List[str],
34
+ submitted_plan: bool,
35
+ plan_entries: List[Dict[str, Any]],
36
+ final_business_impact: float,
37
+ step_count: int,
38
+ total_reward: float,
39
+ ) -> float:
40
+ """Grade an episode deterministically.
41
+
42
+ Args:
43
+ task_id: The task that was run.
44
+ task_def: The full task definition from tasks.py.
45
+ killed_processes: List of {"hostname": ..., "process": ...} killed.
46
+ blocked_iocs: List of IOC values that were blocked.
47
+ forensics_run: List of hostnames where forensics were executed.
48
+ isolated_subnets: List of subnet names that were isolated.
49
+ submitted_plan: Whether the agent submitted a containment plan.
50
+ plan_entries: The containment plan entries (list of dicts).
51
+ final_business_impact: The final business_impact_score at episode end.
52
+ step_count: Total steps taken.
53
+ total_reward: Accumulated trajectory reward.
54
+
55
+ Returns:
56
+ Float in [0.0, 1.0] — the final episode score.
57
+ """
58
+ requirements = task_def["containment_requirements"]
59
+ score = 0.0
60
+
61
+ # ---- 1. Threat Containment (40%) ----
62
+ must_kill = requirements["must_kill"]
63
+ if must_kill:
64
+ kills_matched = 0
65
+ for req in must_kill:
66
+ for k in killed_processes:
67
+ if k.get("hostname") == req["hostname"] and k.get("process") == req["process"]:
68
+ kills_matched += 1
69
+ break
70
+ containment_ratio = kills_matched / len(must_kill)
71
+ score += 0.40 * containment_ratio
72
+ else:
73
+ score += 0.40 # No kills required = full marks
74
+
75
+ # ---- 2. IOC Blocking (20%) ----
76
+ must_block = requirements["must_block_iocs"]
77
+ if must_block:
78
+ blocked_matched = sum(1 for ioc in must_block if ioc in blocked_iocs)
79
+ block_ratio = blocked_matched / len(must_block)
80
+ score += 0.20 * block_ratio
81
+ else:
82
+ score += 0.20
83
+
84
+ # ---- 3. Forensic Coverage (15%) ----
85
+ must_forensics = requirements["must_forensics"]
86
+ if must_forensics:
87
+ forensics_matched = sum(1 for h in must_forensics if h in forensics_run)
88
+ forensics_ratio = forensics_matched / len(must_forensics)
89
+ score += 0.15 * forensics_ratio
90
+ else:
91
+ score += 0.15
92
+
93
+ # ---- 4. Business Impact / Downtime (15%) ----
94
+ must_not_isolate = requirements.get("must_not_isolate", [])
95
+ unnecessary_isolations = sum(1 for s in isolated_subnets if s in must_not_isolate)
96
+
97
+ # Penalty for unnecessary isolations (each costs 5% of this category)
98
+ isolation_penalty = min(1.0, unnecessary_isolations * 0.33)
99
+ # Penalty for high business impact
100
+ impact_penalty = final_business_impact # 0.0 = perfect, 1.0 = catastrophic
101
+
102
+ downtime_score = max(0.0, 1.0 - isolation_penalty - impact_penalty * 0.5)
103
+ score += 0.15 * downtime_score
104
+
105
+ # ---- 5. Plan Quality (10%) ----
106
+ if submitted_plan and plan_entries:
107
+ # Check if plan addresses all attack chain threats
108
+ attack_threats = {t["threat_id"] for t in task_def["attack_chain"]}
109
+ plan_threats = {e.get("threat_id", "") for e in plan_entries}
110
+ threats_addressed = len(attack_threats & plan_threats)
111
+
112
+ if attack_threats:
113
+ plan_coverage = threats_addressed / len(attack_threats)
114
+ else:
115
+ plan_coverage = 1.0
116
+
117
+ # Average confidence of plan entries
118
+ confidences = [e.get("confidence", 0.0) for e in plan_entries]
119
+ avg_confidence = sum(confidences) / len(confidences) if confidences else 0.0
120
+
121
+ plan_score = 0.6 * plan_coverage + 0.4 * avg_confidence
122
+ score += 0.10 * plan_score
123
+ elif submitted_plan:
124
+ score += 0.02 # Submitted but empty plan
125
+ # else: no plan submitted = 0 for this category
126
+
127
+ # Clamp to [0.0, 1.0]
128
+ return round(max(0.0, min(1.0, score)), 4)
129
+
130
+
131
+ def grade_easy(
132
+ killed_processes: List[Dict[str, str]],
133
+ blocked_iocs: List[str],
134
+ forensics_run: List[str],
135
+ isolated_subnets: List[str],
136
+ submitted_plan: bool,
137
+ plan_entries: List[Dict[str, Any]],
138
+ final_business_impact: float,
139
+ step_count: int,
140
+ total_reward: float,
141
+ task_def: Dict[str, Any],
142
+ ) -> float:
143
+ """Grade the easy task."""
144
+ return grade_episode(
145
+ task_id="easy",
146
+ task_def=task_def,
147
+ killed_processes=killed_processes,
148
+ blocked_iocs=blocked_iocs,
149
+ forensics_run=forensics_run,
150
+ isolated_subnets=isolated_subnets,
151
+ submitted_plan=submitted_plan,
152
+ plan_entries=plan_entries,
153
+ final_business_impact=final_business_impact,
154
+ step_count=step_count,
155
+ total_reward=total_reward,
156
+ )
157
+
158
+
159
+ def grade_medium(
160
+ killed_processes: List[Dict[str, str]],
161
+ blocked_iocs: List[str],
162
+ forensics_run: List[str],
163
+ isolated_subnets: List[str],
164
+ submitted_plan: bool,
165
+ plan_entries: List[Dict[str, Any]],
166
+ final_business_impact: float,
167
+ step_count: int,
168
+ total_reward: float,
169
+ task_def: Dict[str, Any],
170
+ ) -> float:
171
+ """Grade the medium task."""
172
+ return grade_episode(
173
+ task_id="medium",
174
+ task_def=task_def,
175
+ killed_processes=killed_processes,
176
+ blocked_iocs=blocked_iocs,
177
+ forensics_run=forensics_run,
178
+ isolated_subnets=isolated_subnets,
179
+ submitted_plan=submitted_plan,
180
+ plan_entries=plan_entries,
181
+ final_business_impact=final_business_impact,
182
+ step_count=step_count,
183
+ total_reward=total_reward,
184
+ )
185
+
186
+
187
+ def grade_hard(
188
+ killed_processes: List[Dict[str, str]],
189
+ blocked_iocs: List[str],
190
+ forensics_run: List[str],
191
+ isolated_subnets: List[str],
192
+ submitted_plan: bool,
193
+ plan_entries: List[Dict[str, Any]],
194
+ final_business_impact: float,
195
+ step_count: int,
196
+ total_reward: float,
197
+ task_def: Dict[str, Any],
198
+ ) -> float:
199
+ """Grade the hard task."""
200
+ return grade_episode(
201
+ task_id="hard",
202
+ task_def=task_def,
203
+ killed_processes=killed_processes,
204
+ blocked_iocs=blocked_iocs,
205
+ forensics_run=forensics_run,
206
+ isolated_subnets=isolated_subnets,
207
+ submitted_plan=submitted_plan,
208
+ plan_entries=plan_entries,
209
+ final_business_impact=final_business_impact,
210
+ step_count=step_count,
211
+ total_reward=total_reward,
212
+ )
server/play_environment.py ADDED
@@ -0,0 +1,594 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ CyberSOCEnv — Enterprise Cybersecurity Operations Center Environment.
9
+
10
+ Implements the OpenEnv Environment interface for a deterministic SOC
11
+ incident response simulation on a 500-node enterprise network.
12
+
13
+ The agent receives SIEM/EDR alerts, queries hosts, runs forensics,
14
+ isolates segments, blocks IOCs, kills processes, and submits a
15
+ containment plan — all while minimizing business downtime.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import copy
21
+ from typing import Any, Dict, List, Optional
22
+ from uuid import uuid4
23
+
24
+ from openenv.core.env_server.interfaces import Environment
25
+ from openenv.core.env_server.types import State
26
+
27
+ try:
28
+ from ..models import (
29
+ SOCObservation,
30
+ SOCActionWrapper,
31
+ SOCState,
32
+ Alert,
33
+ NetworkTopology,
34
+ ForensicsResult,
35
+ TimelineEntry,
36
+ QueryHost,
37
+ IsolateSegment,
38
+ BlockIOC,
39
+ RunForensics,
40
+ KillProcess,
41
+ SubmitContainmentPlan,
42
+ )
43
+ except ImportError:
44
+ from models import (
45
+ SOCObservation,
46
+ SOCActionWrapper,
47
+ SOCState,
48
+ Alert,
49
+ NetworkTopology,
50
+ ForensicsResult,
51
+ TimelineEntry,
52
+ QueryHost,
53
+ IsolateSegment,
54
+ BlockIOC,
55
+ RunForensics,
56
+ KillProcess,
57
+ SubmitContainmentPlan,
58
+ )
59
+
60
+ from .tasks import get_task, build_network
61
+ from .graders import grade_episode
62
+
63
+
64
+ class CyberSOCEnvironment(Environment):
65
+ """
66
+ Deterministic SOC incident response environment.
67
+
68
+ Simulates a 500-node enterprise network under attack. The agent must
69
+ investigate alerts, contain threats, and submit a containment plan
70
+ while minimizing business downtime.
71
+
72
+ Supports concurrent WebSocket sessions (each gets own instance).
73
+
74
+ Example:
75
+ >>> env = CyberSOCEnvironment()
76
+ >>> obs = env.reset(task_id="easy")
77
+ >>> print(len(obs.alert_queue)) # Initial alerts
78
+ >>> obs = env.step(SOCActionWrapper(type="query_host", hostname="WS-042"))
79
+ """
80
+
81
+ SUPPORTS_CONCURRENT_SESSIONS: bool = True
82
+
83
+ def __init__(self):
84
+ """Initialize the environment (actual state set in reset)."""
85
+ super().__init__()
86
+ self._state = SOCState(episode_id=str(uuid4()), step_count=0)
87
+ self._network: Dict[str, List[Dict[str, Any]]] = {}
88
+ self._task_def: Dict[str, Any] = {}
89
+ self._alert_queue: List[Dict[str, Any]] = []
90
+ self._host_index: Dict[str, Dict[str, Any]] = {} # hostname -> host dict
91
+ self._plan_entries: List[Dict[str, Any]] = []
92
+ self._last_forensics: Optional[ForensicsResult] = None
93
+
94
+ # ===========================================================================
95
+ # reset()
96
+ # ===========================================================================
97
+
98
+ def reset(
99
+ self,
100
+ seed: Optional[int] = None,
101
+ episode_id: Optional[str] = None,
102
+ **kwargs: Any,
103
+ ) -> SOCObservation:
104
+ """Reset the environment for a specific task.
105
+
106
+ Args:
107
+ seed: Ignored (environment is fully deterministic).
108
+ episode_id: Optional custom episode ID.
109
+ **kwargs: Must include task_id ('easy', 'medium', or 'hard').
110
+
111
+ Returns:
112
+ Initial SOCObservation with alert queue and network state.
113
+ """
114
+ task_id = kwargs.get("task_id", "easy")
115
+ self._task_def = get_task(task_id)
116
+
117
+ # Build deterministic network
118
+ self._network = build_network()
119
+
120
+ # Build hostname index for O(1) lookups
121
+ self._host_index = {}
122
+ for subnet_name, hosts in self._network.items():
123
+ for host in hosts:
124
+ self._host_index[host["hostname"]] = host
125
+
126
+ # Inject attack chain: mark compromised hosts, add malicious processes
127
+ for threat in self._task_def["attack_chain"]:
128
+ for hostname in threat["compromised_hosts"]:
129
+ if hostname in self._host_index:
130
+ host = self._host_index[hostname]
131
+ host["status"] = "compromised"
132
+ for proc in threat["malicious_processes"]:
133
+ if proc not in host["running_processes"]:
134
+ host["running_processes"].append(proc)
135
+
136
+ # Initialize alert queue (deep copy so mutations don't affect task def)
137
+ self._alert_queue = copy.deepcopy(self._task_def["initial_alerts"])
138
+
139
+ # Reset state
140
+ eid = episode_id or str(uuid4())
141
+ self._state = SOCState(
142
+ episode_id=eid,
143
+ step_count=0,
144
+ task_id=task_id,
145
+ max_steps=self._task_def["max_steps"],
146
+ total_reward=0.0,
147
+ business_impact=self._task_def["initial_business_impact"],
148
+ contained_threats=[],
149
+ active_threats=[t["threat_id"] for t in self._task_def["attack_chain"]],
150
+ blocked_iocs=[],
151
+ isolated_subnets=[],
152
+ forensics_run=[],
153
+ killed_processes=[],
154
+ queried_hosts=[],
155
+ timeline=[],
156
+ is_done=False,
157
+ submitted_plan=False,
158
+ )
159
+
160
+ self._plan_entries = []
161
+ self._last_forensics = None
162
+ self._reset_rubric()
163
+
164
+ return self._build_observation(reward=0.0, done=False)
165
+
166
+ # ===========================================================================
167
+ # step()
168
+ # ===========================================================================
169
+
170
+ def step(
171
+ self,
172
+ action: SOCActionWrapper, # type: ignore[override]
173
+ timeout_s: Optional[float] = None,
174
+ **kwargs: Any,
175
+ ) -> SOCObservation:
176
+ """Process one agent action.
177
+
178
+ Args:
179
+ action: SOCActionWrapper containing the typed action.
180
+ timeout_s: Ignored.
181
+
182
+ Returns:
183
+ SOCObservation with updated state, reward, and done flag.
184
+ """
185
+ if self._state.is_done:
186
+ return self._build_observation(reward=0.0, done=True)
187
+
188
+ # Increment step
189
+ self._state.step_count += 1
190
+
191
+ # Convert wrapper to typed action
192
+ typed_action = action.to_typed_action()
193
+
194
+ # Dispatch to handler
195
+ reward = 0.0
196
+ result_description = "unknown action"
197
+
198
+ if isinstance(typed_action, QueryHost):
199
+ reward, result_description = self._handle_query_host(typed_action)
200
+ elif isinstance(typed_action, IsolateSegment):
201
+ reward, result_description = self._handle_isolate_segment(typed_action)
202
+ elif isinstance(typed_action, BlockIOC):
203
+ reward, result_description = self._handle_block_ioc(typed_action)
204
+ elif isinstance(typed_action, RunForensics):
205
+ reward, result_description = self._handle_run_forensics(typed_action)
206
+ elif isinstance(typed_action, KillProcess):
207
+ reward, result_description = self._handle_kill_process(typed_action)
208
+ elif isinstance(typed_action, SubmitContainmentPlan):
209
+ reward, result_description = self._handle_submit_plan(typed_action)
210
+
211
+ # Business impact grows each step (attacker progresses)
212
+ if not self._state.is_done:
213
+ impact_rate = self._task_def.get("impact_per_step", 0.02)
214
+ # Reduce impact growth if threats are being contained
215
+ active_ratio = len(self._state.active_threats) / max(1, len(self._task_def["attack_chain"]))
216
+ self._state.business_impact = min(
217
+ 1.0,
218
+ self._state.business_impact + impact_rate * active_ratio,
219
+ )
220
+
221
+ # Record timeline
222
+ self._state.timeline.append({
223
+ "step": self._state.step_count,
224
+ "action_type": typed_action.type,
225
+ "target": self._get_action_target(typed_action),
226
+ "result": result_description,
227
+ "reward": reward,
228
+ })
229
+
230
+ # Accumulate reward
231
+ self._state.total_reward += reward
232
+
233
+ # Check termination
234
+ done = False
235
+ if self._state.submitted_plan:
236
+ done = True
237
+ self._state.is_done = True
238
+ elif self._state.step_count >= self._state.max_steps:
239
+ done = True
240
+ self._state.is_done = True
241
+ reward -= 0.20 # Penalty for running out of time
242
+ self._state.total_reward += (-0.20)
243
+
244
+ return self._build_observation(reward=reward, done=done)
245
+
246
+ # ===========================================================================
247
+ # Action Handlers (return (reward, description))
248
+ # ===========================================================================
249
+
250
+ def _handle_query_host(self, action: QueryHost) -> tuple[float, str]:
251
+ """Query a host for status info."""
252
+ hostname = action.hostname
253
+ self._last_forensics = None # Clear forensics from previous step
254
+
255
+ if hostname not in self._host_index:
256
+ return -0.05, f"Host '{hostname}' not found in network"
257
+
258
+ host = self._host_index[hostname]
259
+
260
+ # Reward for querying compromised hosts (useful investigation)
261
+ reward = 0.0
262
+ if host["status"] == "compromised" and hostname not in self._state.queried_hosts:
263
+ reward = 0.05 # Good: investigating a compromised host
264
+ elif hostname in self._state.queried_hosts:
265
+ reward = -0.02 # Penalty: re-querying same host wastes time
266
+
267
+ self._state.queried_hosts.append(hostname)
268
+
269
+ return reward, f"Queried {hostname}: status={host['status']}, procs={len(host['running_processes'])}"
270
+
271
+ def _handle_isolate_segment(self, action: IsolateSegment) -> tuple[float, str]:
272
+ """Isolate a network segment."""
273
+ subnet = action.subnet
274
+ self._last_forensics = None
275
+
276
+ if subnet not in self._network:
277
+ return -0.05, f"Subnet '{subnet}' does not exist"
278
+
279
+ if subnet in self._state.isolated_subnets:
280
+ return -0.02, f"Subnet '{subnet}' is already isolated"
281
+
282
+ # Isolate all hosts in the subnet
283
+ for host in self._network[subnet]:
284
+ host["status"] = "isolated"
285
+
286
+ self._state.isolated_subnets.append(subnet)
287
+
288
+ # Check if this contains any active threats
289
+ reward = 0.0
290
+ threats_contained = []
291
+ for threat in self._task_def["attack_chain"]:
292
+ if threat["threat_id"] in self._state.active_threats:
293
+ # Check if any compromised hosts are in this subnet
294
+ for ch in threat["compromised_hosts"]:
295
+ if ch in self._host_index and self._host_index[ch]["subnet"] == subnet:
296
+ threats_contained.append(threat["threat_id"])
297
+ break
298
+
299
+ if threats_contained:
300
+ reward = 0.15 * len(threats_contained) # Good: containing lateral movement
301
+ for tid in threats_contained:
302
+ if tid not in self._state.contained_threats:
303
+ self._state.contained_threats.append(tid)
304
+ if tid in self._state.active_threats:
305
+ self._state.active_threats.remove(tid)
306
+
307
+ # Check if this is an unnecessary isolation (business downtime)
308
+ must_not_isolate = self._task_def["containment_requirements"].get("must_not_isolate", [])
309
+ if subnet in must_not_isolate:
310
+ reward -= 0.10 # Penalty: unnecessary downtime
311
+ self._state.business_impact = min(1.0, self._state.business_impact + 0.08)
312
+
313
+ return reward, f"Isolated subnet '{subnet}'. Threats contained: {threats_contained}"
314
+
315
+ def _handle_block_ioc(self, action: BlockIOC) -> tuple[float, str]:
316
+ """Block an IOC at the perimeter."""
317
+ ioc = action.ioc_value
318
+ self._last_forensics = None
319
+
320
+ if ioc in self._state.blocked_iocs:
321
+ return -0.02, f"IOC '{ioc}' is already blocked"
322
+
323
+ self._state.blocked_iocs.append(ioc)
324
+
325
+ # Check if this IOC is relevant to any active threat
326
+ reward = 0.0
327
+ relevant = False
328
+ for threat in self._task_def["attack_chain"]:
329
+ all_iocs = (
330
+ threat["iocs"].get("hashes", [])
331
+ + threat["iocs"].get("ips", [])
332
+ + threat["iocs"].get("domains", [])
333
+ )
334
+ if ioc in all_iocs:
335
+ relevant = True
336
+ # Extra reward for blocking C2 server IPs
337
+ if ioc in threat.get("c2_servers", []):
338
+ reward += 0.15 # High value: cutting C2
339
+ else:
340
+ reward += 0.10 # Good: blocking relevant IOC
341
+ break
342
+
343
+ if not relevant:
344
+ reward = -0.03 # Noise: blocking irrelevant IOC
345
+
346
+ return reward, f"Blocked IOC '{ioc}' (type={action.ioc_type}). Relevant: {relevant}"
347
+
348
+ def _handle_run_forensics(self, action: RunForensics) -> tuple[float, str]:
349
+ """Run forensic analysis on a host."""
350
+ hostname = action.hostname
351
+
352
+ if hostname not in self._host_index:
353
+ self._last_forensics = None
354
+ return -0.05, f"Host '{hostname}' not found"
355
+
356
+ host = self._host_index[hostname]
357
+
358
+ # Build forensics result based on actual host state
359
+ is_compromised = host["status"] == "compromised"
360
+ malicious_procs = []
361
+ suspicious_files = []
362
+ network_conns = []
363
+ registry_mods = []
364
+ memory_artifacts = []
365
+
366
+ if is_compromised:
367
+ # Find which threat(s) affect this host
368
+ for threat in self._task_def["attack_chain"]:
369
+ if hostname in threat["compromised_hosts"]:
370
+ malicious_procs.extend(threat["malicious_processes"])
371
+ # Generate deterministic forensic artifacts
372
+ for proc in threat["malicious_processes"]:
373
+ suspicious_files.append(f"C:\\Windows\\Temp\\{proc}.dat")
374
+ registry_mods.append(f"HKLM\\Software\\Microsoft\\Windows\\CurrentVersion\\Run\\{proc}")
375
+ for c2 in threat.get("c2_servers", []):
376
+ network_conns.append(f"{c2}:443")
377
+ for ioc_hash in threat["iocs"].get("hashes", []):
378
+ memory_artifacts.append(f"memory_inject_{ioc_hash[:8]}")
379
+
380
+ self._last_forensics = ForensicsResult(
381
+ hostname=hostname,
382
+ malicious_processes=malicious_procs,
383
+ suspicious_files=suspicious_files,
384
+ network_connections=network_conns,
385
+ registry_modifications=registry_mods,
386
+ memory_artifacts=memory_artifacts,
387
+ is_compromised=is_compromised,
388
+ )
389
+
390
+ # Reward
391
+ reward = 0.0
392
+ if hostname not in self._state.forensics_run:
393
+ if is_compromised:
394
+ reward = 0.10 # Good: found evidence
395
+ else:
396
+ reward = 0.02 # Cleared a host (some value)
397
+ self._state.forensics_run.append(hostname)
398
+ else:
399
+ reward = -0.02 # Re-running forensics wastes time
400
+
401
+ return reward, f"Forensics on {hostname}: compromised={is_compromised}, procs={malicious_procs}"
402
+
403
+ def _handle_kill_process(self, action: KillProcess) -> tuple[float, str]:
404
+ """Kill a process on a host."""
405
+ hostname = action.hostname
406
+ process = action.process_name
407
+ self._last_forensics = None
408
+
409
+ if hostname not in self._host_index:
410
+ return -0.05, f"Host '{hostname}' not found"
411
+
412
+ host = self._host_index[hostname]
413
+
414
+ if host["status"] == "isolated":
415
+ return -0.02, f"Host '{hostname}' is isolated — cannot interact"
416
+
417
+ if process not in host["running_processes"]:
418
+ return -0.03, f"Process '{process}' not running on {hostname}"
419
+
420
+ # Kill the process
421
+ host["running_processes"].remove(process)
422
+ self._state.killed_processes.append({"hostname": hostname, "process": process})
423
+
424
+ # Check if this was a malicious process
425
+ reward = 0.0
426
+ was_malicious = False
427
+ for threat in self._task_def["attack_chain"]:
428
+ if hostname in threat["compromised_hosts"] and process in threat["malicious_processes"]:
429
+ was_malicious = True
430
+ reward = 0.15 # Major reward: stopping malicious activity
431
+
432
+ # Check if all processes for this threat are killed
433
+ all_killed = True
434
+ for th_host in threat["compromised_hosts"]:
435
+ for th_proc in threat["malicious_processes"]:
436
+ still_running = (
437
+ th_host in self._host_index
438
+ and th_proc in self._host_index[th_host]["running_processes"]
439
+ )
440
+ if still_running:
441
+ all_killed = False
442
+ break
443
+
444
+ if all_killed and threat["threat_id"] in self._state.active_threats:
445
+ self._state.active_threats.remove(threat["threat_id"])
446
+ if threat["threat_id"] not in self._state.contained_threats:
447
+ self._state.contained_threats.append(threat["threat_id"])
448
+ reward += 0.10 # Bonus: fully contained a threat
449
+ break
450
+
451
+ if not was_malicious:
452
+ reward = -0.08 # Penalty: killing legitimate process = downtime
453
+ self._state.business_impact = min(1.0, self._state.business_impact + 0.03)
454
+
455
+ return reward, f"Killed '{process}' on {hostname}. Malicious: {was_malicious}"
456
+
457
+ def _handle_submit_plan(self, action: SubmitContainmentPlan) -> tuple[float, str]:
458
+ """Submit the final containment plan."""
459
+ self._last_forensics = None
460
+ self._state.submitted_plan = True
461
+ self._plan_entries = [entry.model_dump() for entry in action.plan]
462
+
463
+ # Grade the episode
464
+ final_score = grade_episode(
465
+ task_id=self._state.task_id,
466
+ task_def=self._task_def,
467
+ killed_processes=self._state.killed_processes,
468
+ blocked_iocs=self._state.blocked_iocs,
469
+ forensics_run=self._state.forensics_run,
470
+ isolated_subnets=self._state.isolated_subnets,
471
+ submitted_plan=True,
472
+ plan_entries=self._plan_entries,
473
+ final_business_impact=self._state.business_impact,
474
+ step_count=self._state.step_count,
475
+ total_reward=self._state.total_reward,
476
+ )
477
+
478
+ # Reward proportional to final grade
479
+ reward = final_score * 1.0 # Scale: perfect score = 1.0 reward
480
+ description = (
481
+ f"Containment plan submitted. "
482
+ f"Grade: {final_score:.3f}. "
483
+ f"Threats contained: {len(self._state.contained_threats)}/{len(self._task_def['attack_chain'])}. "
484
+ f"Business impact: {self._state.business_impact:.2f}"
485
+ )
486
+
487
+ return reward, description
488
+
489
+ # ===========================================================================
490
+ # Helpers
491
+ # ===========================================================================
492
+
493
+ def _build_observation(self, reward: float, done: bool) -> SOCObservation:
494
+ """Build the observation from current state."""
495
+ # Compute network topology summary
496
+ subnet_counts = {name: len(hosts) for name, hosts in self._network.items()}
497
+ compromised = sum(
498
+ 1 for hosts in self._network.values()
499
+ for h in hosts if h["status"] == "compromised"
500
+ )
501
+ isolated = sum(
502
+ 1 for hosts in self._network.values()
503
+ for h in hosts if h["status"] == "isolated"
504
+ )
505
+ total = sum(len(hosts) for hosts in self._network.values())
506
+
507
+ topology = NetworkTopology(
508
+ total_hosts=total,
509
+ subnets=subnet_counts,
510
+ compromised_count=compromised,
511
+ isolated_count=isolated,
512
+ online_count=total - compromised - isolated,
513
+ )
514
+
515
+ # Build alert list
516
+ alerts = [Alert(**a) for a in self._alert_queue]
517
+
518
+ # Build timeline
519
+ timeline = [
520
+ TimelineEntry(
521
+ step=t["step"],
522
+ action_type=t["action_type"],
523
+ target=t["target"],
524
+ result=t["result"],
525
+ reward=t["reward"],
526
+ )
527
+ for t in self._state.timeline
528
+ ]
529
+
530
+ # Compute final grade if done
531
+ final_score_val = None
532
+ grade_breakdown_val = None
533
+
534
+ if done and self._state.submitted_plan:
535
+ computed_score = grade_episode(
536
+ task_id=self._state.task_id,
537
+ task_def=self._task_def,
538
+ killed_processes=self._state.killed_processes,
539
+ blocked_iocs=self._state.blocked_iocs,
540
+ forensics_run=self._state.forensics_run,
541
+ isolated_subnets=self._state.isolated_subnets,
542
+ submitted_plan=self._state.submitted_plan,
543
+ plan_entries=self._plan_entries,
544
+ final_business_impact=self._state.business_impact,
545
+ step_count=self._state.step_count,
546
+ total_reward=self._state.total_reward,
547
+ )
548
+ final_score_val = round(computed_score, 4)
549
+ grade_breakdown_val = {
550
+ "threats_contained": len(self._state.contained_threats),
551
+ "total_threats": len(self._task_def["attack_chain"]),
552
+ "iocs_blocked": len(self._state.blocked_iocs),
553
+ "hosts_forensics": len(self._state.forensics_run),
554
+ "subnets_isolated": len(self._state.isolated_subnets),
555
+ "business_impact": round(self._state.business_impact, 4),
556
+ }
557
+
558
+ return SOCObservation(
559
+ alert_queue=alerts,
560
+ network_topology=topology,
561
+ host_forensics=self._last_forensics,
562
+ timeline=timeline,
563
+ business_impact_score=round(self._state.business_impact, 4),
564
+ step_count=self._state.step_count,
565
+ active_threats=list(self._state.active_threats),
566
+ max_steps=self._state.max_steps,
567
+ task_id=self._state.task_id,
568
+ total_reward=round(self._state.total_reward, 4),
569
+ final_score=final_score_val,
570
+ grade_breakdown=grade_breakdown_val,
571
+ done=done,
572
+ reward=round(reward, 4),
573
+ )
574
+
575
+ def _get_action_target(self, action: Any) -> str:
576
+ """Extract the target string from a typed action for timeline logging."""
577
+ if isinstance(action, QueryHost):
578
+ return action.hostname
579
+ elif isinstance(action, IsolateSegment):
580
+ return action.subnet
581
+ elif isinstance(action, BlockIOC):
582
+ return f"{action.ioc_type}:{action.ioc_value}"
583
+ elif isinstance(action, RunForensics):
584
+ return action.hostname
585
+ elif isinstance(action, KillProcess):
586
+ return f"{action.hostname}/{action.process_name}"
587
+ elif isinstance(action, SubmitContainmentPlan):
588
+ return f"{len(action.plan)} entries"
589
+ return "unknown"
590
+
591
+ @property
592
+ def state(self) -> SOCState:
593
+ """Get the current internal environment state."""
594
+ return self._state
server/requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ openenv-core[core]>=0.2.2
2
+ openai>=1.0.0
3
+ websockets>=12.0
server/task_generator.py ADDED
@@ -0,0 +1,627 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Procedural Task Generator for CyberSOCEnv.
9
+
10
+ Generates 1000+ unique, deterministic attack scenarios from a task_id seed.
11
+ Each task_id (e.g. 'gen_0001') always produces the exact same scenario.
12
+
13
+ Design:
14
+ - hash(task_id) -> deterministic seed -> random.Random instance
15
+ - Seed drives ALL choices: attack type, hosts, processes, IOCs, alerts
16
+ - 12 attack categories, 50+ malware names, 40+ C2 domains
17
+ - 3 difficulty tiers based on task number
18
+
19
+ No actual randomness — reproducible across runs and platforms.
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import hashlib
25
+ import random
26
+ from typing import Any, Dict, List, Tuple
27
+
28
+
29
+ # =============================================================================
30
+ # Template Pools (the "vocabulary" of the generator)
31
+ # =============================================================================
32
+
33
+ # --- Malware process names by category ---
34
+ MALWARE_PROCESSES = {
35
+ "ransomware": [
36
+ "cryptolocker.exe", "wannacry.exe", "blackcat_ransom.exe",
37
+ "lockbit3.exe", "revil_encrypt.exe", "hive_locker.exe",
38
+ "conti_crypt.exe", "ryuk_payload.exe", "maze_encrypt.exe",
39
+ "darkside_enc.exe", "babuk_lock.exe", "avaddon_crypt.exe",
40
+ ],
41
+ "phishing": [
42
+ "outlook_macro.exe", "word_dropper.exe", "macro_loader.exe",
43
+ "vba_agent.exe", "pdf_exploit.exe", "html_smuggler.exe",
44
+ "iso_mounter.exe", "lnk_runner.exe",
45
+ ],
46
+ "credential_theft": [
47
+ "mimikatz.exe", "lazagne.exe", "hashdump.exe",
48
+ "procdump_lsass.exe", "rubeus.exe", "kerbrute.exe",
49
+ "sharphound.exe", "bloodhound_collect.exe",
50
+ ],
51
+ "lateral_movement": [
52
+ "svchost_backdoor.exe", "psexec_svc.exe", "wmic_lateral.exe",
53
+ "rdp_hijack.exe", "ssh_brute.exe", "evil_winrm.exe",
54
+ "dcom_exec.exe", "smb_relay.exe",
55
+ ],
56
+ "c2_communication": [
57
+ "svchost_c2.exe", "cobalt_beacon.exe", "sliver_implant.exe",
58
+ "meterpreter.exe", "covenant_grunt.exe", "mythic_agent.exe",
59
+ "dns_tunnel.exe", "icmp_beacon.exe",
60
+ ],
61
+ "privilege_escalation": [
62
+ "exploit_kernel.exe", "potato_exploit.exe", "uac_bypass.exe",
63
+ "printspoofer.exe", "juicy_potato.exe", "named_pipe_exploit.exe",
64
+ "token_impersonate.exe", "dll_hijack.exe",
65
+ ],
66
+ "data_exfiltration": [
67
+ "data_pump.exe", "rclone_sync.exe", "mega_upload.exe",
68
+ "ftp_exfil.exe", "dns_exfil.exe", "cloud_sync_mal.exe",
69
+ "archive_send.exe", "stealer_agent.exe",
70
+ ],
71
+ "cryptomining": [
72
+ "xmrig_miner.exe", "ethminer.exe", "cpuminer.exe",
73
+ "nicehash_mal.exe", "coinhive_svc.exe", "monero_mine.exe",
74
+ ],
75
+ "supply_chain": [
76
+ "update_agent_mal.exe", "npm_backdoor.exe", "pip_trojan.exe",
77
+ "vscode_ext_mal.exe", "docker_implant.exe", "nuget_poison.exe",
78
+ ],
79
+ "insider_threat": [
80
+ "usb_copy.exe", "screen_capture.exe", "keylogger_svc.exe",
81
+ "email_forward.exe", "cloud_upload.exe", "print_spooler_mal.exe",
82
+ ],
83
+ "webshell": [
84
+ "cmd_webshell.php", "asp_backdoor.exe", "jsp_shell.exe",
85
+ "python_rshell.exe", "nodejs_shell.exe", "perl_cgi_shell.exe",
86
+ ],
87
+ "botnet": [
88
+ "mirai_bot.exe", "emotet_loader.exe", "trickbot_svc.exe",
89
+ "qbot_agent.exe", "dridex_dll.exe", "zloader_inject.exe",
90
+ ],
91
+ }
92
+
93
+ # --- C2 domains ---
94
+ C2_DOMAINS = [
95
+ "cdn-update.malware-c2.net", "api.darkc2.io", "telemetry-svc.ru",
96
+ "secure-update.evil.net", "cdn.payload-delivery.com", "api.shadownet.io",
97
+ "sync.cloud-c2.xyz", "update.legit-looking.com", "beacon.covert-ops.net",
98
+ "dns.tunnel-relay.org", "img.cdn-malware.com", "static.evil-cdn.net",
99
+ "api.stealthc2.io", "ws.encrypted-relay.net", "feed.darkweb-proxy.com",
100
+ "auth.phish-server.net", "login.fake-portal.com", "mail.spoof-relay.org",
101
+ "git.supply-chain.dev", "npm.compromised-pkg.io", "pypi.trojan-lib.org",
102
+ "dl.ransomware-pay.onion", "tor.exit-node-c2.net", "i2p.covert-chan.net",
103
+ "iot.botnet-c2.xyz", "cam.mirai-variant.net", "mqtt.iot-exploit.io",
104
+ "ftp.exfil-server.ru", "sftp.data-steal.com", "mega.cloud-drop.io",
105
+ "gist.code-exfil.dev", "paste.data-dump.xyz", "bin.steganography.net",
106
+ "vpn.tunnel-c2.com", "proxy.relay-beacon.org", "socks.covert-proxy.io",
107
+ "wpad.evil-config.net", "ntp.time-beacon.com", "ldap.ad-exploit.org",
108
+ "kerberos.ticket-steal.net",
109
+ ]
110
+
111
+ # --- C2 IPs (RFC 5737 documentation ranges + realistic-looking) ---
112
+ C2_IPS = [
113
+ "198.51.100.10", "198.51.100.22", "198.51.100.33", "198.51.100.44",
114
+ "198.51.100.55", "198.51.100.66", "198.51.100.77", "198.51.100.88",
115
+ "198.51.100.99", "198.51.100.110", "198.51.100.121", "198.51.100.132",
116
+ "203.0.113.10", "203.0.113.21", "203.0.113.32", "203.0.113.43",
117
+ "203.0.113.54", "203.0.113.65", "203.0.113.76", "203.0.113.87",
118
+ "203.0.113.98", "203.0.113.109", "203.0.113.120", "203.0.113.131",
119
+ "192.0.2.10", "192.0.2.21", "192.0.2.32", "192.0.2.43",
120
+ "192.0.2.54", "192.0.2.65", "192.0.2.76", "192.0.2.87",
121
+ "100.64.0.10", "100.64.0.22", "100.64.0.33", "100.64.0.44",
122
+ ]
123
+
124
+ # --- Subnet definitions (must match build_network() in tasks.py) ---
125
+ SUBNETS = {
126
+ "corporate": {"prefix": "WS", "count": 150, "criticality": 0.3},
127
+ "engineering": {"prefix": "DEV", "count": 100, "criticality": 0.5},
128
+ "finance": {"prefix": "FIN", "count": 50, "criticality": 0.8},
129
+ "dmz": {"prefix": "DMZ", "count": 30, "criticality": 0.6},
130
+ "datacenter": {"prefix": "SRV", "count": 50, "criticality": 0.9},
131
+ "executive": {"prefix": "EXEC", "count": 20, "criticality": 1.0},
132
+ }
133
+
134
+ # --- Attack phases in kill-chain order ---
135
+ ATTACK_PHASES = [
136
+ "initial_access", "execution", "persistence", "privilege_escalation",
137
+ "credential_access", "lateral_movement", "command_and_control",
138
+ "exfiltration", "impact",
139
+ ]
140
+
141
+ # --- Alert description templates ---
142
+ ALERT_TEMPLATES = {
143
+ "ransomware": [
144
+ "EDR detected file encryption activity on {host}. Process '{proc}' is encrypting files in user directories.",
145
+ "Anomalous file system activity: {count} files renamed with .{ext} extension in {secs} seconds on {host}.",
146
+ "Ransomware signature detected in process '{proc}' on {host}. Volume shadow copies being deleted.",
147
+ ],
148
+ "phishing": [
149
+ "User on {host} clicked suspicious link in email. {proc} execution detected downloading payload from {domain}.",
150
+ "Macro-enabled document opened on {host}. Outbound connection to {domain} detected.",
151
+ "Suspicious email attachment executed on {host}. Process '{proc}' spawned child processes.",
152
+ ],
153
+ "credential_theft": [
154
+ "LSASS memory access detected on {host} — possible credential dumping via {proc}.",
155
+ "Kerberos ticket request anomaly on {host}. Process '{proc}' attempting ticket manipulation.",
156
+ "SAM database access detected on {host}. Credential harvesting tool '{proc}' identified.",
157
+ ],
158
+ "lateral_movement": [
159
+ "Suspicious RDP login to {host} from compromised source using admin credentials. Process '{proc}' spawned.",
160
+ "SMB lateral movement detected: '{proc}' deployed on {host} via remote service creation.",
161
+ "WMI remote execution detected on {host}. Process '{proc}' launched from external host.",
162
+ ],
163
+ "c2_communication": [
164
+ "Periodic beaconing detected from {host} to {ip} every {interval} seconds. Encrypted payload exchange observed.",
165
+ "DNS tunneling activity from {host}. Suspicious queries to {domain} with encoded payloads.",
166
+ "Cobalt Strike beacon profile detected on {host}. Process '{proc}' communicating with {ip}.",
167
+ ],
168
+ "privilege_escalation": [
169
+ "Kernel exploit attempt on {host}. Process '{proc}' gained SYSTEM privileges.",
170
+ "UAC bypass detected on {host}. Process '{proc}' elevated to admin without user consent.",
171
+ "Token impersonation attack on {host}. Process '{proc}' obtained domain admin token.",
172
+ ],
173
+ "data_exfiltration": [
174
+ "Large data transfer ({size} GB) to external IP {ip} from {host}. Possible exfiltration of {data_type}.",
175
+ "Staging activity detected on {host}. Process '{proc}' archiving sensitive directories for extraction.",
176
+ "Cloud storage upload from {host} to unauthorized account. Process '{proc}' transferring {data_type}.",
177
+ ],
178
+ "cryptomining": [
179
+ "High CPU usage (98%) on {host}. Process '{proc}' identified as cryptocurrency miner.",
180
+ "Mining pool connection from {host} to {ip}:{port}. Process '{proc}' consuming all available cores.",
181
+ "Stratum protocol detected on {host}. Unauthorized mining process '{proc}' active.",
182
+ ],
183
+ "supply_chain": [
184
+ "Compromised package detected in CI/CD pipeline on {host}. Process '{proc}' executing post-install scripts.",
185
+ "Backdoored update agent on {host}. Process '{proc}' downloading payloads from {domain}.",
186
+ "Malicious dependency loaded on {host}. Process '{proc}' establishing covert communication channels.",
187
+ ],
188
+ "insider_threat": [
189
+ "Unusual data access pattern on {host}. Process '{proc}' accessing files outside user's normal scope.",
190
+ "USB mass storage device connected on {host}. Process '{proc}' copying sensitive files to removable media.",
191
+ "After-hours bulk file download on {host}. Process '{proc}' archiving {data_type} documents.",
192
+ ],
193
+ "webshell": [
194
+ "Web shell detected on {host}. Process '{proc}' executing system commands via HTTP POST requests.",
195
+ "Suspicious file upload on {host}. Process '{proc}' created in web-accessible directory with bash capabilities.",
196
+ "Remote code execution on {host}. Process '{proc}' spawned from web server with SYSTEM context.",
197
+ ],
198
+ "botnet": [
199
+ "Bot agent detected on {host}. Process '{proc}' joining command pool at {ip}.",
200
+ "DDoS toolkit loaded on {host}. Process '{proc}' ready to receive attack instructions from {domain}.",
201
+ "Worm propagation from {host}. Process '{proc}' scanning network for vulnerable hosts.",
202
+ ],
203
+ }
204
+
205
+ # --- Severity levels with weights ---
206
+ SEVERITIES = ["low", "medium", "high", "critical"]
207
+ SEVERITY_WEIGHTS = {"easy": [0.1, 0.4, 0.4, 0.1], "medium": [0.0, 0.2, 0.5, 0.3], "hard": [0.0, 0.1, 0.3, 0.6]}
208
+
209
+ # --- Data types for exfil descriptions ---
210
+ DATA_TYPES = [
211
+ "customer PII", "financial records", "employee credentials",
212
+ "source code", "trade secrets", "medical records",
213
+ "encryption keys", "database backups", "API tokens",
214
+ "board meeting minutes", "M&A documents", "patent filings",
215
+ ]
216
+
217
+ # --- File extensions for ransomware ---
218
+ RANSOM_EXTENSIONS = [
219
+ "locked", "encrypted", "crypted", "crypt", "enc", "pay",
220
+ "ransom", "darkside", "blackcat", "hive", "lockbit", "ryuk",
221
+ ]
222
+
223
+
224
+ # =============================================================================
225
+ # Deterministic Seed Helper
226
+ # =============================================================================
227
+
228
+ def _seed_from_task_id(task_id: str) -> int:
229
+ """Create a deterministic integer seed from a task_id string."""
230
+ h = hashlib.sha256(task_id.encode("utf-8")).hexdigest()
231
+ return int(h[:16], 16)
232
+
233
+
234
+ def _make_hash(rng: random.Random) -> str:
235
+ """Generate a fake MD5-like hash deterministically."""
236
+ return "".join(rng.choice("0123456789abcdef") for _ in range(32))
237
+
238
+
239
+ # =============================================================================
240
+ # Difficulty Classification
241
+ # =============================================================================
242
+
243
+ def _get_difficulty(task_id: str, rng: random.Random) -> str:
244
+ """Determine difficulty from task_id pattern or seed."""
245
+ # If task_id has an explicit difficulty prefix, use it
246
+ if task_id.startswith("easy_") or task_id.startswith("gen_easy_"):
247
+ return "easy"
248
+ if task_id.startswith("medium_") or task_id.startswith("gen_medium_"):
249
+ return "medium"
250
+ if task_id.startswith("hard_") or task_id.startswith("gen_hard_"):
251
+ return "hard"
252
+
253
+ # For gen_NNNN pattern, use number ranges
254
+ if task_id.startswith("gen_"):
255
+ try:
256
+ num = int(task_id.split("_")[1])
257
+ if num <= 333:
258
+ return "easy"
259
+ elif num <= 666:
260
+ return "medium"
261
+ else:
262
+ return "hard"
263
+ except (ValueError, IndexError):
264
+ pass
265
+
266
+ # Fallback: use seed-based distribution
267
+ return rng.choice(["easy", "medium", "hard"])
268
+
269
+
270
+ # =============================================================================
271
+ # Core Generator
272
+ # =============================================================================
273
+
274
+ def _pick_hosts(rng: random.Random, subnet: str, count: int) -> List[str]:
275
+ """Pick `count` unique host names from a subnet."""
276
+ info = SUBNETS[subnet]
277
+ prefix = info["prefix"]
278
+ max_idx = info["count"]
279
+ indices = rng.sample(range(1, max_idx + 1), min(count, max_idx))
280
+ return [f"{prefix}-{idx:03d}" for idx in indices]
281
+
282
+
283
+ def _pick_subnets(rng: random.Random, count: int) -> List[str]:
284
+ """Pick `count` unique subnet names."""
285
+ all_subnets = list(SUBNETS.keys())
286
+ return rng.sample(all_subnets, min(count, len(all_subnets)))
287
+
288
+
289
+ def _generate_threat(
290
+ rng: random.Random,
291
+ threat_id: str,
292
+ attack_type: str,
293
+ phase: str,
294
+ available_subnets: List[str],
295
+ used_hosts: set,
296
+ ) -> Tuple[Dict[str, Any], List[str]]:
297
+ """Generate a single threat in the attack chain.
298
+
299
+ Returns:
300
+ (threat_dict, list_of_compromised_hosts)
301
+ """
302
+ # Pick target subnet and hosts
303
+ subnet = rng.choice(available_subnets)
304
+ num_hosts = rng.randint(1, 3) if attack_type != "ransomware" else rng.randint(1, 2)
305
+
306
+ hosts = _pick_hosts(rng, subnet, num_hosts + 3) # Pick extra to avoid collisions
307
+ hosts = [h for h in hosts if h not in used_hosts][:num_hosts]
308
+ if not hosts:
309
+ # Fallback: pick from any subnet
310
+ fallback_subnet = rng.choice(list(SUBNETS.keys()))
311
+ hosts = _pick_hosts(rng, fallback_subnet, num_hosts + 5)
312
+ hosts = [h for h in hosts if h not in used_hosts][:max(1, num_hosts)]
313
+
314
+ # Pick malware process
315
+ procs = MALWARE_PROCESSES.get(attack_type, MALWARE_PROCESSES["lateral_movement"])
316
+ proc = rng.choice(procs)
317
+
318
+ # Generate IOCs
319
+ num_hashes = rng.randint(1, 2)
320
+ hashes = [_make_hash(rng) for _ in range(num_hashes)]
321
+
322
+ num_ips = rng.randint(0, 2) if attack_type in ("c2_communication", "data_exfiltration", "cryptomining", "botnet") else rng.randint(0, 1)
323
+ ips = rng.sample(C2_IPS, min(num_ips, len(C2_IPS))) if num_ips > 0 else []
324
+
325
+ num_domains = rng.randint(0, 2) if attack_type in ("c2_communication", "phishing", "supply_chain", "botnet") else rng.randint(0, 1)
326
+ domains = rng.sample(C2_DOMAINS, min(num_domains, len(C2_DOMAINS))) if num_domains > 0 else []
327
+
328
+ # C2 servers (subset of IPs for c2/exfil types)
329
+ c2_servers = ips[:1] if attack_type in ("c2_communication", "data_exfiltration", "botnet") else []
330
+
331
+ # Lateral targets (for movement-type threats)
332
+ lateral_targets: List[str] = []
333
+ if attack_type in ("lateral_movement", "credential_theft", "c2_communication"):
334
+ lat_subnet = rng.choice(list(SUBNETS.keys()))
335
+ lat_hosts = _pick_hosts(rng, lat_subnet, 2)
336
+ lateral_targets = [h for h in lat_hosts if h not in used_hosts and h not in hosts][:rng.randint(0, 2)]
337
+
338
+ # Exfil targets
339
+ exfil_targets: List[str] = []
340
+ if attack_type == "data_exfiltration":
341
+ exfil_targets = list(hosts)
342
+
343
+ threat = {
344
+ "threat_id": threat_id,
345
+ "threat_type": attack_type,
346
+ "phase": phase,
347
+ "compromised_hosts": hosts,
348
+ "malicious_processes": [proc],
349
+ "c2_servers": c2_servers,
350
+ "iocs": {
351
+ "hashes": hashes,
352
+ "ips": ips,
353
+ "domains": domains,
354
+ },
355
+ "lateral_targets": lateral_targets,
356
+ "exfil_targets": exfil_targets,
357
+ }
358
+
359
+ return threat, hosts
360
+
361
+
362
+ def _generate_alert(
363
+ rng: random.Random,
364
+ alert_idx: int,
365
+ task_prefix: str,
366
+ threat: Dict[str, Any],
367
+ timestamp_base: int,
368
+ ) -> Dict[str, Any]:
369
+ """Generate a single SIEM alert for a threat."""
370
+ attack_type = threat["threat_type"]
371
+ host = rng.choice(threat["compromised_hosts"])
372
+ proc = threat["malicious_processes"][0]
373
+
374
+ # Pick template
375
+ templates = ALERT_TEMPLATES.get(attack_type, ALERT_TEMPLATES["lateral_movement"])
376
+ template = rng.choice(templates)
377
+
378
+ # Fill template
379
+ description = template.format(
380
+ host=host,
381
+ proc=proc,
382
+ domain=rng.choice(threat["iocs"]["domains"]) if threat["iocs"]["domains"] else "unknown.example.com",
383
+ ip=rng.choice(threat["iocs"]["ips"]) if threat["iocs"]["ips"] else "0.0.0.0",
384
+ count=rng.randint(50, 500),
385
+ ext=rng.choice(RANSOM_EXTENSIONS),
386
+ secs=rng.randint(10, 120),
387
+ interval=rng.choice([30, 60, 90, 120, 300]),
388
+ size=round(rng.uniform(0.5, 15.0), 1),
389
+ data_type=rng.choice(DATA_TYPES),
390
+ port=rng.choice([3333, 4444, 5555, 8080, 8443, 9090]),
391
+ )
392
+
393
+ # Collect IOC indicators for the alert
394
+ ioc_indicators = []
395
+ if threat["iocs"]["hashes"]:
396
+ ioc_indicators.append(rng.choice(threat["iocs"]["hashes"]))
397
+ if threat["iocs"]["ips"]:
398
+ ioc_indicators.append(rng.choice(threat["iocs"]["ips"]))
399
+ if threat["iocs"]["domains"]:
400
+ ioc_indicators.append(rng.choice(threat["iocs"]["domains"]))
401
+
402
+ # Determine subnet from host prefix
403
+ subnet = "corporate"
404
+ for sn, info in SUBNETS.items():
405
+ if host.startswith(info["prefix"]):
406
+ subnet = sn
407
+ break
408
+
409
+ # Severity
410
+ severity_weights = SEVERITY_WEIGHTS.get(
411
+ "hard" if attack_type in ("data_exfiltration", "ransomware", "privilege_escalation") else "medium",
412
+ SEVERITY_WEIGHTS["medium"]
413
+ )
414
+ severity = rng.choices(SEVERITIES, weights=severity_weights, k=1)[0]
415
+
416
+ # Timestamp (spread across a few hours)
417
+ minutes_offset = timestamp_base + alert_idx * rng.randint(5, 45)
418
+ hour = 6 + (minutes_offset // 60)
419
+ minute = minutes_offset % 60
420
+ timestamp = f"2025-01-15T{hour:02d}:{minute:02d}:00Z"
421
+
422
+ return {
423
+ "alert_id": f"ALERT-{task_prefix}{alert_idx + 1:03d}",
424
+ "timestamp": timestamp,
425
+ "source_host": host,
426
+ "severity": severity,
427
+ "threat_type": attack_type,
428
+ "description": description,
429
+ "ioc_indicators": ioc_indicators,
430
+ "subnet": subnet,
431
+ "is_acknowledged": False,
432
+ }
433
+
434
+
435
+ # =============================================================================
436
+ # Main Generator Function
437
+ # =============================================================================
438
+
439
+ def generate_task(task_id: str) -> Dict[str, Any]:
440
+ """Generate a complete, deterministic task definition from a task_id.
441
+
442
+ The task_id is hashed to create a seed, ensuring the same task_id
443
+ always produces the exact same scenario.
444
+
445
+ Args:
446
+ task_id: Any string (e.g. 'gen_0001', 'gen_0500', 'phishing_test')
447
+
448
+ Returns:
449
+ A task_def dict compatible with CyberSOCEnvironment.reset()
450
+ """
451
+ seed = _seed_from_task_id(task_id)
452
+ rng = random.Random(seed)
453
+
454
+ # Determine difficulty
455
+ difficulty = _get_difficulty(task_id, rng)
456
+
457
+ # Configure parameters based on difficulty
458
+ if difficulty == "easy":
459
+ num_threats = 1
460
+ max_steps = rng.randint(12, 18)
461
+ initial_impact = round(rng.uniform(0.02, 0.08), 2)
462
+ impact_per_step = round(rng.uniform(0.01, 0.03), 3)
463
+ num_subnets = rng.randint(1, 2)
464
+ elif difficulty == "medium":
465
+ num_threats = rng.randint(2, 3)
466
+ max_steps = rng.randint(20, 28)
467
+ initial_impact = round(rng.uniform(0.08, 0.15), 2)
468
+ impact_per_step = round(rng.uniform(0.02, 0.04), 3)
469
+ num_subnets = rng.randint(2, 4)
470
+ else: # hard
471
+ num_threats = rng.randint(3, 6)
472
+ max_steps = rng.randint(25, 35)
473
+ initial_impact = round(rng.uniform(0.15, 0.25), 2)
474
+ impact_per_step = round(rng.uniform(0.03, 0.05), 3)
475
+ num_subnets = rng.randint(3, 6)
476
+
477
+ # Pick attack types for this scenario
478
+ all_attack_types = list(MALWARE_PROCESSES.keys())
479
+ if difficulty == "easy":
480
+ # Easy: single focused attack
481
+ attack_types = [rng.choice(all_attack_types)]
482
+ elif difficulty == "medium":
483
+ # Medium: multi-stage, pick a plausible chain
484
+ chains = [
485
+ ["phishing", "credential_theft", "lateral_movement"],
486
+ ["phishing", "c2_communication", "data_exfiltration"],
487
+ ["webshell", "privilege_escalation", "lateral_movement"],
488
+ ["supply_chain", "c2_communication", "credential_theft"],
489
+ ["botnet", "cryptomining", "lateral_movement"],
490
+ ["insider_threat", "data_exfiltration"],
491
+ ]
492
+ chain = rng.choice(chains)
493
+ attack_types = chain[:num_threats]
494
+ else:
495
+ # Hard: complex multi-phase APT
496
+ chains = [
497
+ ["phishing", "c2_communication", "privilege_escalation", "data_exfiltration", "ransomware"],
498
+ ["supply_chain", "c2_communication", "lateral_movement", "credential_theft", "data_exfiltration", "ransomware"],
499
+ ["webshell", "privilege_escalation", "c2_communication", "lateral_movement", "data_exfiltration"],
500
+ ["phishing", "credential_theft", "lateral_movement", "cryptomining", "botnet"],
501
+ ["insider_threat", "privilege_escalation", "data_exfiltration", "c2_communication"],
502
+ ["botnet", "lateral_movement", "privilege_escalation", "ransomware", "data_exfiltration"],
503
+ ]
504
+ chain = rng.choice(chains)
505
+ attack_types = chain[:num_threats]
506
+
507
+ # Pick subnets involved
508
+ involved_subnets = _pick_subnets(rng, num_subnets)
509
+
510
+ # Generate attack chain
511
+ attack_chain: List[Dict[str, Any]] = []
512
+ used_hosts: set = set()
513
+ task_prefix = task_id.replace("gen_", "G").upper()[:6]
514
+
515
+ for i, attack_type in enumerate(attack_types):
516
+ phase_idx = min(i, len(ATTACK_PHASES) - 1)
517
+ # Use realistic phase based on attack type
518
+ phase_map = {
519
+ "phishing": "initial_access",
520
+ "webshell": "initial_access",
521
+ "supply_chain": "initial_access",
522
+ "credential_theft": "credential_access",
523
+ "privilege_escalation": "privilege_escalation",
524
+ "lateral_movement": "lateral_movement",
525
+ "c2_communication": "command_and_control",
526
+ "data_exfiltration": "exfiltration",
527
+ "ransomware": "impact",
528
+ "cryptomining": "impact",
529
+ "insider_threat": "exfiltration",
530
+ "botnet": "command_and_control",
531
+ }
532
+ phase = phase_map.get(attack_type, ATTACK_PHASES[phase_idx])
533
+
534
+ threat_id = f"T-{task_prefix}-{i + 1:03d}"
535
+ threat, new_hosts = _generate_threat(
536
+ rng, threat_id, attack_type, phase, involved_subnets, used_hosts
537
+ )
538
+ attack_chain.append(threat)
539
+ used_hosts.update(new_hosts)
540
+
541
+ # Generate alerts (1-2 per threat)
542
+ initial_alerts: List[Dict[str, Any]] = []
543
+ timestamp_base = rng.randint(0, 60)
544
+ for i, threat in enumerate(attack_chain):
545
+ num_alerts = rng.randint(1, 2)
546
+ for j in range(num_alerts):
547
+ alert = _generate_alert(
548
+ rng, len(initial_alerts), task_prefix, threat, timestamp_base
549
+ )
550
+ initial_alerts.append(alert)
551
+
552
+ # Generate containment requirements
553
+ must_kill = []
554
+ must_block_iocs = []
555
+ must_forensics = []
556
+ must_not_isolate = []
557
+
558
+ for threat in attack_chain:
559
+ for host in threat["compromised_hosts"]:
560
+ for proc in threat["malicious_processes"]:
561
+ must_kill.append({"hostname": host, "process": proc})
562
+ if host not in must_forensics:
563
+ must_forensics.append(host)
564
+
565
+ # Collect all IOCs as required blocks
566
+ for h in threat["iocs"]["hashes"]:
567
+ if h not in must_block_iocs:
568
+ must_block_iocs.append(h)
569
+ for ip in threat["iocs"]["ips"]:
570
+ if ip not in must_block_iocs:
571
+ must_block_iocs.append(ip)
572
+ for d in threat["iocs"]["domains"]:
573
+ if d not in must_block_iocs:
574
+ must_block_iocs.append(d)
575
+
576
+ # Subnets that should NOT be isolated (business-critical ones not in the attack)
577
+ non_involved = [s for s in SUBNETS if s not in involved_subnets]
578
+ if difficulty == "easy":
579
+ must_not_isolate = non_involved
580
+ elif difficulty == "medium":
581
+ must_not_isolate = [s for s in non_involved if SUBNETS[s]["criticality"] >= 0.8]
582
+
583
+ # Build description
584
+ type_names = list(set(t["threat_type"] for t in attack_chain))
585
+ host_count = len(used_hosts)
586
+ desc = (
587
+ f"[{difficulty.upper()}] {', '.join(type_names).replace('_', ' ').title()} "
588
+ f"across {host_count} host(s) in {', '.join(involved_subnets)}."
589
+ )
590
+
591
+ return {
592
+ "description": desc,
593
+ "max_steps": max_steps,
594
+ "initial_business_impact": initial_impact,
595
+ "impact_per_step": impact_per_step,
596
+ "attack_chain": attack_chain,
597
+ "initial_alerts": initial_alerts,
598
+ "optimal_actions": [
599
+ "run_forensics", "kill_process", "block_ioc", "submit_containment_plan"
600
+ ],
601
+ "containment_requirements": {
602
+ "must_kill": must_kill,
603
+ "must_block_iocs": must_block_iocs,
604
+ "must_forensics": must_forensics,
605
+ "must_not_isolate": must_not_isolate,
606
+ },
607
+ }
608
+
609
+
610
+ # =============================================================================
611
+ # Batch Generation (for openenv.yaml and validation)
612
+ # =============================================================================
613
+
614
+ def list_generated_task_ids(count: int = 1000) -> List[str]:
615
+ """Return the list of generated task IDs."""
616
+ return [f"gen_{i:04d}" for i in range(1, count + 1)]
617
+
618
+
619
+ def get_task_summary(task_id: str) -> Dict[str, str]:
620
+ """Get a short summary of a generated task (for openenv.yaml)."""
621
+ task_def = generate_task(task_id)
622
+ difficulty = _get_difficulty(task_id, random.Random(_seed_from_task_id(task_id)))
623
+ return {
624
+ "description": task_def["description"],
625
+ "max_steps": task_def["max_steps"],
626
+ "difficulty": difficulty,
627
+ }
server/tasks.py ADDED
@@ -0,0 +1,513 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Deterministic task definitions for CyberSOCEnv.
9
+
10
+ Each task defines a fixed attack chain, network layout, and expected
11
+ containment actions. No randomness — every run of the same task_id
12
+ produces identical initial state.
13
+
14
+ Tasks:
15
+ - easy: Single ransomware endpoint on the corporate subnet.
16
+ - medium: Multi-stage lateral movement (phishing -> cred theft -> 3 subnets).
17
+ - hard: APT + ransomware with C2, exfiltration, and executive pressure.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ from typing import Any, Dict, List
23
+
24
+
25
+ # =============================================================================
26
+ # Network Topology Builder (deterministic, 500-node)
27
+ # =============================================================================
28
+
29
+ def _build_subnet(
30
+ name: str,
31
+ role: str,
32
+ prefix: str,
33
+ ip_base: str,
34
+ count: int,
35
+ start_idx: int,
36
+ criticality: float,
37
+ default_ports: List[int],
38
+ default_procs: List[str],
39
+ ) -> List[Dict[str, Any]]:
40
+ """Build a list of host dicts for a subnet."""
41
+ hosts = []
42
+ for i in range(count):
43
+ idx = start_idx + i
44
+ hosts.append({
45
+ "hostname": f"{prefix}-{idx:03d}",
46
+ "ip_address": f"{ip_base}.{idx}",
47
+ "subnet": name,
48
+ "role": role,
49
+ "status": "online",
50
+ "running_processes": list(default_procs),
51
+ "open_ports": list(default_ports),
52
+ "criticality": criticality,
53
+ })
54
+ return hosts
55
+
56
+
57
+ def build_network() -> Dict[str, List[Dict[str, Any]]]:
58
+ """Build the deterministic 500-node enterprise network.
59
+
60
+ Returns:
61
+ Dict mapping subnet name -> list of host dicts.
62
+ """
63
+ network: Dict[str, List[Dict[str, Any]]] = {}
64
+
65
+ # Corporate (150 workstations)
66
+ network["corporate"] = _build_subnet(
67
+ name="corporate", role="corporate", prefix="WS",
68
+ ip_base="10.1.1", count=150, start_idx=1,
69
+ criticality=0.3,
70
+ default_ports=[135, 445, 3389],
71
+ default_procs=["outlook.exe", "chrome.exe", "explorer.exe"],
72
+ )
73
+
74
+ # Engineering (100 dev machines)
75
+ network["engineering"] = _build_subnet(
76
+ name="engineering", role="engineering", prefix="DEV",
77
+ ip_base="10.2.1", count=100, start_idx=1,
78
+ criticality=0.5,
79
+ default_ports=[22, 443, 8080, 3389],
80
+ default_procs=["vscode.exe", "python.exe", "docker.exe", "git.exe"],
81
+ )
82
+
83
+ # Finance (50 machines)
84
+ network["finance"] = _build_subnet(
85
+ name="finance", role="finance", prefix="FIN",
86
+ ip_base="10.3.1", count=50, start_idx=1,
87
+ criticality=0.8,
88
+ default_ports=[443, 1433, 3389],
89
+ default_procs=["excel.exe", "sap.exe", "sqlcmd.exe"],
90
+ )
91
+
92
+ # DMZ (30 servers)
93
+ network["dmz"] = _build_subnet(
94
+ name="dmz", role="dmz", prefix="DMZ",
95
+ ip_base="10.4.1", count=30, start_idx=1,
96
+ criticality=0.6,
97
+ default_ports=[80, 443, 8443],
98
+ default_procs=["nginx", "node", "java"],
99
+ )
100
+
101
+ # Datacenter (50 servers)
102
+ network["datacenter"] = _build_subnet(
103
+ name="datacenter", role="datacenter", prefix="SRV",
104
+ ip_base="10.5.1", count=50, start_idx=1,
105
+ criticality=0.9,
106
+ default_ports=[22, 443, 5432, 6379, 9200],
107
+ default_procs=["postgres", "redis-server", "elasticsearch", "kubelet"],
108
+ )
109
+
110
+ # Executive (20 machines)
111
+ network["executive"] = _build_subnet(
112
+ name="executive", role="executive", prefix="EXEC",
113
+ ip_base="10.6.1", count=20, start_idx=1,
114
+ criticality=1.0,
115
+ default_ports=[443, 3389],
116
+ default_procs=["outlook.exe", "teams.exe", "chrome.exe"],
117
+ )
118
+
119
+ return network
120
+
121
+
122
+ # =============================================================================
123
+ # Attack Chain Definitions
124
+ # =============================================================================
125
+
126
+ TASKS: Dict[str, Dict[str, Any]] = {
127
+ # ----- EASY: Single ransomware endpoint -----
128
+ "easy": {
129
+ "description": "Ransomware detected on a single corporate workstation. Isolate and contain.",
130
+ "max_steps": 15,
131
+ "initial_business_impact": 0.05,
132
+ "impact_per_step": 0.02, # Impact grows slowly per step
133
+ "attack_chain": [
134
+ {
135
+ "threat_id": "T-EASY-001",
136
+ "threat_type": "ransomware",
137
+ "phase": "execution",
138
+ "compromised_hosts": ["WS-042"],
139
+ "malicious_processes": ["cryptolocker.exe"],
140
+ "c2_servers": [],
141
+ "iocs": {
142
+ "hashes": ["e99a18c428cb38d5f260853678922e03"],
143
+ "ips": [],
144
+ "domains": [],
145
+ },
146
+ "lateral_targets": [],
147
+ "exfil_targets": [],
148
+ },
149
+ ],
150
+ "initial_alerts": [
151
+ {
152
+ "alert_id": "ALERT-E001",
153
+ "timestamp": "2025-01-15T09:23:17Z",
154
+ "source_host": "WS-042",
155
+ "severity": "critical",
156
+ "threat_type": "ransomware",
157
+ "description": "EDR detected file encryption activity on WS-042. Process 'cryptolocker.exe' is encrypting files in C:\\Users\\jsmith\\Documents.",
158
+ "ioc_indicators": ["e99a18c428cb38d5f260853678922e03"],
159
+ "subnet": "corporate",
160
+ "is_acknowledged": False,
161
+ },
162
+ {
163
+ "alert_id": "ALERT-E002",
164
+ "timestamp": "2025-01-15T09:23:45Z",
165
+ "source_host": "WS-042",
166
+ "severity": "high",
167
+ "threat_type": "ransomware",
168
+ "description": "Anomalous file system activity: 147 files renamed with .locked extension in 28 seconds.",
169
+ "ioc_indicators": [],
170
+ "subnet": "corporate",
171
+ "is_acknowledged": False,
172
+ },
173
+ ],
174
+ # Optimal containment: kill process, run forensics, block hash, submit plan
175
+ "optimal_actions": ["kill_process", "run_forensics", "block_ioc", "submit_containment_plan"],
176
+ "containment_requirements": {
177
+ "must_kill": [{"hostname": "WS-042", "process": "cryptolocker.exe"}],
178
+ "must_block_iocs": ["e99a18c428cb38d5f260853678922e03"],
179
+ "must_forensics": ["WS-042"],
180
+ "must_not_isolate": ["finance", "engineering", "datacenter"], # Unnecessary isolation = downtime
181
+ },
182
+ },
183
+
184
+ # ----- MEDIUM: Multi-stage lateral movement -----
185
+ "medium": {
186
+ "description": "Phishing attack led to credential theft and lateral movement across 3 subnets.",
187
+ "max_steps": 25,
188
+ "initial_business_impact": 0.10,
189
+ "impact_per_step": 0.03,
190
+ "attack_chain": [
191
+ {
192
+ "threat_id": "T-MED-001",
193
+ "threat_type": "phishing",
194
+ "phase": "initial_access",
195
+ "compromised_hosts": ["WS-017"],
196
+ "malicious_processes": ["powershell.exe"],
197
+ "c2_servers": [],
198
+ "iocs": {
199
+ "hashes": ["d41d8cd98f00b204e9800998ecf8427e"],
200
+ "ips": [],
201
+ "domains": ["evil-login.example.com"],
202
+ },
203
+ "lateral_targets": [],
204
+ "exfil_targets": [],
205
+ },
206
+ {
207
+ "threat_id": "T-MED-002",
208
+ "threat_type": "credential_theft",
209
+ "phase": "credential_access",
210
+ "compromised_hosts": ["WS-017"],
211
+ "malicious_processes": ["mimikatz.exe"],
212
+ "c2_servers": [],
213
+ "iocs": {
214
+ "hashes": ["aabbccdd11223344eeff5566778899aa"],
215
+ "ips": [],
216
+ "domains": [],
217
+ },
218
+ "lateral_targets": ["DEV-033", "FIN-012"],
219
+ "exfil_targets": [],
220
+ },
221
+ {
222
+ "threat_id": "T-MED-003",
223
+ "threat_type": "lateral_movement",
224
+ "phase": "lateral_movement",
225
+ "compromised_hosts": ["DEV-033", "FIN-012"],
226
+ "malicious_processes": ["svchost_backdoor.exe"],
227
+ "c2_servers": [],
228
+ "iocs": {
229
+ "hashes": ["112233445566778899aabbccddeeff00"],
230
+ "ips": ["203.0.113.50"],
231
+ "domains": [],
232
+ },
233
+ "lateral_targets": ["SRV-005"],
234
+ "exfil_targets": [],
235
+ },
236
+ ],
237
+ "initial_alerts": [
238
+ {
239
+ "alert_id": "ALERT-M001",
240
+ "timestamp": "2025-01-15T08:15:00Z",
241
+ "source_host": "WS-017",
242
+ "severity": "medium",
243
+ "threat_type": "phishing",
244
+ "description": "User clicked suspicious link in email. PowerShell execution detected downloading payload from evil-login.example.com.",
245
+ "ioc_indicators": ["evil-login.example.com"],
246
+ "subnet": "corporate",
247
+ "is_acknowledged": False,
248
+ },
249
+ {
250
+ "alert_id": "ALERT-M002",
251
+ "timestamp": "2025-01-15T08:32:00Z",
252
+ "source_host": "WS-017",
253
+ "severity": "high",
254
+ "threat_type": "credential_theft",
255
+ "description": "LSASS memory access detected — possible credential dumping via Mimikatz.",
256
+ "ioc_indicators": ["aabbccdd11223344eeff5566778899aa"],
257
+ "subnet": "corporate",
258
+ "is_acknowledged": False,
259
+ },
260
+ {
261
+ "alert_id": "ALERT-M003",
262
+ "timestamp": "2025-01-15T09:05:00Z",
263
+ "source_host": "DEV-033",
264
+ "severity": "high",
265
+ "threat_type": "lateral_movement",
266
+ "description": "Suspicious RDP login from WS-017 using admin credentials. New process svchost_backdoor.exe spawned.",
267
+ "ioc_indicators": ["203.0.113.50", "112233445566778899aabbccddeeff00"],
268
+ "subnet": "engineering",
269
+ "is_acknowledged": False,
270
+ },
271
+ {
272
+ "alert_id": "ALERT-M004",
273
+ "timestamp": "2025-01-15T09:12:00Z",
274
+ "source_host": "FIN-012",
275
+ "severity": "critical",
276
+ "threat_type": "lateral_movement",
277
+ "description": "Unauthorized access to FIN-012 from compromised credentials. Backdoor process active.",
278
+ "ioc_indicators": ["112233445566778899aabbccddeeff00"],
279
+ "subnet": "finance",
280
+ "is_acknowledged": False,
281
+ },
282
+ ],
283
+ "optimal_actions": [
284
+ "query_host", "run_forensics", "kill_process", "block_ioc",
285
+ "isolate_segment", "run_forensics", "submit_containment_plan",
286
+ ],
287
+ "containment_requirements": {
288
+ "must_kill": [
289
+ {"hostname": "WS-017", "process": "powershell.exe"},
290
+ {"hostname": "WS-017", "process": "mimikatz.exe"},
291
+ {"hostname": "DEV-033", "process": "svchost_backdoor.exe"},
292
+ {"hostname": "FIN-012", "process": "svchost_backdoor.exe"},
293
+ ],
294
+ "must_block_iocs": [
295
+ "evil-login.example.com",
296
+ "203.0.113.50",
297
+ "d41d8cd98f00b204e9800998ecf8427e",
298
+ "aabbccdd11223344eeff5566778899aa",
299
+ "112233445566778899aabbccddeeff00",
300
+ ],
301
+ "must_forensics": ["WS-017", "DEV-033", "FIN-012"],
302
+ "must_not_isolate": ["executive", "datacenter"],
303
+ },
304
+ },
305
+
306
+ # ----- HARD: APT + Ransomware, C2, exfiltration, executive pressure -----
307
+ "hard": {
308
+ "description": "Advanced Persistent Threat with active C2 comms, data exfiltration in progress, and ransomware deployment imminent. Board is watching — minimize downtime.",
309
+ "max_steps": 30,
310
+ "initial_business_impact": 0.20,
311
+ "impact_per_step": 0.04,
312
+ "attack_chain": [
313
+ {
314
+ "threat_id": "T-HARD-001",
315
+ "threat_type": "phishing",
316
+ "phase": "initial_access",
317
+ "compromised_hosts": ["EXEC-003"],
318
+ "malicious_processes": ["outlook_macro.exe"],
319
+ "c2_servers": ["198.51.100.77"],
320
+ "iocs": {
321
+ "hashes": ["deadbeef0123456789abcdef01234567"],
322
+ "ips": ["198.51.100.77"],
323
+ "domains": ["cdn-update.malware-c2.net"],
324
+ },
325
+ "lateral_targets": ["WS-088"],
326
+ "exfil_targets": [],
327
+ },
328
+ {
329
+ "threat_id": "T-HARD-002",
330
+ "threat_type": "c2_communication",
331
+ "phase": "command_and_control",
332
+ "compromised_hosts": ["EXEC-003", "WS-088"],
333
+ "malicious_processes": ["svchost_c2.exe"],
334
+ "c2_servers": ["198.51.100.77"],
335
+ "iocs": {
336
+ "hashes": ["cafebabe9876543210fedcba98765432"],
337
+ "ips": ["198.51.100.77"],
338
+ "domains": ["cdn-update.malware-c2.net", "exfil.malware-c2.net"],
339
+ },
340
+ "lateral_targets": ["SRV-002", "FIN-008"],
341
+ "exfil_targets": [],
342
+ },
343
+ {
344
+ "threat_id": "T-HARD-003",
345
+ "threat_type": "privilege_escalation",
346
+ "phase": "privilege_escalation",
347
+ "compromised_hosts": ["SRV-002"],
348
+ "malicious_processes": ["exploit_kernel.exe"],
349
+ "c2_servers": ["198.51.100.77"],
350
+ "iocs": {
351
+ "hashes": ["1122334455667788aabbccddeeff0011"],
352
+ "ips": ["198.51.100.77"],
353
+ "domains": [],
354
+ },
355
+ "lateral_targets": ["SRV-010", "SRV-015"],
356
+ "exfil_targets": ["SRV-002"],
357
+ },
358
+ {
359
+ "threat_id": "T-HARD-004",
360
+ "threat_type": "data_exfiltration",
361
+ "phase": "exfiltration",
362
+ "compromised_hosts": ["SRV-002", "FIN-008"],
363
+ "malicious_processes": ["data_pump.exe"],
364
+ "c2_servers": ["198.51.100.77"],
365
+ "iocs": {
366
+ "hashes": ["ffeeddccbbaa99887766554433221100"],
367
+ "ips": ["198.51.100.77", "203.0.113.99"],
368
+ "domains": ["exfil.malware-c2.net"],
369
+ },
370
+ "lateral_targets": [],
371
+ "exfil_targets": ["SRV-002", "FIN-008"],
372
+ },
373
+ {
374
+ "threat_id": "T-HARD-005",
375
+ "threat_type": "ransomware",
376
+ "phase": "impact",
377
+ "compromised_hosts": ["SRV-010", "SRV-015"],
378
+ "malicious_processes": ["blackcat_ransom.exe"],
379
+ "c2_servers": [],
380
+ "iocs": {
381
+ "hashes": ["aabb0011ccdd2233eeff4455667788"],
382
+ "ips": [],
383
+ "domains": [],
384
+ },
385
+ "lateral_targets": [],
386
+ "exfil_targets": [],
387
+ },
388
+ ],
389
+ "initial_alerts": [
390
+ {
391
+ "alert_id": "ALERT-H001",
392
+ "timestamp": "2025-01-15T06:00:00Z",
393
+ "source_host": "EXEC-003",
394
+ "severity": "medium",
395
+ "threat_type": "phishing",
396
+ "description": "Executive VP opened macro-enabled document. Outbound connection to cdn-update.malware-c2.net detected.",
397
+ "ioc_indicators": ["cdn-update.malware-c2.net", "198.51.100.77"],
398
+ "subnet": "executive",
399
+ "is_acknowledged": False,
400
+ },
401
+ {
402
+ "alert_id": "ALERT-H002",
403
+ "timestamp": "2025-01-15T06:45:00Z",
404
+ "source_host": "WS-088",
405
+ "severity": "high",
406
+ "threat_type": "c2_communication",
407
+ "description": "Periodic beaconing detected to 198.51.100.77 every 60 seconds. Encrypted payload exchange observed.",
408
+ "ioc_indicators": ["198.51.100.77", "cafebabe9876543210fedcba98765432"],
409
+ "subnet": "corporate",
410
+ "is_acknowledged": False,
411
+ },
412
+ {
413
+ "alert_id": "ALERT-H003",
414
+ "timestamp": "2025-01-15T07:30:00Z",
415
+ "source_host": "SRV-002",
416
+ "severity": "critical",
417
+ "threat_type": "privilege_escalation",
418
+ "description": "Kernel exploit attempt on SRV-002 (database server). Process exploit_kernel.exe gained SYSTEM privileges.",
419
+ "ioc_indicators": ["1122334455667788aabbccddeeff0011"],
420
+ "subnet": "datacenter",
421
+ "is_acknowledged": False,
422
+ },
423
+ {
424
+ "alert_id": "ALERT-H004",
425
+ "timestamp": "2025-01-15T08:00:00Z",
426
+ "source_host": "SRV-002",
427
+ "severity": "critical",
428
+ "threat_type": "data_exfiltration",
429
+ "description": "Large data transfer (2.3 GB) to external IP 203.0.113.99 from database server SRV-002. Possible exfiltration of customer PII.",
430
+ "ioc_indicators": ["203.0.113.99", "exfil.malware-c2.net"],
431
+ "subnet": "datacenter",
432
+ "is_acknowledged": False,
433
+ },
434
+ {
435
+ "alert_id": "ALERT-H005",
436
+ "timestamp": "2025-01-15T08:10:00Z",
437
+ "source_host": "FIN-008",
438
+ "severity": "critical",
439
+ "threat_type": "data_exfiltration",
440
+ "description": "Financial records being staged for exfiltration on FIN-008. Process data_pump.exe accessing sensitive directories.",
441
+ "ioc_indicators": ["ffeeddccbbaa99887766554433221100"],
442
+ "subnet": "finance",
443
+ "is_acknowledged": False,
444
+ },
445
+ {
446
+ "alert_id": "ALERT-H006",
447
+ "timestamp": "2025-01-15T08:30:00Z",
448
+ "source_host": "SRV-010",
449
+ "severity": "critical",
450
+ "threat_type": "ransomware",
451
+ "description": "BlackCat ransomware deployment detected on SRV-010! File encryption starting on production storage.",
452
+ "ioc_indicators": ["aabb0011ccdd2233eeff4455667788"],
453
+ "subnet": "datacenter",
454
+ "is_acknowledged": False,
455
+ },
456
+ ],
457
+ "optimal_actions": [
458
+ "block_ioc", "kill_process", "run_forensics", "isolate_segment",
459
+ "kill_process", "block_ioc", "run_forensics", "kill_process",
460
+ "submit_containment_plan",
461
+ ],
462
+ "containment_requirements": {
463
+ "must_kill": [
464
+ {"hostname": "EXEC-003", "process": "outlook_macro.exe"},
465
+ {"hostname": "EXEC-003", "process": "svchost_c2.exe"},
466
+ {"hostname": "WS-088", "process": "svchost_c2.exe"},
467
+ {"hostname": "SRV-002", "process": "exploit_kernel.exe"},
468
+ {"hostname": "SRV-002", "process": "data_pump.exe"},
469
+ {"hostname": "FIN-008", "process": "data_pump.exe"},
470
+ {"hostname": "SRV-010", "process": "blackcat_ransom.exe"},
471
+ {"hostname": "SRV-015", "process": "blackcat_ransom.exe"},
472
+ ],
473
+ "must_block_iocs": [
474
+ "198.51.100.77",
475
+ "203.0.113.99",
476
+ "cdn-update.malware-c2.net",
477
+ "exfil.malware-c2.net",
478
+ "deadbeef0123456789abcdef01234567",
479
+ "cafebabe9876543210fedcba98765432",
480
+ ],
481
+ "must_forensics": ["EXEC-003", "WS-088", "SRV-002", "FIN-008", "SRV-010"],
482
+ "must_not_isolate": [], # In APT scenario, any isolation decision is valid
483
+ },
484
+ },
485
+ }
486
+
487
+
488
+ def get_task(task_id: str) -> Dict[str, Any]:
489
+ """Retrieve a task definition by ID.
490
+
491
+ Supports:
492
+ - 'easy', 'medium', 'hard': Hand-crafted curated benchmarks
493
+ - 'gen_0001' through 'gen_1000': Procedurally generated scenarios
494
+ - Any other string: Generated on-the-fly via seeded procedural generation
495
+
496
+ Args:
497
+ task_id: Task identifier string.
498
+
499
+ Returns:
500
+ Task definition dict.
501
+ """
502
+ # Check hand-crafted tasks first
503
+ if task_id in TASKS:
504
+ return TASKS[task_id]
505
+
506
+ # Fall back to procedural generation
507
+ try:
508
+ from .task_generator import generate_task
509
+ except ImportError:
510
+ from server.task_generator import generate_task
511
+
512
+ return generate_task(task_id)
513
+
uv.lock ADDED
The diff for this file is too large to render. See raw diff
 
validate_submission.sh ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ set -uo pipefail
3
+
4
+ DOCKER_BUILD_TIMEOUT=600
5
+ if [ -t 1 ]; then
6
+ RED='\033[0;31m'
7
+ GREEN='\033[0;32m'
8
+ YELLOW='\033[1;33m'
9
+ BOLD='\033[1m'
10
+ NC='\033[0m'
11
+ else
12
+ RED='' GREEN='' YELLOW='' BOLD='' NC=''
13
+ fi
14
+
15
+ run_with_timeout() {
16
+ local secs="$1"; shift
17
+ if command -v timeout &>/dev/null; then
18
+ timeout "$secs" "$@"
19
+ elif command -v gtimeout &>/dev/null; then
20
+ gtimeout "$secs" "$@"
21
+ else
22
+ "$@" &
23
+ local pid=$!
24
+ ( sleep "$secs" && kill "$pid" 2>/dev/null ) &
25
+ local watcher=$!
26
+ wait "$pid" 2>/dev/null
27
+ local rc=$?
28
+ kill "$watcher" 2>/dev/null
29
+ wait "$watcher" 2>/dev/null
30
+ return $rc
31
+ fi
32
+ }
33
+
34
+ portable_mktemp() {
35
+ local prefix="${1:-validate}"
36
+ mktemp "${TMPDIR:-/tmp}/${prefix}-XXXXXX" 2>/dev/null || mktemp
37
+ }
38
+
39
+ CLEANUP_FILES=()
40
+ cleanup() { rm -f "${CLEANUP_FILES[@]+"${CLEANUP_FILES[@]}"}"; }
41
+ trap cleanup EXIT
42
+
43
+ PING_URL="${1:-}"
44
+ REPO_DIR="${2:-.}"
45
+
46
+ if [ -z "$PING_URL" ]; then
47
+ printf "Usage: %s <ping_url> [repo_dir]\n" "$0"
48
+ printf "\n"
49
+ printf " ping_url Your HuggingFace Space URL (e.g. https://your-space.hf.space)\n"
50
+ printf " repo_dir Path to your repo (default: current directory)\n"
51
+ exit 1
52
+ fi
53
+
54
+ if ! REPO_DIR="$(cd "$REPO_DIR" 2>/dev/null && pwd)"; then
55
+ printf "Error: directory '%s' not found\n" "${2:-.}"
56
+ exit 1
57
+ fi
58
+ PING_URL="${PING_URL%/}"
59
+ export PING_URL
60
+ PASS=0
61
+
62
+ log() { printf "[%s] %b\n" "$(date -u +%H:%M:%S)" "$*"; }
63
+ pass() { log "${GREEN}PASSED${NC} -- $1"; PASS=$((PASS + 1)); }
64
+ fail() { log "${RED}FAILED${NC} -- $1"; }
65
+ hint() { printf " ${YELLOW}Hint:${NC} %b\n" "$1"; }
66
+ stop_at() {
67
+ printf "\n"
68
+ printf "${RED}${BOLD}Validation stopped at %s.${NC} Fix the above before continuing.\n" "$1"
69
+ exit 1
70
+ }
71
+
72
+ printf "\n"
73
+ printf "${BOLD}========================================${NC}\n"
74
+ printf "${BOLD} OpenEnv Submission Validator${NC}\n"
75
+ printf "${BOLD}========================================${NC}\n"
76
+ log "Repo: $REPO_DIR"
77
+ log "Ping URL: $PING_URL"
78
+ printf "\n"
79
+
80
+ log "${BOLD}Step 1/3: Pinging HF Space${NC} ($PING_URL/reset) ..."
81
+
82
+ CURL_OUTPUT=$(portable_mktemp "validate-curl")
83
+ CLEANUP_FILES+=("$CURL_OUTPUT")
84
+ HTTP_CODE=$(curl -s -o "$CURL_OUTPUT" -w "%{http_code}" -X POST \
85
+ -H "Content-Type: application/json" -d '{}' \
86
+ "$PING_URL/reset" --max-time 30 2>"$CURL_OUTPUT" || printf "000")
87
+
88
+ if [ "$HTTP_CODE" = "200" ]; then
89
+ pass "HF Space is live and responds to /reset"
90
+ elif [ "$HTTP_CODE" = "000" ]; then
91
+ fail "HF Space not reachable (connection failed or timed out)"
92
+ hint "Check your network connection and that the Space is running."
93
+ hint "Try: curl -s -o /dev/null -w '%%{http_code}' -X POST $PING_URL/reset"
94
+ stop_at "Step 1"
95
+ else
96
+ fail "HF Space /reset returned HTTP $HTTP_CODE (expected 200)"
97
+ hint "Make sure your Space is running and the URL is correct."
98
+ hint "Try opening $PING_URL in your browser first."
99
+ stop_at "Step 1"
100
+ fi
101
+
102
+ log "${BOLD}Step 2/3: Running docker build${NC} ..."
103
+
104
+ if ! command -v docker &>/dev/null; then
105
+ fail "docker command not found"
106
+ hint "Install Docker: https://docs.docker.com/get-docker/"
107
+ stop_at "Step 2"
108
+ fi
109
+
110
+ if [ -f "$REPO_DIR/Dockerfile" ]; then
111
+ DOCKER_CONTEXT="$REPO_DIR"
112
+ elif [ -f "$REPO_DIR/server/Dockerfile" ]; then
113
+ DOCKER_CONTEXT="$REPO_DIR/server"
114
+ else
115
+ fail "No Dockerfile found in repo root or server/ directory"
116
+ stop_at "Step 2"
117
+ fi
118
+
119
+ log " Found Dockerfile in $DOCKER_CONTEXT"
120
+
121
+ BUILD_OK=false
122
+ BUILD_OUTPUT=$(run_with_timeout "$DOCKER_BUILD_TIMEOUT" docker build "$DOCKER_CONTEXT" 2>&1) && BUILD_OK=true
123
+
124
+ if [ "$BUILD_OK" = true ]; then
125
+ pass "Docker build succeeded"
126
+ else
127
+ fail "Docker build failed (timeout=${DOCKER_BUILD_TIMEOUT}s)"
128
+ printf "%s\n" "$BUILD_OUTPUT" | tail -20
129
+ stop_at "Step 2"
130
+ fi
131
+
132
+ log "${BOLD}Step 3/3: Running openenv validate${NC} ..."
133
+
134
+ if ! command -v openenv &>/dev/null; then
135
+ fail "openenv command not found"
136
+ hint "Install it: pip install openenv-core"
137
+ stop_at "Step 3"
138
+ fi
139
+
140
+ VALIDATE_OK=false
141
+ VALIDATE_OUTPUT=$(cd "$REPO_DIR" && openenv validate 2>&1) && VALIDATE_OK=true
142
+
143
+ if [ "$VALIDATE_OK" = true ]; then
144
+ pass "openenv validate passed"
145
+ [ -n "$VALIDATE_OUTPUT" ] && log " $VALIDATE_OUTPUT"
146
+ else
147
+ fail "openenv validate failed"
148
+ printf "%s\n" "$VALIDATE_OUTPUT"
149
+ stop_at "Step 3"
150
+ fi
151
+
152
+ printf "\n"
153
+ printf "${BOLD}========================================${NC}\n"
154
+ printf "${GREEN}${BOLD} All 3/3 checks passed!${NC}\n"
155
+ printf "${GREEN}${BOLD} Your submission is ready to submit.${NC}\n"
156
+ printf "${BOLD}========================================${NC}\n"
157
+ printf "\n"
158
+
159
+ exit 0