Tusharp2006 committed on
Commit
02daf9d
·
1 Parent(s): 650ccdc
pyproject.toml CHANGED
@@ -82,9 +82,14 @@ Repository = "https://github.com/scalar/adaptive-alert-triage"
82
  [project.scripts]
83
  alert-triage = "adaptive_alert_triage.env:main"
84
  openenv = "adaptive_alert_triage.validate:main"
 
 
 
85
 
86
  [tool.setuptools.packages.find]
87
- where = ["src"]
 
 
88
 
89
  [tool.black]
90
  line-length = 100
 
82
  [project.scripts]
83
  alert-triage = "adaptive_alert_triage.env:main"
84
  openenv = "adaptive_alert_triage.validate:main"
85
+ server = "server.app:main"
86
+
87
+
88
 
89
  [tool.setuptools.packages.find]
90
+ where = ["src", "."]
91
+ include = ["adaptive_alert_triage*", "server*"]
92
+
93
 
94
  [tool.black]
95
  line-length = 100
server/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Server package
server/app.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Launcher for the Adaptive Alert Triage FastAPI server.

Puts the repository's ``src`` directory on ``sys.path`` (when it exists) and
serves the ``app`` object from ``adaptive_alert_triage.server`` with uvicorn.
"""

import os
import sys
from pathlib import Path

import uvicorn

# Ensure 'src' is in sys.path so we can find 'adaptive_alert_triage' and
# 'openenv_shim' when running from a source checkout.
_HERE = Path(__file__).resolve()
_REPO_ROOT = _HERE.parent.parent
_SRC = _REPO_ROOT / "src"
if _SRC.exists() and str(_SRC) not in sys.path:
    sys.path.insert(0, str(_SRC))

# This import must stay below the sys.path adjustment above.
from adaptive_alert_triage.server import app  # noqa: E402


def main():
    """Main entry point for the server application.

    Reads the listen port from the ``PORT`` environment variable (default
    7860) and the bind address from ``HOST`` (default ``0.0.0.0``), then
    hands the FastAPI ``app`` to ``uvicorn.run``.

    Raises:
        ValueError: if ``PORT`` is set to something ``int()`` cannot parse.
    """
    port = int(os.environ.get("PORT", 7860))
    host = os.environ.get("HOST", "0.0.0.0")
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()
src/adaptive_alert_triage/server.py CHANGED
@@ -1,709 +1,721 @@
1
- """
2
- FastAPI OpenEnv Server for Adaptive Alert Triage Environment β€” v0.3.1
3
-
4
- Root-cause fixes from v0.3.0:
5
- FIX 1 β€” "No active episode" on /agent/recommend
6
- FIX 2 β€” Queued alerts (real_alerts_queue) never appeared in env.alerts
7
- FIX 3 β€” alert.dict() / obs.dict() removed in Pydantic v2
8
- FIX 4 β€” task_score missing from info dict
9
- FIX 5 β€” real_alerts_queue dropped on /env/reset
10
- FIX 6 β€” state.system_load AttributeError
11
-
12
- New in v0.3.1 (pre-submission compliance):
13
- FIX 7 β€” Added POST /reset (OpenEnv spec requires top-level /reset endpoint)
14
- FIX 8 β€” Added POST /env/reset (alias without task_id, defaults to "hard")
15
- FIX 9 β€” Registered `openenv validate` CLI entry-point via pyproject.toml
16
- (see companion pyproject.toml fix)
17
- """
18
-
19
- from __future__ import annotations
20
-
21
- import asyncio
22
- import os
23
- import sys
24
- import traceback
25
- from collections import deque
26
- from typing import Any, Dict, List, Optional
27
-
28
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect
29
- from fastapi.middleware.cors import CORSMiddleware
30
- from fastapi.responses import FileResponse
31
- from pydantic import BaseModel
32
-
33
- from .env import AdaptiveAlertTriageEnv
34
- from .models import Action, Observation, Reward
35
-
36
-
37
- # ── Try to load trained PPO agent (lazy import, server starts without it) ─────
38
- _PPO_AVAILABLE = False
39
- try:
40
- _project_root = os.path.dirname(
41
- os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
42
- )
43
- if _project_root not in sys.path:
44
- sys.path.insert(0, _project_root)
45
- from rl_agent import PPOTrainer, encode_state, _ACTION_NAMES # type: ignore
46
- _PPO_AVAILABLE = True
47
- except ImportError:
48
- _project_root = ""
49
-
50
-
51
- # ── Request / response models ─────────────────────────────────────────────────
52
-
53
- class IngestAlert(BaseModel):
54
- id: str
55
- visible_severity: float
56
- confidence: float
57
- type: str
58
-
59
-
60
- class StepRequest(BaseModel):
61
- alert_id: str
62
- action_type: str
63
-
64
-
65
- class ResetRequest(BaseModel):
66
- """Optional body for POST /reset β€” task_id defaults to 'hard'."""
67
- task_id: Optional[str] = "hard"
68
- seed: Optional[int] = None
69
-
70
-
71
- class HealthResponse(BaseModel):
72
- status: str
73
- env_ready: bool
74
- queue_size: int
75
-
76
-
77
- # ── Alert-type normaliser ─────────────────────────────────────────────────────
78
-
79
- _TYPE_REMAP: Dict[str, str] = {
80
- "cpu": "CPU", "cpu_spike": "CPU",
81
- "memory": "MEMORY", "memory_leak": "MEMORY",
82
- "disk": "DISK", "disk_full": "DISK",
83
- "network": "NETWORK", "net": "NETWORK", "network_latency": "NETWORK",
84
- "application": "APPLICATION", "app": "APPLICATION",
85
- "security": "SECURITY", "sec": "SECURITY",
86
- }
87
- _VALID = {"CPU", "MEMORY", "DISK", "NETWORK", "APPLICATION", "SECURITY"}
88
-
89
-
90
- def _norm(raw: str) -> str:
91
- return _TYPE_REMAP.get(raw.lower(), raw.upper()) if raw else "APPLICATION"
92
-
93
-
94
- # ── App ───────────────────────────────────────────────────────────────────────
95
-
96
- app = FastAPI(title="Adaptive Alert Triage RL Server", version="0.3.1")
97
- app.add_middleware(CORSMiddleware, allow_origins=["*"],
98
- allow_credentials=False, allow_methods=["*"], allow_headers=["*"])
99
-
100
- @app.middleware("http")
101
- async def log_requests(request, call_next):
102
- print(f"REQUEST: {request.method} {request.url}")
103
- return await call_next(request)
104
-
105
- # ── Global state ──────────────────────────────────────────────────────────────
106
-
107
- env: Optional[AdaptiveAlertTriageEnv] = None
108
- episode_scores: List[float] = []
109
- _ppo_agents: Dict[str, Any] = {} # task_id β†’ PPOTrainer
110
- _loop_task: Optional[asyncio.Task] = None
111
- _last_action: Optional[str] = None
112
- _step_correct: int = 0
113
- _step_total: int = 0
114
-
115
- STEP_INTERVAL = 1.0 # seconds between autonomous episode-loop steps
116
-
117
-
118
- # ── Score helpers ─────────────────────────────────────────────────────────────
119
-
120
- def _reset_score() -> None:
121
- global _step_correct, _step_total
122
- _step_correct = _step_total = 0
123
-
124
-
125
- def _tick(info: Dict) -> None:
126
- global _step_correct, _step_total
127
- _step_total += 1
128
- if info.get("action_correct", False):
129
- _step_correct += 1
130
-
131
-
132
- def _score() -> float:
133
- return _step_correct / _step_total if _step_total else 0.0
134
-
135
-
136
- # ── PPO helpers ─────────────────────────────────────────��─────────────────────
137
-
138
- def _load_ppo(task_id: str) -> Optional[Any]:
139
- if not _PPO_AVAILABLE:
140
- return None
141
- path = os.path.join(_project_root, "weights", f"ppo_{task_id}.json")
142
- if not os.path.exists(path):
143
- print(f" [PPO] weights not found: {path}")
144
- return None
145
- try:
146
- agent = PPOTrainer(task_id=task_id)
147
- agent.load(path)
148
- print(f" [PPO] loaded {path}")
149
- return agent
150
- except Exception as e:
151
- print(f" [PPO] load error: {e}")
152
- return None
153
-
154
-
155
- def _ppo_act() -> Optional[Action]:
156
- if not env or not env.alerts:
157
- return None
158
- agent = _ppo_agents.get(env.task_id)
159
- if agent is None:
160
- return None
161
- try:
162
- obs = Observation(
163
- alerts = list(env.alerts),
164
- system_load = getattr(env, "_last_system_load", 0.5),
165
- queue_length = len(env.alerts),
166
- time_remaining = env.max_steps - env.current_step,
167
- resource_budget=(
168
- env.max_investigations_per_step - env.investigations_used
169
- if env.max_investigations_per_step is not None else None
170
- ),
171
- episode_step = env.current_step,
172
- )
173
- return agent.act(obs)
174
- except Exception:
175
- return None
176
-
177
-
178
- def _rule_act() -> Optional[Action]:
179
- if not env or not env.alerts:
180
- return None
181
- top = max(env.alerts, key=lambda a: a.visible_severity)
182
- sev = top.visible_severity
183
- conf = top.confidence
184
- rem = (env.max_investigations_per_step - env.investigations_used
185
- if env.max_investigations_per_step is not None else None)
186
- if sev >= 0.75 and conf >= 0.60:
187
- atype = "ESCALATE" if (rem is not None and rem <= 0) else "INVESTIGATE"
188
- elif conf < 0.30 or sev < 0.30:
189
- atype = "IGNORE"
190
- elif sev >= 0.55:
191
- atype = "ESCALATE"
192
- else:
193
- atype = "DELAY"
194
- return Action(alert_id=top.id, action_type=atype)
195
-
196
-
197
- # ── Always-live episode loop ──────────────────────────────────────────────────
198
-
199
- async def _episode_loop() -> None:
200
- global env, _last_action
201
-
202
- while True:
203
- try:
204
- if env is None:
205
- await asyncio.sleep(STEP_INTERVAL)
206
- continue
207
-
208
- if not env.alerts or env._is_terminal():
209
- if _step_total > 0:
210
- episode_scores.append(_score())
211
- _reset_score()
212
- env.reset()
213
-
214
- if not env.alerts:
215
- await asyncio.sleep(STEP_INTERVAL)
216
- continue
217
-
218
- import time
219
- if time.time() - globals().get("_last_manual_step_time", 0.0) < 5.0:
220
- await asyncio.sleep(STEP_INTERVAL)
221
- continue
222
-
223
- action = _ppo_act() or _rule_act()
224
- if action is None:
225
- await asyncio.sleep(STEP_INTERVAL)
226
- continue
227
-
228
- _last_action = action.action_type
229
- _, reward, done, info = env.step(action)
230
- _tick(info)
231
-
232
- if done:
233
- episode_scores.append(_score())
234
- if len(episode_scores) > 1000:
235
- episode_scores[:] = episode_scores[-1000:]
236
- _reset_score()
237
- env.reset()
238
-
239
- except Exception as exc:
240
- print(f"[episode_loop] {exc}")
241
-
242
- await asyncio.sleep(STEP_INTERVAL)
243
-
244
-
245
- # ── Startup / shutdown ────────────────────────────────────────────────────────
246
-
247
- def _restore_pristine_weights():
248
- import shutil
249
- pristine_dir = os.path.join(_project_root if _project_root else os.getcwd(), "weights_pristine")
250
- weights_dir = os.path.join(_project_root if _project_root else os.getcwd(), "weights")
251
-
252
- if not os.path.exists(pristine_dir):
253
- print(" [STARTUP] No pristine weights found, skipping restore.")
254
- return
255
-
256
- os.makedirs(weights_dir, exist_ok=True)
257
- for f in os.listdir(pristine_dir):
258
- if f.startswith("ppo_") and f.endswith(".json"):
259
- src = os.path.join(pristine_dir, f)
260
- dst = os.path.join(weights_dir, f)
261
- shutil.copy2(src, dst)
262
- print(f" [STARTUP] Restored pristine weights: {f}")
263
-
264
-
265
- @app.on_event("startup")
266
- async def startup():
267
- global env, _loop_task
268
-
269
- _restore_pristine_weights()
270
-
271
- env = AdaptiveAlertTriageEnv(task_id="hard")
272
- env.real_alerts_queue = deque(maxlen=50)
273
- env.reset()
274
-
275
- for tid in ("easy", "medium", "hard"):
276
- agent = _load_ppo(tid)
277
- if agent:
278
- _ppo_agents[tid] = agent
279
-
280
- _loop_task = asyncio.create_task(_episode_loop())
281
-
282
- print("βœ… Alert Triage RL Server v0.3.1")
283
- print(f" Active alerts : {len(env.alerts)}")
284
- print(f" PPO loaded : {list(_ppo_agents.keys()) or 'none (run train_rl.py first)'}")
285
- print(f" Episode loop : every {STEP_INTERVAL}s")
286
-
287
-
288
- @app.on_event("shutdown")
289
- async def shutdown():
290
- if _loop_task:
291
- _loop_task.cancel()
292
-
293
-
294
- # ── Health ────────────────────────────────────────────────────────────────────
295
-
296
- @app.get("/health", response_model=HealthResponse)
297
- async def health():
298
- return HealthResponse(
299
- status = "ok",
300
- env_ready = env is not None and bool(env.alerts),
301
- queue_size= len(env.real_alerts_queue) if env and hasattr(env, "real_alerts_queue") else 0,
302
- )
303
-
304
-
305
- @app.get("/metrics")
306
- async def metrics():
307
- if not env:
308
- return {"error": "not initialized"}
309
- mean = sum(episode_scores[-100:]) / len(episode_scores[-100:]) if episode_scores else 0.0
310
- delta = (mean - 0.61) * 100
311
- return {
312
- "mean_score": round(mean, 3),
313
- "vs_baseline": f"+{delta:.0f}%" if delta >= 0 else f"{delta:.0f}%",
314
- "active_alerts": len(env.alerts),
315
- "episodes_completed": len(episode_scores),
316
- "current_step_score": round(_score(), 3),
317
- "current_step": env.current_step,
318
- "last_action": _last_action,
319
- "queue_size": len(env.real_alerts_queue) if hasattr(env, "real_alerts_queue") else 0,
320
- "ppo_loaded": list(_ppo_agents.keys()),
321
- }
322
-
323
-
324
- # ── Alert ingestion ───────────────────────────────────────────────────────────
325
-
326
- @app.post("/ingest/alerts")
327
- async def ingest_one(alert: IngestAlert):
328
- if not env:
329
- return {"error": "not initialized"}
330
- if not hasattr(env, "real_alerts_queue"):
331
- env.real_alerts_queue = deque(maxlen=50)
332
- raw = alert.model_dump()
333
- raw["type"] = _norm(raw.get("type", "APPLICATION"))
334
- env.real_alerts_queue.appendleft(raw)
335
- return {
336
- "status": "queued", "queued": len(env.real_alerts_queue),
337
- "alert_id": alert.id, "resolved_type": raw["type"],
338
- "note": "Episode loop will process this within ~1s",
339
- }
340
-
341
-
342
- @app.post("/ingest/alert-batch")
343
- async def ingest_batch(alerts: List[IngestAlert]):
344
- if not env:
345
- return {"error": "not initialized"}
346
- if not hasattr(env, "real_alerts_queue"):
347
- env.real_alerts_queue = deque(maxlen=50)
348
- ingested = []
349
- for alert in alerts:
350
- raw = alert.model_dump()
351
- raw["type"] = _norm(raw.get("type", "APPLICATION"))
352
- env.real_alerts_queue.appendleft(raw)
353
- ingested.append({"alert_id": alert.id, "resolved_type": raw["type"]})
354
- return {"status": "queued", "queued": len(env.real_alerts_queue), "ingested": ingested}
355
-
356
-
357
- # ── Environment control ───────────────────────────────────────────────────────
358
-
359
- async def _do_reset(task_id: str = "hard", seed: Optional[int] = None) -> dict:
360
- """
361
- Shared reset logic used by all reset endpoints.
362
- Returns a dict suitable for JSON response.
363
- """
364
- global env
365
- if task_id not in ("easy", "medium", "hard"):
366
- return {"error": f"Invalid task_id '{task_id}'. Must be one of: easy, medium, hard"}
367
- try:
368
- saved = env.real_alerts_queue if (env and hasattr(env, "real_alerts_queue")) else None
369
- env = AdaptiveAlertTriageEnv(task_id=task_id)
370
- env.real_alerts_queue = saved if saved is not None else deque(maxlen=50)
371
- agent = _load_ppo(task_id)
372
- if agent:
373
- _ppo_agents[task_id] = agent
374
- obs = env.reset(seed=seed)
375
- _reset_score()
376
- return {"status": "reset", "task_id": task_id, "obs": obs.model_dump()}
377
- except Exception as e:
378
- return {"error": str(e), "traceback": traceback.format_exc()}
379
-
380
-
381
- # FIX 7 β€” Top-level /reset endpoint required by OpenEnv validator ping
382
- # The pre-submission checker does: POST $PING_URL/reset
383
- # This must return 200 and a valid Observation.
384
- @app.post("/reset")
385
- async def reset_top_level(request: Optional[ResetRequest] = None):
386
- """
387
- OpenEnv-required top-level reset endpoint.
388
-
389
- POST /reset
390
- Body (optional JSON): {"task_id": "easy"|"medium"|"hard", "seed": int}
391
-
392
- Returns the initial Observation for the new episode.
393
- This is the endpoint pinged by the pre-submission checker.
394
- """
395
- task_id = "hard"
396
- seed = None
397
- if request is not None:
398
- task_id = request.task_id or "hard"
399
- seed = request.seed
400
- return await _do_reset(task_id=task_id, seed=seed)
401
-
402
-
403
- # FIX 8 β€” /env/reset without a path parameter (alias, defaults to "hard")
404
- @app.post("/env/reset")
405
- async def reset_env_default(request: Optional[ResetRequest] = None):
406
- """
407
- Alias for /env/reset/{task_id} without requiring a path parameter.
408
- Accepts the same optional JSON body as /reset.
409
- """
410
- task_id = "hard"
411
- seed = None
412
- if request is not None:
413
- task_id = request.task_id or "hard"
414
- seed = request.seed
415
- return await _do_reset(task_id=task_id, seed=seed)
416
-
417
-
418
- @app.post("/env/reset/{task_id}")
419
- async def reset_env(task_id: str = "hard"):
420
- """Reset with explicit task_id in path (original endpoint, kept for compatibility)."""
421
- return await _do_reset(task_id=task_id)
422
-
423
-
424
- import time
425
- _last_manual_step_time = 0.0
426
-
427
- @app.post("/env/step")
428
- async def step_env(request: StepRequest):
429
- global episode_scores, _last_manual_step_time
430
- _last_manual_step_time = time.time()
431
-
432
- if not env:
433
- return {"error": "not initialized"}
434
- if request.action_type not in {"INVESTIGATE", "IGNORE", "ESCALATE", "DELAY"}:
435
- return {"error": f"Invalid action '{request.action_type}'"}
436
- try:
437
- from rl_agent import encode_state # type: ignore
438
- old_obs = Observation(
439
- alerts = list(env.alerts),
440
- system_load = getattr(env, "_last_system_load", 0.5),
441
- queue_length = len(env.alerts),
442
- time_remaining = env.max_steps - env.current_step,
443
- resource_budget=(
444
- env.max_investigations_per_step - env.investigations_used
445
- if env.max_investigations_per_step is not None else None
446
- ),
447
- episode_step = env.current_step,
448
- )
449
-
450
- action = Action(alert_id=request.alert_id, action_type=request.action_type)
451
- obs, reward, done, info = env.step(action)
452
-
453
- agent = _ppo_agents.get(env.task_id)
454
- if agent is not None:
455
- agent.net.forward(encode_state(old_obs))
456
-
457
- _tick(info)
458
- s = _score()
459
- info["task_score"] = s
460
- if done:
461
- episode_scores.append(s)
462
- _reset_score()
463
- return {"obs": obs.model_dump(), "reward": reward.value,
464
- "done": done, "info": info, "score": s}
465
- except Exception as e:
466
- return {"error": str(e), "traceback": traceback.format_exc()}
467
-
468
-
469
- @app.get("/env/state")
470
- async def get_state():
471
- if not env:
472
- return {"error": "not initialized"}
473
- try:
474
- state = env.state()
475
- return {
476
- "visible_state": {
477
- "alerts": [a.model_dump() for a in env.alerts],
478
- "current_step": env.current_step,
479
- "max_steps": env.max_steps,
480
- "failures_count": env.failures_count,
481
- "system_load": state.observation.system_load,
482
- "queue_length": len(env.alerts),
483
- "task_id": env.task_id,
484
- "real_queue_size": len(env.real_alerts_queue) if hasattr(env, "real_alerts_queue") else 0,
485
- },
486
- "hidden_state": state.hidden_state,
487
- "cumulative_reward": state.cumulative_reward,
488
- }
489
- except Exception as e:
490
- return {"error": str(e), "traceback": traceback.format_exc()}
491
-
492
-
493
- # ── Agent recommendation ──────────────────────────────────────────────────────
494
-
495
- @app.get("/agent/recommend")
496
- async def recommend():
497
- if not env or not env.alerts:
498
- return {
499
- "error": "No alerts yet β€” episode loop is starting, retry in 2s",
500
- "active_alerts": len(env.alerts) if env else 0,
501
- }
502
-
503
- task_id = env.task_id
504
- top = max(env.alerts, key=lambda a: a.visible_severity)
505
-
506
- ppo = _ppo_agents.get(task_id)
507
- if ppo is not None:
508
- try:
509
- import numpy as np
510
- obs = Observation(
511
- alerts = list(env.alerts),
512
- system_load = getattr(env, "_last_system_load", 0.5),
513
- queue_length = len(env.alerts),
514
- time_remaining = env.max_steps - env.current_step,
515
- resource_budget=(
516
- env.max_investigations_per_step - env.investigations_used
517
- if env.max_investigations_per_step is not None else None
518
- ),
519
- episode_step = env.current_step,
520
- )
521
- s = encode_state(obs)
522
- old_h, old_c = ppo.net.h.copy(), ppo.net.c.copy()
523
- probs, val = ppo.net.forward(s)
524
- ppo.net.h, ppo.net.c = old_h, old_c
525
- idx = int(np.random.choice(4, p=probs))
526
- act = _ACTION_NAMES[idx]
527
- conf = round(float(probs[idx]) * 100, 1)
528
- return {
529
- "alert_id": top.id,
530
- "action_type": act,
531
- "reasoning": f"PPO ({conf:.1f}% confidence)",
532
- "source": "trained_ppo",
533
- "model_confidence": conf,
534
- "probabilities": {_ACTION_NAMES[i]: round(float(probs[i]), 4) for i in range(4)},
535
- "value_estimate": round(float(val), 3),
536
- "alert_severity": top.visible_severity,
537
- "alert_confidence": top.confidence,
538
- "alert_age": top.age,
539
- "alert_type": top.alert_type,
540
- "active_alerts": len(env.alerts),
541
- "episode_step": env.current_step,
542
- "task_id": task_id,
543
- }
544
- except Exception as exc:
545
- print(f"PPO recommend error: {exc}")
546
-
547
- # Rule-based fallback
548
- sev, conf = top.visible_severity, top.confidence
549
- rem = (env.max_investigations_per_step - env.investigations_used
550
- if env.max_investigations_per_step is not None else None)
551
- if sev >= 0.75 and conf >= 0.60:
552
- act = "ESCALATE" if (rem is not None and rem <= 0) else "INVESTIGATE"
553
- elif conf < 0.30 or sev < 0.30:
554
- act = "IGNORE"
555
- elif sev >= 0.55:
556
- act = "ESCALATE"
557
- else:
558
- act = "DELAY"
559
-
560
- return {
561
- "alert_id": top.id, "action_type": act,
562
- "source": "rule_based",
563
- "alert_severity": sev, "alert_confidence": conf,
564
- "alert_type": top.alert_type, "active_alerts": len(env.alerts),
565
- "task_id": task_id,
566
- "hint": "Run `python train_rl.py --episodes 300` to load PPO weights",
567
- }
568
-
569
-
570
- @app.get("/agent/weights/{task_id}")
571
- async def download_weights(task_id: str):
572
- from fastapi import HTTPException
573
- path = os.path.join(_project_root if _project_root else os.getcwd(), "weights", f"ppo_{task_id}.json")
574
- if not os.path.exists(path):
575
- raise HTTPException(status_code=404, detail=f"No trained weights found for {task_id}")
576
- return FileResponse(path, media_type='application/json', filename=f"ppo_{task_id}.json")
577
-
578
-
579
- # ── WebSocket ─────────────────────────────────────────────────────────────────
580
-
581
- @app.websocket("/ws/train")
582
- async def ws_train(websocket: WebSocket):
583
- global env, episode_scores
584
- await websocket.accept()
585
- lc = lt = 0
586
- try:
587
- while True:
588
- data = await websocket.receive_json()
589
- if data.get("type") == "reset":
590
- tid = data.get("task_id", "hard")
591
- saved = env.real_alerts_queue if (env and hasattr(env, "real_alerts_queue")) else None
592
- env = AdaptiveAlertTriageEnv(task_id=tid)
593
- env.real_alerts_queue = saved or deque(maxlen=50)
594
- obs = env.reset()
595
- lc = lt = 0
596
- await websocket.send_json({"obs": obs.model_dump(), "task_id": tid})
597
- elif data.get("type") == "step":
598
- if not env:
599
- await websocket.send_json({"error": "Reset first"}); continue
600
- ad = data.get("action", {})
601
- act = Action(alert_id=ad.get("alert_id",""), action_type=ad.get("action_type","IGNORE"))
602
- obs, reward, done, info = env.step(act)
603
- lt += 1
604
- if info.get("action_correct", False): lc += 1
605
- s = lc / lt if lt else 0.0
606
- if done: episode_scores.append(s)
607
- info["task_score"] = s
608
- await websocket.send_json({
609
- "obs": obs.model_dump(), "reward": reward.value,
610
- "done": done, "info": info, "task_score": s,
611
- "action_correct": info.get("action_correct", False),
612
- "failures_this_step": info.get("failures_this_step", 0),
613
- })
614
- elif data.get("type") == "close":
615
- break
616
- except WebSocketDisconnect:
617
- pass
618
- except Exception as e:
619
- try: await websocket.send_json({"error": str(e)})
620
- except Exception: pass
621
-
622
-
623
- # ── Utility ───────────────────────────────────────────────────────────────────
624
-
625
- @app.get("/")
626
- async def root():
627
- return {
628
- "name": "Adaptive Alert Triage RL Server", "version": "0.3.1",
629
- "openenv_endpoints": {
630
- "reset": "POST /reset",
631
- "step": "POST /env/step",
632
- "state": "GET /env/state",
633
- "health": "GET /health",
634
- },
635
- "quick_start": [
636
- "1. python train_rl.py --episodes 300",
637
- "2. uvicorn src.adaptive_alert_triage.server:app --port 7860",
638
- "3. curl -X POST localhost:7860/reset",
639
- "4. curl localhost:7860/agent/recommend",
640
- ],
641
- }
642
-
643
-
644
- import threading
645
- import subprocess
646
-
647
- _training_proc = None
648
- _training_logs = []
649
-
650
- def _run_training(episodes: int):
651
- global _training_proc, _training_logs, _ppo_agents
652
- _training_logs = [f"Starting training with --episodes {episodes}..."]
653
- try:
654
- _training_proc = subprocess.Popen(
655
- [sys.executable, "train_rl.py", "--episodes", str(episodes)],
656
- stdout=subprocess.PIPE,
657
- stderr=subprocess.STDOUT,
658
- text=True,
659
- bufsize=1,
660
- cwd=_project_root if _project_root else os.getcwd()
661
- )
662
- for line in iter(_training_proc.stdout.readline, ''):
663
- if line:
664
- _training_logs.append(line.rstrip('\n'))
665
- if len(_training_logs) > 1000:
666
- _training_logs.pop(0)
667
- _training_proc.wait()
668
- _training_logs.append(f"Training finished with exit code {_training_proc.returncode}")
669
-
670
- if _training_proc.returncode == 0:
671
- for tid in ("easy", "medium", "hard"):
672
- agent = _load_ppo(tid)
673
- if agent:
674
- _ppo_agents[tid] = agent
675
- _training_logs.append("Successfully reloaded PPO weights for all tasks.")
676
- except Exception as e:
677
- _training_logs.append(f"Error starting training: {e}")
678
-
679
- @app.post("/train")
680
- async def start_training(episodes: int = 300):
681
- global _training_proc
682
- if _training_proc is not None and _training_proc.poll() is None:
683
- return {"status": "already running"}
684
- threading.Thread(target=_run_training, args=(episodes,), daemon=True).start()
685
- return {"status": "started"}
686
-
687
- @app.get("/train/status")
688
- async def get_training_status():
689
- global _training_proc, _training_logs
690
- is_running = _training_proc is not None and _training_proc.poll() is None
691
- return {"is_running": is_running, "logs": _training_logs}
692
-
693
- @app.get("/web")
694
- async def web_ui():
695
- import os
696
- dashboard_path = os.path.join(
697
- os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
698
- "dashboard.html"
699
- )
700
- return FileResponse(dashboard_path, media_type="text/html")
701
-
702
-
703
- @app.get("/tasks")
704
- async def list_tasks():
705
- return {"tasks": [
706
- {"id": "easy", "success_threshold": 0.70, "max_steps": 30},
707
- {"id": "medium", "success_threshold": 0.55, "max_steps": 40},
708
- {"id": "hard", "success_threshold": 0.50, "max_steps": 50},
709
- ]}
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI OpenEnv Server for Adaptive Alert Triage Environment — v0.3.1
3
+
4
+ Root-cause fixes from v0.3.0:
5
+ FIX 1 — "No active episode" on /agent/recommend
6
+ FIX 2 — Queued alerts (real_alerts_queue) never appeared in env.alerts
7
+ FIX 3 — alert.dict() / obs.dict() removed in Pydantic v2
8
+ FIX 4 — task_score missing from info dict
9
+ FIX 5 — real_alerts_queue dropped on /env/reset
10
+ FIX 6 — state.system_load AttributeError
11
+
12
+ New in v0.3.1 (pre-submission compliance):
13
+ FIX 7 — Added POST /reset (OpenEnv spec requires top-level /reset endpoint)
14
+ FIX 8 — Added POST /env/reset (alias without task_id, defaults to "hard")
15
+ FIX 9 — Registered `openenv validate` CLI entry-point via pyproject.toml
16
+ (see companion pyproject.toml fix)
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import asyncio
22
+ import os
23
+ import sys
24
+ import traceback
25
+ from collections import deque
26
+ from typing import Any, Dict, List, Optional
27
+
28
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
29
+ from fastapi.middleware.cors import CORSMiddleware
30
+ from fastapi.responses import FileResponse
31
+ from pydantic import BaseModel
32
+
33
+ from .env import AdaptiveAlertTriageEnv
34
+ from .models import Action, Observation, Reward
35
+
36
+
37
+ # ── Try to load trained PPO agent (lazy import, server starts without it) ─────
38
# Optional dependency: the PPO helpers come from a top-level ``rl_agent``
# module located via the directory three levels above this file. If that
# import fails the flag stays False and ``_project_root`` is blanked.
_PPO_AVAILABLE = False
try:
    _project_root = os.path.abspath(
        os.path.join(
            os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir
        )
    )
    if _project_root not in sys.path:
        sys.path.insert(0, _project_root)
    from rl_agent import PPOTrainer, encode_state, _ACTION_NAMES  # type: ignore
except ImportError:
    _project_root = ""
else:
    _PPO_AVAILABLE = True
49
+
50
+
51
+ # ── Request / response models ─────────────────────────────────────────────────
52
+
53
class IngestAlert(BaseModel):
    """Request model describing one externally supplied alert."""

    # Identifier supplied by the sender.
    id: str
    # Severity as visible to the agent (presumably 0.0-1.0; verify against the env).
    visible_severity: float
    # Confidence attached to the alert (presumably 0.0-1.0; verify against the env).
    confidence: float
    # Free-form alert type string.
    type: str
58
+
59
+
60
class StepRequest(BaseModel):
    """Request model naming an alert and the action to apply to it."""

    # Identifier of the alert the action targets.
    alert_id: str
    # Name of the action to take on that alert.
    action_type: str
63
+
64
+
65
class ResetRequest(BaseModel):
    """Optional body for POST /reset — task_id defaults to 'hard'."""

    # Task to start; "hard" when omitted.
    task_id: Optional[str] = "hard"
    # Optional seed for the reset; None when omitted.
    seed: Optional[int] = None
69
+
70
+
71
class HealthResponse(BaseModel):
    """Response model carrying the server's health summary."""

    # Overall status string.
    status: str
    # Whether the environment is ready.
    env_ready: bool
    # Number of alerts waiting in the queue.
    queue_size: int
