krishuggingface committed on
Commit
4b406c3
·
1 Parent(s): cd2d585

fix(inference): refactor proxy initialization and baseline logic

Browse files
Files changed (1) hide show
  1. inference.py +338 -261
inference.py CHANGED
@@ -1,36 +1,53 @@
1
- """
2
- Inference Script — PLL Cyberattack Detection OpenEnv
3
- =====================================================
4
- Hardened for the Meta PyTorch Hackathon Validator.
5
- Proxy-compliant, local-env safe, and crash-resistant.
6
-
7
- MANDATORY environment variables (for proxy):
8
- API_BASE_URL The API endpoint for the LLM proxy
9
- API_KEY The injected proxy token
10
- """
11
-
12
  import os
13
  import json
14
  import time
15
- import requests
 
 
16
  from typing import Optional, Dict, Any
17
 
18
- # 1) Validator-injected LLM proxy variables (No HF_TOKEN hardcoding)
19
- API_BASE_URL = os.environ.get("API_BASE_URL")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  API_KEY = os.environ.get("API_KEY")
21
-
22
- # 2) Change ENV_URL default to validator local container
23
  ENV_URL = os.getenv("ENV_URL", "http://127.0.0.1:7860")
24
  USE_LLM = os.environ.get("USE_LLM", "0") == "1"
25
 
26
- # Initialize client ONLY if proxy vars exist
27
- client = None
 
28
  if API_BASE_URL and API_KEY:
29
  try:
30
- from openai import OpenAI
 
31
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
 
 
32
  except Exception as e:
33
- print(f"Warning: Failed to initialize OpenAI client: {e}")
 
34
 
35
  SYSTEM_PROMPT = """You are an AI agent monitoring a power grid inverter's Phase-Locked Loop (PLL).
36
  You receive time-windowed sensor readings each step and must detect cyberattacks.
@@ -38,10 +55,14 @@ You receive time-windowed sensor readings each step and must detect cyberattacks
38
  vq_window: q-axis voltage error (should be ~0 when healthy)
39
  vd_window: d-axis voltage
40
  omega_window: estimated frequency (normalized, nominal=0)
41
- omega_deviation_window: frequency deviation from nominal in rad/s
42
  raw_voltages: [va, vb, vc] at current step
43
  task_id: 0=detect only, 1=classify type, 2=detect stealthy attack
44
 
 
 
 
 
45
  Respond ONLY with valid JSON, no explanation:
46
  {
47
  "attack_detected": <bool>,
@@ -64,112 +85,140 @@ DEFAULT_ACTION = {
64
  }
65
 
66
 
67
- # =====================================================================
68
- # Logging Helpers (OpenEnv compliance)
69
- # =====================================================================
70
-
71
  def log_start(task: str, env: str, model: str) -> None:
72
- try:
73
- print(f"[START] task={task} env={env} model={model}", flush=True)
74
- except Exception:
75
- pass
76
 
77
 
78
  def log_step(step: int, action: dict, reward: float, done: bool, error) -> None:
79
- try:
80
- action_str = json.dumps(action, separators=(',', ':'))
81
- error_val = error if error else "null"
82
- print(f"[STEP] step={step} action={action_str} reward={reward:.2f} done={str(done).lower()} error={error_val}", flush=True)
83
- except Exception:
84
- pass
85
 
86
 
87
  def log_end(success: bool, steps: int, score: float, rewards: list) -> None:
88
- try:
89
- rewards_str = ",".join(f"{r:.2f}" for r in rewards)
90
- print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
91
- except Exception:
92
- pass
93
 
94
 
95
- # =====================================================================
96
- # Safe Network Client Helpers
97
- # =====================================================================
 
 
 
 
 
 
 
 
 
98
 
99
- def safe_post_json(url: str, payload: dict, timeout: int = 30, retries: int = 2) -> Optional[Dict[str, Any]]:
100
- """Safe POST request handler with retries and no unhandled exceptions."""
 
 
 
 
 
 
 
 
101
  for attempt in range(retries + 1):
102
  try:
103
  response = requests.post(url, json=payload, timeout=timeout)
104
  response.raise_for_status()
 
105
  return response.json()
106
  except Exception as e:
107
- if attempt == retries:
108
- print(f" Network error on {url} after {retries} retries: {e}")
109
- return None
110
- time.sleep(1.0)
 
 
 
 
111
  return None
112
 
113
 
114
- def warmup_proxy() -> None:
115
- """Make at least one tiny proxy call at startup if client exists."""
116
- global client
117
- if not client:
118
  return
 
 
 
119
  try:
120
- print("Warming up LLM proxy connection...")
121
- client.chat.completions.create(
122
- model=os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct"),
123
  messages=[{"role": "user", "content": "ping"}],
124
  max_tokens=1,
125
- timeout=10,
126
  )
127
- print("Proxy warmup successful.")
128
  except Exception as e:
129
- print(f"Proxy warmup failed (non-fatal): {e}")
130
-
131
 
132
- # =====================================================================
133
- # Action Parser and Clamper
134
- # =====================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
- def safe_clamp_action(action: dict) -> dict:
137
- """Clamps outputs to valid bounds and handles missing keys safely."""
138
  try:
139
- return {
140
- "attack_detected": bool(action.get("attack_detected", False)),
141
- "attack_type": max(0, min(4, int(action.get("attack_type", 0)))),
142
- "confidence": max(0.0, min(1.0, float(action.get("confidence", 0.5)))),
143
- "protective_action": max(0, min(3, int(action.get("protective_action", 0)))),
144
- }
145
- except Exception:
146
- return DEFAULT_ACTION.copy()
147
 
 
 
 
148
 
149
- # =====================================================================
150
- # Detector-Based Agent
151
- # =====================================================================
152
 
153
  def detector_agent(prev_info: dict) -> Optional[dict]:
154
- """Reads the environment's adaptive detector output."""
155
- try:
156
- if not prev_info:
157
- return None
158
- det = prev_info.get("detector", {})
159
- if not det or "attack_detected" not in det:
160
- return None
161
- return safe_clamp_action(det)
162
- except Exception:
163
  return None
 
 
 
 
 
 
