Spaces:
Sleeping
Sleeping
Refactor: Restore intrinsic detector to fallback logic, rewrite README.md, and polish all codebase comments for final submission
01f8cd5 | """ | |
| Inference Script — PLL Cyberattack Detection OpenEnv | |
| ===================================================== | |
| Environment variables (injected by the judging sandbox): | |
| API_BASE_URL LiteLLM proxy endpoint (MUST be used for all LLM calls) | |
| API_KEY LiteLLM proxy key (MUST be used — do not hardcode keys) | |
| MODEL_NAME Model identifier | |
| ENV_URL Environment server URL (default: http://localhost:7860) | |
| STDOUT FORMAT (OpenEnv compliance): | |
| [START] task=<task_name> env=<benchmark> model=<model_name> | |
| [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null> | |
| [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn> | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import time | |
| import requests | |
| from typing import List, Optional | |
| from openai import OpenAI | |
| try: | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| except ImportError: | |
| pass | |
# ── Config — always read from environment, never hardcode ─────────────────────
# The judging sandbox injects API_BASE_URL and API_KEY via their LiteLLM proxy.
# All LLM calls MUST go through these values or the submission will be rejected.
API_BASE_URL: str = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME: str = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
# Key precedence: sandbox-injected API_KEY, then HF_TOKEN, then a placeholder.
API_KEY: str = os.environ.get("API_KEY") or os.environ.get("HF_TOKEN", "dummy")
ENV_URL: str = os.environ.get("ENV_URL", "http://localhost:7860")
# OpenAI client pointed at the proxy — never bypass this
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
# Persistent HTTP session for env calls — avoids TCP handshake per step
_session = requests.Session()
# ── Task metadata ─────────────────────────────────────────────────────────────
# Human-readable episode names keyed by task_id (0..2); used in [START] lines.
TASK_NAMES = {
    0: "Sinusoidal FDI Detection (Easy)",
    1: "Multi-Attack Classification (Medium)",
    2: "Stealthy Attack Detection (Hard)",
}
BENCHMARK = "pll-cyberattack-detection"
# Safe no-op action: returned when LLM parsing fails or no heuristic rule fires.
DEFAULT_ACTION = {
    "attack_detected": False,
    "attack_type": 0,
    "confidence": 0.5,
    "protective_action": 0,
}
# ── System prompt ─────────────────────────────────────────────────────────────
# Instructs the model to answer with bare JSON; parse_llm_response tolerates
# fenced replies as well.  Runtime string — do not edit casually.
SYSTEM_PROMPT = """You are an AI agent monitoring a power grid inverter's Phase-Locked Loop (PLL).
You receive time-windowed sensor readings each step and must detect cyberattacks.
vq_window: q-axis voltage error (should be ~0 when healthy)
vd_window: d-axis voltage
omega_window: estimated frequency (normalized, nominal=0)
omega_deviation_window: frequency deviation from nominal in rad/s (useful for detecting slow phase drift)
raw_voltages: [va, vb, vc] at current step
task_id: 0=detect only, 1=classify type, 2=detect stealthy attack
For task_id=0: Focus on detecting any attack (attack_detected=True/False).
For task_id=1: Also classify the attack type (1=sinusoidal, 2=ramp, 3=pulse).
For task_id=2: Detect very subtle attacks before the PLL loses lock. Look for slow drifts in omega_deviation and vq.
Analysis tips:
- In healthy state, vq values should be near 0 and stable.
- Sinusoidal attacks cause oscillating patterns in vq.
- Ramp attacks cause steadily increasing vq magnitude.
- Pulse attacks cause sudden step changes in vq.
- Stealthy attacks cause very slow, gradual drift in omega_deviation_window.
- Look at trends across the full window, not just the latest value.
Respond ONLY with valid JSON, no explanation:
{
"attack_detected": <bool>,
"attack_type": <int 0-4>,
"confidence": <float 0.0-1.0>,
"protective_action": <int 0-3>
}"""
| # ── Logging helpers ─────────────────────────────────────────────────────────── | |
def log_start(task: str, env: str, model: str) -> None:
    """Emit the OpenEnv-compliant [START] banner on stdout."""
    fields = [f"task={task}", f"env={env}", f"model={model}"]
    print("[START] " + " ".join(fields), flush=True)
def log_step(step: int, action: dict, reward: float, done: bool, error) -> None:
    """Emit one OpenEnv-compliant [STEP] line on stdout.

    The action is serialized as compact JSON; a falsy error is shown as the
    literal string "null" per the protocol.
    """
    compact_action = json.dumps(action, separators=(',', ':'))
    shown_error = error or "null"
    line = (
        f"[STEP] step={step} action={compact_action} "
        f"reward={reward:.2f} done={str(done).lower()} error={shown_error}"
    )
    print(line, flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Emit the final OpenEnv-compliant [END] summary line on stdout."""
    reward_csv = ",".join("%.2f" % r for r in rewards)
    line = (
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.3f} rewards={reward_csv}"
    )
    print(line, flush=True)
