| """ |
| Baseline Policies for OpenGrid |
| ================================ |
| Provides two agent implementations: |
| 1. heuristic_policy — deterministic rule-based baseline for reproducible scoring |
| 2. llm_policy — LLM-based policy using OpenAI-compatible API |
| |
| Both support GridObservation (single-agent) and ZoneObservation (multi-agent). |
| """ |
|
|
| import json |
| import logging |
| import os |
| from typing import List, Union |
|
|
| from openai import OpenAI |
| from .models import GridAction, BusAdjustment, GridObservation, ZoneObservation |
|
|
logger = logging.getLogger(__name__)

# Connection settings for the OpenAI-compatible backend, read once at import.
# `or` chaining (instead of a getenv default) makes an exported-but-empty
# variable fall through to the next source, so e.g. OPENAI_API_KEY="" no longer
# masks a valid HF_TOKEN, and an empty API_BASE_URL / MODEL_NAME falls back to
# the built-in default instead of producing a broken client/request.
API_BASE_URL = os.getenv("API_BASE_URL") or "https://api.openai.com/v1"
MODEL_NAME = os.getenv("MODEL_NAME") or "gpt-4o"
API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("HF_TOKEN") or ""

# Shared OpenAI client; stays None until _get_client() first creates it.
_CLIENT = None
|
|
|
|
def _get_client() -> OpenAI:
    """Return the module-wide OpenAI client, building it on first call.

    The client targets API_BASE_URL with API_KEY and timeout=15.0, and is
    cached in the module global _CLIENT so later calls reuse it.

    Raises:
        RuntimeError: when no client exists yet and API_KEY is empty.
    """
    global _CLIENT
    if _CLIENT is not None:
        return _CLIENT

    if not API_KEY:
        raise RuntimeError(
            "Missing API key. Set OPENAI_API_KEY or HF_TOKEN environment variable."
        )

    _CLIENT = OpenAI(base_url=API_BASE_URL, api_key=API_KEY, timeout=15.0)
    return _CLIENT
|
|
|
|
def _obs_buses(obs):
    """Return the bus collection carried by *obs*.

    Looks up ``buses`` first and falls back to ``local_buses``, so it accepts
    either GridObservation or ZoneObservation. When both exist, ``buses`` is
    returned; when neither exists, an empty list is returned.
    """
    if hasattr(obs, "buses"):
        return obs.buses
    return getattr(obs, "local_buses", [])
|
|
|
|
def _obs_lines(obs):
    """Return the line collection carried by *obs*.

    If *obs* has a ``lines`` attribute, that object is returned unchanged.
    Otherwise a new list is built from ``internal_lines`` followed by
    ``boundary_lines``; either one is treated as empty when missing.
    """
    try:
        return obs.lines
    except AttributeError:
        pass
    return [
        *getattr(obs, "internal_lines", []),
        *getattr(obs, "boundary_lines", []),
    ]
|
|
|
|
# System message sent with every llm_policy request, ahead of the serialized
# observation. Its JSON example uses the same top-level keys
# ("bus_adjustments", "topology_actions") that GridAction is built from
# elsewhere in this module, so keep the two in sync. Any edit to this text
# changes what the model is asked to do.
SYSTEM_PROMPT = """You are a Power Grid Controller AI. Your goal is to maintain grid stability.

Key objectives:
1. Keep grid frequency close to 50.0 Hz (acceptable: 49.5–50.5 Hz)
2. Prevent transmission line overloads (rho < 1.0)
3. Avoid grid islanding (blackout)

Available actions:
1. bus_adjustments: List of {"bus_id": int, "delta": float}
- Positive delta = increase power injection (discharge battery / ramp up generator)
- Negative delta = decrease power injection (charge battery / ramp down generator)
- Only works on battery and generator buses (NOT slack, load, solar, or wind)
- Slack bus injection is computed by physics — adjustments are ignored
2. topology_actions: List of {"line_id": str, "action": "open" | "close"}
- Opening a line removes it; closing reconnects. 3-step cooldown after each switch.
- WARNING: Opening lines can cause islanding → blackout → -100 reward
- Prefer NO topology actions unless absolutely necessary.

Strategy tips:
- If frequency < 50 Hz: grid needs more generation → discharge batteries or ramp up generators
- If frequency > 50 Hz: grid has excess generation → charge batteries or ramp down generators
- If a line rho > 0.9: reduce generation at one end or increase at the other to shift flow
- Prefer minimal actions. Do-nothing is better than reckless switching.

Respond with ONLY a valid JSON object, no markdown, no explanation. Example:
{"bus_adjustments": [{"bus_id": 2, "delta": 5.0}], "topology_actions": []}
"""
|
|
|
|
def _first_json_object(text: str):
    """Return the first JSON object embedded in *text*, or None if there is none.

    Scans for each ``{`` / ``[`` and tries to decode one complete JSON value
    starting there, so surrounding prose, Markdown fences, or trailing text
    that itself contains braces do not corrupt the extraction. A decoded
    top-level array is unwrapped to its first element.
    """
    decoder = json.JSONDecoder()
    for idx, ch in enumerate(text):
        if ch not in "{[":
            continue
        try:
            value, _ = decoder.raw_decode(text, idx)
        except json.JSONDecodeError:
            continue  # not valid JSON from this position; keep scanning
        # Some models answer with a list of actions; take the first one.
        if isinstance(value, list):
            value = value[0] if value else None
        if isinstance(value, dict):
            return value
    return None


