Spaces:
Sleeping
Sleeping
Commit ·
cd2d585
1
Parent(s): a437a7c
fix(inference): final cleanup and user modifications
Browse files- inference.py +25 -98
inference.py
CHANGED
|
@@ -38,22 +38,10 @@ You receive time-windowed sensor readings each step and must detect cyberattacks
|
|
| 38 |
vq_window: q-axis voltage error (should be ~0 when healthy)
|
| 39 |
vd_window: d-axis voltage
|
| 40 |
omega_window: estimated frequency (normalized, nominal=0)
|
| 41 |
-
omega_deviation_window: frequency deviation from nominal in rad/s
|
| 42 |
raw_voltages: [va, vb, vc] at current step
|
| 43 |
task_id: 0=detect only, 1=classify type, 2=detect stealthy attack
|
| 44 |
|
| 45 |
-
For task_id=0: Focus on detecting any attack (attack_detected=True/False).
|
| 46 |
-
For task_id=1: Also classify the attack type (1=sinusoidal, 2=ramp, 3=pulse).
|
| 47 |
-
For task_id=2: Detect very subtle attacks before the PLL loses lock. Look for slow drifts in omega_deviation and vq.
|
| 48 |
-
|
| 49 |
-
Analysis tips:
|
| 50 |
-
- In healthy state, vq values should be near 0 and stable.
|
| 51 |
-
- Sinusoidal attacks cause oscillating patterns in vq.
|
| 52 |
-
- Ramp attacks cause steadily increasing vq magnitude.
|
| 53 |
-
- Pulse attacks cause sudden step changes in vq.
|
| 54 |
-
- Stealthy attacks cause very slow, gradual drift in omega_deviation_window.
|
| 55 |
-
- Look at trends across the full window, not just the latest value.
|
| 56 |
-
|
| 57 |
Respond ONLY with valid JSON, no explanation:
|
| 58 |
{
|
| 59 |
"attack_detected": <bool>,
|
|
@@ -170,12 +158,6 @@ def detector_agent(prev_info: dict) -> Optional[dict]:
|
|
| 170 |
det = prev_info.get("detector", {})
|
| 171 |
if not det or "attack_detected" not in det:
|
| 172 |
return None
|
| 173 |
-
|
| 174 |
-
# Fall back to heuristic if detector confidence is < 0.5
|
| 175 |
-
# to preserve heuristic base logic scoring results.
|
| 176 |
-
if float(det.get("confidence", 0.0)) < 0.5:
|
| 177 |
-
return None
|
| 178 |
-
|
| 179 |
return safe_clamp_action(det)
|
| 180 |
except Exception:
|
| 181 |
return None
|
|
@@ -186,29 +168,20 @@ def detector_agent(prev_info: dict) -> Optional[dict]:
|
|
| 186 |
# =====================================================================
|
| 187 |
|
| 188 |
class HeuristicState:
|
| 189 |
-
"""Tracks running state for the heuristic agent across steps."""
|
| 190 |
def __init__(self):
|
| 191 |
self.reset()
|
| 192 |
-
|
| 193 |
def reset(self):
|
| 194 |
-
self.vq_history = []
|
| 195 |
-
self.omega_dev_history = []
|
| 196 |
-
self.attack_detected = False
|
| 197 |
-
self.predicted_type = 0
|
| 198 |
-
self.settled_baseline = None
|
| 199 |
-
self.peak_vq = 0.0
|
| 200 |
-
|
| 201 |
|
| 202 |
_hstate = HeuristicState()
|
| 203 |
|
| 204 |
def heuristic_agent(obs: dict) -> dict:
|
| 205 |
-
"""
|
| 206 |
-
Rule-based attack detector using cumulative state tracking.
|
| 207 |
-
This runs instantly.
|
| 208 |
-
The key insight is that the PLL's closed-loop response transforms
|
| 209 |
-
attack signals, so I track statistics over time rather than
|
| 210 |
-
trying to classify from a single 20-step vq window shape.
|
| 211 |
-
"""
|
| 212 |
try:
|
| 213 |
global _hstate
|
| 214 |
vq = obs.get("vq_window", [])
|
|
@@ -222,41 +195,27 @@ def heuristic_agent(obs: dict) -> dict:
|
|
| 222 |
if step == 0:
|
| 223 |
_hstate.reset()
|
| 224 |
|
| 225 |
-
# --- Computing per-step features ---
|
| 226 |
vq_abs = [abs(v) for v in vq]
|
| 227 |
-
vq_mean = sum(vq_abs) / len(vq_abs)
|
| 228 |
-
vq_max = max(vq_abs)
|
| 229 |
-
vq_latest = abs(vq[-1])
|
| 230 |
|
| 231 |
omega_dev_abs = [abs(v) for v in omega_dev]
|
| 232 |
-
omega_dev_mean = sum(omega_dev_abs) / len(omega_dev_abs)
|
| 233 |
|
| 234 |
-
# Tracking history
|
| 235 |
_hstate.vq_history.append(vq_mean)
|
| 236 |
_hstate.omega_dev_history.append(omega_dev_mean)
|
| 237 |
_hstate.peak_vq = max(_hstate.peak_vq, vq_mean)
|
| 238 |
|
| 239 |
-
# Recording baseline around step 45-50 (PLL settled)
|
| 240 |
if step == 50:
|
| 241 |
_hstate.settled_baseline = omega_dev_mean
|
| 242 |
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
# After PLL warm-start settles (~step 20-30), healthy vq < 0.005
|
| 246 |
-
# -----------------------------------------------------------------
|
| 247 |
-
if step < 25:
|
| 248 |
-
# PLL still settling, don't detect
|
| 249 |
-
detected = False
|
| 250 |
-
else:
|
| 251 |
detected = vq_mean > 0.01 or vq_max > 0.025
|
| 252 |
|
| 253 |
-
# Latch detection on
|
| 254 |
if detected:
|
| 255 |
_hstate.attack_detected = True
|
| 256 |
|
| 257 |
-
# -----------------------------------------------------------------
|
| 258 |
-
# Task 0: Binary detection only
|
| 259 |
-
# -----------------------------------------------------------------
|
| 260 |
if task_id == 0:
|
| 261 |
return safe_clamp_action({
|
| 262 |
"attack_detected": _hstate.attack_detected,
|
|
@@ -265,9 +224,6 @@ def heuristic_agent(obs: dict) -> dict:
|
|
| 265 |
"protective_action": 1 if _hstate.attack_detected else 0,
|
| 266 |
})
|
| 267 |
|
| 268 |
-
# -----------------------------------------------------------------
|
| 269 |
-
# Task 1: Classification using cumulative patterns
|
| 270 |
-
# -----------------------------------------------------------------
|
| 271 |
if task_id == 1:
|
| 272 |
if not _hstate.attack_detected:
|
| 273 |
return safe_clamp_action({
|
|
@@ -276,26 +232,16 @@ def heuristic_agent(obs: dict) -> dict:
|
|
| 276 |
"confidence": 0.7,
|
| 277 |
"protective_action": 0,
|
| 278 |
})
|
| 279 |
-
|
| 280 |
-
# Classify using cumulative vq_history
|
| 281 |
-
# Only classify after enough attack data (10+ steps of elevated vq)
|
| 282 |
n_elevated = sum(1 for v in _hstate.vq_history if v > 0.01)
|
| 283 |
-
|
| 284 |
if n_elevated < 5:
|
| 285 |
-
# Not enough data yet, use simple guess
|
| 286 |
attack_type = 1
|
| 287 |
else:
|
| 288 |
-
# Get recent vq trend (last 10 elevated values)
|
| 289 |
elevated = [v for v in _hstate.vq_history if v > 0.005]
|
| 290 |
recent = elevated[-min(20, len(elevated)):]
|
| 291 |
-
|
| 292 |
-
# Feature 1: Is vq currently high or has it decayed?
|
| 293 |
current_vs_peak = vq_mean / _hstate.peak_vq if _hstate.peak_vq > 0 else 0
|
| 294 |
-
|
| 295 |
-
# Feature 2: How many zero crossings in current window
|
| 296 |
zero_crossings = sum(1 for i in range(1, len(vq)) if vq[i] * vq[i-1] < 0)
|
| 297 |
-
|
| 298 |
-
# Feature 3: Is vq growing or shrinking over recent history
|
| 299 |
if len(recent) >= 6:
|
| 300 |
first_third = sum(recent[:len(recent)//3]) / (len(recent)//3)
|
| 301 |
last_third = sum(recent[-len(recent)//3:]) / (len(recent)//3)
|
|
@@ -303,35 +249,23 @@ def heuristic_agent(obs: dict) -> dict:
|
|
| 303 |
else:
|
| 304 |
growth = 1.0
|
| 305 |
|
| 306 |
-
# Classification logic:
|
| 307 |
-
# Sinusoidal: persistent oscillation, zero crossings, stable amplitude
|
| 308 |
-
# Ramp: growing vq over time (growth > 1)
|
| 309 |
-
# Pulse: high initial vq that decays to near zero (current_vs_peak < 0.3)
|
| 310 |
-
|
| 311 |
if current_vs_peak < 0.15 and _hstate.peak_vq > 0.05:
|
| 312 |
-
# vq has decayed significantly from peak -> pulse (ended)
|
| 313 |
attack_type = 3
|
| 314 |
elif current_vs_peak < 0.4 and n_elevated > 30:
|
| 315 |
-
# vq decayed after a long time -> pulse
|
| 316 |
attack_type = 3
|
| 317 |
elif zero_crossings >= 2 and growth < 1.5:
|
| 318 |
-
# Active oscillation without growing -> sinusoidal
|
| 319 |
attack_type = 1
|
| 320 |
elif growth > 1.3:
|
| 321 |
-
# Growing signal -> ramp
|
| 322 |
attack_type = 2
|
| 323 |
elif zero_crossings >= 1:
|
| 324 |
-
# Some oscillation -> sinusoidal
|
| 325 |
attack_type = 1
|
| 326 |
else:
|
| 327 |
-
# Default: if mono-decrease, pulse; else sinusoidal
|
| 328 |
vq_diffs = [vq[i] - vq[i-1] for i in range(1, len(vq))]
|
| 329 |
neg = sum(1 for d in vq_diffs if d < 0)
|
| 330 |
-
if neg > 14:
|
| 331 |
attack_type = 3
|
| 332 |
else:
|
| 333 |
attack_type = 1
|
| 334 |
-
|
| 335 |
_hstate.predicted_type = attack_type
|
| 336 |
|
| 337 |
return safe_clamp_action({
|
|
@@ -341,20 +275,12 @@ def heuristic_agent(obs: dict) -> dict:
|
|
| 341 |
"protective_action": 1,
|
| 342 |
})
|
| 343 |
|
| 344 |
-
# -----------------------------------------------------------------
|
| 345 |
-
# Task 2: Stealthy attack — detecting omega_dev rising above baseline
|
| 346 |
-
# -----------------------------------------------------------------
|
| 347 |
if task_id == 2:
|
| 348 |
drift_detected = False
|
| 349 |
confidence = 0.3
|
| 350 |
-
|
| 351 |
if step > 50 and _hstate.settled_baseline is not None:
|
| 352 |
baseline = _hstate.settled_baseline
|
| 353 |
-
|
| 354 |
-
# Compare current to baseline
|
| 355 |
ratio = omega_dev_mean / baseline if baseline > 0.01 else omega_dev_mean * 100
|
| 356 |
-
|
| 357 |
-
# Checking if omega_dev is rising relative to recent history
|
| 358 |
if len(_hstate.omega_dev_history) > 10:
|
| 359 |
recent_10 = _hstate.omega_dev_history[-10:]
|
| 360 |
old_10 = _hstate.omega_dev_history[-20:-10] if len(_hstate.omega_dev_history) > 20 else _hstate.omega_dev_history[:10]
|
|
@@ -363,7 +289,6 @@ def heuristic_agent(obs: dict) -> dict:
|
|
| 363 |
rising = recent_avg > old_avg * 1.1
|
| 364 |
else:
|
| 365 |
rising = False
|
| 366 |
-
|
| 367 |
if ratio > 2.0:
|
| 368 |
drift_detected = True
|
| 369 |
confidence = 0.9
|
|
@@ -376,10 +301,8 @@ def heuristic_agent(obs: dict) -> dict:
|
|
| 376 |
elif vq_mean > 0.2:
|
| 377 |
drift_detected = True
|
| 378 |
confidence = 0.5
|
| 379 |
-
|
| 380 |
if drift_detected:
|
| 381 |
_hstate.attack_detected = True
|
| 382 |
-
|
| 383 |
return safe_clamp_action({
|
| 384 |
"attack_detected": drift_detected,
|
| 385 |
"attack_type": 4 if drift_detected else 0,
|
|
@@ -489,16 +412,20 @@ def run_episode(task_id: int) -> float:
|
|
| 489 |
while not done:
|
| 490 |
action = None
|
| 491 |
|
| 492 |
-
# Priority 1:
|
| 493 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 494 |
try:
|
| 495 |
action = llm_agent(obs)
|
| 496 |
except Exception:
|
| 497 |
pass
|
| 498 |
|
| 499 |
-
# Priority
|
| 500 |
-
# Note: We bypass `detector_agent` here to perfectly preserve
|
| 501 |
-
# the baseline 0.6786 performance trajectory from github.
|
| 502 |
if not action:
|
| 503 |
try:
|
| 504 |
action = heuristic_agent(obs)
|
|
|
|
| 38 |
vq_window: q-axis voltage error (should be ~0 when healthy)
|
| 39 |
vd_window: d-axis voltage
|
| 40 |
omega_window: estimated frequency (normalized, nominal=0)
|
| 41 |
+
omega_deviation_window: frequency deviation from nominal in rad/s
|
| 42 |
raw_voltages: [va, vb, vc] at current step
|
| 43 |
task_id: 0=detect only, 1=classify type, 2=detect stealthy attack
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
Respond ONLY with valid JSON, no explanation:
|
| 46 |
{
|
| 47 |
"attack_detected": <bool>,
|
|
|
|
| 158 |
det = prev_info.get("detector", {})
|
| 159 |
if not det or "attack_detected" not in det:
|
| 160 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
return safe_clamp_action(det)
|
| 162 |
except Exception:
|
| 163 |
return None
|
|
|
|
| 168 |
# =====================================================================
|
| 169 |
|
| 170 |
class HeuristicState:
|
|
|
|
| 171 |
def __init__(self):
|
| 172 |
self.reset()
|
|
|
|
| 173 |
def reset(self):
|
| 174 |
+
self.vq_history = []
|
| 175 |
+
self.omega_dev_history = []
|
| 176 |
+
self.attack_detected = False
|
| 177 |
+
self.predicted_type = 0
|
| 178 |
+
self.settled_baseline = None
|
| 179 |
+
self.peak_vq = 0.0
|
|
|
|
| 180 |
|
| 181 |
_hstate = HeuristicState()
|
| 182 |
|
| 183 |
def heuristic_agent(obs: dict) -> dict:
|
| 184 |
+
"""Safe heuristic agent fallback."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
try:
|
| 186 |
global _hstate
|
| 187 |
vq = obs.get("vq_window", [])
|
|
|
|
| 195 |
if step == 0:
|
| 196 |
_hstate.reset()
|
| 197 |
|
|
|
|
| 198 |
vq_abs = [abs(v) for v in vq]
|
| 199 |
+
vq_mean = sum(vq_abs) / len(vq_abs) if vq_abs else 0.0
|
| 200 |
+
vq_max = max(vq_abs) if vq_abs else 0.0
|
|
|
|
| 201 |
|
| 202 |
omega_dev_abs = [abs(v) for v in omega_dev]
|
| 203 |
+
omega_dev_mean = sum(omega_dev_abs) / len(omega_dev_abs) if omega_dev_abs else 0.0
|
| 204 |
|
|
|
|
| 205 |
_hstate.vq_history.append(vq_mean)
|
| 206 |
_hstate.omega_dev_history.append(omega_dev_mean)
|
| 207 |
_hstate.peak_vq = max(_hstate.peak_vq, vq_mean)
|
| 208 |
|
|
|
|
| 209 |
if step == 50:
|
| 210 |
_hstate.settled_baseline = omega_dev_mean
|
| 211 |
|
| 212 |
+
detected = False
|
| 213 |
+
if step >= 25:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
detected = vq_mean > 0.01 or vq_max > 0.025
|
| 215 |
|
|
|
|
| 216 |
if detected:
|
| 217 |
_hstate.attack_detected = True
|
| 218 |
|
|
|
|
|
|
|
|
|
|
| 219 |
if task_id == 0:
|
| 220 |
return safe_clamp_action({
|
| 221 |
"attack_detected": _hstate.attack_detected,
|
|
|
|
| 224 |
"protective_action": 1 if _hstate.attack_detected else 0,
|
| 225 |
})
|
| 226 |
|
|
|
|
|
|
|
|
|
|
| 227 |
if task_id == 1:
|
| 228 |
if not _hstate.attack_detected:
|
| 229 |
return safe_clamp_action({
|
|
|
|
| 232 |
"confidence": 0.7,
|
| 233 |
"protective_action": 0,
|
| 234 |
})
|
| 235 |
+
|
|
|
|
|
|
|
| 236 |
n_elevated = sum(1 for v in _hstate.vq_history if v > 0.01)
|
|
|
|
| 237 |
if n_elevated < 5:
|
|
|
|
| 238 |
attack_type = 1
|
| 239 |
else:
|
|
|
|
| 240 |
elevated = [v for v in _hstate.vq_history if v > 0.005]
|
| 241 |
recent = elevated[-min(20, len(elevated)):]
|
|
|
|
|
|
|
| 242 |
current_vs_peak = vq_mean / _hstate.peak_vq if _hstate.peak_vq > 0 else 0
|
|
|
|
|
|
|
| 243 |
zero_crossings = sum(1 for i in range(1, len(vq)) if vq[i] * vq[i-1] < 0)
|
| 244 |
+
|
|
|
|
| 245 |
if len(recent) >= 6:
|
| 246 |
first_third = sum(recent[:len(recent)//3]) / (len(recent)//3)
|
| 247 |
last_third = sum(recent[-len(recent)//3:]) / (len(recent)//3)
|
|
|
|
| 249 |
else:
|
| 250 |
growth = 1.0
|
| 251 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
if current_vs_peak < 0.15 and _hstate.peak_vq > 0.05:
|
|
|
|
| 253 |
attack_type = 3
|
| 254 |
elif current_vs_peak < 0.4 and n_elevated > 30:
|
|
|
|
| 255 |
attack_type = 3
|
| 256 |
elif zero_crossings >= 2 and growth < 1.5:
|
|
|
|
| 257 |
attack_type = 1
|
| 258 |
elif growth > 1.3:
|
|
|
|
| 259 |
attack_type = 2
|
| 260 |
elif zero_crossings >= 1:
|
|
|
|
| 261 |
attack_type = 1
|
| 262 |
else:
|
|
|
|
| 263 |
vq_diffs = [vq[i] - vq[i-1] for i in range(1, len(vq))]
|
| 264 |
neg = sum(1 for d in vq_diffs if d < 0)
|
| 265 |
+
if neg > 14:
|
| 266 |
attack_type = 3
|
| 267 |
else:
|
| 268 |
attack_type = 1
|
|
|
|
| 269 |
_hstate.predicted_type = attack_type
|
| 270 |
|
| 271 |
return safe_clamp_action({
|
|
|
|
| 275 |
"protective_action": 1,
|
| 276 |
})
|
| 277 |
|
|
|
|
|
|
|
|
|
|
| 278 |
if task_id == 2:
|
| 279 |
drift_detected = False
|
| 280 |
confidence = 0.3
|
|
|
|
| 281 |
if step > 50 and _hstate.settled_baseline is not None:
|
| 282 |
baseline = _hstate.settled_baseline
|
|
|
|
|
|
|
| 283 |
ratio = omega_dev_mean / baseline if baseline > 0.01 else omega_dev_mean * 100
|
|
|
|
|
|
|
| 284 |
if len(_hstate.omega_dev_history) > 10:
|
| 285 |
recent_10 = _hstate.omega_dev_history[-10:]
|
| 286 |
old_10 = _hstate.omega_dev_history[-20:-10] if len(_hstate.omega_dev_history) > 20 else _hstate.omega_dev_history[:10]
|
|
|
|
| 289 |
rising = recent_avg > old_avg * 1.1
|
| 290 |
else:
|
| 291 |
rising = False
|
|
|
|
| 292 |
if ratio > 2.0:
|
| 293 |
drift_detected = True
|
| 294 |
confidence = 0.9
|
|
|
|
| 301 |
elif vq_mean > 0.2:
|
| 302 |
drift_detected = True
|
| 303 |
confidence = 0.5
|
|
|
|
| 304 |
if drift_detected:
|
| 305 |
_hstate.attack_detected = True
|
|
|
|
| 306 |
return safe_clamp_action({
|
| 307 |
"attack_detected": drift_detected,
|
| 308 |
"attack_type": 4 if drift_detected else 0,
|
|
|
|
| 412 |
while not done:
|
| 413 |
action = None
|
| 414 |
|
| 415 |
+
# Priority 1: Detector Output
|
| 416 |
+
try:
|
| 417 |
+
action = detector_agent(prev_info)
|
| 418 |
+
except Exception:
|
| 419 |
+
pass
|
| 420 |
+
|
| 421 |
+
# Priority 2: Optional LLM
|
| 422 |
+
if not action and USE_LLM:
|
| 423 |
try:
|
| 424 |
action = llm_agent(obs)
|
| 425 |
except Exception:
|
| 426 |
pass
|
| 427 |
|
| 428 |
+
# Priority 3: Safe Rule-Based Heuristic Fallback
|
|
|
|
|
|
|
| 429 |
if not action:
|
| 430 |
try:
|
| 431 |
action = heuristic_agent(obs)
|