suraj-01 committed on
Commit
eea342f
·
1 Parent(s): dd69fa9
EasterEgg.jpeg ADDED
inference.py CHANGED
@@ -67,7 +67,7 @@ except ImportError:
67
  API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
68
  MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o")
69
  HF_TOKEN = os.environ.get("HF_TOKEN")
70
- _API_KEY = HF_TOKEN or os.environ.get("OPENAI_API_KEY", "no-key-set")
71
 
72
  # ── Task registry ─────────────────────────────────────────────────────────────
73
  _TASKS: Dict[str, Dict[str, Any]] = {
@@ -171,6 +171,7 @@ class LLMTriageAgent:
171
  def act(self, obs: Observation) -> Action:
172
  if not obs.alerts:
173
  raise ValueError("act() called with empty alerts")
 
174
  text = self._call_api(_build_user_message(obs))
175
  if text is None:
176
  self.fallbacks += 1
@@ -306,7 +307,7 @@ def run_episode(agent: LLMTriageAgent, task_id: str, episode: int, seed: int) ->
306
 
307
  def run_baseline(
308
  tasks: List[str],
309
- num_episodes: int = 3,
310
  seed_offset: int = 42,
311
  ) -> Dict[str, Any]:
312
  """
@@ -383,9 +384,9 @@ if __name__ == "__main__":
383
  )
384
  p.add_argument("--task", choices=["easy", "medium", "hard"],
385
  default=None, help="Single task (default: all three)")
386
- p.add_argument("--n", type=int, default=3,
387
  metavar="N",
388
- help="Episodes per task (default: 3 β€” fits in 20 min budget)")
389
  p.add_argument("--seed", type=int, default=42,
390
  help="Base random seed (default: 42)")
391
  args = p.parse_args()
 
67
  API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
68
  MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o")
69
  HF_TOKEN = os.environ.get("HF_TOKEN")
70
+ _API_KEY = os.environ.get("API_KEY") or HF_TOKEN or os.environ.get("OPENAI_API_KEY", "no-key-set")
71
 
72
  # ── Task registry ─────────────────────────────────────────────────────────────
73
  _TASKS: Dict[str, Dict[str, Any]] = {
 
171
  def act(self, obs: Observation) -> Action:
172
  if not obs.alerts:
173
  raise ValueError("act() called with empty alerts")
174
+
175
  text = self._call_api(_build_user_message(obs))
176
  if text is None:
177
  self.fallbacks += 1
 
307
 
308
  def run_baseline(
309
  tasks: List[str],
310
+ num_episodes: int = 1,
311
  seed_offset: int = 42,
312
  ) -> Dict[str, Any]:
313
  """
 
384
  )
385
  p.add_argument("--task", choices=["easy", "medium", "hard"],
386
  default=None, help="Single task (default: all three)")
387
+ p.add_argument("--n", type=int, default=1,
388
  metavar="N",
389
+ help="Episodes per task (default: 1 β€” strict API budget)")
390
  p.add_argument("--seed", type=int, default=42,
391
  help="Base random seed (default: 42)")
392
  args = p.parse_args()
requirements.txt CHANGED
@@ -6,6 +6,7 @@
6
  # ── Core environment ──────────────────────────────────────────────────────────
7
  numpy>=1.24.0
8
  openenv>=0.1.0
 
9
  pydantic>=2.0.0
10
 
11
  # ── Web framework (FastAPI server) ────────────────────────────────────────────
 
6
  # ── Core environment ──────────────────────────────────────────────────────────
7
  numpy>=1.24.0
8
  openenv>=0.1.0
9
+ openenv-core>=0.2.0
10
  pydantic>=2.0.0
11
 
12
  # ── Web framework (FastAPI server) ────────────────────────────────────────────
rewards/reward.py CHANGED
@@ -315,6 +315,7 @@ def calculate_reward(
315
  components = {k: v * multiplier for k, v in components.items()}
316
 
317
  total_reward: float = sum(components.values())
 
318
 
319
  # -----------------------------------------------------------------------
320
  # Info payload β€” consumed by graders and evaluation scripts
@@ -331,10 +332,11 @@ def calculate_reward(
331
  action_type, is_critical, is_false_positive, resource_constrained
332
  ),
333
  "task_multiplier": multiplier,
 
334
  }
335
 
336
  return Reward(
337
- value=total_reward,
338
  components=components,
339
  info=info,
340
  )
@@ -611,7 +613,8 @@ if __name__ == "__main__":
611
  for desc, act, alert, cfg, expected in cases:
612
  action = Action(alert_id=alert.id, action_type=act)
613
  result = calculate_reward(action, alert, cfg)
614
- ok = abs(result.value - expected) < 1e-4
 
615
  status = "PASS" if ok else "FAIL"
616
  if not ok:
617
  all_pass = False
 
315
  components = {k: v * multiplier for k, v in components.items()}
316
 
317
  total_reward: float = sum(components.values())
318
+ norm_reward: float = max(0.01, min(0.99, (total_reward + 40.0) / 80.0))
319
 
320
  # -----------------------------------------------------------------------
321
  # Info payload β€” consumed by graders and evaluation scripts
 
332
  action_type, is_critical, is_false_positive, resource_constrained
333
  ),
334
  "task_multiplier": multiplier,
335
+ "raw_reward": total_reward,
336
  }
337
 
338
  return Reward(
339
+ value=norm_reward,
340
  components=components,
341
  info=info,
342
  )
 
613
  for desc, act, alert, cfg, expected in cases:
