| import numpy as np |
| import math |
| from typing import List, Dict, Tuple, Optional |
| from .models import ( |
| GridObservation, GridAction, GridReward, GridInfo, |
| LineStatus, BusState, ZoneObservation, ZoneInfo, |
| SafetyReport, OversightReport, MultiAgentStepResult, |
| ) |
| from .physics import DCSolver, IslandedException |
| from .safety import SafetyLayer |
| from .oversight import OversightAgent |
|
|
|
|
class OpenGridEnv:
    """
    OpenGrid: A renewable energy grid load-balancing environment.

    Supports two modes:
      1. Single-agent (backward compatible): reset()/step()/state()
      2. Multi-agent POMDP: reset_multi()/step_multi() with per-zone
         partial observability, safety layer, and oversight agent.

    The agent(s) must maintain grid stability by:
      - Balancing generation and load (frequency control)
      - Managing transmission line loading (congestion management)
      - Coordinating battery storage and topology switching
    """

    # Frequency the droop model is centred on (see _compute_frequency).
    NOMINAL_FREQ = 50.0
    # Absolute frequency deviation tolerated before step() applies a penalty.
    FREQ_DEADBAND = 0.5
    # Std-dev of the Gaussian noise added to the frequency in zone observations.
    FREQ_NOISE_STD = 0.05
    # Std-dev of the Gaussian noise added to line loading (rho) in zone observations.
    LINE_NOISE_STD = 0.02

    def __init__(self, config: Dict):
        """Build the environment from a task configuration dict.

        Required keys: ``num_buses``, ``lines``, ``buses``.
        Optional keys (with defaults): ``max_steps`` (50), ``num_agents`` (1),
        ``zone_assignments``, ``zone_names``, ``zone_bus_ids``,
        ``internal_lines``, ``boundary_lines`` and ``seed`` (42).

        Dynamic state (bus/line state, cooldowns) stays empty until
        ``reset()`` is called.
        """
        self.config = config
        self.num_buses = config['num_buses']
        self.lines_config = config['lines']
        self.buses_config = config['buses']

        # The slack bus is the first bus typed 'slack'; bus 0 is the fallback.
        self.slack_bus_id = next(
            (b['id'] for b in self.buses_config if b['type'] == 'slack'), 0
        )

        self.solver = DCSolver(self.num_buses, slack_bus=self.slack_bus_id)
        self.timestep = 0
        self.max_steps = config.get('max_steps', 50)

        # Dynamic state, populated by reset().
        self.bus_state = []
        self.line_state = []
        self.cooldowns = {}
        self.slack_injection = 0.0
        self._is_blackout = False

        # Static id -> config lookups.
        self._bus_cfg_by_id = {b['id']: b for b in self.buses_config}
        self._line_cfg_by_id = {l['id']: l for l in self.lines_config}

        # id -> dynamic-state lookups. They are built by reset()/_set_state();
        # while None, _find_line/_find_bus_state fall back to a linear scan.
        self._bus_state_by_id: Optional[Dict] = None
        self._line_state_by_id: Optional[Dict] = None

        # Multi-agent zone layout.
        self.num_agents = config.get('num_agents', 1)
        self.zone_assignments = config.get('zone_assignments', {})
        self.zone_names = config.get('zone_names', [])
        self.zone_bus_ids = config.get('zone_bus_ids', {})
        self.internal_lines = config.get('internal_lines', {})
        self.boundary_lines = config.get('boundary_lines', {})

        self.safety_layer = SafetyLayer(config)
        self.oversight_agent = OversightAgent(config)

        # Reports produced by the most recent step_multi() call
        # (safety reports are keyed by agent id).
        self._safety_reports_this_step: Dict[int, SafetyReport] = {}
        self._oversight_report_this_step: Optional[OversightReport] = None

        # Droop constant scaled by system size: summed load base_p plus summed
        # generation max_p, floored at 50.0 so small grids do not get an
        # oversized frequency response.
        total_load = sum(
            b['base_p'] for b in self.buses_config if b['type'] == 'load'
        )
        total_gen = sum(
            b['max_p'] for b in self.buses_config
            if b['type'] in ['slack', 'generator', 'solar', 'wind']
        )
        total_system = max(total_load + total_gen, 50.0)
        self.droop_constant = 2.5 / total_system

        # Per-episode RNG; re-created from the same seed on every reset().
        self._seed = config.get('seed', 42)
        self._rng = np.random.default_rng(self._seed)

    # ------------------------------------------------------------------
    # State restoration
    # ------------------------------------------------------------------

    def _set_state(self, obs_dict: dict) -> None:
        """Restore the environment to a state described by an observation dict.

        This enables environment-grounded GRPO rewards: instead of scoring
        actions with a heuristic proxy, we restore the env to the observed
        state, step with the proposed action, and use the real reward.

        Only buses and lines that already exist in the dynamic state are
        updated, so ``reset()`` should have been called beforehand.

        Args:
            obs_dict: A dict from ZoneObservation.model_dump() or
                GridObservation.model_dump(), containing at minimum:
                timestep, grid_frequency, and bus/line state.
        """
        self.timestep = obs_dict.get('timestep', 0)
        self._is_blackout = obs_dict.get('is_blackout', False)

        # Always build a fresh dict: step() mutates self.cooldowns, so aliasing
        # the caller's dict would corrupt the observation. Starting from every
        # known line at 0 also keeps lines missing from a partial (zone)
        # observation switchable instead of silently dropping them.
        cooldowns = {l['id']: 0 for l in self.lines_config}
        cooldowns.update(dict.fromkeys(self.cooldowns, 0))
        cooldowns.update(obs_dict.get('cooldowns') or {})
        self.cooldowns = cooldowns

        # Zone observations carry 'local_buses'; full observations carry 'buses'.
        observed_buses = obs_dict.get('local_buses', obs_dict.get('buses', [])) or []
        for b_obs in observed_buses:
            b_dyn = self._find_bus_state(b_obs['id'])
            if b_dyn is not None:
                b_dyn['p'] = b_obs.get('p_injection', b_dyn['p'])
                b_dyn['soc'] = b_obs.get('soc', b_dyn.get('soc', 0.0))

        # Accept the line lists of either observation flavour.
        observed_lines = (
            (obs_dict.get('internal_lines', []) or [])
            + (obs_dict.get('boundary_lines', []) or [])
            + (obs_dict.get('lines', []) or [])
        )
        for l_obs in observed_lines:
            l_dyn = self._find_line(l_obs['id'])
            if l_dyn is not None:
                l_dyn['connected'] = l_obs.get('connected', True)
                l_dyn['flow'] = l_obs.get('flow', 0.0)

        # Rebuild the id -> state indexes.
        self._bus_state_by_id = {b['id']: b for b in self.bus_state}
        self._line_state_by_id = {l['id']: l for l in self.line_state}

        # Invert the droop model (_compute_frequency) to recover the slack
        # injection implied by the observed frequency.
        freq = obs_dict.get('grid_frequency', self.NOMINAL_FREQ)
        self.slack_injection = (self.NOMINAL_FREQ - freq) / self.droop_constant

        slack_dyn = self._find_bus_state(self.slack_bus_id)
        if slack_dyn is not None:
            slack_dyn['p'] = self.slack_injection

    # ------------------------------------------------------------------
    # Single-agent API
    # ------------------------------------------------------------------

    def reset(self) -> GridObservation:
        """Reset the environment to initial state. Returns initial observation."""
        self.timestep = 0
        self.slack_injection = 0.0
        self.cooldowns = {l['id']: 0 for l in self.lines_config}
        self._rng = np.random.default_rng(self._seed)
        self.oversight_agent.reset()

        # Generators start at half of max_p; every other bus starts at 0 and
        # loads/renewables are filled in by _update_loads_and_renewables().
        self.bus_state = []
        for b in self.buses_config:
            init_p = b['max_p'] * 0.5 if b['type'] in ['generator'] else 0.0
            self.bus_state.append({
                'id': b['id'], 'p': init_p, 'soc': b.get('init_soc', 0.0)
            })
        self.line_state = [
            {'id': l['id'], 'connected': True, 'flow': 0.0}
            for l in self.lines_config
        ]

        self._bus_state_by_id = {b['id']: b for b in self.bus_state}
        self._line_state_by_id = {l['id']: l for l in self.line_state}

        self._is_blackout = False
        self._update_loads_and_renewables()
        self._run_power_flow()

        return self._get_obs()

    def step(self, action: GridAction) -> Tuple[GridObservation, GridReward, bool, GridInfo]:
        """Execute one step: apply action, update dynamics, solve physics, compute reward.

        Order of operations: topology switches, cooldown tick, bus
        adjustments, load/renewable update, power flow, reward.

        Reward components:
          - survival: 1.0 per step, replaced by -100.0 when the power flow
            raises IslandedException (blackout).
          - action_cost: -0.5 per line whose status actually changed.
          - overload: per connected line, -(rho - 1)^2 * 20 above rho 1.0,
            or -0.1 above rho 0.8.
          - frequency: penalty of (deviation - deadband) * 0.5 capped at 1.5
            outside the deadband, or a +0.2 bonus when deviation < 0.1.

        The episode ends on blackout or once ``max_steps`` is reached.
        """
        self.timestep += 1
        reward_components = {"survival": 1.0, "frequency": 0.0, "overload": 0.0, "action_cost": 0.0}
        self._is_blackout = False

        self._apply_topology_actions(action.topology_actions, reward_components)

        # Tick every cooldown. A line switched above was just set to 3 and is
        # decremented here as well, so it reads 2 after this step.
        for l_id in self.cooldowns:
            self.cooldowns[l_id] = max(0, self.cooldowns[l_id] - 1)

        self._apply_bus_adjustments(action.bus_adjustments)
        self._update_loads_and_renewables()

        try:
            self._run_power_flow()
        except IslandedException:
            # Flows and slack injection keep their previous-step values here.
            self._is_blackout = True
            reward_components['survival'] = -100.0
        else:
            for line in self.line_state:
                if line['connected']:
                    reward_components['overload'] -= self._overload_penalty(
                        self._line_rho(line), 20, 0.1
                    )

            freq_dev = abs(self._compute_frequency() - self.NOMINAL_FREQ)
            if freq_dev > self.FREQ_DEADBAND:
                raw_penalty = (freq_dev - self.FREQ_DEADBAND) * 0.5
                reward_components['frequency'] -= min(raw_penalty, 1.5)
            elif freq_dev < 0.1:
                reward_components['frequency'] += 0.2

        done = self._is_blackout or (self.timestep >= self.max_steps)

        total_reward = sum(reward_components.values())
        reward = GridReward(value=total_reward, components=reward_components)
        info = GridInfo(task_id=self.config['id'], is_blackout=self._is_blackout)

        return self._get_obs(), reward, done, info

    def _apply_topology_actions(self, topology_actions, reward_components: Dict) -> None:
        """Open/close lines whose cooldown is 0, charging 0.5 action cost per change."""
        for t_act in topology_actions:
            l_id = t_act.line_id
            # Unknown lines and lines still cooling down are ignored.
            if l_id not in self.cooldowns or self.cooldowns[l_id] != 0:
                continue
            line = self._find_line(l_id)
            if line is None:
                continue

            new_status = (t_act.action == "close")
            # Requests that match the current status are free no-ops.
            if line['connected'] != new_status:
                line['connected'] = new_status
                self.cooldowns[l_id] = 3
                reward_components['action_cost'] -= 0.5

    def _apply_bus_adjustments(self, bus_adjustments) -> None:
        """Apply battery dispatch and ramp-limited setpoint changes to bus state."""
        for adj in bus_adjustments:
            bus_cfg = self._find_bus_config(adj.bus_id)
            bus_dyn = self._find_bus_state(adj.bus_id)
            if bus_cfg is None or bus_dyn is None:
                continue

            delta = adj.delta

            if bus_cfg['type'] == 'battery':
                # Positive delta discharges (bounded by stored energy);
                # negative delta charges (bounded by remaining headroom).
                max_charge = bus_cfg['capacity'] - bus_dyn['soc']
                max_discharge = bus_dyn['soc']
                if delta > 0:
                    delta = min(delta, max_discharge)
                else:
                    delta = max(delta, -max_charge)

                # float(): keep plain Python floats in state, matching the
                # wind update, rather than numpy scalars.
                bus_dyn['soc'] = float(
                    np.clip(bus_dyn['soc'] - delta, 0.0, bus_cfg['capacity'])
                )
                # NOTE(review): 'p' keeps this value on later steps that carry
                # no adjustment for the battery, while 'soc' only changes when
                # an adjustment is applied -- confirm this is intended.
                bus_dyn['p'] = delta

            elif bus_cfg['type'] not in ['load', 'solar', 'wind']:
                # Dispatchable unit: limit the change by the ramp rate, then
                # clamp the new setpoint to [min_p, max_p].
                max_ramp = bus_cfg.get('ramp_rate', 10.0)
                delta = float(np.clip(delta, -max_ramp, max_ramp))
                bus_dyn['p'] = float(
                    np.clip(bus_dyn['p'] + delta, bus_cfg['min_p'], bus_cfg['max_p'])
                )

    def state(self) -> GridObservation:
        """Return current state (alias for observation)."""
        return self._get_obs()

    # ------------------------------------------------------------------
    # Multi-agent API
    # ------------------------------------------------------------------

    def reset_multi(self) -> Dict[int, ZoneObservation]:
        """Reset environment and return per-agent partial observations."""
        self.reset()
        return {
            agent_id: self._get_zone_obs(agent_id)
            for agent_id in range(self.num_agents)
        }

    def step_multi(self, agent_actions: Dict[int, GridAction]) -> MultiAgentStepResult:
        """Multi-agent step with safety layer and oversight.

        Flow:
          1. Safety layer validates each agent's actions
          2. Combine corrected actions into one GridAction
          3. Run single-agent step with combined action
          4. Oversight agent evaluates coordination
          5. Compute per-agent rewards (local + global + safety + coordination)

        Agents missing from ``agent_actions`` contribute an empty GridAction.
        The oversight agent receives the originally proposed actions, not the
        corrected ones.
        """
        pre_frequency = self._compute_frequency()
        # Shallow per-bus copies so the snapshot survives the in-place step.
        pre_bus_state = [dict(b) for b in self.bus_state]

        # 1. Safety validation.
        safety_reports: Dict[int, SafetyReport] = {}
        corrected_actions: Dict[int, GridAction] = {}
        for agent_id in range(self.num_agents):
            proposed = agent_actions.get(agent_id, GridAction())
            corrected, report = self.safety_layer.validate_and_correct(
                agent_id=agent_id,
                proposed_action=proposed,
                current_line_state=self.line_state,
                current_bus_state=self.bus_state,
                cooldowns=self.cooldowns,
            )
            corrected_actions[agent_id] = corrected
            safety_reports[agent_id] = report

        self._safety_reports_this_step = safety_reports

        # 2. Merge the corrected actions, preserving agent order.
        combined = GridAction(
            bus_adjustments=[
                adj for action in corrected_actions.values()
                for adj in action.bus_adjustments
            ],
            topology_actions=[
                t for action in corrected_actions.values()
                for t in action.topology_actions
            ],
        )

        # 3. Physics step (the global observation is not needed here).
        _, base_reward, done, info = self.step(combined)
        post_frequency = self._compute_frequency()

        # 4. Oversight.
        oversight_report = self.oversight_agent.evaluate(
            agent_actions=agent_actions,
            safety_reports=safety_reports,
            pre_frequency=pre_frequency,
            post_frequency=post_frequency,
            pre_bus_state=pre_bus_state,
            post_bus_state=self.bus_state,
        )
        self._oversight_report_this_step = oversight_report

        # 5. Per-agent rewards.
        per_agent_rewards = {
            agent_id: self._compute_agent_reward(
                agent_id=agent_id,
                base_reward=base_reward,
                safety_report=safety_reports.get(agent_id),
                oversight_report=oversight_report,
                is_blackout=info.is_blackout,
            )
            for agent_id in range(self.num_agents)
        }

        team_reward = base_reward.value

        # Zone observations read the blackout flag set by step() directly, so
        # the models do not need to be patched after construction.
        per_agent_obs = {
            agent_id: self._get_zone_obs(agent_id)
            for agent_id in range(self.num_agents)
        }

        return MultiAgentStepResult(
            observations=per_agent_obs,
            rewards=per_agent_rewards,
            team_reward=round(team_reward, 4),
            done=done,
            safety_reports=safety_reports,
            oversight_report=oversight_report,
            info=info,
        )

    def get_zone_info(self) -> Dict[int, ZoneInfo]:
        """Get metadata about each agent's zone."""
        return {
            agent_id: ZoneInfo(
                agent_id=agent_id,
                zone_name=self._zone_name(agent_id),
                bus_ids=self.zone_bus_ids.get(agent_id, []),
                boundary_line_ids=self.boundary_lines.get(agent_id, []),
                internal_line_ids=self.internal_lines.get(agent_id, []),
            )
            for agent_id in range(self.num_agents)
        }

    # ------------------------------------------------------------------
    # Rewards
    # ------------------------------------------------------------------

    def _compute_agent_reward(
        self,
        agent_id: int,
        base_reward: GridReward,
        safety_report: Optional[SafetyReport],
        oversight_report: OversightReport,
        is_blackout: bool,
    ) -> GridReward:
        """Compute per-agent reward with composable components.

        Components:
          - survival: shared team component (same for all)
          - frequency: shared (all agents affected equally)
          - overload_shared: global overload penalty split evenly across agents
          - local_congestion: penalty for overloads on the agent's internal
            and boundary lines
          - safety_compliance: -0.3 * (1 + blocked topology actions) if the
            safety layer corrected the action, otherwise +0.1
          - coordination: negated oversight penalty for this agent
          - action_cost: global switching cost split evenly across agents

        ``is_blackout`` is accepted for interface compatibility; the blackout
        outcome already arrives through the ``survival`` component.
        """
        share = max(self.num_agents, 1)
        base = base_reward.components
        components = {}

        components['survival'] = base.get('survival', 1.0)
        components['frequency'] = base.get('frequency', 0.0)
        components['overload_shared'] = base.get('overload', 0.0) / share

        # Local congestion uses half the weights of the global overload term.
        zone_overload = 0.0
        agent_lines = set(self.internal_lines.get(agent_id, []))
        agent_lines.update(self.boundary_lines.get(agent_id, []))
        for line in self.line_state:
            if line['id'] in agent_lines and line['connected']:
                zone_overload -= self._overload_penalty(
                    self._line_rho(line), 10, 0.05
                )
        components['local_congestion'] = zone_overload

        if safety_report and safety_report.was_corrected:
            components['safety_compliance'] = -0.3 * (
                1 + safety_report.blocked_topology_actions
            )
        else:
            components['safety_compliance'] = 0.1

        coord_penalty = oversight_report.coordination_penalties.get(agent_id, 0.0)
        components['coordination'] = -coord_penalty

        components['action_cost'] = base.get('action_cost', 0.0) / share

        total = sum(components.values())
        return GridReward(value=round(total, 4), components=components)

    @staticmethod
    def _overload_penalty(rho: float, overload_weight: float, warning_penalty: float) -> float:
        """Return the (positive) penalty magnitude for a line loaded at ``rho``.

        Quadratic in the excess above 1.0, a flat ``warning_penalty`` between
        0.8 and 1.0, and zero otherwise.
        """
        if rho > 1.0:
            return (rho - 1.0) ** 2 * overload_weight
        if rho > 0.8:
            return warning_penalty
        return 0.0

    # ------------------------------------------------------------------
    # Observations
    # ------------------------------------------------------------------

    def _zone_name(self, agent_id: int) -> str:
        """Return the configured zone name, or ``Zone_<id>`` when none is configured."""
        if agent_id < len(self.zone_names):
            return self.zone_names[agent_id]
        return f"Zone_{agent_id}"

    def _line_rho(self, line: Dict) -> float:
        """Return |flow| / capacity for a connected line, else 0.0."""
        limit = self._get_line_capacity(line['id'])
        if line['connected'] and limit > 0:
            return abs(line['flow']) / limit
        return 0.0

    def _noisy_line_statuses(self, line_ids: set) -> List[LineStatus]:
        """Build LineStatus entries with noisy rho for the given line ids.

        Lines are visited in ``line_state`` order and one RNG draw is made per
        reported line. The reported flow is exact; only rho is noisy, and it
        is floored at 0.
        """
        statuses = []
        for line in self.line_state:
            if line['id'] not in line_ids:
                continue
            noise = float(self._rng.normal(0, self.LINE_NOISE_STD)) if self._rng else 0.0
            noisy_rho = max(0.0, self._line_rho(line) + noise)
            statuses.append(LineStatus(
                id=line['id'], connected=line['connected'],
                flow=round(line['flow'], 4),
                rho=round(noisy_rho, 4),
            ))
        return statuses

    def _get_zone_obs(self, agent_id: int) -> ZoneObservation:
        """Build partial observation for one agent (POMDP).

        Each agent sees:
          - Only buses in their zone
          - Internal + boundary lines
          - Noisy global frequency
          - Limited neighbor signals

        NOTE(review): the observation noise is drawn from the same RNG that
        drives wind generation, so building observations advances the
        stream used by the dynamics -- confirm this coupling is intended.
        """
        zone_bus_ids = set(self.zone_bus_ids.get(agent_id, []))
        local_buses = []
        zone_load = 0.0
        zone_gen = 0.0
        for b in self.bus_state:
            if b['id'] not in zone_bus_ids:
                continue
            b_cfg = self._find_bus_config(b['id'])
            if b_cfg is None:
                continue
            local_buses.append(BusState(
                id=b['id'], type=b_cfg['type'],
                p_injection=round(b['p'], 4),
                soc=round(b.get('soc', 0.0), 4),
                ramp_rate=b_cfg.get('ramp_rate', 0.0),
            ))
            if b_cfg['type'] == 'load':
                zone_load += abs(b['p'])
            elif b_cfg['type'] in ('generator', 'solar', 'wind', 'slack'):
                zone_gen += b['p']

        # RNG draw order is part of seeded behavior: internal lines first,
        # then boundary lines, then the frequency.
        int_line_ids = set(self.internal_lines.get(agent_id, []))
        internal_lines = self._noisy_line_statuses(int_line_ids)

        bnd_line_ids = set(self.boundary_lines.get(agent_id, []))
        boundary_lines = self._noisy_line_statuses(bnd_line_ids)

        true_freq = self._compute_frequency()
        noisy_freq = true_freq + (
            float(self._rng.normal(0, self.FREQ_NOISE_STD)) if self._rng else 0.0
        )

        # Neighbors are summarised by their zone's mean bus injection.
        neighbor_signals = {}
        for other_id in range(self.num_agents):
            if other_id == agent_id:
                continue
            other_bus_ids = set(self.zone_bus_ids.get(other_id, []))
            injections = [
                b['p'] for b in self.bus_state if b['id'] in other_bus_ids
            ]
            # Skip zones with no live bus state rather than emitting NaN.
            if injections:
                neighbor_signals[other_id] = round(float(np.mean(injections)), 2)

        visible_lines = int_line_ids | bnd_line_ids
        visible_cooldowns = {
            k: v for k, v in self.cooldowns.items() if k in visible_lines
        }

        return ZoneObservation(
            agent_id=agent_id,
            zone_name=self._zone_name(agent_id),
            timestep=self.timestep,
            grid_frequency=round(noisy_freq, 4),
            local_buses=local_buses,
            boundary_lines=boundary_lines,
            internal_lines=internal_lines,
            neighbor_signals=neighbor_signals,
            cooldowns=visible_cooldowns,
            # Mirror the environment flag, as _get_obs does.
            is_blackout=getattr(self, '_is_blackout', False),
            zone_load_mw=round(zone_load, 2),
            zone_gen_mw=round(zone_gen, 2),
        )

    def _get_obs(self) -> GridObservation:
        """Build observation from current state."""
        obs_lines = [
            LineStatus(
                id=line['id'], connected=line['connected'],
                flow=round(line['flow'], 4), rho=round(self._line_rho(line), 4)
            )
            for line in self.line_state
        ]

        obs_buses = []
        for b in self.bus_state:
            b_cfg = self._find_bus_config(b['id'])
            if b_cfg is None:
                continue
            obs_buses.append(BusState(
                id=b['id'], type=b_cfg['type'],
                p_injection=round(b['p'], 4),
                soc=round(b.get('soc', 0.0), 4),
                ramp_rate=b_cfg.get('ramp_rate', 0.0)
            ))

        freq = self._compute_frequency()

        return GridObservation(
            timestep=self.timestep,
            grid_frequency=round(freq, 4),
            buses=obs_buses,
            lines=obs_lines,
            # Copy so the observation never shares the live cooldown dict.
            cooldowns=dict(self.cooldowns),
            is_blackout=getattr(self, '_is_blackout', False)
        )

    # ------------------------------------------------------------------
    # Physics and dynamics
    # ------------------------------------------------------------------

    def _run_power_flow(self):
        """Build active line list, solve DC power flow, update line flows and slack injection.

        An IslandedException raised by the solver propagates to the caller.
        """
        active_lines = []
        for l_cfg in self.lines_config:
            l_dyn = self._find_line(l_cfg['id'])
            if l_dyn and l_dyn['connected']:
                active_lines.append({
                    'id': l_cfg['id'], 'from': l_cfg['from'], 'to': l_cfg['to'],
                    'susceptance': l_cfg['susceptance'], 'connected': True
                })

        self.solver.update_grid(active_lines)

        # Bus ids are used directly as indexes into the injection vector.
        p_inj = np.zeros(self.num_buses)
        for b_dyn in self.bus_state:
            p_inj[b_dyn['id']] = b_dyn['p']

        # The first return value is not used by the environment.
        _, flows, slack_inj = self.solver.solve(p_inj)

        self.slack_injection = slack_inj
        slack_dyn = self._find_bus_state(self.slack_bus_id)
        if slack_dyn is not None:
            slack_dyn['p'] = slack_inj

        # Connected lines missing from `flows` keep their previous value.
        for line in self.line_state:
            if line['connected'] and line['id'] in flows:
                line['flow'] = flows[line['id']]
            elif not line['connected']:
                line['flow'] = 0.0

    def _compute_frequency(self) -> float:
        """Frequency proxy using droop model, calibrated to system size."""
        return self.NOMINAL_FREQ - self.droop_constant * self.slack_injection

    def _update_loads_and_renewables(self):
        """Update time-varying loads and renewable generation. Uses per-episode RNG.

        Loads and solar follow the positive half of a 24-step sine cycle
        (zero at timestep 6, peak at timestep 12); wind is a random walk with
        uniform steps in [-5, 5], clipped to [0, max_p].
        """
        for b_dyn in self.bus_state:
            b_cfg = self._find_bus_config(b_dyn['id'])
            if b_cfg is None:
                continue

            bus_type = b_cfg['type']
            if bus_type == 'load':
                daily_cycle = math.sin((self.timestep % 24 - 6) * math.pi / 12)
                # Loads are negative injections: 0.8x to 1.2x of base_p.
                b_dyn['p'] = -b_cfg['base_p'] * (0.8 + 0.4 * max(0, daily_cycle))
            elif bus_type == 'solar':
                solar_cycle = max(0, math.sin((self.timestep % 24 - 6) * math.pi / 12))
                b_dyn['p'] = b_cfg['max_p'] * solar_cycle
            elif bus_type == 'wind':
                wind_delta = self._rng.uniform(-5, 5)
                b_dyn['p'] = float(np.clip(b_dyn['p'] + wind_delta, 0, b_cfg['max_p']))

    # ------------------------------------------------------------------
    # Lookups
    # ------------------------------------------------------------------

    def _find_line(self, line_id: str):
        """Return the dynamic state dict of a line, or None if unknown."""
        idx = getattr(self, '_line_state_by_id', None)
        if idx is not None:
            return idx.get(line_id)
        # No index built yet: fall back to a linear scan.
        return next((l for l in self.line_state if l['id'] == line_id), None)

    def _find_bus_config(self, bus_id: int):
        """Return the static config dict of a bus, or None if unknown."""
        return self._bus_cfg_by_id.get(bus_id)

    def _find_bus_state(self, bus_id: int):
        """Return the dynamic state dict of a bus, or None if unknown."""
        idx = getattr(self, '_bus_state_by_id', None)
        if idx is not None:
            return idx.get(bus_id)
        # No index built yet: fall back to a linear scan.
        return next((b for b in self.bus_state if b['id'] == bus_id), None)

    def _get_line_capacity(self, line_id: str) -> float:
        """Return the configured capacity of a line; 1.0 for unknown lines."""
        cfg = self._line_cfg_by_id.get(line_id)
        return cfg['capacity'] if cfg else 1.0