75
+
76
+
77
+ # ── Alert-type normaliser ─────────────────────────────────────────────────────
78
+
79
# Lower-case aliases accepted from callers, mapped to canonical alert types.
_TYPE_REMAP: Dict[str, str] = {
    "cpu": "CPU", "cpu_spike": "CPU",
    "memory": "MEMORY", "memory_leak": "MEMORY",
    "disk": "DISK", "disk_full": "DISK",
    "network": "NETWORK", "net": "NETWORK", "network_latency": "NETWORK",
    "application": "APPLICATION", "app": "APPLICATION",
    "security": "SECURITY", "sec": "SECURITY",
}
# The canonical alert types; _norm never returns anything outside this set.
_VALID = {"CPU", "MEMORY", "DISK", "NETWORK", "APPLICATION", "SECURITY"}


def _norm(raw: str) -> str:
    """Return the canonical alert type for *raw*.

    Surrounding whitespace is ignored and matching is case-insensitive.
    Known aliases are resolved through ``_TYPE_REMAP``; any other value is
    upper-cased. Empty input, or a result that is not one of ``_VALID``,
    falls back to ``"APPLICATION"`` so an unrecognised type never leaves
    this function.
    """
    if not raw:
        return "APPLICATION"
    cleaned = raw.strip()
    resolved = _TYPE_REMAP.get(cleaned.lower(), cleaned.upper())
    # Previously unknown types were passed through unchanged (e.g. "FOO").
    return resolved if resolved in _VALID else "APPLICATION"