def parse_action_response(response_text: str) -> GridAction:
    """Parse an LLM reply into a GridAction, falling back to a no-op.

    The first JSON object found in the reply is passed as keyword arguments
    to ``GridAction``. The reply may be wrapped in Markdown code fences,
    surrounded by prose, or be a JSON array whose first element is the action.

    Args:
        response_text: Raw assistant message content. ``None``, non-string,
            or blank input is tolerated.

    Returns:
        The parsed action, or an empty ``GridAction()`` when no JSON object
        is present or the object is rejected by ``GridAction``.
    """
    if not isinstance(response_text, str) or not response_text.strip():
        return GridAction()

    data = _first_json_object(response_text)
    if data is None:
        logger.debug("LLM response contained no JSON object: %r", response_text)
        return GridAction()

    try:
        return GridAction(**data)
    except Exception as e:
        # Unknown fields / wrong types from the model: degrade to a no-op
        # instead of crashing the episode, but leave a trace for debugging.
        logger.debug("Could not build GridAction from %r: %s", data, e)
        return GridAction()
|
|
|
|
def llm_policy(obs: Union[GridObservation, ZoneObservation]) -> GridAction:
    """Ask the configured chat model for an action on *obs*.

    The observation (GridObservation or ZoneObservation) is serialized with
    ``model_dump_json()`` and sent after SYSTEM_PROMPT to MODEL_NAME
    (temperature 0.0, at most 300 completion tokens). The reply text is
    decoded by ``parse_action_response``.

    Errors raised while obtaining the client (``_get_client`` raises
    RuntimeError when no API key is configured) or while serializing *obs*
    propagate to the caller. Any exception from the API request or from
    handling its response is logged at DEBUG level and a no-op
    ``GridAction()`` is returned instead.
    """
    client = _get_client()
    state_json = obs.model_dump_json()

    try:
        chat_messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": f"Current Grid State:\n{state_json}"},
        ]
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=chat_messages,
            temperature=0.0,
            max_tokens=300,
        )
        reply_text = completion.choices[0].message.content
        action = parse_action_response(reply_text)
    except Exception as err:
        logger.debug("LLM policy error: %s", err, exc_info=True)
        action = GridAction()

    return action
|
|
|
|
def heuristic_policy(
    obs: Union[GridObservation, ZoneObservation],
) -> GridAction:
    """Deterministic rule-based baseline used for reproducible scoring.

    The action is assembled in three passes; adjustments are appended in this
    order:

    1. Batteries (deadband 0.1 Hz around 50 Hz): a proportional correction of
       ``-15 * (f - 50)``, capped to [-20, 20], split evenly over all battery
       buses. A positive (discharge) share is only sent to batteries with
       ``soc > 0``; a negative (charge) share is sent to every battery.
    2. Generators (deadband 0.25 Hz): every generator gets ``-5 * (f - 50)``,
       capped to ± its ``ramp_rate`` (20.0 when the bus has no such attribute).
    3. Overloads: for each connected line with ``rho > 0.95``, the first
       generator not already picked in this pass and with ``p_injection > 5``
       is lowered by 3.0. Such a generator may also have an entry from pass 2.

    Only battery and generator buses are ever targeted (never the slack bus,
    whose injection the physics solver overwrites), no topology actions are
    emitted (opening overloaded lines risks cascading failures), and any
    remaining out-of-range delta is left to the environment/safety layer.

    Works with GridObservation (single-agent) and ZoneObservation
    (multi-agent) alike.
    """
    deviation = obs.grid_frequency - 50.0

    all_buses = list(_obs_buses(obs))
    all_lines = list(_obs_lines(obs))

    battery_buses = [bus for bus in all_buses if bus.type == 'battery']
    generator_buses = [bus for bus in all_buses if bus.type == 'generator']

    adjustments: List[BusAdjustment] = []

    # Pass 1: batteries share one capped proportional correction.
    if abs(deviation) > 0.1 and battery_buses:
        fleet_delta = max(-20.0, min(20.0, -deviation * 15.0))
        share = fleet_delta / len(battery_buses)
        for bus in battery_buses:
            # An empty battery cannot discharge; charging is always requested.
            if share < 0 or (share > 0 and bus.soc > 0):
                adjustments.append(BusAdjustment(bus_id=bus.id, delta=share))

    # Pass 2: generators only join in on larger deviations, limited by ramp rate.
    if abs(deviation) > 0.25:
        wanted = -deviation * 5.0
        for bus in generator_buses:
            ramp_limit = getattr(bus, 'ramp_rate', 20.0)
            clamped = max(-ramp_limit, min(ramp_limit, wanted))
            adjustments.append(BusAdjustment(bus_id=bus.id, delta=clamped))

    # Pass 3: shed a little output from one distinct loaded generator per
    # overloaded, still-connected line.
    used_for_overload = set()
    for line in all_lines:
        if not (line.rho > 0.95 and line.connected):
            continue
        pick = next(
            (
                gen
                for gen in generator_buses
                if gen.id not in used_for_overload and gen.p_injection > 5
            ),
            None,
        )
        if pick is None:
            continue
        adjustments.append(BusAdjustment(bus_id=pick.id, delta=-3.0))
        used_for_overload.add(pick.id)

    return GridAction(bus_adjustments=adjustments, topology_actions=[])