Spaces:
Sleeping
Sleeping
- EasterEgg.jpeg +0 -0
- inference.py +5 -4
- requirements.txt +1 -0
- rewards/reward.py +5 -2
- src/adaptive_alert_triage/env.py +8 -8
- src/adaptive_alert_triage/models.py +8 -1
- src/adaptive_alert_triage/validate.py +1 -1
EasterEgg.jpeg
ADDED
|
inference.py
CHANGED
|
@@ -67,7 +67,7 @@ except ImportError:
|
|
| 67 |
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
|
| 68 |
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o")
|
| 69 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 70 |
-
_API_KEY = HF_TOKEN or os.environ.get("OPENAI_API_KEY", "no-key-set")
|
| 71 |
|
| 72 |
# ── Task registry ─────────────────────────────────────────────────────────────
|
| 73 |
_TASKS: Dict[str, Dict[str, Any]] = {
|
|
@@ -171,6 +171,7 @@ class LLMTriageAgent:
|
|
| 171 |
def act(self, obs: Observation) -> Action:
|
| 172 |
if not obs.alerts:
|
| 173 |
raise ValueError("act() called with empty alerts")
|
|
|
|
| 174 |
text = self._call_api(_build_user_message(obs))
|
| 175 |
if text is None:
|
| 176 |
self.fallbacks += 1
|
|
@@ -306,7 +307,7 @@ def run_episode(agent: LLMTriageAgent, task_id: str, episode: int, seed: int) ->
|
|
| 306 |
|
| 307 |
def run_baseline(
|
| 308 |
tasks: List[str],
|
| 309 |
-
num_episodes: int =
|
| 310 |
seed_offset: int = 42,
|
| 311 |
) -> Dict[str, Any]:
|
| 312 |
"""
|
|
@@ -383,9 +384,9 @@ if __name__ == "__main__":
|
|
| 383 |
)
|
| 384 |
p.add_argument("--task", choices=["easy", "medium", "hard"],
|
| 385 |
default=None, help="Single task (default: all three)")
|
| 386 |
-
p.add_argument("--n", type=int, default=
|
| 387 |
metavar="N",
|
| 388 |
-
help="Episodes per task (default:
|
| 389 |
p.add_argument("--seed", type=int, default=42,
|
| 390 |
help="Base random seed (default: 42)")
|
| 391 |
args = p.parse_args()
|
|
|
|
| 67 |
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
|
| 68 |
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o")
|
| 69 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 70 |
+
_API_KEY = os.environ.get("API_KEY") or HF_TOKEN or os.environ.get("OPENAI_API_KEY", "no-key-set")
|
| 71 |
|
| 72 |
# ── Task registry ─────────────────────────────────────────────────────────────
|
| 73 |
_TASKS: Dict[str, Dict[str, Any]] = {
|
|
|
|
| 171 |
def act(self, obs: Observation) -> Action:
|
| 172 |
if not obs.alerts:
|
| 173 |
raise ValueError("act() called with empty alerts")
|
| 174 |
+
|
| 175 |
text = self._call_api(_build_user_message(obs))
|
| 176 |
if text is None:
|
| 177 |
self.fallbacks += 1
|
|
|
|
| 307 |
|
| 308 |
def run_baseline(
|
| 309 |
tasks: List[str],
|
| 310 |
+
num_episodes: int = 1,
|
| 311 |
seed_offset: int = 42,
|
| 312 |
) -> Dict[str, Any]:
|
| 313 |
"""
|
|
|
|
| 384 |
)
|
| 385 |
p.add_argument("--task", choices=["easy", "medium", "hard"],
|
| 386 |
default=None, help="Single task (default: all three)")
|
| 387 |
+
p.add_argument("--n", type=int, default=1,
|
| 388 |
metavar="N",
|
| 389 |
+
help="Episodes per task (default: 1 — strict API budget)")
|
| 390 |
p.add_argument("--seed", type=int, default=42,
|
| 391 |
help="Base random seed (default: 42)")
|
| 392 |
args = p.parse_args()
|
requirements.txt
CHANGED
|
@@ -6,6 +6,7 @@
|
|
| 6 |
# ── Core environment ──────────────────────────────────────────────────────────
|
| 7 |
numpy>=1.24.0
|
| 8 |
openenv>=0.1.0
|
|
|
|
| 9 |
pydantic>=2.0.0
|
| 10 |
|
| 11 |
# ── Web framework (FastAPI server) ────────────────────────────────────────────
|
|
|
|
| 6 |
# ── Core environment ──────────────────────────────────────────────────────────
|
| 7 |
numpy>=1.24.0
|
| 8 |
openenv>=0.1.0
|
| 9 |
+
openenv-core>=0.2.0
|
| 10 |
pydantic>=2.0.0
|
| 11 |
|
| 12 |
# ── Web framework (FastAPI server) ────────────────────────────────────────────
|
rewards/reward.py
CHANGED
|
@@ -315,6 +315,7 @@ def calculate_reward(
|
|
| 315 |
components = {k: v * multiplier for k, v in components.items()}
|
| 316 |
|
| 317 |
total_reward: float = sum(components.values())
|
|
|
|
| 318 |
|
| 319 |
# -----------------------------------------------------------------------
|
| 320 |
# Info payload β consumed by graders and evaluation scripts
|
|
@@ -331,10 +332,11 @@ def calculate_reward(
|
|
| 331 |
action_type, is_critical, is_false_positive, resource_constrained
|
| 332 |
),
|
| 333 |
"task_multiplier": multiplier,
|
|
|
|
| 334 |
}
|
| 335 |
|
| 336 |
return Reward(
|
| 337 |
-
value=
|
| 338 |
components=components,
|
| 339 |
info=info,
|
| 340 |
)
|
|
@@ -611,7 +613,8 @@ if __name__ == "__main__":
|
|
| 611 |
for desc, act, alert, cfg, expected in cases:
|
| 612 |
action = Action(alert_id=alert.id, action_type=act)
|
| 613 |
result = calculate_reward(action, alert, cfg)
|
| 614 |
-
|
|
|
|
| 615 |
status = "PASS" if ok else "FAIL"
|
| 616 |
if not ok:
|
| 617 |
all_pass = False
|
|
|
|
| 315 |
components = {k: v * multiplier for k, v in components.items()}
|
| 316 |
|
| 317 |
total_reward: float = sum(components.values())
|
| 318 |
+
norm_reward: float = max(0.01, min(0.99, (total_reward + 40.0) / 80.0))
|
| 319 |
|
| 320 |
# -----------------------------------------------------------------------
|
| 321 |
# Info payload — consumed by graders and evaluation scripts
|
|
|
|
| 332 |
action_type, is_critical, is_false_positive, resource_constrained
|
| 333 |
),
|
| 334 |
"task_multiplier": multiplier,
|
| 335 |
+
"raw_reward": total_reward,
|
| 336 |
}
|
| 337 |
|
| 338 |
return Reward(
|
| 339 |
+
value=norm_reward,
|
| 340 |
components=components,
|
| 341 |
info=info,
|
| 342 |
)
|
|
|
|
| 613 |
for desc, act, alert, cfg, expected in cases:
|
| 614 |
action = Action(alert_id=alert.id, action_type=act)
|
| 615 |
result = calculate_reward(action, alert, cfg)
|
| 616 |
+
normalized_expected = max(0.01, min(0.99, (expected + 40.0) / 80.0))
|
| 617 |
+
ok = abs(result.value - normalized_expected) < 1e-4
|
| 618 |
status = "PASS" if ok else "FAIL"
|
| 619 |
if not ok:
|
| 620 |
all_pass = False
|
src/adaptive_alert_triage/env.py
CHANGED
|
@@ -76,22 +76,22 @@ except ImportError:
|
|
| 76 |
|
| 77 |
_TASK_CONFIGS: Dict[str, Dict[str, Any]] = {
|
| 78 |
"easy": {
|
| 79 |
-
"max_steps":
|
| 80 |
-
"failure_threshold":
|
| 81 |
"max_investigations": None, # unconstrained
|
| 82 |
"correlation_probability": 0.10,
|
| 83 |
"description": "Basic alert prioritisation — no resource constraint.",
|
| 84 |
},
|
| 85 |
"medium": {
|
| 86 |
-
"max_steps":
|
| 87 |
-
"failure_threshold":
|
| 88 |
"max_investigations": 3, # K = 3 per step
|
| 89 |
"correlation_probability": 0.20,
|
| 90 |
"description": "Resource-constrained triage — K=3 investigations/step.",
|
| 91 |
},
|
| 92 |
"hard": {
|
| 93 |
-
"max_steps":
|
| 94 |
-
"failure_threshold":
|
| 95 |
"max_investigations": 3,
|
| 96 |
"correlation_probability": 0.40,
|
| 97 |
"description": (
|
|
@@ -267,7 +267,7 @@ class AdaptiveAlertTriageEnv(gym.Env):
|
|
| 267 |
alert = self._get_alert_by_id(action.alert_id)
|
| 268 |
if alert is None:
|
| 269 |
reward = Reward(
|
| 270 |
-
value=
|
| 271 |
components={"invalid_action": -5.0},
|
| 272 |
info={"error": f"Alert ID '{action.alert_id}' not found in queue"},
|
| 273 |
)
|
|
@@ -284,7 +284,7 @@ class AdaptiveAlertTriageEnv(gym.Env):
|
|
| 284 |
):
|
| 285 |
if self.investigations_used >= self.max_investigations_per_step:
|
| 286 |
reward = Reward(
|
| 287 |
-
value=
|
| 288 |
components={"resource_budget_exceeded": -3.0},
|
| 289 |
info={
|
| 290 |
"error": "Investigation budget exhausted for this step",
|
|
|
|
| 76 |
|
| 77 |
_TASK_CONFIGS: Dict[str, Dict[str, Any]] = {
|
| 78 |
"easy": {
|
| 79 |
+
"max_steps": 10,
|
| 80 |
+
"failure_threshold": 2,
|
| 81 |
"max_investigations": None, # unconstrained
|
| 82 |
"correlation_probability": 0.10,
|
| 83 |
"description": "Basic alert prioritisation — no resource constraint.",
|
| 84 |
},
|
| 85 |
"medium": {
|
| 86 |
+
"max_steps": 15,
|
| 87 |
+
"failure_threshold": 3,
|
| 88 |
"max_investigations": 3, # K = 3 per step
|
| 89 |
"correlation_probability": 0.20,
|
| 90 |
"description": "Resource-constrained triage — K=3 investigations/step.",
|
| 91 |
},
|
| 92 |
"hard": {
|
| 93 |
+
"max_steps": 20,
|
| 94 |
+
"failure_threshold": 2, # stricter
|
| 95 |
"max_investigations": 3,
|
| 96 |
"correlation_probability": 0.40,
|
| 97 |
"description": (
|
|
|
|
| 267 |
alert = self._get_alert_by_id(action.alert_id)
|
| 268 |
if alert is None:
|
| 269 |
reward = Reward(
|
| 270 |
+
value=0.01,
|
| 271 |
components={"invalid_action": -5.0},
|
| 272 |
info={"error": f"Alert ID '{action.alert_id}' not found in queue"},
|
| 273 |
)
|
|
|
|
| 284 |
):
|
| 285 |
if self.investigations_used >= self.max_investigations_per_step:
|
| 286 |
reward = Reward(
|
| 287 |
+
value=0.01,
|
| 288 |
components={"resource_budget_exceeded": -3.0},
|
| 289 |
info={
|
| 290 |
"error": "Investigation budget exhausted for this step",
|
src/adaptive_alert_triage/models.py
CHANGED
|
@@ -222,7 +222,14 @@ class Reward(BaseModel):
|
|
| 222 |
info: Debugging / logging extras (ground-truth reveal, etc.).
|
| 223 |
"""
|
| 224 |
|
| 225 |
-
value: float = Field(..., description="Total scalar reward")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
components: Dict[str, float] = Field(
|
| 227 |
default_factory=dict, description="Per-component reward breakdown"
|
| 228 |
)
|
|
|
|
| 222 |
info: Debugging / logging extras (ground-truth reveal, etc.).
|
| 223 |
"""
|
| 224 |
|
| 225 |
+
value: float = Field(..., ge=0.0, le=1.0, description="Total scalar reward in [0.0, 1.0]")
|
| 226 |
+
|
| 227 |
+
@field_validator("value", mode="before")
|
| 228 |
+
@classmethod
|
| 229 |
+
def clamp_reward_value(cls, v: float) -> float:
|
| 230 |
+
"""Silently clamp reward value to [0.01, 0.99] — strict (0, 1) bounds."""
|
| 231 |
+
return float(max(0.01, min(0.99, float(v))))
|
| 232 |
+
|
| 233 |
components: Dict[str, float] = Field(
|
| 234 |
default_factory=dict, description="Per-component reward breakdown"
|
| 235 |
)
|
src/adaptive_alert_triage/validate.py
CHANGED
|
@@ -123,7 +123,7 @@ class OpenEnvValidator:
|
|
| 123 |
action_ok = restored.alert_id == action.alert_id
|
| 124 |
self.check("Action serialization round-trip", action_ok)
|
| 125 |
|
| 126 |
-
reward = Reward(value=
|
| 127 |
restored = Reward.model_validate_json(reward.model_dump_json())
|
| 128 |
reward_ok = restored.value == reward.value
|
| 129 |
self.check("Reward serialization round-trip", reward_ok)
|
|
|
|
| 123 |
action_ok = restored.alert_id == action.alert_id
|
| 124 |
self.check("Action serialization round-trip", action_ok)
|
| 125 |
|
| 126 |
+
reward = Reward(value=0.5, components={"test": 0.5})
|
| 127 |
restored = Reward.model_validate_json(reward.model_dump_json())
|
| 128 |
reward_ok = restored.value == reward.value
|
| 129 |
self.check("Reward serialization round-trip", reward_ok)
|