92
+
93
+
94
+ # ── App ───────────────────────────────────────────────────────────────────────
95
+
96
# The ASGI application object; routes and middleware are attached to it.
app = FastAPI(
    title="Adaptive Alert Triage RL Server",
    version="0.3.1",
)

# CORS: every origin, method and header is allowed, without credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
99
+
100
@app.middleware("http")
async def log_requests(request, call_next):
    """Print the method and URL of each HTTP request, then pass it on.

    The response produced by the rest of the stack is returned unchanged.
    """
    print(f"REQUEST: {request.method} {request.url}")
    response = await call_next(request)
    return response
104
+
105
+ # ── Global state ──────────────────────────────────────────────────────────────
106
+
107
# Environment instance shared by the module; None until one is created.
env: Optional[AdaptiveAlertTriageEnv] = None

# Scores of finished episodes, appended in completion order.
episode_scores: List[float] = []

# task_id → PPOTrainer
_ppo_agents: Dict[str, Any] = {}

# Handle for the background episode-loop task, when one has been started.
_loop_task: Optional[asyncio.Task] = None

# action_type of the most recent action taken by the episode loop.
_last_action: Optional[str] = None

# Running tallies behind _score(): correct steps and total steps.
_step_correct: int = 0
_step_total: int = 0

# seconds between autonomous episode-loop steps
STEP_INTERVAL = 1.0
116
+
117
+
118
+ # ── Score helpers ─────────────────────────────────────────────────────────────
119
+
120
def _reset_score() -> None:
    """Zero the correct-step and total-step counters."""
    global _step_correct, _step_total
    _step_correct = 0
    _step_total = 0
