krishuggingface commited on
Commit
cd2d585
·
1 Parent(s): a437a7c

fix(inference): final cleanup and user modifications

Browse files
Files changed (1) hide show
  1. inference.py +25 -98
inference.py CHANGED
@@ -38,22 +38,10 @@ You receive time-windowed sensor readings each step and must detect cyberattacks
38
  vq_window: q-axis voltage error (should be ~0 when healthy)
39
  vd_window: d-axis voltage
40
  omega_window: estimated frequency (normalized, nominal=0)
41
- omega_deviation_window: frequency deviation from nominal in rad/s (useful for detecting slow phase drift)
42
  raw_voltages: [va, vb, vc] at current step
43
  task_id: 0=detect only, 1=classify type, 2=detect stealthy attack
44
 
45
- For task_id=0: Focus on detecting any attack (attack_detected=True/False).
46
- For task_id=1: Also classify the attack type (1=sinusoidal, 2=ramp, 3=pulse).
47
- For task_id=2: Detect very subtle attacks before the PLL loses lock. Look for slow drifts in omega_deviation and vq.
48
-
49
- Analysis tips:
50
- - In healthy state, vq values should be near 0 and stable.
51
- - Sinusoidal attacks cause oscillating patterns in vq.
52
- - Ramp attacks cause steadily increasing vq magnitude.
53
- - Pulse attacks cause sudden step changes in vq.
54
- - Stealthy attacks cause very slow, gradual drift in omega_deviation_window.
55
- - Look at trends across the full window, not just the latest value.
56
-
57
  Respond ONLY with valid JSON, no explanation:
58
  {
59
  "attack_detected": <bool>,
@@ -170,12 +158,6 @@ def detector_agent(prev_info: dict) -> Optional[dict]:
170
  det = prev_info.get("detector", {})
171
  if not det or "attack_detected" not in det:
172
  return None
173
-
174
- # Fall back to heuristic if detector confidence is < 0.5
175
- # to preserve heuristic base logic scoring results.
176
- if float(det.get("confidence", 0.0)) < 0.5:
177
- return None
178
-
179
  return safe_clamp_action(det)
180
  except Exception:
181
  return None
@@ -186,29 +168,20 @@ def detector_agent(prev_info: dict) -> Optional[dict]:
186
  # =====================================================================
187
 
188
  class HeuristicState:
189
- """Tracks running state for the heuristic agent across steps."""
190
  def __init__(self):
191
  self.reset()
192
-
193
  def reset(self):
194
- self.vq_history = [] # all vq_mean(abs) values
195
- self.omega_dev_history = [] # all omega_dev_mean(abs) values
196
- self.attack_detected = False # latched detection flag
197
- self.predicted_type = 0 # latched classification
198
- self.settled_baseline = None # omega_dev baseline when PLL settles
199
- self.peak_vq = 0.0 # highest vq_mean seen
200
-
201
 
202
  _hstate = HeuristicState()
203
 
204
  def heuristic_agent(obs: dict) -> dict:
205
- """
206
- Rule-based attack detector using cumulative state tracking.
207
- This runs instantly.
208
- The key insight is that the PLL's closed-loop response transforms
209
- attack signals, so I track statistics over time rather than
210
- trying to classify from a single 20-step vq window shape.
211
- """
212
  try:
213
  global _hstate
214
  vq = obs.get("vq_window", [])
@@ -222,41 +195,27 @@ def heuristic_agent(obs: dict) -> dict:
222
  if step == 0:
223
  _hstate.reset()
224
 
225
- # --- Computing per-step features ---
226
  vq_abs = [abs(v) for v in vq]
227
- vq_mean = sum(vq_abs) / len(vq_abs)
228
- vq_max = max(vq_abs)
229
- vq_latest = abs(vq[-1])
230
 
231
  omega_dev_abs = [abs(v) for v in omega_dev]
232
- omega_dev_mean = sum(omega_dev_abs) / len(omega_dev_abs)
233
 
234
- # Tracking history
235
  _hstate.vq_history.append(vq_mean)
236
  _hstate.omega_dev_history.append(omega_dev_mean)
237
  _hstate.peak_vq = max(_hstate.peak_vq, vq_mean)
238
 
239
- # Recording baseline around step 45-50 (PLL settled)
240
  if step == 50:
241
  _hstate.settled_baseline = omega_dev_mean
242
 
243
- # -----------------------------------------------------------------
244
- # Detection: is vq significantly elevated?
245
- # After PLL warm-start settles (~step 20-30), healthy vq < 0.005
246
- # -----------------------------------------------------------------
247
- if step < 25:
248
- # PLL still settling, don't detect
249
- detected = False
250
- else:
251
  detected = vq_mean > 0.01 or vq_max > 0.025
252
 
253
- # Latch detection on
254
  if detected:
255
  _hstate.attack_detected = True
256
 
257
- # -----------------------------------------------------------------
258
- # Task 0: Binary detection only
259
- # -----------------------------------------------------------------
260
  if task_id == 0:
261
  return safe_clamp_action({
262
  "attack_detected": _hstate.attack_detected,
@@ -265,9 +224,6 @@ def heuristic_agent(obs: dict) -> dict:
265
  "protective_action": 1 if _hstate.attack_detected else 0,
266
  })
267
 
268
- # -----------------------------------------------------------------
269
- # Task 1: Classification using cumulative patterns
270
- # -----------------------------------------------------------------
271
  if task_id == 1:
272
  if not _hstate.attack_detected:
273
  return safe_clamp_action({
@@ -276,26 +232,16 @@ def heuristic_agent(obs: dict) -> dict:
276
  "confidence": 0.7,
277
  "protective_action": 0,
278
  })
279
-
280
- # Classify using cumulative vq_history
281
- # Only classify after enough attack data (10+ steps of elevated vq)
282
  n_elevated = sum(1 for v in _hstate.vq_history if v > 0.01)
283
-
284
  if n_elevated < 5:
285
- # Not enough data yet, use simple guess
286
  attack_type = 1
287
  else:
288
- # Get recent vq trend (last 10 elevated values)
289
  elevated = [v for v in _hstate.vq_history if v > 0.005]
290
  recent = elevated[-min(20, len(elevated)):]
291
-
292
- # Feature 1: Is vq currently high or has it decayed?
293
  current_vs_peak = vq_mean / _hstate.peak_vq if _hstate.peak_vq > 0 else 0
294
-
295
- # Feature 2: How many zero crossings in current window
296
  zero_crossings = sum(1 for i in range(1, len(vq)) if vq[i] * vq[i-1] < 0)
297
-
298
- # Feature 3: Is vq growing or shrinking over recent history
299
  if len(recent) >= 6:
300
  first_third = sum(recent[:len(recent)//3]) / (len(recent)//3)
301
  last_third = sum(recent[-len(recent)//3:]) / (len(recent)//3)
@@ -303,35 +249,23 @@ def heuristic_agent(obs: dict) -> dict:
303
  else:
304
  growth = 1.0
305
 
306
- # Classification logic:
307
- # Sinusoidal: persistent oscillation, zero crossings, stable amplitude
308
- # Ramp: growing vq over time (growth > 1)
309
- # Pulse: high initial vq that decays to near zero (current_vs_peak < 0.3)
310
-
311
  if current_vs_peak < 0.15 and _hstate.peak_vq > 0.05:
312
- # vq has decayed significantly from peak -> pulse (ended)
313
  attack_type = 3
314
  elif current_vs_peak < 0.4 and n_elevated > 30:
315
- # vq decayed after a long time -> pulse
316
  attack_type = 3
317
  elif zero_crossings >= 2 and growth < 1.5:
318
- # Active oscillation without growing -> sinusoidal
319
  attack_type = 1
320
  elif growth > 1.3:
321
- # Growing signal -> ramp
322
  attack_type = 2
323
  elif zero_crossings >= 1:
324
- # Some oscillation -> sinusoidal
325
  attack_type = 1
326
  else:
327
- # Default: if mono-decrease, pulse; else sinusoidal
328
  vq_diffs = [vq[i] - vq[i-1] for i in range(1, len(vq))]
329
  neg = sum(1 for d in vq_diffs if d < 0)
330
- if neg > 14: # 14/19 = 73% decreasing
331
  attack_type = 3
332
  else:
333
  attack_type = 1
334
-
335
  _hstate.predicted_type = attack_type
336
 
337
  return safe_clamp_action({
@@ -341,20 +275,12 @@ def heuristic_agent(obs: dict) -> dict:
341
  "protective_action": 1,
342
  })
343
 
344
- # -----------------------------------------------------------------
345
- # Task 2: Stealthy attack — detecting omega_dev rising above baseline
346
- # -----------------------------------------------------------------
347
  if task_id == 2:
348
  drift_detected = False
349
  confidence = 0.3
350
-
351
  if step > 50 and _hstate.settled_baseline is not None:
352
  baseline = _hstate.settled_baseline
353
-
354
- # Compare current to baseline
355
  ratio = omega_dev_mean / baseline if baseline > 0.01 else omega_dev_mean * 100
356
-
357
- # Checking if omega_dev is rising relative to recent history
358
  if len(_hstate.omega_dev_history) > 10:
359
  recent_10 = _hstate.omega_dev_history[-10:]
360
  old_10 = _hstate.omega_dev_history[-20:-10] if len(_hstate.omega_dev_history) > 20 else _hstate.omega_dev_history[:10]
@@ -363,7 +289,6 @@ def heuristic_agent(obs: dict) -> dict:
363
  rising = recent_avg > old_avg * 1.1
364
  else:
365
  rising = False
366
-
367
  if ratio > 2.0:
368
  drift_detected = True
369
  confidence = 0.9
@@ -376,10 +301,8 @@ def heuristic_agent(obs: dict) -> dict:
376
  elif vq_mean > 0.2:
377
  drift_detected = True
378
  confidence = 0.5
379
-
380
  if drift_detected:
381
  _hstate.attack_detected = True
382
-
383
  return safe_clamp_action({
384
  "attack_detected": drift_detected,
385
  "attack_type": 4 if drift_detected else 0,
@@ -489,16 +412,20 @@ def run_episode(task_id: int) -> float:
489
  while not done:
490
  action = None
491
 
492
- # Priority 1: Optional LLM
493
- if USE_LLM:
 
 
 
 
 
 
494
  try:
495
  action = llm_agent(obs)
496
  except Exception:
497
  pass
498
 
499
- # Priority 2: Safe Rule-Based Heuristic Fallback
500
- # Note: We bypass `detector_agent` here to perfectly preserve
501
- # the baseline 0.6786 performance trajectory from github.
502
  if not action:
503
  try:
504
  action = heuristic_agent(obs)
 
38
  vq_window: q-axis voltage error (should be ~0 when healthy)
39
  vd_window: d-axis voltage
40
  omega_window: estimated frequency (normalized, nominal=0)
41
+ omega_deviation_window: frequency deviation from nominal in rad/s
42
  raw_voltages: [va, vb, vc] at current step
43
  task_id: 0=detect only, 1=classify type, 2=detect stealthy attack
44
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  Respond ONLY with valid JSON, no explanation:
46
  {
47
  "attack_detected": <bool>,
 
158
  det = prev_info.get("detector", {})
159
  if not det or "attack_detected" not in det:
160
  return None
 
 
 
 
 
 
161
  return safe_clamp_action(det)
162
  except Exception:
163
  return None
 
168
  # =====================================================================
169
 
170
  class HeuristicState:
 
171
  def __init__(self):
172
  self.reset()
 
173
  def reset(self):
174
+ self.vq_history = []
175
+ self.omega_dev_history = []
176
+ self.attack_detected = False
177
+ self.predicted_type = 0
178
+ self.settled_baseline = None
179
+ self.peak_vq = 0.0
 
180
 
181
  _hstate = HeuristicState()
182
 
183
  def heuristic_agent(obs: dict) -> dict:
184
+ """Safe heuristic agent fallback."""
 
 
 
 
 
 
185
  try:
186
  global _hstate
187
  vq = obs.get("vq_window", [])
 
195
  if step == 0:
196
  _hstate.reset()
197
 
 
198
  vq_abs = [abs(v) for v in vq]
199
+ vq_mean = sum(vq_abs) / len(vq_abs) if vq_abs else 0.0
200
+ vq_max = max(vq_abs) if vq_abs else 0.0
 
201
 
202
  omega_dev_abs = [abs(v) for v in omega_dev]
203
+ omega_dev_mean = sum(omega_dev_abs) / len(omega_dev_abs) if omega_dev_abs else 0.0
204
 
 
205
  _hstate.vq_history.append(vq_mean)
206
  _hstate.omega_dev_history.append(omega_dev_mean)
207
  _hstate.peak_vq = max(_hstate.peak_vq, vq_mean)
208
 
 
209
  if step == 50:
210
  _hstate.settled_baseline = omega_dev_mean
211
 
212
+ detected = False
213
+ if step >= 25:
 
 
 
 
 
 
214
  detected = vq_mean > 0.01 or vq_max > 0.025
215
 
 
216
  if detected:
217
  _hstate.attack_detected = True
218
 
 
 
 
219
  if task_id == 0:
220
  return safe_clamp_action({
221
  "attack_detected": _hstate.attack_detected,
 
224
  "protective_action": 1 if _hstate.attack_detected else 0,
225
  })
226
 
 
 
 
227
  if task_id == 1:
228
  if not _hstate.attack_detected:
229
  return safe_clamp_action({
 
232
  "confidence": 0.7,
233
  "protective_action": 0,
234
  })
235
+
 
 
236
  n_elevated = sum(1 for v in _hstate.vq_history if v > 0.01)
 
237
  if n_elevated < 5:
 
238
  attack_type = 1
239
  else:
 
240
  elevated = [v for v in _hstate.vq_history if v > 0.005]
241
  recent = elevated[-min(20, len(elevated)):]
 
 
242
  current_vs_peak = vq_mean / _hstate.peak_vq if _hstate.peak_vq > 0 else 0
 
 
243
  zero_crossings = sum(1 for i in range(1, len(vq)) if vq[i] * vq[i-1] < 0)
244
+
 
245
  if len(recent) >= 6:
246
  first_third = sum(recent[:len(recent)//3]) / (len(recent)//3)
247
  last_third = sum(recent[-len(recent)//3:]) / (len(recent)//3)
 
249
  else:
250
  growth = 1.0
251
 
 
 
 
 
 
252
  if current_vs_peak < 0.15 and _hstate.peak_vq > 0.05:
 
253
  attack_type = 3
254
  elif current_vs_peak < 0.4 and n_elevated > 30:
 
255
  attack_type = 3
256
  elif zero_crossings >= 2 and growth < 1.5:
 
257
  attack_type = 1
258
  elif growth > 1.3:
 
259
  attack_type = 2
260
  elif zero_crossings >= 1:
 
261
  attack_type = 1
262
  else:
 
263
  vq_diffs = [vq[i] - vq[i-1] for i in range(1, len(vq))]
264
  neg = sum(1 for d in vq_diffs if d < 0)
265
+ if neg > 14:
266
  attack_type = 3
267
  else:
268
  attack_type = 1
 
269
  _hstate.predicted_type = attack_type
270
 
271
  return safe_clamp_action({
 
275
  "protective_action": 1,
276
  })
277
 
 
 
 
278
  if task_id == 2:
279
  drift_detected = False
280
  confidence = 0.3
 
281
  if step > 50 and _hstate.settled_baseline is not None:
282
  baseline = _hstate.settled_baseline
 
 
283
  ratio = omega_dev_mean / baseline if baseline > 0.01 else omega_dev_mean * 100
 
 
284
  if len(_hstate.omega_dev_history) > 10:
285
  recent_10 = _hstate.omega_dev_history[-10:]
286
  old_10 = _hstate.omega_dev_history[-20:-10] if len(_hstate.omega_dev_history) > 20 else _hstate.omega_dev_history[:10]
 
289
  rising = recent_avg > old_avg * 1.1
290
  else:
291
  rising = False
 
292
  if ratio > 2.0:
293
  drift_detected = True
294
  confidence = 0.9
 
301
  elif vq_mean > 0.2:
302
  drift_detected = True
303
  confidence = 0.5
 
304
  if drift_detected:
305
  _hstate.attack_detected = True
 
306
  return safe_clamp_action({
307
  "attack_detected": drift_detected,
308
  "attack_type": 4 if drift_detected else 0,
 
412
  while not done:
413
  action = None
414
 
415
+ # Priority 1: Detector Output
416
+ try:
417
+ action = detector_agent(prev_info)
418
+ except Exception:
419
+ pass
420
+
421
+ # Priority 2: Optional LLM
422
+ if not action and USE_LLM:
423
  try:
424
  action = llm_agent(obs)
425
  except Exception:
426
  pass
427
 
428
+ # Priority 3: Safe Rule-Based Heuristic Fallback
 
 
429
  if not action:
430
  try:
431
  action = heuristic_agent(obs)