614
  action = Action(alert_id=alert.id, action_type=act)
615
  result = calculate_reward(action, alert, cfg)
616
+ normalized_expected = max(0.01, min(0.99, (expected + 40.0) / 80.0))
617
+ ok = abs(result.value - normalized_expected) < 1e-4
618
  status = "PASS" if ok else "FAIL"
619
  if not ok:
620
  all_pass = False
src/adaptive_alert_triage/env.py CHANGED
@@ -76,22 +76,22 @@ except ImportError:
76
 
77
  _TASK_CONFIGS: Dict[str, Dict[str, Any]] = {
78
  "easy": {
79
- "max_steps": 30,
80
- "failure_threshold": 5,
81
  "max_investigations": None, # unconstrained
82
  "correlation_probability": 0.10,
83
  "description": "Basic alert prioritisation β€” no resource constraint.",
84
  },
85
  "medium": {
86
- "max_steps": 40,
87
- "failure_threshold": 5,
88
  "max_investigations": 3, # K = 3 per step
89
  "correlation_probability": 0.20,
90
  "description": "Resource-constrained triage β€” K=3 investigations/step.",
91
  },
92
  "hard": {
93
- "max_steps": 50,
94
- "failure_threshold": 3, # stricter
95
  "max_investigations": 3,
96
  "correlation_probability": 0.40,
97
  "description": (
@@ -267,7 +267,7 @@ class AdaptiveAlertTriageEnv(gym.Env):
267
  alert = self._get_alert_by_id(action.alert_id)
268
  if alert is None:
269
  reward = Reward(
270
- value=-5.0,
271
  components={"invalid_action": -5.0},
272
  info={"error": f"Alert ID '{action.alert_id}' not found in queue"},
273
  )
@@ -284,7 +284,7 @@ class AdaptiveAlertTriageEnv(gym.Env):
284
  ):
285
  if self.investigations_used >= self.max_investigations_per_step:
286
  reward = Reward(
287
- value=-3.0,
288
  components={"resource_budget_exceeded": -3.0},
289
  info={
290
  "error": "Investigation budget exhausted for this step",
 
76
 
77
  _TASK_CONFIGS: Dict[str, Dict[str, Any]] = {
78
  "easy": {
79
+ "max_steps": 10,
80
+ "failure_threshold": 2,
81
  "max_investigations": None, # unconstrained
82
  "correlation_probability": 0.10,
83
  "description": "Basic alert prioritisation β€” no resource constraint.",
84
  },
85
  "medium": {
86
+ "max_steps": 15,
87
+ "failure_threshold": 3,
88
  "max_investigations": 3, # K = 3 per step
89
  "correlation_probability": 0.20,
90
  "description": "Resource-constrained triage β€” K=3 investigations/step.",
91
  },
92
  "hard": {
93
+ "max_steps": 20,
94
+ "failure_threshold": 2, # stricter
95
  "max_investigations": 3,
96
  "correlation_probability": 0.40,
97
  "description": (
 
267
  alert = self._get_alert_by_id(action.alert_id)
268
  if alert is None:
269
  reward = Reward(
270
+ value=0.01,
271
  components={"invalid_action": -5.0},
272
  info={"error": f"Alert ID '{action.alert_id}' not found in queue"},
273
  )
 
284
  ):
285
  if self.investigations_used >= self.max_investigations_per_step:
286
  reward = Reward(
287
+ value=0.01,
288
  components={"resource_budget_exceeded": -3.0},
289
  info={
290
  "error": "Investigation budget exhausted for this step",
src/adaptive_alert_triage/models.py CHANGED
@@ -222,7 +222,14 @@ class Reward(BaseModel):
222
  info: Debugging / logging extras (ground-truth reveal, etc.).
223
  """
224
 
225
- value: float = Field(..., description="Total scalar reward")
 
 
 
 
 
 
 
226
  components: Dict[str, float] = Field(
227
  default_factory=dict, description="Per-component reward breakdown"
228
  )
 
222
  info: Debugging / logging extras (ground-truth reveal, etc.).
223
  """
224
 
225
+ value: float = Field(..., ge=0.0, le=1.0, description="Total scalar reward in [0.0, 1.0]")
226
+
227
+ @field_validator("value", mode="before")
228
+ @classmethod
229
+ def clamp_reward_value(cls, v: float) -> float:
230
+ """Silently clamp reward value to [0.01, 0.99] β€” strict (0, 1) bounds."""
231
+ return float(max(0.01, min(0.99, float(v))))
232
+
233
  components: Dict[str, float] = Field(
234
  default_factory=dict, description="Per-component reward breakdown"
235
  )
src/adaptive_alert_triage/validate.py CHANGED
@@ -123,7 +123,7 @@ class OpenEnvValidator:
123
  action_ok = restored.alert_id == action.alert_id
124
  self.check("Action serialization round-trip", action_ok)
125
 
126
- reward = Reward(value=10.0, components={"test": 10.0})
127
  restored = Reward.model_validate_json(reward.model_dump_json())
128
  reward_ok = restored.value == reward.value
129
  self.check("Reward serialization round-trip", reward_ok)
 
123
  action_ok = restored.alert_id == action.alert_id
124
  self.check("Action serialization round-trip", action_ok)
125
 
126
+ reward = Reward(value=0.5, components={"test": 0.5})
127
  restored = Reward.model_validate_json(reward.model_dump_json())
128
  reward_ok = restored.value == reward.value
129
  self.check("Reward serialization round-trip", reward_ok)