123
+
124
+
125
def _tick(info: Dict) -> None:
    """Record one step in the running score counters.

    The total always grows by one; the correct count grows by one only when
    ``info["action_correct"]`` is present and truthy.
    """
    global _step_correct, _step_total
    _step_total += 1
    _step_correct += 1 if info.get("action_correct", False) else 0
130
+
131
+
132
def _score() -> float:
    """Return the fraction of recorded steps that were correct.

    Yields 0.0 while no steps have been recorded.
    """
    if not _step_total:
        return 0.0
    return _step_correct / _step_total
134
+
135
+
136
+ # ── PPO helpers ───────────────────────────────────────────────────────────────
137
+
138
def _load_ppo(task_id: str) -> Optional[Any]:
    """Build a ``PPOTrainer`` for *task_id* and load its saved weights.

    Weights are read from ``<_project_root>/weights/ppo_<task_id>.json``.

    Returns:
        The loaded trainer, or ``None`` when PPO support is unavailable, the
        weights file does not exist, or construction/loading raises.
    """
    if not _PPO_AVAILABLE:
        return None

    weights_path = os.path.join(_project_root, "weights", f"ppo_{task_id}.json")
    if not os.path.exists(weights_path):
        print(f" [PPO] weights not found: {weights_path}")
        return None

    try:
        trainer = PPOTrainer(task_id=task_id)
        trainer.load(weights_path)
        print(f" [PPO] loaded {weights_path}")
        return trainer
    except Exception as exc:
        # Best effort: a failed load leaves the caller without an agent.
        print(f" [PPO] load error: {exc}")
        return None