164
 
165
 
166
- # =====================================================================
167
- # Rule-Based Heuristic Agent
168
- # =====================================================================
169
-
170
  class HeuristicState:
171
  def __init__(self):
172
  self.reset()
 
173
  def reset(self):
174
  self.vq_history = []
175
  self.omega_dev_history = []
@@ -178,26 +227,29 @@ class HeuristicState:
178
  self.settled_baseline = None
179
  self.peak_vq = 0.0
180
 
 
181
  _hstate = HeuristicState()
182
 
 
183
  def heuristic_agent(obs: dict) -> dict:
184
- """Safe heuristic agent fallback."""
185
- try:
186
- global _hstate
187
- vq = obs.get("vq_window", [])
188
- omega_dev = obs.get("omega_deviation_window", [])
189
- task_id = obs.get("task_id", 0)
190
- step = obs.get("step", 0)
191
 
192
- if not vq or not omega_dev:
193
- return DEFAULT_ACTION.copy()
 
 
 
 
 
194
 
195
- if step == 0:
196
- _hstate.reset()
197
 
 
198
  vq_abs = [abs(v) for v in vq]
199
- vq_mean = sum(vq_abs) / len(vq_abs) if vq_abs else 0.0
200
- vq_max = max(vq_abs) if vq_abs else 0.0
 
201
 
202
  omega_dev_abs = [abs(v) for v in omega_dev]
203
  omega_dev_mean = sum(omega_dev_abs) / len(omega_dev_abs) if omega_dev_abs else 0.0
@@ -209,42 +261,46 @@ def heuristic_agent(obs: dict) -> dict:
209
  if step == 50:
210
  _hstate.settled_baseline = omega_dev_mean
211
 
212
- detected = False
213
- if step >= 25:
 
214
  detected = vq_mean > 0.01 or vq_max > 0.025
215
 
216
  if detected:
217
  _hstate.attack_detected = True
218
 
219
  if task_id == 0:
220
- return safe_clamp_action({
221
  "attack_detected": _hstate.attack_detected,
222
  "attack_type": 1 if _hstate.attack_detected else 0,
223
  "confidence": min(1.0, vq_mean * 50) if _hstate.attack_detected else 0.8,
224
  "protective_action": 1 if _hstate.attack_detected else 0,
225
- })
226
 
227
  if task_id == 1:
228
  if not _hstate.attack_detected:
229
- return safe_clamp_action({
230
  "attack_detected": False,
231
  "attack_type": 0,
232
  "confidence": 0.7,
233
  "protective_action": 0,
234
- })
235
-
236
  n_elevated = sum(1 for v in _hstate.vq_history if v > 0.01)
 
237
  if n_elevated < 5:
238
  attack_type = 1
239
  else:
240
  elevated = [v for v in _hstate.vq_history if v > 0.005]
241
  recent = elevated[-min(20, len(elevated)):]
242
- current_vs_peak = vq_mean / _hstate.peak_vq if _hstate.peak_vq > 0 else 0
243
- zero_crossings = sum(1 for i in range(1, len(vq)) if vq[i] * vq[i-1] < 0)
244
-
 
245
  if len(recent) >= 6:
246
- first_third = sum(recent[:len(recent)//3]) / (len(recent)//3)
247
- last_third = sum(recent[-len(recent)//3:]) / (len(recent)//3)
 
248
  growth = last_third / first_third if first_third > 0.001 else 1.0
249
  else:
250
  growth = 1.0
@@ -260,35 +316,40 @@ def heuristic_agent(obs: dict) -> dict:
260
  elif zero_crossings >= 1:
261
  attack_type = 1
262
  else:
263
- vq_diffs = [vq[i] - vq[i-1] for i in range(1, len(vq))]
264
  neg = sum(1 for d in vq_diffs if d < 0)
265
- if neg > 14:
266
- attack_type = 3
267
- else:
268
- attack_type = 1
269
  _hstate.predicted_type = attack_type
270
 
271
- return safe_clamp_action({
272
  "attack_detected": True,
273
  "attack_type": _hstate.predicted_type,
274
  "confidence": 0.8,
275
  "protective_action": 1,
276
- })
277
 
278
  if task_id == 2:
279
  drift_detected = False
280
  confidence = 0.3
 
281
  if step > 50 and _hstate.settled_baseline is not None:
282
  baseline = _hstate.settled_baseline
283
- ratio = omega_dev_mean / baseline if baseline > 0.01 else omega_dev_mean * 100
 
284
  if len(_hstate.omega_dev_history) > 10:
285
  recent_10 = _hstate.omega_dev_history[-10:]
286
- old_10 = _hstate.omega_dev_history[-20:-10] if len(_hstate.omega_dev_history) > 20 else _hstate.omega_dev_history[:10]
 
 
 
 
287
  recent_avg = sum(recent_10) / len(recent_10)
288
  old_avg = sum(old_10) / len(old_10)
289
  rising = recent_avg > old_avg * 1.1
290
  else:
291
  rising = False
 
292
  if ratio > 2.0:
293
  drift_detected = True
294
  confidence = 0.9
@@ -301,58 +362,27 @@ def heuristic_agent(obs: dict) -> dict:
301
  elif vq_mean > 0.2:
302
  drift_detected = True
303
  confidence = 0.5
 
304
  if drift_detected:
305
  _hstate.attack_detected = True
306
- return safe_clamp_action({
 
307
  "attack_detected": drift_detected,
308
  "attack_type": 4 if drift_detected else 0,
309
  "confidence": confidence,
310
  "protective_action": 2 if drift_detected else 0,
311
- })
312
 
313
  return DEFAULT_ACTION.copy()
 
314
  except Exception as e:
315
- print(f"Heuristic agent error: {e}")
316
  return DEFAULT_ACTION.copy()
317
 
318
 
319
- # =====================================================================
320
- # LLM Agent
321
- # =====================================================================
322
-
323
- def llm_agent(obs: dict) -> Optional[dict]:
324
- """Safe LLM execution."""
325
- global client
326
- if not client:
327
- return None
328
-
329
  try:
330
- parts = [
331
- f"Step: {obs.get('step', 0)}",
332
- f"Task: {obs.get('task_id', 0)}",
333
- f"vq_window: {[round(v, 6) for v in obs.get('vq_window', [])]}",
334
- f"vd_window: {[round(v, 6) for v in obs.get('vd_window', [])]}",
335
- f"omega_window: {[round(v, 6) for v in obs.get('omega_window', [])]}",
336
- f"omega_deviation_window: {[round(v, 6) for v in obs.get('omega_deviation_window', [])]}",
337
- f"raw_voltages: {[round(v, 6) for v in obs.get('raw_voltages', [])]}",
338
- ]
339
- obs_text = "\n".join(parts)
340
-
341
- model_name = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
342
- completion = client.chat.completions.create(
343
- model=model_name,
344
- messages=[
345
- {"role": "system", "content": SYSTEM_PROMPT},
346
- {"role": "user", "content": obs_text},
347
- ],
348
- temperature=0.1,
349
- max_tokens=200,
350
- timeout=15,
351
- )
352
- llm_response = completion.choices[0].message.content
353
-
354
- # Parse JSON
355
- text = llm_response.strip()
356
  if text.startswith("```"):
357
  lines = text.split("\n")
358
  json_lines = []
@@ -368,143 +398,190 @@ def llm_agent(obs: dict) -> Optional[dict]:
368
  text = "\n".join(json_lines)
369
 
370
  parsed = json.loads(text)
371
- return safe_clamp_action(parsed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
  except Exception as e:
373
- print(f" LLM error: {e}, returning None")
374
- return None
375
 
376
 
377
- # =====================================================================
378
- # Episode Runner
379
- # =====================================================================
 
 
 
 
 
380
 
381
- def run_episode(task_id: int) -> float:
382
- # 3) Detector-first default logic
383
- agent_name = "Hybrid (Detector -> Heuristic)"
384
- if USE_LLM and API_BASE_URL and API_KEY:
385
- agent_name = "Verbose Hybrid (Detector -> LLM -> Heuristic)"
386
 
387
- log_start(task=TASK_NAMES.get(task_id, str(task_id)), env="pll-cyberattack-detection", model=agent_name)
 
 
 
 
 
388
 
389
- print(f"\n{'='*60}")
390
- print(f"Task {task_id}: {TASK_NAMES.get(task_id, 'Unknown')}")
391
- print(f"Agent Hierarchy: {agent_name}")
392
- print(f"{'='*60}")
393
 
394
  step_count = 0
395
  grader_score = 0.0
396
  rewards = []
397
-
 
 
398
  try:
399
- reset_url = f"{ENV_URL}/reset"
400
- reset_payload = {"task_id": task_id}
401
- obs = safe_post_json(reset_url, reset_payload)
402
-
403
- if not obs:
404
- print(f"Failed to reset environment via {reset_url}. Aborting episode.")
405
- log_end(success=False, steps=0, score=0.0, rewards=[])
 
406
  return 0.0
407
 
 
408
  done = False
409
  total_reward = 0.0
410
- prev_info = {}
411
 
412
  while not done:
413
- action = None
414
-
415
- # Priority 1: Detector Output
416
  try:
417
- action = detector_agent(prev_info)
418
- except Exception:
419
- pass
 
 
 
 
 
 
 
 
 
 
 
420
 
421
- # Priority 2: Optional LLM
422
- if not action and USE_LLM:
423
- try:
424
- action = llm_agent(obs)
425
- except Exception:
426
- pass
427
 
428
- # Priority 3: Safe Rule-Based Heuristic Fallback
429
- if not action:
430
  try:
431
- action = heuristic_agent(obs)
432
  except Exception:
433
- action = DEFAULT_ACTION.copy()
434
 
435
- # Execute step safely
436
- step_url = f"{ENV_URL}/step"
437
- result = safe_post_json(step_url, action)
438
 
439
- if not result:
440
- print("Environment step failed after retries. Safely terminating episode.")
441
- break
442
 
 
 
 
 
 
 
 
 
 
 
443
  try:
444
- obs = result.get("observation", {})
445
- reward_info = result.get("reward", {"total": 0.0})
446
- reward = reward_info.get("total", 0.0)
447
- done = bool(result.get("done", True))
448
- info = result.get("info", {})
449
- prev_info = info
450
-
451
- total_reward += reward
452
- rewards.append(reward)
453
- log_step(step=step_count, action=action, reward=reward, done=done, error=None)
454
-
455
- step_count += 1
456
- if step_count % 50 == 0:
457
- print(f" Step {step_count:3d} | Reward: {reward:+.4f} | "
458
- f"Cumulative: {total_reward:+.4f} | "
459
- f"Detected: {action.get('attack_detected', False)} | "
460
- f"Type: {action.get('attack_type', 0)}")
461
-
462
- # Early breaks
463
- if done:
464
- grader_score = info.get("grader_score", 0.0)
465
-
466
- except Exception as loop_e:
467
- print(f"Error handling step response data: {loop_e}. Terminating cleanly.")
468
- break
469
 
470
  print(f"\n Episode complete: {step_count} steps")
471
  print(f" Total reward: {total_reward:+.4f}")
472
  print(f" Grader score: {grader_score:.4f}")
473
-
 
 
474
  except Exception as e:
475
- print(f"Critical episode failure caught safely: {e}")
 
 
476
  finally:
477
  log_end(success=grader_score > 0.0, steps=step_count, score=grader_score, rewards=rewards)
478
 
479
- return grader_score
480
-
481
 
482
  if __name__ == "__main__":
483
- print("PLL Cyberattack Detection Hardened Agentic Inference")
484
- print(f"Proxy Env: {ENV_URL}")
485
-
486
- # 4) Warm up proxy safely
 
 
 
487
  warmup_proxy()
488
 
489
  start_time = time.time()
490
  scores = []
491
 
492
- try:
493
- for task_id in range(3):
494
- score = run_episode(task_id)
495
- print(f"Task {task_id} score: {score:.4f}")
496
- scores.append(score)
497
-
498
- elapsed = time.time() - start_time
499
-
500
- print(f"\n{'='*60}")
501
- print("FINAL RESULTS")
502
- print(f"{'='*60}")
503
- for i, score in enumerate(scores):
504
- print(f" Task {i} ({TASK_NAMES.get(i, str(i))}): {score:.4f}")
505
- if scores:
506
- print(f"\n Average score: {sum(scores)/len(scores):.4f}")
507
- print(f" Total time: {elapsed:.1f}s ({elapsed/60:.1f} min)")
508
- print(f"{'='*60}")
509
- except Exception as e:
510
- print(f"Main loop crashed safely: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import json
3
  import time
4
+ import logging
5
+ import traceback
6
+ import threading
7
  from typing import Optional, Dict, Any
8
 
9
+ import requests
10
+ from openai import OpenAI
11
+
12
+ # ---------------------------------------------------------------------
13
+ # 1. SETUP LOGGING
14
+ # ---------------------------------------------------------------------
15
+ # Ensure logs look like: [TIMESTAMP] [STAGE] message
16
class StageFormatter(logging.Formatter):
    """Formatter that renders records as ``[TIMESTAMP] [STAGE] message``.

    The stage label comes from ``record.stage`` — supplied through the
    ``extra={"stage": ...}`` argument of a logging call — and falls back
    to ``SYSTEM`` when absent.
    """

    def format(self, record):
        # Pull the stage tag injected via `extra`; default keeps the line parseable.
        label = getattr(record, "stage", "SYSTEM")
        # Rewrite the per-record format string, then let the stock Formatter
        # fill in asctime and the interpolated message.
        self._style._fmt = f"[%(asctime)s] [{label}] %(message)s"
        return super().format(record)
23
+
24
# ---------------------------------------------------------------------
# Module bootstrap: logger, environment configuration, optional client.
# Runs once at import time.
# ---------------------------------------------------------------------
logger = logging.getLogger("inference")
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler()
handler.setFormatter(StageFormatter(datefmt="%Y-%m-%d %H:%M:%S"))
logger.addHandler(handler)

logger.info("Initializing Agent Scripts", extra={"stage": "APP STARTUP"})

# All knobs are environment-driven so the validator can inject them.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
API_KEY = os.environ.get("API_KEY")
ENV_URL = os.getenv("ENV_URL", "http://127.0.0.1:7860")
USE_LLM = os.environ.get("USE_LLM", "0") == "1"

logger.info("Environment variables loaded.", extra={"stage": "APP STARTUP"})

# The OpenAI-compatible client is optional: without credentials the script
# simply degrades to the rule-based heuristic agent.
client: Optional[OpenAI] = None
if API_BASE_URL and API_KEY:
    try:
        logger.info("Initializing OpenAI Client", extra={"stage": "MODEL LOADING"})
        _t0 = time.time()
        client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
        _t1 = time.time()
        logger.info(f"Client Initialized. Duration: {_t1 - _t0:.4f}s", extra={"stage": "MODEL LOADING"})
    except Exception as e:
        logger.error(f"Failed to initialize OpenAI client: {e}\n{traceback.format_exc()}", extra={"stage": "APP STARTUP"})
        client = None
51
 
52
  SYSTEM_PROMPT = """You are an AI agent monitoring a power grid inverter's Phase-Locked Loop (PLL).
53
  You receive time-windowed sensor readings each step and must detect cyberattacks.
 
55
  vq_window: q-axis voltage error (should be ~0 when healthy)
56
  vd_window: d-axis voltage
57
  omega_window: estimated frequency (normalized, nominal=0)
58
+ omega_deviation_window: frequency deviation from nominal in rad/s (useful for detecting slow phase drift)
59
  raw_voltages: [va, vb, vc] at current step
60
  task_id: 0=detect only, 1=classify type, 2=detect stealthy attack
61
 
62
+ For task_id=0: Focus on detecting any attack (attack_detected=True/False).
63
+ For task_id=1: Also classify the attack type (1=sinusoidal, 2=ramp, 3=pulse).
64
+ For task_id=2: Detect very subtle attacks before the PLL loses lock. Look for slow drifts in omega_deviation and vq.
65
+
66
  Respond ONLY with valid JSON, no explanation:
67
  {
68
  "attack_detected": <bool>,
 
85
  }
86
 
87
 
 
 
 
 
88
def log_start(task: str, env: str, model: str) -> None:
    """Emit the [EPISODE START] marker line expected by the validator.

    Uses lazy %-style logging arguments so interpolation only happens
    when the record is actually emitted (standard logging idiom).
    """
    logger.info("task=%s env=%s model=%s", task, env, model, extra={"stage": "EPISODE START"})
 
 
 
90
 
91
 
92
  def log_step(step: int, action: dict, reward: float, done: bool, error) -> None:
93
+ action_str = json.dumps(action, separators=(",", ":"))
94
+ error_val = error if error else "null"
95
+ logger.debug(
96
+ f"step={step} action={action_str} reward={reward:.2f} done={str(done).lower()} error={error_val}",
97
+ extra={"stage": "EPISODE STEP"}
98
+ )
99
 
100
 
101
  def log_end(success: bool, steps: int, score: float, rewards: list) -> None:
102
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
103
+ logger.info(
104
+ f"success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}",
105
+ extra={"stage": "EPISODE END"}
106
+ )
107
 
108
 
109
def safe_action(action: Dict[str, Any]) -> Dict[str, Any]:
    """Coerce a raw action dict into the environment's valid action space.

    Missing keys fall back to safe defaults; numeric fields are clamped to
    their legal ranges (type 0..4, confidence 0..1, protective 0..3). Any
    conversion failure yields a copy of DEFAULT_ACTION.
    """
    def _clamp(value, lo, hi):
        # Small helper so the bounds read left-to-right below.
        return max(lo, min(hi, value))

    try:
        return {
            "attack_detected": bool(action.get("attack_detected", False)),
            "attack_type": _clamp(int(action.get("attack_type", 0)), 0, 4),
            "confidence": _clamp(float(action.get("confidence", 0.5)), 0.0, 1.0),
            "protective_action": _clamp(int(action.get("protective_action", 0)), 0, 3),
        }
    except Exception as e:
        logger.error(f"Action constraint failed: {e}\n{traceback.format_exc()}", extra={"stage": "POSTPROCESSING"})
        return DEFAULT_ACTION.copy()
120
+
121
 
122
def safe_post_json(
    url: str,
    payload: Dict[str, Any],
    timeout: int = 10,
    retries: int = 2,
) -> Optional[Dict[str, Any]]:
    """POST `payload` as JSON to `url` and return the decoded JSON body.

    Retries up to `retries` additional times on any failure (connection
    error, HTTP error status, undecodable body) with a short backoff.
    Returns None when every attempt fails — callers treat None as
    "endpoint unavailable".
    """
    last_error: Optional[Exception] = None
    last_tb = ""
    logger.debug(f"Calling endpoint {url}", extra={"stage": "API CALL (REQ)"})
    _start_t = time.time()

    for attempt in range(retries + 1):
        try:
            response = requests.post(url, json=payload, timeout=timeout)
            response.raise_for_status()
            logger.debug(f"Response ok from {url} in {time.time()-_start_t:.4f}s", extra={"stage": "API CALL (RES)"})
            return response.json()
        except Exception as e:
            last_error = e
            # Capture the traceback *inside* the handler; calling
            # traceback.format_exc() after the loop reported the useless
            # "NoneType: None" because no exception is active there.
            last_tb = traceback.format_exc()
            logger.warning(
                f"HTTP error calling {url} (attempt {attempt + 1}/{retries + 1}): {e}",
                extra={"stage": "API CALL (ERR)"}
            )
            # Back off only when another attempt remains — sleeping after
            # the final failure just added latency before giving up.
            if attempt < retries:
                time.sleep(0.5)

    logger.error(f"Giving up on {url}: {last_error}\n{last_tb}", extra={"stage": "API CALL (ERR)"})
    return None
148
 
149
 
150
def _warmup_worker() -> None:
    """Fire one tiny completion against the proxy; intended to run in a thread."""
    if client is None:
        logger.info("LLM proxy warmup skipped (client unavailable).", extra={"stage": "MODEL LOADING"})
        return

    logger.info("Initializing LLM Proxy Warmup Thread...", extra={"stage": "MODEL LOADING"})
    started = time.time()
    try:
        # One-token "ping": enough to open the connection and authenticate
        # without spending meaningful tokens.
        client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": "ping"}],
            max_tokens=1,
            temperature=0,
        )
        logger.info(f"LLM proxy warmup successful in {time.time() - started:.4f}s.", extra={"stage": "MODEL LOADING"})
    except Exception as e:
        # Warmup failure is non-fatal; the heuristic path still works.
        logger.error(f"LLM proxy warmup failed: {e}\n{traceback.format_exc()}", extra={"stage": "MODEL LOADING (ERR)"})
 
168
 
169
def warmup_proxy() -> None:
    """Kick off the proxy warmup on a daemon thread so startup never blocks."""
    threading.Thread(target=_warmup_worker, daemon=True).start()
173
+
174
+
175
+ # ---------------------------------------------------------------------
176
+ # ZERO-DEPENDENCY HEALTHCHECK SERVER
177
+ # ---------------------------------------------------------------------
178
+ from http.server import BaseHTTPRequestHandler, HTTPServer
179
+
180
class FastHealthcheck(BaseHTTPRequestHandler):
    """Minimal handler answering every GET with a static JSON 200 immediately."""

    # Static response body; no per-request work needed.
    _BODY = b'{"status":"ok"}'

    def do_GET(self):
        logger.info(f"Healthcheck triggered at {self.path}", extra={"stage": "HEALTHCHECK"})
        self.send_response(200)
        self.send_header("Content-type", "application/json")
        self.end_headers()
        self.wfile.write(self._BODY)
        logger.info("Healthcheck returned 200 OK immediately", extra={"stage": "HEALTHCHECK"})

    def log_message(self, format, *args):
        """Silence BaseHTTPRequestHandler's default per-request stderr log."""
191
 
192
def _run_healthcheck() -> None:
    """Serve the healthcheck endpoint forever; never propagates exceptions."""
    try:
        # Hugging Face Spaces probes port 7860 by default, so bind there.
        server = HTTPServer(("0.0.0.0", 7860), FastHealthcheck)
        logger.info("Background Healthcheck server bound to 0.0.0.0:7860", extra={"stage": "APP STARTUP"})
        server.serve_forever()
    except Exception as e:
        logger.error(f"Healthcheck server crash: {e}\n{traceback.format_exc()}", extra={"stage": "APP STARTUP (ERR)"})


# Launch the healthcheck listener immediately at import time, before any
# episode work, so the platform sees the port open right away.
t_health = threading.Thread(target=_run_healthcheck, daemon=True)
t_health.start()
204
 
 
 
 
205
 
206
def detector_agent(prev_info: dict) -> Optional[dict]:
    """Project the environment detector's verdict into an action dict.

    Returns None when the previous step's info carries no usable
    "detector" entry, signalling the caller to fall back to another agent.
    """
    candidate = (prev_info or {}).get("detector", {})
    if not isinstance(candidate, dict):
        return None
    if "attack_detected" not in candidate:
        return None
    # Fill any field the detector omitted with the neutral default.
    defaults = {
        "attack_detected": False,
        "attack_type": 0,
        "confidence": 0.5,
        "protective_action": 0,
    }
    return {key: candidate.get(key, fallback) for key, fallback in defaults.items()}
216
 
217
 
 
 
 
 
218
  class HeuristicState:
219
  def __init__(self):
220
  self.reset()
221
+
222
  def reset(self):
223
  self.vq_history = []
224
  self.omega_dev_history = []
 
227
  self.settled_baseline = None
228
  self.peak_vq = 0.0
229
 
230
+
231
  _hstate = HeuristicState()
232
 
233
+
234
  def heuristic_agent(obs: dict) -> dict:
235
+ global _hstate
 
 
 
 
 
 
236
 
237
+ try:
238
+ vq = obs["vq_window"]
239
+ omega_dev = obs["omega_deviation_window"]
240
+ task_id = int(obs["task_id"])
241
+ step = int(obs["step"])
242
+ except Exception:
243
+ return DEFAULT_ACTION.copy()
244
 
245
+ if step == 0:
246
+ _hstate.reset()
247
 
248
+ try:
249
  vq_abs = [abs(v) for v in vq]
250
+ vq_mean = sum(vq_abs) / len(vq_abs)
251
+ vq_max = max(vq_abs)
252
+ vq_latest = abs(vq[-1]) if vq else 0.0
253
 
254
  omega_dev_abs = [abs(v) for v in omega_dev]
255
  omega_dev_mean = sum(omega_dev_abs) / len(omega_dev_abs) if omega_dev_abs else 0.0
 
261
  if step == 50:
262
  _hstate.settled_baseline = omega_dev_mean
263
 
264
+ if step < 25:
265
+ detected = False
266
+ else:
267
  detected = vq_mean > 0.01 or vq_max > 0.025
268
 
269
  if detected:
270
  _hstate.attack_detected = True
271
 
272
  if task_id == 0:
273
+ return {
274
  "attack_detected": _hstate.attack_detected,
275
  "attack_type": 1 if _hstate.attack_detected else 0,
276
  "confidence": min(1.0, vq_mean * 50) if _hstate.attack_detected else 0.8,
277
  "protective_action": 1 if _hstate.attack_detected else 0,
278
+ }
279
 
280
  if task_id == 1:
281
  if not _hstate.attack_detected:
282
+ return {
283
  "attack_detected": False,
284
  "attack_type": 0,
285
  "confidence": 0.7,
286
  "protective_action": 0,
287
+ }
288
+
289
  n_elevated = sum(1 for v in _hstate.vq_history if v > 0.01)
290
+
291
  if n_elevated < 5:
292
  attack_type = 1
293
  else:
294
  elevated = [v for v in _hstate.vq_history if v > 0.005]
295
  recent = elevated[-min(20, len(elevated)):]
296
+
297
+ current_vs_peak = vq_mean / _hstate.peak_vq if _hstate.peak_vq > 0 else 0.0
298
+ zero_crossings = sum(1 for i in range(1, len(vq)) if vq[i] * vq[i - 1] < 0)
299
+
300
  if len(recent) >= 6:
301
+ third = max(1, len(recent) // 3)
302
+ first_third = sum(recent[:third]) / third
303
+ last_third = sum(recent[-third:]) / third
304
  growth = last_third / first_third if first_third > 0.001 else 1.0
305
  else:
306
  growth = 1.0
 
316
  elif zero_crossings >= 1:
317
  attack_type = 1
318
  else:
319
+ vq_diffs = [vq[i] - vq[i - 1] for i in range(1, len(vq))]
320
  neg = sum(1 for d in vq_diffs if d < 0)
321
+ attack_type = 3 if neg > 14 else 1
322
+
 
 
323
  _hstate.predicted_type = attack_type
324
 
325
+ return {
326
  "attack_detected": True,
327
  "attack_type": _hstate.predicted_type,
328
  "confidence": 0.8,
329
  "protective_action": 1,
330
+ }
331
 
332
  if task_id == 2:
333
  drift_detected = False
334
  confidence = 0.3
335
+
336
  if step > 50 and _hstate.settled_baseline is not None:
337
  baseline = _hstate.settled_baseline
338
+ ratio = omega_dev_mean / baseline if baseline > 0.01 else omega_dev_mean * 100.0
339
+
340
  if len(_hstate.omega_dev_history) > 10:
341
  recent_10 = _hstate.omega_dev_history[-10:]
342
+ old_10 = (
343
+ _hstate.omega_dev_history[-20:-10]
344
+ if len(_hstate.omega_dev_history) > 20
345
+ else _hstate.omega_dev_history[:10]
346
+ )
347
  recent_avg = sum(recent_10) / len(recent_10)
348
  old_avg = sum(old_10) / len(old_10)
349
  rising = recent_avg > old_avg * 1.1
350
  else:
351
  rising = False
352
+
353
  if ratio > 2.0:
354
  drift_detected = True
355
  confidence = 0.9
 
362
  elif vq_mean > 0.2:
363
  drift_detected = True
364
  confidence = 0.5
365
+
366
  if drift_detected:
367
  _hstate.attack_detected = True
368
+
369
+ return {
370
  "attack_detected": drift_detected,
371
  "attack_type": 4 if drift_detected else 0,
372
  "confidence": confidence,
373
  "protective_action": 2 if drift_detected else 0,
374
+ }
375
 
376
  return DEFAULT_ACTION.copy()
377
+
378
  except Exception as e:
379
+ logger.warning(f"heuristic_agent failed: {e}\n{traceback.format_exc()}", extra={"stage": "HEURISTIC AGENT (ERR)"})
380
  return DEFAULT_ACTION.copy()
381
 
382
 
383
+ def parse_llm_response(response_text: str) -> dict:
 
 
 
 
 
 
 
 
 
384
  try:
385
+ text = (response_text or "").strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
  if text.startswith("```"):
387
  lines = text.split("\n")
388
  json_lines = []
 
398
  text = "\n".join(json_lines)
399
 
400
  parsed = json.loads(text)
401
+ return safe_action(
402
+ {
403
+ "attack_detected": parsed.get("attack_detected", False),
404
+ "attack_type": parsed.get("attack_type", 0),
405
+ "confidence": parsed.get("confidence", 0.5),
406
+ "protective_action": parsed.get("protective_action", 0),
407
+ }
408
+ )
409
+ except Exception:
410
+ return DEFAULT_ACTION.copy()
411
+
412
+
413
def format_observation(obs: dict) -> str:
    """Render an observation dict as the plain-text prompt body for the LLM.

    Returns "" when any expected key is missing so the caller can treat
    the observation as unusable.
    """
    def _vec(key: str) -> list:
        # Round each sample to 6 decimals to keep the prompt compact.
        return [round(v, 6) for v in obs[key]]

    try:
        lines = [
            f"Step: {obs['step']}",
            f"Task: {obs['task_id']}",
            f"vq_window (last 20): {_vec('vq_window')}",
            f"vd_window (last 20): {_vec('vd_window')}",
            f"omega_window (last 20): {_vec('omega_window')}",
            f"omega_deviation_window (last 20): {_vec('omega_deviation_window')}",
            f"raw_voltages: {_vec('raw_voltages')}",
        ]
        return "\n".join(lines)
    except Exception:
        return ""
427
+
428
+
429
def llm_agent(obs: dict) -> dict:
    """Ask the proxied LLM for an action; fall back to the heuristic on any failure."""
    if client is None:
        return heuristic_agent(obs)

    try:
        prompt = format_observation(obs)
        reply = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            temperature=0.1,
            max_tokens=200,
        )
        # Guard against empty/odd SDK responses before touching choices[0].
        if reply and reply.choices:
            raw_text = reply.choices[0].message.content
        else:
            raw_text = ""
        return parse_llm_response(raw_text)
    except Exception as e:
        logger.warning(f"LLM error ({type(e).__name__}: {e})\n{traceback.format_exc()}", extra={"stage": "LLM AGENT (ERR)"})
        return heuristic_agent(obs)
449
 
450
 
451
def choose_action(obs: dict, prev_info: dict) -> dict:
    """Pick the next action: LLM when enabled and available, heuristic otherwise.

    `prev_info` is currently unused but kept for interface stability with
    detector-driven selection. The default path preserves the baseline
    heuristic behavior.
    """
    llm_available = USE_LLM and client is not None
    if llm_available:
        try:
            return safe_action(llm_agent(obs))
        except Exception:
            # Any LLM-path failure silently degrades to the heuristic.
            pass
    return safe_action(heuristic_agent(obs))
459
 
 
 
 
 
 
460
 
461
def run_episode(task_id: int) -> float:
    """Run one full episode against the environment and return its grader score.

    Resets the env for `task_id`, then loops: choose an action, POST it to
    /step, accumulate rewards, and track the grader score reported in the
    step `info`. Returns 0.0 when reset fails or the episode crashes.
    """

    def _reward_total(reward_obj) -> float:
        # Best-effort extraction of reward["total"]; anything malformed is 0.0.
        if not isinstance(reward_obj, dict):
            return 0.0
        try:
            return float(reward_obj.get("total", 0.0))
        except Exception:
            return 0.0

    log_start(
        task=TASK_NAMES[task_id],
        env="pll-cyberattack-detection",
        model=MODEL_NAME if USE_LLM else "rule-based-heuristic",
    )

    print(f"\n{'=' * 60}")
    print(f"Task {task_id}: {TASK_NAMES[task_id]}")
    print(f"Agent: {'LLM (' + MODEL_NAME + ')' if USE_LLM else 'Rule-Based Heuristic'}")
    print(f"{'=' * 60}")

    step_count = 0
    grader_score = 0.0
    rewards = []
    info: Dict[str, Any] = {}
    prev_info: Dict[str, Any] = {}

    try:
        obs = safe_post_json(f"{ENV_URL}/reset", {"task_id": task_id}, timeout=10, retries=2)
        if not isinstance(obs, dict):
            logger.error("Reset failed; skipping episode.", extra={"stage": "ENV RESET"})
            return 0.0

        done = False
        total_reward = 0.0

        while not done:
            # Action selection must never kill the episode.
            try:
                action = choose_action(obs, prev_info)
            except Exception as e:
                logger.warning(f"Action selection failed: {e}\n{traceback.format_exc()}", extra={"stage": "ACTION SELECTION"})
                action = DEFAULT_ACTION.copy()

            step_result = safe_post_json(f"{ENV_URL}/step", action, timeout=10, retries=2)
            if not isinstance(step_result, dict):
                logger.error("Step failed; ending episode early.", extra={"stage": "ENV STEP"})
                break

            obs = step_result.get("observation", obs)
            done = bool(step_result.get("done", False))
            info = step_result.get("info", {})
            step_reward = _reward_total(step_result.get("reward", {}))

            total_reward += step_reward
            rewards.append(step_reward)
            log_step(step=step_count, action=action, reward=step_reward, done=done, error=None)

            prev_info = info if isinstance(info, dict) else {}
            step_count += 1

            # Periodic progress line for human observers.
            if step_count % 50 == 0:
                print(
                    f" Step {step_count:3d} | Reward: {step_reward:+.4f} | "
                    f"Cumulative: {total_reward:+.4f} | "
                    f"Detected: {action.get('attack_detected', False)} | "
                    f"Type: {action.get('attack_type', 0)}",
                    flush=True,
                )

            # Track the latest grader score reported by the environment.
            if isinstance(info, dict):
                try:
                    grader_score = float(info.get("grader_score", 0.0))
                except Exception:
                    grader_score = 0.0

        print(f"\n Episode complete: {step_count} steps")
        print(f" Total reward: {total_reward:+.4f}")
        print(f" Grader score: {grader_score:.4f}")

        return grader_score

    except Exception as e:
        logger.error(f"Episode crashed safely: {e}\n{traceback.format_exc()}", extra={"stage": "EPISODE SEVERE ERR"})
        return 0.0

    finally:
        log_end(success=grader_score > 0.0, steps=step_count, score=grader_score, rewards=rewards)
557
 
 
 
558
 
559
def main() -> None:
    """Entry point: announce config, warm up the proxy, run all three tasks,
    and print the final score summary."""
    agent_name = f"LLM ({MODEL_NAME})" if USE_LLM else "Rule-Based Heuristic"
    logger.info("PLL Cyberattack Detection — Agentic Inference", extra={"stage": "APP STARTUP"})
    logger.info(f"Agent: {agent_name}", extra={"stage": "APP STARTUP"})
    logger.info(f"Environment: {ENV_URL}", extra={"stage": "APP STARTUP"})
    if not USE_LLM:
        logger.info("(Set USE_LLM=1 to use LLM agent instead of heuristic)", extra={"stage": "APP STARTUP"})

    # Non-blocking: warmup runs on a daemon thread.
    warmup_proxy()

    start_time = time.time()
    scores = []

    for task_id in range(3):
        score = run_episode(task_id)
        print(f"Task {task_id} score: {score:.4f}")
        scores.append(score)

    elapsed = time.time() - start_time

    print(f"\n{'=' * 60}")
    print("FINAL RESULTS")
    print(f"{'=' * 60}")
    for i, score in enumerate(scores):
        print(f" Task {i} ({TASK_NAMES[i]}): {score:.4f}")
    if scores:
        print(f"\n Average score: {sum(scores) / len(scores):.4f}")
    print(f" Total time: {elapsed:.1f}s ({elapsed / 60:.1f} min)")
    print(f"{'=' * 60}")


if __name__ == "__main__":
    # Wrapping the script body in main() keeps these locals out of the
    # module namespace and follows the standard script-entry idiom.
    main()
+ print(f"{'=' * 60}")