| # ── Detector Agent & Smart Blending ─────────────────────────────────────────── | |
def detector_agent(prev_info: dict) -> Optional[dict]:
    """Read the environment's intrinsic physics-based detector verdict.

    Returns a complete action dict built from prev_info["detector"], or None
    when the payload is missing or lacks an "attack_detected" field.
    """
    verdict = prev_info.get("detector", {})
    if not verdict or "attack_detected" not in verdict:
        return None
    defaults = {
        "attack_detected": False,
        "attack_type": 0,
        "confidence": 0.5,
        "protective_action": 0,
    }
    return {field: verdict.get(field, fallback) for field, fallback in defaults.items()}
def smart_blend_agent(obs: dict, prev_info: dict) -> dict:
    """Return the intrinsic detector's verdict when confident, else the heuristic's.

    The heuristic is evaluated unconditionally so its per-episode state stays
    current even on steps where the detector wins.
    """
    fallback = heuristic_agent(obs)
    verdict = detector_agent(prev_info)
    if verdict is not None and verdict["confidence"] >= 0.5:
        return verdict
    return fallback
| # ── Rule-Based Heuristic Agent ──────────────────────────────────────────────── | |
| class HeuristicState: | |
| """Tracks running state for the heuristic agent across steps.""" | |
| def __init__(self): | |
| self.reset() | |
| def reset(self): | |
| self.vq_history = [] | |
| self.omega_dev_history = [] | |
| self.attack_detected = False | |
| self.predicted_type = 0 | |
| self.settled_baseline = None | |
| self.peak_vq = 0.0 | |
| _hstate = HeuristicState() | |
def heuristic_agent(obs: dict) -> dict:
    """Stateful rule-based policy over the PLL sensor windows.

    Called on every blended step via smart_blend_agent (not only on LLM
    failure).  Accumulates per-step statistics in the module-level _hstate
    singleton, which is cleared whenever a new episode starts (step == 0).
    Returns a full action dict for whichever task_id the observation carries.
    """
    global _hstate
    vq = obs["vq_window"]
    omega_dev = obs["omega_deviation_window"]
    task_id = obs["task_id"]
    step = obs["step"]
    if step == 0:
        _hstate.reset()
    # Window summary statistics: mean/max of |vq| and mean of |omega deviation|.
    vq_abs = [abs(v) for v in vq]
    vq_mean = sum(vq_abs) / len(vq_abs)
    vq_max = max(vq_abs)
    omega_dev_abs = [abs(v) for v in omega_dev]
    omega_dev_mean = sum(omega_dev_abs) / len(omega_dev_abs)
    _hstate.vq_history.append(vq_mean)
    _hstate.omega_dev_history.append(omega_dev_mean)
    _hstate.peak_vq = max(_hstate.peak_vq, vq_mean)
    if step == 50:
        # Snapshot the "settled" omega-deviation level once the PLL has locked;
        # task 2 compares later readings against this baseline.
        _hstate.settled_baseline = omega_dev_mean
    # Ignore the first 25 steps (PLL start-up transient), then latch detection
    # once mean or peak |vq| crosses its threshold.
    detected = False if step < 25 else (vq_mean > 0.01 or vq_max > 0.025)
    if detected:
        _hstate.attack_detected = True
    # ── Task 0: binary detection ──────────────────────────────────────────────
    if task_id == 0:
        return {
            "attack_detected": _hstate.attack_detected,
            "attack_type": 1 if _hstate.attack_detected else 0,
            # Confidence scales with signal strength, capped at 1.0.
            "confidence": min(1.0, vq_mean * 50) if _hstate.attack_detected else 0.8,
            "protective_action": 1 if _hstate.attack_detected else 0,
        }
    # ── Task 1: classification ────────────────────────────────────────────────
    if task_id == 1:
        if not _hstate.attack_detected:
            return {
                "attack_detected": False,
                "attack_type": 0,
                "confidence": 0.7,
                "protective_action": 0,
            }
        # With fewer than 5 elevated samples there is too little evidence to
        # classify; default to type 1 (sinusoidal).
        n_elevated = sum(1 for v in _hstate.vq_history if v > 0.01)
        if n_elevated < 5:
            attack_type = 1
        else:
            elevated = [v for v in _hstate.vq_history if v > 0.005]
            recent = elevated[-min(20, len(elevated)):]
            # Low current-vs-peak ratio ⇒ the disturbance has passed (pulse-like).
            current_vs_peak = vq_mean / _hstate.peak_vq if _hstate.peak_vq > 0 else 0
            # Sign flips within the window ⇒ oscillation (sinusoidal).
            zero_crossings = sum(1 for i in range(1, len(vq)) if vq[i] * vq[i - 1] < 0)
            if len(recent) >= 6:
                # Growth of the last third vs the first third of recent samples;
                # sustained growth indicates a ramp attack.
                first_third = sum(recent[: len(recent) // 3]) / (len(recent) // 3)
                last_third = sum(recent[-len(recent) // 3 :]) / (len(recent) // 3)
                growth = last_third / first_third if first_third > 0.001 else 1.0
            else:
                growth = 1.0
            # Decision cascade: pulse → sinusoidal → ramp → sinusoidal → shape test.
            if current_vs_peak < 0.15 and _hstate.peak_vq > 0.05:
                attack_type = 3
            elif current_vs_peak < 0.4 and n_elevated > 30:
                attack_type = 3
            elif zero_crossings >= 2 and growth < 1.5:
                attack_type = 1
            elif growth > 1.3:
                attack_type = 2
            elif zero_crossings >= 1:
                attack_type = 1
            else:
                # Last resort: mostly-decreasing window reads as pulse decay.
                vq_diffs = [vq[i] - vq[i - 1] for i in range(1, len(vq))]
                neg = sum(1 for d in vq_diffs if d < 0)
                attack_type = 3 if neg > 14 else 1
        _hstate.predicted_type = attack_type
        return {
            "attack_detected": True,
            "attack_type": _hstate.predicted_type,
            "confidence": 0.8,
            "protective_action": 1,
        }
    # ── Task 2: stealthy attack ───────────────────────────────────────────────
    if task_id == 2:
        drift_detected = False
        confidence = 0.3
        # Only judge drift after the baseline exists (captured at step 50).
        if step > 50 and _hstate.settled_baseline is not None:
            baseline = _hstate.settled_baseline
            # Near-zero baselines would blow up the ratio; scale absolutely instead.
            ratio = omega_dev_mean / baseline if baseline > 0.01 else omega_dev_mean * 100
            if len(_hstate.omega_dev_history) > 10:
                # Compare the latest 10-step average against the previous 10 to
                # decide whether omega deviation is still rising.
                recent_10 = _hstate.omega_dev_history[-10:]
                old_10 = (_hstate.omega_dev_history[-20:-10]
                          if len(_hstate.omega_dev_history) > 20
                          else _hstate.omega_dev_history[:10])
                recent_avg = sum(recent_10) / len(recent_10)
                old_avg = sum(old_10) / len(old_10)
                rising = recent_avg > old_avg * 1.1
            else:
                rising = False
            # Ordered triggers, strongest evidence first; confidence mirrors strength.
            if ratio > 2.0:
                drift_detected, confidence = True, 0.9
            elif ratio > 1.3 and rising:
                drift_detected, confidence = True, 0.8
            elif rising and vq_mean > 0.1:
                drift_detected, confidence = True, 0.6
            elif vq_mean > 0.2:
                drift_detected, confidence = True, 0.5
        if drift_detected:
            _hstate.attack_detected = True
        return {
            "attack_detected": drift_detected,
            "attack_type": 4 if drift_detected else 0,
            "confidence": confidence,
            "protective_action": 2 if drift_detected else 0,
        }
    # Unknown task_id — return the neutral default action.
    return DEFAULT_ACTION.copy()
| # ── LLM agent (PRIMARY — always called first) ───────────────────────────────── | |
def parse_llm_response(response_text: str) -> dict:
    """Parse the LLM's reply into a sanitized action dict.

    Accepts raw JSON, ```-fenced JSON, and JSON embedded in surrounding prose:
    the parser takes the span between the first '{' and the last '}'.  (The
    previous fence-stripper only fired when the reply *started* with ```, so a
    prose-prefixed answer silently fell back to DEFAULT_ACTION.)

    Fields are clamped to the env action space — attack_type to [0, 4],
    confidence to [0.0, 1.0], protective_action to [0, 3] — with defaults for
    anything missing.  Any parse failure returns a copy of DEFAULT_ACTION.
    """
    try:
        text = response_text.strip()
        # Locate the JSON object even when the model wraps it in code fences
        # or explanation; json.loads then validates the candidate span.
        start = text.find("{")
        end = text.rfind("}")
        if start != -1 and end > start:
            text = text[start:end + 1]
        parsed = json.loads(text)
        return {
            "attack_detected": bool(parsed.get("attack_detected", False)),
            "attack_type": max(0, min(4, int(parsed.get("attack_type", 0)))),
            "confidence": max(0.0, min(1.0, float(parsed.get("confidence", 0.5)))),
            "protective_action": max(0, min(3, int(parsed.get("protective_action", 0)))),
        }
    except (json.JSONDecodeError, KeyError, TypeError, ValueError):
        return DEFAULT_ACTION.copy()
def format_observation(obs: dict) -> str:
    """Render the observation dict as the user message sent to the LLM.

    All sensor values are rounded to 6 decimal places to keep the prompt short.
    """
    def rounded(values):
        return [round(v, 6) for v in values]

    lines = [
        f"Step: {obs['step']}",
        f"Task: {obs['task_id']}",
        f"vq_window (last 20): {rounded(obs['vq_window'])}",
        f"vd_window (last 20): {rounded(obs['vd_window'])}",
        f"omega_window (last 20): {rounded(obs['omega_window'])}",
        f"omega_deviation_window (last 20): {rounded(obs['omega_deviation_window'])}",
        f"raw_voltages: {rounded(obs['raw_voltages'])}",
    ]
    return "\n".join(lines)
_llm_disabled = False  # circuit breaker — flips True after first LLM failure
def llm_agent(obs: dict, prev_info: dict) -> dict:
    """Primary policy: query the LLM through the injected proxy.

    On any API exception the circuit breaker trips: the failure is logged to
    stderr, _llm_disabled latches True, and this call (plus every later one)
    short-circuits to smart_blend_agent without touching the network.
    """
    global _llm_disabled
    if _llm_disabled:
        return smart_blend_agent(obs, prev_info)
    try:
        reply = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": format_observation(obs)},
            ],
            temperature=0.1,
            max_tokens=200,
            timeout=10.0,
        )
        content = reply.choices[0].message.content or ""
        return parse_llm_response(content)
    except Exception as e:
        print(f"[WARN] LLM error ({type(e).__name__}: {e}), disabling LLM for remaining steps", file=sys.stderr, flush=True)
        _llm_disabled = True
        return smart_blend_agent(obs, prev_info)
| # ── Episode runner ──────────────────────────────────────────────────────────── | |
def run_episode(task_id: int) -> float:
    """Run one full episode for *task_id* against the environment server.

    Emits the OpenEnv [START]/[STEP]/[END] stdout protocol and returns the
    grader score clamped into [0.01, 0.99] — or 0.0 if the episode raised
    before a score was obtained.
    """
    task_name = TASK_NAMES[task_id]
    log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
    # Reset heuristic state before every episode so stale data from a previous
    # task never bleeds into the next one (also covers the LLM fallback path).
    _hstate.reset()
    step_count = 0
    grader_score = 0.0
    rewards: List[float] = []
    success = False
    try:
        reset_resp = _session.post(
            f"{ENV_URL}/reset",
            json={"task_id": task_id},
            timeout=60,
        )
        reset_resp.raise_for_status()
        obs = reset_resp.json()
        done = False
        total_reward = 0.0
        info = {}
        while not done:
            # Frame skipping: call LLM every 10 steps, heuristic in between.
            # This caps LLM calls at ~150 total across 3 tasks, keeping runtime
            # well under the 20-min judging limit even with 3s/call latency.
            if step_count % 10 == 0:
                action = llm_agent(obs, info)
            else:
                action = smart_blend_agent(obs, info)
            step_resp = _session.post(
                f"{ENV_URL}/step",
                json=action,
                timeout=60,
            )
            step_resp.raise_for_status()
            result = step_resp.json()
            obs = result["observation"]
            reward = result["reward"]
            done = result["done"]
            info = result.get("info", {})
            error = result.get("error", None)
            # The env may return reward as a scalar or as a component dict
            # carrying a "total" key; normalize to a float either way.
            step_reward = reward["total"] if isinstance(reward, dict) else float(reward)
            total_reward += step_reward
            rewards.append(step_reward)
            step_count += 1
            log_step(step=step_count, action=action, reward=step_reward, done=done, error=error)
            if step_count % 50 == 0:
                # Periodic progress line on stderr only — stdout is reserved
                # for the OpenEnv protocol markers.
                print(
                    f"[DEBUG] step={step_count} cumulative_reward={total_reward:+.4f} "
                    f"detected={action['attack_detected']} type={action['attack_type']}",
                    file=sys.stderr, flush=True,
                )
        grader_score = info.get("grader_score", 0.01)
        grader_score = max(0.01, min(0.99, grader_score))  # strict (0, 1)
        success = grader_score > 0.01
    except Exception as exc:
        print(f"[DEBUG] Episode error: {type(exc).__name__}: {exc}", file=sys.stderr, flush=True)
        success = False
    except BaseException as exc:
        # KeyboardInterrupt / SystemExit: record the failure, then re-raise so
        # the caller still observes the interruption.
        print(f"[DEBUG] Critical interruption: {type(exc).__name__}: {exc}", file=sys.stderr, flush=True)
        success = False
        raise
    finally:
        # [END] must be printed exactly once per episode, even on crash paths.
        log_end(success=success, steps=step_count, score=grader_score, rewards=rewards)
    return grader_score
| # ── Server Check ────────────────────────────────────────────────────────────── | |
def wait_for_server(env_url: str, timeout: int = 60) -> bool:
    """Poll GET {env_url}/health once per second until it answers 200.

    Returns True as soon as the server responds with HTTP 200, False once
    *timeout* seconds have elapsed without a healthy response.
    """
    print(f"[DEBUG] Waiting for environment server at {env_url} to start...", file=sys.stderr, flush=True)
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            if _session.get(f"{env_url}/health", timeout=2).status_code == 200:
                print("[DEBUG] Environment server is up!", file=sys.stderr, flush=True)
                return True
        except Exception:
            # Connection refused/reset while the server boots — keep polling.
            pass
        time.sleep(1)
    print(f"[DEBUG] Environment server failed to start within {timeout}s.", file=sys.stderr, flush=True)
    return False
| # ── Entry point ─────────────────────────────────────────────────────────────── | |
def main() -> None:
    """Run all three PLL tasks in order and report the average score on stderr."""
    print(f"[DEBUG] PLL Cyberattack Detection — model={MODEL_NAME} env={ENV_URL}", file=sys.stderr, flush=True)
    if not wait_for_server(ENV_URL):
        print("[DEBUG] Exiting due to server unavailable.", file=sys.stderr, flush=True)
        return
    started_at = time.time()
    scores: List[float] = []
    try:
        for task_id in (0, 1, 2):
            # One crashed episode scores 0.0 but never aborts the remaining tasks.
            try:
                task_score = run_episode(task_id)
            except Exception as exc:
                print(f"[DEBUG] run_episode({task_id}) crashed: {exc}", file=sys.stderr, flush=True)
                task_score = 0.0
            scores.append(task_score)
            print(f"[DEBUG] task={task_id} score={task_score:.4f}", file=sys.stderr, flush=True)
    except BaseException as exc:
        # Even on interruption, fall through to report whatever was completed.
        print(f"[DEBUG] Process interrupted: {type(exc).__name__}: {exc}", file=sys.stderr, flush=True)
    elapsed = time.time() - started_at
    avg = sum(scores) / len(scores) if scores else 0.0
    print(f"[DEBUG] avg_score={avg:.4f} elapsed={elapsed:.1f}s", file=sys.stderr, flush=True)
# Script entry guard — run the benchmark only when executed directly.
if __name__ == "__main__":
    main()