153
+
154
+
155
def _ppo_act() -> Optional[Action]:
    """Ask the PPO agent registered for the current task to pick an action.

    An ``Observation`` is assembled from the live environment and passed to
    the agent's ``act`` method.

    Returns:
        The agent's action, or ``None`` when there is no environment, no
        active alerts, no agent for ``env.task_id``, or when building the
        observation or running the agent raises. The episode loop falls back
        to ``_rule_act()`` on ``None``.
    """
    if not env or not env.alerts:
        return None
    agent = _ppo_agents.get(env.task_id)
    if agent is None:
        return None
    try:
        # No per-step investigation limit means no resource budget.
        if env.max_investigations_per_step is not None:
            budget = env.max_investigations_per_step - env.investigations_used
        else:
            budget = None
        obs = Observation(
            alerts=list(env.alerts),
            # 0.5 is used when the env has not recorded a system load yet.
            system_load=getattr(env, "_last_system_load", 0.5),
            queue_length=len(env.alerts),
            time_remaining=env.max_steps - env.current_step,
            resource_budget=budget,
            episode_step=env.current_step,
        )
        return agent.act(obs)
    except Exception as exc:
        # Still best effort, but report the failure instead of hiding it.
        print(f" [PPO] act error: {exc}")
        return None
176
+
177
+
178
def _rule_act() -> Optional[Action]:
    """Pick an action for the most severe visible alert using fixed thresholds.

    Returns:
        An action targeting the alert with the highest ``visible_severity``,
        or ``None`` when there is no environment or no pending alert.
    """
    if not env or not env.alerts:
        return None

    target = max(env.alerts, key=lambda a: a.visible_severity)
    severity = target.visible_severity
    confidence = target.confidence

    # Investigations still available this step; None means no limit applies.
    if env.max_investigations_per_step is not None:
        remaining = env.max_investigations_per_step - env.investigations_used
    else:
        remaining = None

    if severity >= 0.75 and confidence >= 0.60:
        # Severe and trustworthy: investigate unless the budget is used up.
        budget_exhausted = remaining is not None and remaining <= 0
        chosen = "ESCALATE" if budget_exhausted else "INVESTIGATE"
    elif confidence < 0.30 or severity < 0.30:
        chosen = "IGNORE"
    elif severity >= 0.55:
        chosen = "ESCALATE"
    else:
        chosen = "DELAY"

    return Action(alert_id=target.id, action_type=chosen)
195
+
196
+
197
+ # ── Always-live episode loop ──────────────────────────────────────────────────
198
+
199
async def _episode_loop() -> None:
    """Drive the environment forever, one autonomous step per ``STEP_INTERVAL``.

    Each pass:
      1. waits if no environment exists yet;
      2. closes the episode and resets when the alert list is empty or the
         environment reports a terminal state;
      3. stands down while a manual step happened within the last 5 seconds;
      4. otherwise applies the PPO action, falling back to the rule policy,
         and updates the running score.

    Any exception raised inside a pass is printed and the loop carries on.
    """
    global env, _last_action
    import time  # imported once when the loop starts, not on every pass

    def _close_episode() -> None:
        # Record the episode (only if it had at least one step), keep just the
        # newest 1000 scores on every path, then start a fresh episode.
        if _step_total > 0:
            episode_scores.append(_score())
            del episode_scores[:-1000]
        _reset_score()
        env.reset()

    while True:
        try:
            if env is None:
                await asyncio.sleep(STEP_INTERVAL)
                continue

            if not env.alerts or env._is_terminal():
                _close_episode()
                if not env.alerts:
                    await asyncio.sleep(STEP_INTERVAL)
                    continue

            # The timestamp is assigned further down the module, hence the
            # tolerant globals() lookup with a 0.0 default.
            last_manual = globals().get("_last_manual_step_time", 0.0)
            if time.time() - last_manual < 5.0:
                await asyncio.sleep(STEP_INTERVAL)
                continue

            action = _ppo_act() or _rule_act()
            if action is None:
                await asyncio.sleep(STEP_INTERVAL)
                continue

            _last_action = action.action_type
            _, reward, done, info = env.step(action)
            _tick(info)

            if done:
                _close_episode()

        except Exception as exc:
            print(f"[episode_loop] {exc}")

        await asyncio.sleep(STEP_INTERVAL)
243
+
244
+
245
+ # ── Startup / shutdown ────────────────────────────────────────────────────────
246
+
247
def _restore_pristine_weights():
    """Copy every ``ppo_*.json`` file from ``weights_pristine`` into ``weights``.

    Both directories are resolved against ``_project_root``, or the current
    working directory when no project root is set. Does nothing except print
    a notice when the pristine directory is absent.
    """
    import shutil

    base_dir = _project_root or os.getcwd()
    pristine_dir = os.path.join(base_dir, "weights_pristine")
    weights_dir = os.path.join(base_dir, "weights")

    if not os.path.exists(pristine_dir):
        print(" [STARTUP] No pristine weights found, skipping restore.")
        return

    os.makedirs(weights_dir, exist_ok=True)
    for name in os.listdir(pristine_dir):
        is_weight_file = name.startswith("ppo_") and name.endswith(".json")
        if not is_weight_file:
            continue
        shutil.copy2(os.path.join(pristine_dir, name), os.path.join(weights_dir, name))
        print(f" [STARTUP] Restored pristine weights: {name}")
263
+
264
+
265
@app.on_event("startup")
async def startup():
    """Prepare the server: weights, environment, PPO agents and episode loop.

    Restores the pristine weights, creates a ``hard`` environment with an
    empty ingestion queue capped at 50 entries, loads whichever PPO agents
    are available, and launches the background episode loop.
    """
    global env, _loop_task

    _restore_pristine_weights()

    env = AdaptiveAlertTriageEnv(task_id="hard")
    env.real_alerts_queue = deque(maxlen=50)
    env.reset()

    for tid in ("easy", "medium", "hard"):
        agent = _load_ppo(tid)
        if agent:
            _ppo_agents[tid] = agent

    # Keep the task handle in a module global so it can be cancelled later.
    _loop_task = asyncio.create_task(_episode_loop())

    # The banner glyph was mis-encoded in the source text; restored here.
    print("✅ Alert Triage RL Server v0.3.1")
    print(f" Active alerts : {len(env.alerts)}")
    print(f" PPO loaded : {list(_ppo_agents.keys()) or 'none (run train_rl.py first)'}")
    print(f" Episode loop : every {STEP_INTERVAL}s")
286
+
287
+
288
@app.on_event("shutdown")
async def shutdown():
    """Cancel the background episode loop and wait for it to stop."""
    if not _loop_task:
        return
    _loop_task.cancel()
    try:
        # cancel() only requests cancellation; awaiting lets the task actually
        # unwind before the application finishes shutting down.
        await _loop_task
    except asyncio.CancelledError:
        pass
292
+
293
+
294
+ # ── Health ────────────────────────────────────────────────────────────────────
295
+
296
@app.get("/health", response_model=HealthResponse)
async def health():
    """Report liveness, whether alerts are active, and the ingestion backlog."""
    ready = env is not None and bool(env.alerts)
    if env and hasattr(env, "real_alerts_queue"):
        backlog = len(env.real_alerts_queue)
    else:
        backlog = 0
    return HealthResponse(status="ok", env_ready=ready, queue_size=backlog)
303
+
304
+
305
@app.get("/metrics")
async def metrics():
    """Summarise recent performance and the state of the running episode.

    ``mean_score`` averages at most the 100 newest episode scores and
    ``vs_baseline`` expresses it in percentage points relative to 0.61.
    """
    if not env:
        return {"error": "not initialized"}

    recent = episode_scores[-100:]
    mean = sum(recent) / len(recent) if episode_scores else 0.0
    delta = (mean - 0.61) * 100
    if delta >= 0:
        versus = f"+{delta:.0f}%"
    else:
        versus = f"{delta:.0f}%"

    if hasattr(env, "real_alerts_queue"):
        backlog = len(env.real_alerts_queue)
    else:
        backlog = 0

    return {
        "mean_score": round(mean, 3),
        "vs_baseline": versus,
        "active_alerts": len(env.alerts),
        "episodes_completed": len(episode_scores),
        "current_step_score": round(_score(), 3),
        "current_step": env.current_step,
        "last_action": _last_action,
        "queue_size": backlog,
        "ppo_loaded": list(_ppo_agents.keys()),
    }
322
+
323
+
324
+ # ── Alert ingestion ───────────────────────────────────────────────────────────
325
+
326
@app.post("/ingest/alerts")
async def ingest_one(alert: IngestAlert):
    """Queue a single externally supplied alert on the environment.

    The alert's ``type`` is normalised with ``_norm`` (defaulting to
    ``"APPLICATION"``) before the payload is pushed onto the left end of
    ``env.real_alerts_queue``.
    """
    if not env:
        return {"error": "not initialized"}
    if not hasattr(env, "real_alerts_queue"):
        env.real_alerts_queue = deque(maxlen=50)

    payload = alert.model_dump()
    payload["type"] = _norm(payload.get("type", "APPLICATION"))
    env.real_alerts_queue.appendleft(payload)

    response = {"status": "queued", "queued": len(env.real_alerts_queue)}
    response["alert_id"] = alert.id
    response["resolved_type"] = payload["type"]
    response["note"] = "Episode loop will process this within ~1s"
    return response
340
+
341
+
342
@app.post("/ingest/alert-batch")
async def ingest_batch(alerts: List[IngestAlert]):
    """Queue several externally supplied alerts in the order received.

    Each alert gets the same type normalisation and left-end push as the
    single-alert endpoint; the response lists the id and resolved type of
    every alert that was queued.
    """
    if not env:
        return {"error": "not initialized"}
    if not hasattr(env, "real_alerts_queue"):
        env.real_alerts_queue = deque(maxlen=50)

    accepted = []
    for item in alerts:
        payload = item.model_dump()
        resolved = _norm(payload.get("type", "APPLICATION"))
        payload["type"] = resolved
        env.real_alerts_queue.appendleft(payload)
        accepted.append({"alert_id": item.id, "resolved_type": resolved})

    queued_now = len(env.real_alerts_queue)
    return {"status": "queued", "queued": queued_now, "ingested": accepted}
355
+
356
+
357
+ # ── Environment control ───────────────────────────────────────────────────────
358
+
359
async def _do_reset(task_id: str = "hard", seed: Optional[int] = None) -> dict:
    """Replace the global environment with a fresh one and reset it.

    Shared by every reset endpoint. The existing ingestion queue, if any, is
    carried over to the new environment, and the PPO agent for *task_id* is
    reloaded when its weights can be loaded.

    Args:
        task_id: One of ``easy``, ``medium`` or ``hard``.
        seed: Optional seed forwarded to ``env.reset``.

    Returns:
        A JSON-ready dict: the reset status with the first observation, or an
        ``error`` entry (with a traceback when an exception was raised).
    """
    global env
    valid_tasks = ("easy", "medium", "hard")
    if task_id not in valid_tasks:
        return {"error": f"Invalid task_id '{task_id}'. Must be one of: easy, medium, hard"}

    try:
        previous_queue = None
        if env and hasattr(env, "real_alerts_queue"):
            previous_queue = env.real_alerts_queue

        env = AdaptiveAlertTriageEnv(task_id=task_id)
        if previous_queue is not None:
            env.real_alerts_queue = previous_queue
        else:
            env.real_alerts_queue = deque(maxlen=50)

        reloaded = _load_ppo(task_id)
        if reloaded:
            _ppo_agents[task_id] = reloaded

        first_obs = env.reset(seed=seed)
        _reset_score()
        return {"status": "reset", "task_id": task_id, "obs": first_obs.model_dump()}
    except Exception as e:
        return {"error": str(e), "traceback": traceback.format_exc()}
379
+
380
+
381
+ # FIX 7 — Top-level /reset endpoint required by OpenEnv validator ping
382
+ # The pre-submission checker does: POST $PING_URL/reset
383
+ # This must return 200 and a valid Observation.
384
@app.post("/reset")
async def reset_top_level(request: Optional[ResetRequest] = None):
    """
    Top-level reset endpoint (the one pinged by the pre-submission checker).

    POST /reset
    Body (optional JSON): {"task_id": "easy"|"medium"|"hard", "seed": int}

    Without a body, or with an empty task_id, the "hard" task is used.
    Responds with the shared reset result for the new episode.
    """
    if request is None:
        return await _do_reset(task_id="hard", seed=None)
    chosen_task = request.task_id or "hard"
    return await _do_reset(task_id=chosen_task, seed=request.seed)
401
+
402
+
403
+ # FIX 8 — /env/reset without a path parameter (alias, defaults to "hard")
404
@app.post("/env/reset")
async def reset_env_default(request: Optional[ResetRequest] = None):
    """
    Path-parameter-free variant of /env/reset/{task_id}.

    Takes the same optional JSON body as /reset; a missing body or empty
    task_id selects the "hard" task.
    """
    has_body = request is not None
    chosen_task = (request.task_id or "hard") if has_body else "hard"
    chosen_seed = request.seed if has_body else None
    return await _do_reset(task_id=chosen_task, seed=chosen_seed)
416
+
417
+
418
@app.post("/env/reset/{task_id}")
async def reset_env(task_id: str = "hard"):
    """Reset using the task_id taken from the URL path (original endpoint, kept for compatibility)."""
    outcome = await _do_reset(task_id=task_id)
    return outcome
422
+
423
+
424
import time

# Wall-clock time of the most recent /env/step call. The autonomous episode
# loop reads it to stand down while a client is stepping manually.
_last_manual_step_time = 0.0


@app.post("/env/step")
async def step_env(request: StepRequest):
    """Apply one client-chosen action to the environment.

    Args:
        request: Carries the target ``alert_id`` and an ``action_type`` that
            must be INVESTIGATE, IGNORE, ESCALATE or DELAY.

    Returns:
        The new observation, reward value, done flag, step info (with
        ``task_score`` added) and running score; or an ``error`` entry when
        the server is not initialised, the action is invalid, or the step
        raised.
    """
    global episode_scores, _last_manual_step_time
    _last_manual_step_time = time.time()

    if not env:
        return {"error": "not initialized"}
    if request.action_type not in {"INVESTIGATE", "IGNORE", "ESCALATE", "DELAY"}:
        return {"error": f"Invalid action '{request.action_type}'"}
    try:
        agent = _ppo_agents.get(env.task_id)

        # The pre-step observation is only needed to feed a loaded PPO agent,
        # so it is captured (before env.step mutates the state) only then.
        old_obs = None
        if agent is not None:
            old_obs = Observation(
                alerts=list(env.alerts),
                system_load=getattr(env, "_last_system_load", 0.5),
                queue_length=len(env.alerts),
                time_remaining=env.max_steps - env.current_step,
                resource_budget=(
                    env.max_investigations_per_step - env.investigations_used
                    if env.max_investigations_per_step is not None else None
                ),
                episode_step=env.current_step,
            )

        action = Action(alert_id=request.alert_id, action_type=request.action_type)
        obs, reward, done, info = env.step(action)

        if agent is not None:
            # Run the agent's network on the pre-step state (presumably to
            # keep its internal state in step with manual actions). The
            # rl_agent import stays inside this branch so manual stepping
            # works without it, and a failure here must not discard the
            # result of a step that has already been applied.
            try:
                from rl_agent import encode_state  # type: ignore
                agent.net.forward(encode_state(old_obs))
            except Exception as exc:
                print(f"[env/step] PPO forward pass skipped: {exc}")

        _tick(info)
        s = _score()
        info["task_score"] = s
        if done:
            episode_scores.append(s)
            # Same 1000-entry cap the autonomous loop applies to the history.
            del episode_scores[:-1000]
            _reset_score()
        return {"obs": obs.model_dump(), "reward": reward.value,
                "done": done, "info": info, "score": s}
    except Exception as e:
        return {"error": str(e), "traceback": traceback.format_exc()}
467
+
468
+
469
@app.get("/env/state")
async def get_state():
    """Return the environment's visible state, hidden state and cumulative reward.

    Responds with an ``error`` entry when the server is not initialised or
    reading the state raised.
    """
    if not env:
        return {"error": "not initialized"}
    try:
        snapshot = env.state()

        if hasattr(env, "real_alerts_queue"):
            backlog = len(env.real_alerts_queue)
        else:
            backlog = 0

        visible = {
            "alerts": [a.model_dump() for a in env.alerts],
            "current_step": env.current_step,
            "max_steps": env.max_steps,
            "failures_count": env.failures_count,
            "system_load": snapshot.observation.system_load,
            "queue_length": len(env.alerts),
            "task_id": env.task_id,
            "real_queue_size": backlog,
        }
        return {
            "visible_state": visible,
            "hidden_state": snapshot.hidden_state,
            "cumulative_reward": snapshot.cumulative_reward,
        }
    except Exception as e:
        return {"error": str(e), "traceback": traceback.format_exc()}
491
+
492
+
493
+ # ── Agent recommendation ──────────────────────────────────────────────────────
494
+
495
@app.get("/agent/recommend")
async def recommend():
    """Recommend an action for the most severe active alert.

    Uses the PPO agent loaded for the active task when there is one: the
    action is sampled from the network's output probabilities and returned
    with those probabilities and the value estimate. If no agent is loaded,
    or the PPO path raises, a threshold-based rule supplies the answer.

    Returns:
        A recommendation dict whose ``source`` is ``trained_ppo`` or
        ``rule_based``, or an ``error`` entry while no alerts are active.
    """
    if not env or not env.alerts:
        return {
            # The dash in this message was mis-encoded in the source text.
            "error": "No alerts yet — episode loop is starting, retry in 2s",
            "active_alerts": len(env.alerts) if env else 0,
        }

    task_id = env.task_id
    top = max(env.alerts, key=lambda a: a.visible_severity)

    ppo = _ppo_agents.get(task_id)
    if ppo is not None:
        try:
            import numpy as np
            obs = Observation(
                alerts=list(env.alerts),
                system_load=getattr(env, "_last_system_load", 0.5),
                queue_length=len(env.alerts),
                time_remaining=env.max_steps - env.current_step,
                resource_budget=(
                    env.max_investigations_per_step - env.investigations_used
                    if env.max_investigations_per_step is not None else None
                ),
                episode_step=env.current_step,
            )
            s = encode_state(obs)
            # Save and restore the network's h/c arrays around the forward
            # pass so that asking for a recommendation leaves them unchanged.
            old_h, old_c = ppo.net.h.copy(), ppo.net.c.copy()
            probs, val = ppo.net.forward(s)
            ppo.net.h, ppo.net.c = old_h, old_c
            idx = int(np.random.choice(4, p=probs))
            act = _ACTION_NAMES[idx]
            conf = round(float(probs[idx]) * 100, 1)
            return {
                "alert_id": top.id,
                "action_type": act,
                "reasoning": f"PPO ({conf:.1f}% confidence)",
                "source": "trained_ppo",
                "model_confidence": conf,
                "probabilities": {_ACTION_NAMES[i]: round(float(probs[i]), 4) for i in range(4)},
                "value_estimate": round(float(val), 3),
                "alert_severity": top.visible_severity,
                "alert_confidence": top.confidence,
                "alert_age": top.age,
                "alert_type": top.alert_type,
                "active_alerts": len(env.alerts),
                "episode_step": env.current_step,
                "task_id": task_id,
            }
        except Exception as exc:
            # Fall through to the rule-based answer below.
            print(f"PPO recommend error: {exc}")

    # Rule-based fallback
    sev, conf = top.visible_severity, top.confidence
    rem = (env.max_investigations_per_step - env.investigations_used
           if env.max_investigations_per_step is not None else None)
    if sev >= 0.75 and conf >= 0.60:
        # Investigate unless this step's investigation budget is used up.
        act = "ESCALATE" if (rem is not None and rem <= 0) else "INVESTIGATE"
    elif conf < 0.30 or sev < 0.30:
        act = "IGNORE"
    elif sev >= 0.55:
        act = "ESCALATE"
    else:
        act = "DELAY"

    return {
        "alert_id": top.id, "action_type": act,
        "source": "rule_based",
        "alert_severity": sev, "alert_confidence": conf,
        "alert_type": top.alert_type, "active_alerts": len(env.alerts),
        "task_id": task_id,
        "hint": "Run `python train_rl.py --episodes 300` to load PPO weights",
    }
568
+
569
+
570
@app.get("/agent/weights/{task_id}")
async def download_weights(task_id: str):
    """Send the saved PPO weights file for *task_id* as a JSON download.

    Args:
        task_id: One of ``easy``, ``medium`` or ``hard``.

    Raises:
        HTTPException: 404 when the task id is not a known task or no
            weights file exists for it.
    """
    from fastapi import HTTPException

    # task_id comes straight from the URL and is interpolated into a file
    # name, so only the task ids this server knows about are accepted.
    if task_id not in ("easy", "medium", "hard"):
        raise HTTPException(status_code=404, detail=f"No trained weights found for {task_id}")

    base_dir = _project_root if _project_root else os.getcwd()
    path = os.path.join(base_dir, "weights", f"ppo_{task_id}.json")
    if not os.path.exists(path):
        raise HTTPException(status_code=404, detail=f"No trained weights found for {task_id}")
    return FileResponse(path, media_type='application/json', filename=f"ppo_{task_id}.json")
577
+
578
+
579
+ # ── WebSocket ─────────────────────────────────────────────────────────────────
580
+
581
@app.websocket("/ws/train")
async def ws_train(websocket: WebSocket):
    """Serve a reset/step/close message protocol over one WebSocket.

    Messages are JSON objects dispatched on their ``type``:
      * ``reset`` - build a new environment for ``task_id`` (default
        ``hard``) and reply with the first observation;
      * ``step``  - apply ``action`` and reply with observation, reward, done
        flag, info and the connection-local score;
      * ``close`` - end the session.

    The score counters are local to the connection and restart on ``reset``.
    Any other exception is reported to the client on a best-effort basis and
    ends the session.
    """
    global env, episode_scores
    await websocket.accept()
    lc = lt = 0  # correct / total steps counted on this connection
    try:
        while True:
            data = await websocket.receive_json()
            kind = data.get("type")

            if kind == "reset":
                tid = data.get("task_id", "hard")
                saved = env.real_alerts_queue if (env and hasattr(env, "real_alerts_queue")) else None
                env = AdaptiveAlertTriageEnv(task_id=tid)
                # An empty deque is falsy, so test against None to keep an
                # existing (possibly empty) queue, as the HTTP reset does.
                env.real_alerts_queue = saved if saved is not None else deque(maxlen=50)
                obs = env.reset()
                lc = lt = 0
                await websocket.send_json({"obs": obs.model_dump(), "task_id": tid})

            elif kind == "step":
                if not env:
                    await websocket.send_json({"error": "Reset first"})
                    continue
                ad = data.get("action", {})
                act = Action(
                    alert_id=ad.get("alert_id", ""),
                    action_type=ad.get("action_type", "IGNORE"),
                )
                obs, reward, done, info = env.step(act)
                lt += 1
                if info.get("action_correct", False):
                    lc += 1
                s = lc / lt if lt else 0.0
                if done:
                    episode_scores.append(s)
                    # Same 1000-entry cap the autonomous loop applies.
                    del episode_scores[:-1000]
                info["task_score"] = s
                await websocket.send_json({
                    "obs": obs.model_dump(), "reward": reward.value,
                    "done": done, "info": info, "task_score": s,
                    "action_correct": info.get("action_correct", False),
                    "failures_this_step": info.get("failures_this_step", 0),
                })

            elif kind == "close":
                break
    except WebSocketDisconnect:
        pass
    except Exception as e:
        # The socket may already be unusable, so the report itself may fail.
        try:
            await websocket.send_json({"error": str(e)})
        except Exception:
            pass
621
+
622
+
623
+ # ── Utility ───────────────────────────────────────────────────────────────────
624
+
625
@app.get("/")
async def root():
    """Describe the server: name, version, core endpoints and a quick start."""
    endpoints = {
        "reset": "POST /reset",
        "step": "POST /env/step",
        "state": "GET /env/state",
        "health": "GET /health",
    }
    quick_start = [
        "1. python train_rl.py --episodes 300",
        "2. uvicorn src.adaptive_alert_triage.server:app --port 7860",
        "3. curl -X POST localhost:7860/reset",
        "4. curl localhost:7860/agent/recommend",
    ]
    return {
        "name": "Adaptive Alert Triage RL Server",
        "version": "0.3.1",
        "openenv_endpoints": endpoints,
        "quick_start": quick_start,
    }
642
+
643
+
644
+ import threading
645
+ import subprocess
646
+
647
# Handle of the most recently started training subprocess (None before any run).
_training_proc = None
# Output lines of the current or most recent training run, capped at 1000.
_training_logs = []


def _run_training(episodes: int):
    """Run ``train_rl.py`` in a subprocess and capture its output.

    Meant to run on a worker thread. Output lines (stderr merged into stdout)
    are appended to ``_training_logs``, keeping only the newest 1000. After a
    zero exit code the PPO weights for all three tasks are reloaded.

    Args:
        episodes: Value passed to the script's ``--episodes`` option.
    """
    global _training_proc, _training_logs, _ppo_agents
    _training_logs = [f"Starting training with --episodes {episodes}..."]
    try:
        # The context manager closes the stdout pipe and waits for the child,
        # so the pipe is released even if reading raises.
        with subprocess.Popen(
            [sys.executable, "train_rl.py", "--episodes", str(episodes)],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
            cwd=_project_root if _project_root else os.getcwd(),
        ) as proc:
            # Published right away so other code can poll the running process.
            _training_proc = proc
            for line in proc.stdout:
                _training_logs.append(line.rstrip('\n'))
                del _training_logs[:-1000]

        _training_logs.append(f"Training finished with exit code {proc.returncode}")

        if proc.returncode == 0:
            for tid in ("easy", "medium", "hard"):
                agent = _load_ppo(tid)
                if agent:
                    _ppo_agents[tid] = agent
            _training_logs.append("Successfully reloaded PPO weights for all tasks.")
    except Exception as e:
        _training_logs.append(f"Error starting training: {e}")
678
+
679
@app.post("/train")
async def start_training(episodes: int = 300):
    """Start a training run on a daemon thread unless one is still running.

    Args:
        episodes: Number of episodes handed to the training worker.
    """
    global _training_proc
    still_running = _training_proc is not None and _training_proc.poll() is None
    if still_running:
        return {"status": "already running"}

    worker = threading.Thread(target=_run_training, args=(episodes,), daemon=True)
    worker.start()
    return {"status": "started"}
686
+
687
@app.get("/train/status")
async def get_training_status():
    """Report whether the training subprocess is alive, plus its captured log lines."""
    global _training_proc, _training_logs
    if _training_proc is None:
        is_running = False
    else:
        is_running = _training_proc.poll() is None
    return {"is_running": is_running, "logs": _training_logs}
692
+
693
@app.get("/web")
async def web_ui():
    """Serve ``dashboard.html`` from the directory three levels above this file."""
    import os

    package_dir = os.path.dirname(__file__)
    parent_dir = os.path.dirname(package_dir)
    top_dir = os.path.dirname(parent_dir)
    dashboard_path = os.path.join(top_dir, "dashboard.html")
    return FileResponse(dashboard_path, media_type="text/html")
701
+
702
+
703
@app.get("/tasks")
async def list_tasks():
    """List the available tasks with their success threshold and step limit."""
    # (task id, success threshold, maximum steps)
    catalogue = (
        ("easy", 0.70, 30),
        ("medium", 0.55, 40),
        ("hard", 0.50, 50),
    )
    return {"tasks": [
        {"id": tid, "success_threshold": threshold, "max_steps": steps}
        for tid, threshold, steps in catalogue
    ]}
710
+
711
+
712
def main():
    """Run the API with uvicorn on 0.0.0.0, using ``$PORT`` or 7860 as the port."""
    import os
    import uvicorn

    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(
        "adaptive_alert_triage.server:app",
        host="0.0.0.0",
        port=listen_port,
    )


if __name__ == "__main__":
    main()
uv.lock ADDED
The diff for this file is too large to render. See raw diff