Commit ·
78131a0
0
Parent(s):
OpenGrid: Multi-agent POMDP power grid environment with GRPO training
Browse filesFeatures:
- Multi-agent POMDP environment with safety layer and oversight agent
- Environment-grounded GRPO reward function (steps actual physics)
- FastAPI server with single/multi-agent APIs, grading, and visualization
- Heuristic baseline, LLM inference pipeline, and training notebook
- Karnataka KPTCL real-world grid task
- 4 task difficulties: easy, medium, hard, karnataka
- .dockerignore +31 -0
- .gitattributes +5 -0
- .gitignore +36 -0
- Dockerfile +32 -0
- LICENSE +21 -0
- README.md +390 -0
- app.py +416 -0
- changes.md +111 -0
- inference.py +535 -0
- openenv.yaml +40 -0
- pyproject.toml +41 -0
- requirements.txt +9 -0
- server/__init__.py +1 -0
- server/app.py +21 -0
- src/__init__.py +0 -0
- src/baseline.py +199 -0
- src/environment.py +672 -0
- src/grader.py +232 -0
- src/models.py +162 -0
- src/oversight.py +190 -0
- src/physics.py +172 -0
- src/safety.py +316 -0
- src/tasks.py +384 -0
- src/visualization.py +224 -0
- static/app.js +680 -0
- static/index.html +225 -0
- static/karnataka.svg +3 -0
- static/logo.png +3 -0
- static/style.css +935 -0
- tests/__init__.py +0 -0
- tests/test_multi_agent.py +345 -0
- tests/test_solver.py +195 -0
- training/__init__.py +1 -0
- training/opengrid_grpo_colab.ipynb +635 -0
- training/train_grpo.py +827 -0
- validate-submission.sh +103 -0
.dockerignore
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
.pytest_cache/
|
| 4 |
+
.venv/
|
| 5 |
+
venv/
|
| 6 |
+
.git/
|
| 7 |
+
.gitignore
|
| 8 |
+
.vscode/
|
| 9 |
+
.env
|
| 10 |
+
|
| 11 |
+
# Docs (keep README for the Space)
|
| 12 |
+
guide.md
|
| 13 |
+
detailed judging criteria.md
|
| 14 |
+
ui_skill.md
|
| 15 |
+
project-spec.md
|
| 16 |
+
codebase_summary.md
|
| 17 |
+
pyrightconfig.json
|
| 18 |
+
|
| 19 |
+
# Generated files
|
| 20 |
+
inference_output.txt
|
| 21 |
+
generate_code_md.py
|
| 22 |
+
uv.lock
|
| 23 |
+
|
| 24 |
+
# Training outputs (not needed in Docker image)
|
| 25 |
+
training/outputs/
|
| 26 |
+
*.safetensors
|
| 27 |
+
*.bin
|
| 28 |
+
|
| 29 |
+
# Tests not needed in production
|
| 30 |
+
tests/
|
| 31 |
+
test_multiagent.py
|
.gitattributes
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.gif filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ico filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
*.pyo
|
| 4 |
+
.pytest_cache/
|
| 5 |
+
.venv/
|
| 6 |
+
venv/
|
| 7 |
+
*.egg-info/
|
| 8 |
+
dist/
|
| 9 |
+
build/
|
| 10 |
+
.env
|
| 11 |
+
.vscode/
|
| 12 |
+
|
| 13 |
+
# Generated / temporary files
|
| 14 |
+
inference_output.txt
|
| 15 |
+
codebase_summary.md
|
| 16 |
+
generate_code_md.py
|
| 17 |
+
uv.lock
|
| 18 |
+
|
| 19 |
+
# Reference docs (not part of submission)
|
| 20 |
+
guide.md
|
| 21 |
+
detailed judging criteria.md
|
| 22 |
+
ui_skill.md
|
| 23 |
+
project-spec.md
|
| 24 |
+
pyrightconfig.json
|
| 25 |
+
|
| 26 |
+
# Training outputs (large files — push separately or add to HF)
|
| 27 |
+
training/outputs/
|
| 28 |
+
*.safetensors
|
| 29 |
+
*.bin
|
| 30 |
+
|
| 31 |
+
# OS files
|
| 32 |
+
Thumbs.db
|
| 33 |
+
.DS_Store
|
| 34 |
+
|
| 35 |
+
# Duplicate test file (tests/ directory has the real one)
|
| 36 |
+
test_multiagent.py
|
Dockerfile
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hugging Face Docker Space — OpenGrid
|
| 2 |
+
# Docs: https://huggingface.co/docs/hub/spaces-sdks-docker
|
| 3 |
+
|
| 4 |
+
FROM python:3.10-slim
|
| 5 |
+
|
| 6 |
+
LABEL org.opencontainers.image.title="OpenGrid"
|
| 7 |
+
LABEL org.opencontainers.image.description="Renewable energy grid load-balancing environment"
|
| 8 |
+
LABEL openenv="true"
|
| 9 |
+
|
| 10 |
+
# Create non-root user required by HF Spaces
|
| 11 |
+
RUN useradd -m -u 1000 user
|
| 12 |
+
USER user
|
| 13 |
+
ENV PATH="/home/user/.local/bin:$PATH"
|
| 14 |
+
|
| 15 |
+
WORKDIR /app
|
| 16 |
+
|
| 17 |
+
# Install dependencies
|
| 18 |
+
COPY --chown=user requirements.txt .
|
| 19 |
+
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
| 20 |
+
|
| 21 |
+
# Copy application code
|
| 22 |
+
COPY --chown=user . /app
|
| 23 |
+
|
| 24 |
+
# Expose HF Spaces default port
|
| 25 |
+
EXPOSE 7860
|
| 26 |
+
|
| 27 |
+
# Healthcheck
|
| 28 |
+
HEALTHCHECK --interval=30s --timeout=5s --start-period=15s \
|
| 29 |
+
CMD python -c "import httpx; httpx.get('http://localhost:7860/health').raise_for_status()" || exit 1
|
| 30 |
+
|
| 31 |
+
# Run the server
|
| 32 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 KRISHNA GOYAL
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: OpenGrid
|
| 3 |
+
emoji: ⚡
|
| 4 |
+
colorFrom: green
|
| 5 |
+
colorTo: blue
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_file: app.py
|
| 8 |
+
pinned: false
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
<p align="center">
|
| 12 |
+
<img src="static/logo.png" alt="OpenGrid Logo" width="120">
|
| 13 |
+
</p>
|
| 14 |
+
|
| 15 |
+
<h1 align="center">OpenGrid ⚡</h1>
|
| 16 |
+
<p align="center"><strong>Safe Multi-Agent RL for Power Grid Operations</strong></p>
|
| 17 |
+
|
| 18 |
+
<p align="center">
|
| 19 |
+
<a href="https://huggingface.co/spaces/K446/Opengrid"><img src="https://img.shields.io/badge/🤗%20Live%20Demo-HuggingFace%20Space-yellow" alt="Live Demo"></a>
|
| 20 |
+
<a href="https://github.com/krishnagoyal099/Opengrid_env"><img src="https://img.shields.io/badge/GitHub-Repository-181717?logo=github" alt="GitHub"></a>
|
| 21 |
+
<a href="https://github.com/openenv"><img src="https://img.shields.io/badge/OpenEnv-compatible-blue" alt="OpenEnv"></a>
|
| 22 |
+
<a href="https://www.python.org"><img src="https://img.shields.io/badge/python-3.10%2B-blue" alt="Python 3.10+"></a>
|
| 23 |
+
<a href="LICENSE"><img src="https://img.shields.io/badge/License-MIT-green.svg" alt="License: MIT"></a>
|
| 24 |
+
</p>
|
| 25 |
+
|
| 26 |
+
---
|
| 27 |
+
|
| 28 |
+
## What is OpenGrid?
|
| 29 |
+
|
| 30 |
+
OpenGrid is a **multi-agent reinforcement learning environment** where AI agents control a power grid. Multiple agents, each managing a zone, must coordinate under **partial observability** to keep the lights on — balancing electricity supply and demand in real-time while managing renewable energy volatility.
|
| 31 |
+
|
| 32 |
+
What makes OpenGrid different:
|
| 33 |
+
|
| 34 |
+
- **Multi-Agent POMDP**: 2-3 agents, each seeing only their local zone + noisy global signals
|
| 35 |
+
- **Safety Layer**: Hard constraint filter blocks unsafe actions before they reach the physics engine (N-1 security, anti-islanding, ramp limits)
|
| 36 |
+
- **Oversight Agent**: Monitors cross-zone coordination, penalizes selfish behavior
|
| 37 |
+
- **Composable Rewards**: 6 independent reward functions — survival, frequency, congestion, safety compliance, coordination, efficiency
|
| 38 |
+
- **Real Physics**: DC power flow solver with droop frequency model
|
| 39 |
+
|
| 40 |
+
> **🔗 Try it live:** [huggingface.co/spaces/K446/Opengrid](https://huggingface.co/spaces/K446/Opengrid)
|
| 41 |
+
|
| 42 |
+
---
|
| 43 |
+
|
| 44 |
+
## How It Works
|
| 45 |
+
|
| 46 |
+
```
|
| 47 |
+
┌─────────────────────────────────────────────────────────┐
|
| 48 |
+
│ MULTI-AGENT LOOP │
|
| 49 |
+
│ │
|
| 50 |
+
│ Each agent observes LOCAL zone state (POMDP) │
|
| 51 |
+
│ │ │
|
| 52 |
+
│ ▼ │
|
| 53 |
+
│ Each agent proposes action (adjust power, switch │
|
| 54 |
+
│ lines — only within their zone) │
|
| 55 |
+
│ │ │
|
| 56 |
+
│ ▼ │
|
| 57 |
+
│ SAFETY LAYER validates all actions: │
|
| 58 |
+
│ - N-1 security check │
|
| 59 |
+
│ - Anti-islanding │
|
| 60 |
+
│ - Projects unsafe → nearest safe alternative │
|
| 61 |
+
│ │ │
|
| 62 |
+
│ ▼ │
|
| 63 |
+
│ OVERSIGHT AGENT evaluates coordination: │
|
| 64 |
+
│ - Detects conflicts between agents │
|
| 65 |
+
│ - Penalizes selfish behavior │
|
| 66 |
+
│ │ │
|
| 67 |
+
│ ▼ │
|
| 68 |
+
│ Physics engine solves DC power flow │
|
| 69 |
+
│ │ │
|
| 70 |
+
│ ▼ │
|
| 71 |
+
│ Per-agent rewards: local + global + safety + coord │
|
| 72 |
+
│ │ │
|
| 73 |
+
│ Repeat for 50 steps — or until blackout! │
|
| 74 |
+
└─────────────────────────────────────────────────────────┘
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
The agent interacts through a **REST API** — any language or framework that can make HTTP requests can play. Both single-agent (backward compatible) and multi-agent modes are supported.
|
| 78 |
+
|
| 79 |
+
---
|
| 80 |
+
|
| 81 |
+
## Three Difficulty Levels
|
| 82 |
+
|
| 83 |
+
| Task | Grid Size | Agents | Renewable Mix | What Makes It Hard |
|
| 84 |
+
|---|---|---|---|---|
|
| 85 |
+
| `task_easy` | 5 buses | 2 | 20% | Basic frequency control, 2-zone coordination |
|
| 86 |
+
| `task_medium` | 10 buses | 3 | 50% | Volatile renewables + congestion + 3-zone POMDP |
|
| 87 |
+
| `task_hard` | 14 buses | 3 | 70% | High volatility, tight margins, complex topology |
|
| 88 |
+
| `task_karnataka` | 15 buses | 4 | Real mix | Real KPTCL topology (Raichur, Ballari, Bengaluru, Mysuru) with GPS coordinates |
|
| 89 |
+
|
| 90 |
+
All tasks run for **50 timesteps**. Scores range from **0.02 to 0.98** (higher = better).
|
| 91 |
+
|
| 92 |
+
---
|
| 93 |
+
|
| 94 |
+
## Quick Start
|
| 95 |
+
|
| 96 |
+
### 1. Clone & Install
|
| 97 |
+
|
| 98 |
+
```bash
|
| 99 |
+
git clone https://github.com/krishnagoyal099/Opengrid_env.git
|
| 100 |
+
cd Opengrid_env
|
| 101 |
+
|
| 102 |
+
pip install -r requirements.txt
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
### 2. Start the Server
|
| 106 |
+
|
| 107 |
+
```bash
|
| 108 |
+
uvicorn app:app --host 0.0.0.0 --port 7860
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
Then open [http://localhost:7860](http://localhost:7860) — you'll see the **interactive SCADA dashboard** with a Leaflet.js GIS map showing the Karnataka grid topology in real-time.
|
| 112 |
+
|
| 113 |
+
### 3. Run the AI Agent
|
| 114 |
+
|
| 115 |
+
```bash
|
| 116 |
+
# Set your LLM API credentials
|
| 117 |
+
export API_BASE_URL="https://api.openai.com/v1"
|
| 118 |
+
export MODEL_NAME="gpt-4o"
|
| 119 |
+
export HF_TOKEN="your-api-key"
|
| 120 |
+
export ENV_URL="http://localhost:7860"
|
| 121 |
+
|
| 122 |
+
# Run inference on all 3 tasks
|
| 123 |
+
python inference.py
|
| 124 |
+
```
|
| 125 |
+
|
| 126 |
+
### 4. Train with GRPO
|
| 127 |
+
|
| 128 |
+
```bash
|
| 129 |
+
# Test the training pipeline (no GPU needed)
|
| 130 |
+
python training/train_grpo.py --test-mode
|
| 131 |
+
|
| 132 |
+
# Full training with Unsloth (needs GPU)
|
| 133 |
+
python training/train_grpo.py --model unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit --use-unsloth
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
### Docker (Alternative)
|
| 137 |
+
|
| 138 |
+
```bash
|
| 139 |
+
docker build -t opengrid .
|
| 140 |
+
docker run -p 7860:7860 opengrid
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
---
|
| 144 |
+
|
| 145 |
+
## Multi-Agent API
|
| 146 |
+
|
| 147 |
+
### Reset in Multi-Agent Mode
|
| 148 |
+
|
| 149 |
+
```bash
|
| 150 |
+
curl -X POST "http://localhost:7860/reset_multi?task_id=task_medium"
|
| 151 |
+
# Returns: {
|
| 152 |
+
# "session_id": "abc-123",
|
| 153 |
+
# "num_agents": 3,
|
| 154 |
+
# "zone_info": {"0": {"zone_name": "Bengaluru_Region", "bus_ids": [...]}, ...},
|
| 155 |
+
# "observations": {"0": {...}, "1": {...}, "2": {...}}
|
| 156 |
+
# }
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
### Take a Multi-Agent Step
|
| 160 |
+
|
| 161 |
+
```bash
|
| 162 |
+
curl -X POST "http://localhost:7860/step_multi?session_id=abc-123" \
|
| 163 |
+
-H "Content-Type: application/json" \
|
| 164 |
+
-d '{
|
| 165 |
+
"agent_actions": {
|
| 166 |
+
"0": {"bus_adjustments": [{"bus_id": 0, "delta": 5.0}], "topology_actions": []},
|
| 167 |
+
"1": {"bus_adjustments": [], "topology_actions": []},
|
| 168 |
+
"2": {"bus_adjustments": [{"bus_id": 9, "delta": -3.0}], "topology_actions": []}
|
| 169 |
+
}
|
| 170 |
+
}'
|
| 171 |
+
# Returns: per-agent observations, per-agent rewards, safety reports, oversight report
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
### Single-Agent API (Backward Compatible)
|
| 175 |
+
|
| 176 |
+
The original single-agent API (`/reset`, `/step`, `/state`, `/grader`) is fully preserved.
|
| 177 |
+
|
| 178 |
+
---
|
| 179 |
+
|
| 180 |
+
## What Each Agent Sees (POMDP Observation)
|
| 181 |
+
|
| 182 |
+
Each agent receives a **partial** observation of their zone:
|
| 183 |
+
|
| 184 |
+
| Field | Example | Meaning |
|
| 185 |
+
|---|---|---|
|
| 186 |
+
| `grid_frequency` | `49.87` | **Noisy** frequency reading (Gaussian noise added) |
|
| 187 |
+
| `local_buses[].type` | `"solar"` | Bus type (only buses in agent's zone) |
|
| 188 |
+
| `local_buses[].p_injection` | `35.2` | Power output in MW |
|
| 189 |
+
| `boundary_lines[].rho` | `0.78` | Lines connecting to other zones |
|
| 190 |
+
| `internal_lines[].flow` | `62.4` | Lines within agent's zone |
|
| 191 |
+
| `neighbor_signals` | `{1: 12.5}` | Average injection of neighboring zones |
|
| 192 |
+
| `zone_load_mw` | `85.3` | Total load in this zone |
|
| 193 |
+
| `zone_gen_mw` | `42.1` | Total generation in this zone |
|
| 194 |
+
|
| 195 |
+
Agents do **NOT** see buses or lines in other zones — they must coordinate through limited neighbor signals and the shared (but noisy) frequency reading.
|
| 196 |
+
|
| 197 |
+
---
|
| 198 |
+
|
| 199 |
+
## Safety Layer
|
| 200 |
+
|
| 201 |
+
The safety layer validates every action BEFORE it reaches the physics engine:
|
| 202 |
+
|
| 203 |
+
| Check | What It Does | If Violated |
|
| 204 |
+
|---|---|---|
|
| 205 |
+
| **Zone Boundary** | Agent can only adjust buses in their zone | Action removed |
|
| 206 |
+
| **N-1 Security** | Grid must survive loss of any single line | Action blocked |
|
| 207 |
+
| **Anti-Islanding** | Opening a line must not disconnect the grid | Switch blocked |
|
| 208 |
+
| **Ramp Limits** | Power changes within physical ramp rates | Delta clamped |
|
| 209 |
+
| **Capacity Limits** | Generation within min/max bounds | Output clamped |
|
| 210 |
+
| **Battery SoC** | Can't discharge below 0 or charge above capacity | Delta clamped |
|
| 211 |
+
|
| 212 |
+
Critically, unsafe actions are **projected to the nearest safe alternative** rather than simply rejected. This preserves the agent's intent while enforcing safety, and provides a richer training signal.
|
| 213 |
+
|
| 214 |
+
---
|
| 215 |
+
|
| 216 |
+
## Reward System
|
| 217 |
+
|
| 218 |
+
Six composable, independent reward functions:
|
| 219 |
+
|
| 220 |
+
| Component | Range | When |
|
| 221 |
+
|---|---|---|
|
| 222 |
+
| **survival** | +1.0 / -100.0 | Grid stays connected / blackout |
|
| 223 |
+
| **frequency** | -1.5 to +0.2 | Based on deviation from 50 Hz |
|
| 224 |
+
| **local_congestion** | ≤ 0 | Line overloads in agent's zone |
|
| 225 |
+
| **safety_compliance** | -0.3 to +0.1 | Penalty if safety layer corrected action |
|
| 226 |
+
| **coordination** | ≤ 0 | Penalty for selfish/conflicting actions |
|
| 227 |
+
| **action_cost** | -0.5 / switch | Topology change cost |
|
| 228 |
+
|
| 229 |
+
---
|
| 230 |
+
|
| 231 |
+
## Scoring
|
| 232 |
+
|
| 233 |
+
Scores are normalized to **(0.02 – 0.98)** using:
|
| 234 |
+
|
| 235 |
+
```
|
| 236 |
+
score = (agent_reward - worst_case) / (best_case - worst_case) + N1_bonus
|
| 237 |
+
```
|
| 238 |
+
|
| 239 |
+
| Bound | How It's Computed |
|
| 240 |
+
|---|---|
|
| 241 |
+
| **Worst case (floor)** | Random agent that chaotically switches lines — causes blackouts fast |
|
| 242 |
+
| **Best case (ceiling)** | Theoretical perfect agent: survives every step + perfect frequency bonus |
|
| 243 |
+
| **N-1 bonus** | Up to +10% for completing the episode without a blackout |
|
| 244 |
+
|
| 245 |
+
### Baseline Scores (Heuristic Policy)
|
| 246 |
+
|
| 247 |
+
| Task | Score | Strategy |
|
| 248 |
+
|---|---|---|
|
| 249 |
+
| `task_easy` | ~0.90 | Proportional frequency control, no line switching |
|
| 250 |
+
| `task_medium` | ~0.98 | Same heuristic — medium grid happens to be well-balanced |
|
| 251 |
+
| `task_hard` | ~0.98 | Same heuristic — hard grid has more buses but similar dynamics |
|
| 252 |
+
| `task_karnataka` | ~0.98 | 15-bus real topology, 4 zones, generators warm-started |
|
| 253 |
+
|
| 254 |
+
> Reproduce with: `python get_scores.py`
|
| 255 |
+
|
| 256 |
+
---
|
| 257 |
+
|
| 258 |
+
## Project Structure
|
| 259 |
+
|
| 260 |
+
```
|
| 261 |
+
OpenGrid/
|
| 262 |
+
├── app.py # FastAPI server (single + multi-agent endpoints)
|
| 263 |
+
├── inference.py # LLM inference script
|
| 264 |
+
├── get_scores.py # Reproduce baseline scores
|
| 265 |
+
├── openenv.yaml # OpenEnv manifest
|
| 266 |
+
├── Dockerfile # Container config
|
| 267 |
+
├── requirements.txt # Python dependencies
|
| 268 |
+
│
|
| 269 |
+
├── src/ # Core environment
|
| 270 |
+
│ ├── models.py # Pydantic models (single + multi-agent)
|
| 271 |
+
│ ├── environment.py # Grid simulation (POMDP + backward-compatible)
|
| 272 |
+
│ ├── physics.py # DC power flow solver
|
| 273 |
+
│ ├── tasks.py # Procedural grid generation with zone assignment
|
| 274 |
+
│ ├── grader.py # Scoring (floor/ceiling normalization)
|
| 275 |
+
│ ├── baseline.py # Heuristic + LLM policies
|
| 276 |
+
│ ├── safety.py # Safety layer (N-1, anti-islanding, projection)
|
| 277 |
+
│ ├── oversight.py # Oversight agent (coordination monitoring)
|
| 278 |
+
│ └── visualization.py # Grid topology & frequency plots
|
| 279 |
+
│
|
| 280 |
+
├── training/ # RL training pipeline
|
| 281 |
+
│ ├── train_grpo.py # TRL GRPO training script
|
| 282 |
+
│ └── opengrid_grpo_colab.ipynb # Google Colab notebook for GPU training
|
| 283 |
+
│
|
| 284 |
+
├── tests/ # Test suite (28 tests)
|
| 285 |
+
│ ├── test_solver.py # Physics, environment, grader tests
|
| 286 |
+
│ └── test_multi_agent.py # Multi-agent, safety, oversight tests
|
| 287 |
+
│
|
| 288 |
+
├── static/ # Dashboard frontend
|
| 289 |
+
│ ├── index.html
|
| 290 |
+
│ ├── style.css
|
| 291 |
+
│ └── app.js
|
| 292 |
+
│
|
| 293 |
+
└── server/ # Alternative entry point
|
| 294 |
+
└── app.py
|
| 295 |
+
```
|
| 296 |
+
|
| 297 |
+
---
|
| 298 |
+
|
| 299 |
+
## Training Results (GRPO)
|
| 300 |
+
|
| 301 |
+
We trained **Qwen 2.5 1.5B** using GRPO (Group Relative Policy Optimization) on the Karnataka grid topology.
|
| 302 |
+
|
| 303 |
+
### Training Loss
|
| 304 |
+
|
| 305 |
+
The loss converges from ~0.09 to near 0 by step ~400, confirming end-to-end training pipeline functionality.
|
| 306 |
+
|
| 307 |
+
### Before vs After (Average Episode Reward)
|
| 308 |
+
|
| 309 |
+
| Task | Heuristic Baseline | GRPO Trained |
|
| 310 |
+
|---|---|---|
|
| 311 |
+
| `task_easy` | 27.6 | 27.6 |
|
| 312 |
+
| `task_medium` | 48.7 | 48.7 |
|
| 313 |
+
| `task_karnataka` | 19.6 | -316.9 |
|
| 314 |
+
|
| 315 |
+
**Key Finding**: Naive LLM training on simplified proxy rewards does not transfer to real-world grid topologies — Karnataka collapses to -316.9. This validates our architectural decision to pair RL agents with a **safety layer + oversight agent**. The heuristic baseline with safety corrections (19.6 reward, zero blackouts) outperforms pure RL, proving that critical infrastructure needs guardrails, not just learned policies.
|
| 316 |
+
|
| 317 |
+
> **Reproduce training**: Open `training/opengrid_grpo_colab.ipynb` in Google Colab (T4 GPU)
|
| 318 |
+
|
| 319 |
+
---
|
| 320 |
+
|
| 321 |
+
## Technical Details
|
| 322 |
+
|
| 323 |
+
<details>
|
| 324 |
+
<summary><strong>Physics Engine</strong></summary>
|
| 325 |
+
|
| 326 |
+
- **DC Power Flow** with B-matrix formulation (standard power systems approximation)
|
| 327 |
+
- **Slack bus** absorbs generation/load imbalance after each power flow solve
|
| 328 |
+
- **Islanding detection** via NetworkX graph connectivity checks
|
| 329 |
+
- **Droop frequency model** calibrated to system size: `f = 50.0 - (2.5 / total_capacity) * P_slack`
|
| 330 |
+
|
| 331 |
+
</details>
|
| 332 |
+
|
| 333 |
+
<details>
|
| 334 |
+
<summary><strong>Multi-Agent Design</strong></summary>
|
| 335 |
+
|
| 336 |
+
- Buses partitioned into zones using **greedy modularity community detection** (NetworkX)
|
| 337 |
+
- Each zone maps to a KPTCL transmission region (Bengaluru, Mysuru, Kalaburagi)
|
| 338 |
+
- **Partial observability**: agents see only local buses, boundary lines, noisy frequency
|
| 339 |
+
- **Neighbor signals**: each agent receives average injection of adjacent zones
|
| 340 |
+
- **Safety-first**: all actions validated by constraint filter before physics engine
|
| 341 |
+
|
| 342 |
+
</details>
|
| 343 |
+
|
| 344 |
+
<details>
|
| 345 |
+
<summary><strong>Thread Safety</strong></summary>
|
| 346 |
+
|
| 347 |
+
- All session reads/writes are protected by a `threading.Lock`
|
| 348 |
+
- Grader bounds use double-checked locking to avoid duplicate rollouts
|
| 349 |
+
- Safe for concurrent requests from multiple agents
|
| 350 |
+
|
| 351 |
+
</details>
|
| 352 |
+
|
| 353 |
+
<details>
|
| 354 |
+
<summary><strong>Reproducibility</strong></summary>
|
| 355 |
+
|
| 356 |
+
| Component | Mechanism |
|
| 357 |
+
|---|---|
|
| 358 |
+
| Task grids | Seeded procedural generation (`np.random.default_rng`) |
|
| 359 |
+
| Zone partitioning | Deterministic community detection with seed |
|
| 360 |
+
| Wind variability | Per-episode RNG (same seed → same wind pattern) |
|
| 361 |
+
| Floor estimation | Seeded thrash policy + 10 diverse-seeded episodes |
|
| 362 |
+
| Ceiling | Analytical formula (deterministic) |
|
| 363 |
+
| Scoring | Shared `normalize_score()` across all endpoints |
|
| 364 |
+
|
| 365 |
+
</details>
|
| 366 |
+
|
| 367 |
+
---
|
| 368 |
+
|
| 369 |
+
## Related Work
|
| 370 |
+
|
| 371 |
+
- **Massgen**: When Multiple LLMs Think Together (Gradient Network, 2025)
|
| 372 |
+
- **Symphony**: Multi-Agent Intelligence in a Collective Fabric (Gradient Network, 2025)
|
| 373 |
+
- **Grid2Op**: Power grid RL environment (RTE, 2020)
|
| 374 |
+
- **OpenEnv**: Standardized agentic execution environments (Scalar/HuggingFace/Meta, 2026)
|
| 375 |
+
|
| 376 |
+
---
|
| 377 |
+
|
| 378 |
+
## Links
|
| 379 |
+
|
| 380 |
+
| Resource | URL |
|
| 381 |
+
|---|---|
|
| 382 |
+
| **Live Demo** | [huggingface.co/spaces/K446/Opengrid](https://huggingface.co/spaces/K446/Opengrid) |
|
| 383 |
+
| **GitHub Repo** | [github.com/krishnagoyal099/Opengrid_env](https://github.com/krishnagoyal099/Opengrid_env) |
|
| 384 |
+
| **API Docs (Swagger)** | [huggingface.co/spaces/K446/Opengrid/docs](https://k446-opengrid.hf.space/docs) |
|
| 385 |
+
|
| 386 |
+
---
|
| 387 |
+
|
| 388 |
+
## License
|
| 389 |
+
|
| 390 |
+
MIT — see [LICENSE](LICENSE) for details.
|
app.py
ADDED
|
@@ -0,0 +1,416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, HTTPException
|
| 2 |
+
from fastapi.staticfiles import StaticFiles
|
| 3 |
+
from fastapi.responses import FileResponse
|
| 4 |
+
from typing import Dict, List
|
| 5 |
+
from src.models import (
|
| 6 |
+
GridAction, GridObservation, GridReward,
|
| 7 |
+
MultiAgentAction, MultiAgentStepResult,
|
| 8 |
+
)
|
| 9 |
+
from src.environment import OpenGridEnv
|
| 10 |
+
from src.tasks import TASKS
|
| 11 |
+
from src.grader import RobustnessGrader, normalize_score, _SCORE_EPSILON, _clamp_score
|
| 12 |
+
from src.baseline import heuristic_policy, llm_policy
|
| 13 |
+
from src.visualization import generate_dashboard
|
| 14 |
+
import copy
|
| 15 |
+
import uuid
|
| 16 |
+
import os
|
| 17 |
+
import time
|
| 18 |
+
import pathlib
|
| 19 |
+
import threading
|
| 20 |
+
import warnings
|
| 21 |
+
|
| 22 |
+
app = FastAPI(
|
| 23 |
+
title="OpenGrid Environment",
|
| 24 |
+
description="Multi-agent renewable energy grid load-balancing environment with safety constraints",
|
| 25 |
+
version="2.0.0"
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
# Static files — mount only if present (allows API-only or test deployments)
|
| 29 |
+
STATIC_DIR = pathlib.Path(__file__).parent / "static"
|
| 30 |
+
if STATIC_DIR.exists():
|
| 31 |
+
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
|
| 32 |
+
else:
|
| 33 |
+
warnings.warn(
|
| 34 |
+
f"Static directory not found: {STATIC_DIR}. "
|
| 35 |
+
"Dashboard UI disabled; API endpoints remain available."
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
# ---------------------------------------------------------------------------
|
| 39 |
+
# Session storage with TTL + per-session locking
|
| 40 |
+
# ---------------------------------------------------------------------------
|
| 41 |
+
# _session_lock guards the sessions/history *dicts* for insert/delete/lookup.
|
| 42 |
+
# Each session also has its own lock ("lock" key) that serializes env
|
| 43 |
+
# operations, preventing race conditions when concurrent requests target
|
| 44 |
+
# the same session (e.g. two /step calls, or /step racing with /grader).
|
| 45 |
+
# ---------------------------------------------------------------------------
|
| 46 |
+
sessions: Dict[str, Dict] = {}
|
| 47 |
+
history: Dict[str, List] = {}
|
| 48 |
+
MAX_SESSIONS = 100
|
| 49 |
+
SESSION_TTL_SECONDS = 3600 # 1 hour
|
| 50 |
+
_session_lock = threading.Lock()
|
| 51 |
+
|
| 52 |
+
# Grader cache: bounds are expensive (10 rollouts per task), compute once.
|
| 53 |
+
# Construction AND bounds estimation are serialized under _grader_lock.
|
| 54 |
+
_grader_cache: Dict[str, RobustnessGrader] = {}
|
| 55 |
+
_grader_lock = threading.Lock()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _new_session(env: OpenGridEnv, task_id: str, mode: str, **extra) -> dict:
    """Build the bookkeeping dict for a freshly created environment session.

    The dict carries the env, creation/last-access timestamps (used by the
    TTL eviction logic), reward accumulators, terminal flags, and a
    per-session lock that serializes env operations. Any extra keyword
    arguments are merged on top of the defaults.
    """
    now = time.time()
    return {
        "env": env,
        "created": now,
        "last_access": now,
        "task_id": task_id,
        "rewards": [],
        "mode": mode,
        "done": False,
        "is_blackout": False,
        "lock": threading.Lock(),
        **extra,
    }
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _session_age(s: dict, now: float) -> float:
|
| 76 |
+
"""Return the last-access timestamp for a session (for eviction sorting)."""
|
| 77 |
+
ts = s.get("last_access")
|
| 78 |
+
if ts is None:
|
| 79 |
+
ts = s.get("created")
|
| 80 |
+
return float(ts) if ts is not None else now
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _cleanup_sessions():
    """Evict expired and excess sessions. Caller must hold _session_lock."""
    now = time.time()

    # Drop every session whose last activity is older than the TTL.
    stale = [
        sid for sid, sess in sessions.items()
        if now - _session_age(sess, now) > SESSION_TTL_SECONDS
    ]
    for sid in stale:
        sessions.pop(sid, None)
        history.pop(sid, None)

    # Still at/over capacity? Drop least-recently-used sessions one by one.
    while len(sessions) >= MAX_SESSIONS:
        lru_sid = min(sessions, key=lambda sid: _session_age(sessions[sid], 0.0))
        sessions.pop(lru_sid, None)
        history.pop(lru_sid, None)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _get_session(session_id: str) -> dict:
    """Fetch a session by id, refreshing its last-access timestamp.

    Raises HTTPException(404) when the id is unknown. Caller must NOT
    already hold _session_lock — this function acquires it.
    """
    with _session_lock:
        session = sessions.get(session_id)
        if session is None:
            raise HTTPException(404, "Session not found")
        session["last_access"] = time.time()
    return session
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _get_grader(task_id: str) -> RobustnessGrader:
    """Return the cached RobustnessGrader for *task_id*, building it on demand.

    Construction and the expensive bounds estimation both happen while
    holding _grader_lock, so concurrent /grader requests can neither
    duplicate the rollouts nor race on _estimate_bounds() mutations.
    """
    with _grader_lock:
        grader = _grader_cache.get(task_id)
        if grader is None:
            grader = RobustnessGrader(copy.deepcopy(TASKS[task_id]))
            grader.get_bounds()  # pay the expensive bounds cost while locked
            _grader_cache[task_id] = grader
        return grader
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@app.get("/")
def root():
    """Serve the dashboard page, or a JSON API banner when no static UI ships."""
    index_page = STATIC_DIR / "index.html"
    if not index_page.exists():
        return {"status": "OpenGrid API", "version": "2.0.0", "docs": "/docs"}
    return FileResponse(str(index_page))
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
@app.get("/health")
def health():
    """Liveness probe — always returns a small JSON status payload."""
    return {
        "status": "OpenGrid Running",
        "version": "2.0.0",
        "docs": "/docs",
    }
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
@app.get("/tasks")
def get_tasks():
    """List available tasks with metadata including multi-agent zone info.

    Each entry carries grid size, episode length, agent/zone layout, and
    the JSON schemas for actions and observations so clients can validate
    payloads up front.
    """
    action_schema = GridAction.model_json_schema()
    obs_schema = GridObservation.model_json_schema()
    return [
        {
            "id": k,
            # split('_', 1)[-1] never raises, unlike split('_')[1] which
            # crashed with IndexError for ids lacking an underscore; the
            # result is identical for all current "task_*" ids.
            "difficulty": v.get("difficulty", k.split('_', 1)[-1]),
            "num_buses": v["num_buses"],
            "max_steps": v["max_steps"],
            "num_agents": v.get("num_agents", 1),
            "zone_names": v.get("zone_names", []),
            "buses": v.get("buses", []),
            "action_schema": action_schema,
            "observation_schema": obs_schema
        } for k, v in TASKS.items()
    ]
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# ===========================================================================
|
| 168 |
+
# Single-Agent API (backward compatible)
|
| 169 |
+
# ===========================================================================
|
| 170 |
+
|
| 171 |
+
@app.post("/reset")
def reset(task_id: str = "task_easy"):
    """Create a fresh single-agent session for *task_id*; return its id and first observation."""
    if task_id not in TASKS:
        raise HTTPException(404, f"Task '{task_id}' not found. Available: {list(TASKS.keys())}")

    env = OpenGridEnv(copy.deepcopy(TASKS[task_id]))
    first_obs = env.reset()
    session_id = str(uuid.uuid4())

    with _session_lock:
        _cleanup_sessions()  # make room before registering the new session
        sessions[session_id] = _new_session(env, task_id, mode="single")
        history[session_id] = [first_obs]

    return {"session_id": session_id, "observation": first_obs.model_dump()}
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
@app.post("/step")
def step(session_id: str, action: GridAction):
    """Execute one step in the environment.

    Raises 404 for an unknown session and 400 when the episode has
    already terminated. Returns the new observation, reward breakdown,
    done flag, and step info.
    """
    session = _get_session(session_id)

    # Per-session lock serializes all env operations for this session
    with session["lock"]:
        if session.get("done"):
            raise HTTPException(400, "Episode already done. Call /reset to start a new session.")

        env = session["env"]
        obs, reward, done, info = env.step(action)

        session["rewards"].append(reward.value)
        session["done"] = done
        session["is_blackout"] = info.is_blackout

        with _session_lock:
            # Fix: the session (and its history entry) can be evicted by
            # _cleanup_sessions() between _get_session() and this append;
            # the unguarded history[session_id] raised KeyError in that race.
            if session_id in history:
                history[session_id].append(obs)

        return {
            "observation": obs.model_dump(),
            "reward": reward.model_dump(),
            "done": done,
            "info": info.model_dump()
        }
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
@app.get("/state")
def get_state(session_id: str):
    """Return the full current grid state of a session as JSON."""
    session = _get_session(session_id)

    with session["lock"]:
        state = session["env"].state()
    return state.model_dump()
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
# ===========================================================================
|
| 227 |
+
# Multi-Agent POMDP API
|
| 228 |
+
# ===========================================================================
|
| 229 |
+
|
| 230 |
+
@app.post("/reset_multi")
def reset_multi(task_id: str = "task_easy"):
    """Start a multi-agent episode; return per-agent partial observations and zone layout."""
    if task_id not in TASKS:
        raise HTTPException(404, f"Task '{task_id}' not found. Available: {list(TASKS.keys())}")

    env = OpenGridEnv(copy.deepcopy(TASKS[task_id]))
    zone_obs = env.reset_multi()
    session_id = str(uuid.uuid4())
    zone_info = env.get_zone_info()

    with _session_lock:
        _cleanup_sessions()
        sessions[session_id] = _new_session(
            env, task_id, mode="multi",
            per_agent_rewards={agent: [] for agent in range(env.num_agents)},
        )
        # History keeps the *full-grid* view for visualization, not the
        # per-agent partial observations.
        history[session_id] = [env.state()]

    return {
        "session_id": session_id,
        "num_agents": env.num_agents,
        "zone_info": {str(aid): info.model_dump() for aid, info in zone_info.items()},
        "observations": {str(aid): ob.model_dump() for aid, ob in zone_obs.items()},
    }
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
@app.post("/step_multi")
def step_multi(session_id: str, actions: MultiAgentAction):
    """Multi-agent step with safety layer and oversight.

    Each agent submits actions for their zone. The safety layer validates,
    the oversight agent evaluates coordination, and per-agent rewards are
    computed. Raises 404 for unknown sessions, 400 for finished episodes,
    wrong-mode sessions, or malformed agent ids.
    """
    session = _get_session(session_id)

    with session["lock"]:
        if session.get("done"):
            raise HTTPException(400, "Episode already done. Call /reset_multi to start a new session.")

        env = session["env"]
        if session.get("mode") != "multi":
            raise HTTPException(400, "Session not in multi-agent mode. Use /reset_multi first.")

        # JSON object keys arrive as strings; normalize to ints and validate range.
        agent_actions = {}
        for key, agent_action in actions.agent_actions.items():
            try:
                agent_id = int(key) if isinstance(key, str) else key
            except (TypeError, ValueError) as exc:
                # Chain the original conversion error for debuggability.
                raise HTTPException(400, f"Invalid agent_id: {key!r}") from exc
            if not (0 <= agent_id < env.num_agents):
                raise HTTPException(
                    400,
                    f"Invalid agent_id {agent_id}; expected 0..{env.num_agents - 1}",
                )
            agent_actions[agent_id] = agent_action

        result = env.step_multi(agent_actions)

        session["rewards"].append(result.team_reward)
        session["done"] = result.done
        session["is_blackout"] = result.info.is_blackout
        for agent_id, reward in result.rewards.items():
            if agent_id in session.get("per_agent_rewards", {}):
                session["per_agent_rewards"][agent_id].append(reward.value)

        # Store full-grid observation for visualization
        with _session_lock:
            # Fix: the session's history entry can be evicted concurrently
            # by _cleanup_sessions(); the unguarded history[session_id]
            # raised KeyError in that race.
            if session_id in history:
                history[session_id].append(env.state())

        return {
            "observations": {str(k): v.model_dump() for k, v in result.observations.items()},
            "rewards": {str(k): v.model_dump() for k, v in result.rewards.items()},
            "team_reward": result.team_reward,
            "done": result.done,
            "safety_reports": {str(k): v.model_dump() for k, v in result.safety_reports.items()},
            "oversight_report": result.oversight_report.model_dump(),
            "info": result.info.model_dump(),
        }
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
@app.get("/zones")
def get_zones(session_id: str):
    """Return zone assignments and the agent count for a multi-agent session."""
    session = _get_session(session_id)
    env = session["env"]

    with session["lock"]:
        zone_info = env.get_zone_info()

    return {
        "num_agents": env.num_agents,
        "zones": {str(aid): info.model_dump() for aid, info in zone_info.items()},
    }
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
# ===========================================================================
|
| 329 |
+
# Grading & Baseline
|
| 330 |
+
# ===========================================================================
|
| 331 |
+
|
| 332 |
+
@app.get("/grader")
def run_grader(session_id: str):
    """
    Grade a completed (or in-progress) session.

    The score lies strictly inside (0, 1) and uses the same normalization
    as /baseline: an analytical reward ceiling plus an empirical floor.
    """
    session = _get_session(session_id)

    # Snapshot the mutable fields while holding the per-session lock.
    with session["lock"]:
        rewards = list(session["rewards"])
        task_id = session["task_id"]
        is_blackout = session.get("is_blackout", False)

    if not rewards:
        return {"score": _SCORE_EPSILON, "message": "No steps taken yet. Run /step first."}

    cumulative = sum(rewards)
    bounds = _get_grader(task_id).get_bounds()

    raw = normalize_score(
        cumulative_reward=cumulative,
        reward_floor=bounds["reward_floor"],
        reward_ceiling=bounds["reward_ceiling"],
        n1_survival_rate=0.0 if is_blackout else 1.0,
    )
    # Defense-in-depth: clamp again at the API boundary
    score = _clamp_score(raw)

    return {
        "score": score,
        "cumulative_reward": round(cumulative, 4),
        "steps": len(rewards),
        "is_blackout": is_blackout,
        "task_id": task_id,
        "reward_floor": bounds["reward_floor"],
        "reward_ceiling": bounds["reward_ceiling"]
    }
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
@app.get("/baseline")
def run_baseline(use_llm: bool = False):
    """
    Run the baseline policy on all registered tasks. Returns 0.0–1.0 scores.

    Default: heuristic (reproducible). Set use_llm=true for the LLM agent,
    which requires HF_TOKEN or OPENAI_API_KEY to be set.

    Uses the same cached grader as /grader — bounds are computed once
    and reused across all endpoints.
    """
    api_key = os.getenv("HF_TOKEN", os.getenv("OPENAI_API_KEY", ""))
    if use_llm and not api_key:
        raise HTTPException(
            400,
            "use_llm=true requires HF_TOKEN or OPENAI_API_KEY environment variable",
        )

    policy = llm_policy if use_llm and api_key else heuristic_policy
    policy_name = "llm" if policy is llm_policy else "heuristic"

    results = {}
    # Iterate task ids only — the previous `.items()` loop bound the task
    # config to an unused variable on every iteration.
    for task_id in TASKS:
        grader = _get_grader(task_id)  # cached — no duplicate rollouts
        results[task_id] = grader.evaluate_policy(policy, n_episodes=3)

    return {"policy": policy_name, "baseline_scores": results}
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
@app.get("/visualize")
def visualize(session_id: str):
    """Render the current grid state plus frequency history as a base64 image."""
    session = _get_session(session_id)

    with session["lock"]:
        current_obs = session["env"].state()
    with _session_lock:
        past_states = list(history.get(session_id, []))

    return {"image_base64": generate_dashboard(past_states, current_obs)}
|
changes.md
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Notebook Changes — opengrid_grpo_colab.ipynb
|
| 2 |
+
|
| 3 |
+
## Bug fixes applied (2026-04-25)
|
| 4 |
+
|
| 5 |
+
### Cell 7 — Generate Training Prompts
|
| 6 |
+
|
| 7 |
+
| # | Severity | Bug | Fix |
|
| 8 |
+
|---|----------|-----|-----|
|
| 9 |
+
| 1 | 🔴 Critical | `obs_dict = obs.model_dump()` produces dicts with integer keys; `Dataset.from_dict({"obs_context": obs_contexts})` fails with `ArrowTypeError: Expected dict key of type str or bytes, got 'int'` | Changed to `json.loads(obs.model_dump_json())` so all keys are strings; then stored as `json.dumps(obs_dict)` — a flat JSON string PyArrow handles trivially |
|
| 10 |
+
| 2 | 🟡 Bug | `env = OpenGridEnv(task_config)` instantiated before the loop but immediately replaced inside the loop — wasted object creation | Removed stray instantiation |
|
| 11 |
+
| 3 | 🟡 Bug | `import copy`, `import json` inside inner loop body — re-imported on every iteration | Moved to top of cell |
|
| 12 |
+
| 4 | 🟡 Bug | Slack bus included in random action choices — physics solver overwrites it, wasting action budget | Filtered to `['generator', 'battery']` only |
|
| 13 |
+
|
| 14 |
+
### Cell 8 — Reward Function
|
| 15 |
+
|
| 16 |
+
| # | Severity | Bug | Fix |
|
| 17 |
+
|---|----------|-----|-----|
|
| 18 |
+
| 5 | 🔴 Critical | `reward_fn` received `obs_context` as JSON strings from the dataset column but passed them directly to `compute_grpo_reward` which expects dicts | Added `json.loads(ctx) if isinstance(ctx, str) else ctx` deserialization before scoring |
|
| 19 |
+
| 6 | 🟡 Bug | No assertion to catch silent arity mismatches | Added `assert len(test_rewards) == 2` sanity check |
|
| 20 |
+
|
| 21 |
+
### Cell 9 — Training
|
| 22 |
+
|
| 23 |
+
| # | Severity | Bug | Fix |
|
| 24 |
+
|---|----------|-----|-----|
|
| 25 |
+
| 7 | 🟡 Bug | `bf16=torch.cuda.is_bf16_supported()` raises `AssertionError` when CUDA is not available (no GPU runtime) | Guarded: `_cuda_ok = torch.cuda.is_available()` then `_bf16 = _cuda_ok and ...` |
|
| 26 |
+
|
| 27 |
+
### Cell 12 — Before/After Plot
|
| 28 |
+
|
| 29 |
+
| # | Severity | Bug | Fix |
|
| 30 |
+
|---|----------|-----|-----|
|
| 31 |
+
| 8 | 🟡 Bug | Bar labels used `va='bottom'` for all bars; for negative-height bars the label renders inside/below the bar | Fixed: `va='bottom'` when `h >= 0`, `va='top'` when `h < 0`, with matching y-offset |
|
| 32 |
+
|
| 33 |
+
### Cell 13 — Summary Table
|
| 34 |
+
|
| 35 |
+
| # | Severity | Bug | Fix |
|
| 36 |
+
|---|----------|-----|-----|
|
| 37 |
+
| 9 | 🟡 Bug | `common_tasks` was set in Cell 12; if the user skips the plot cell, Cell 13 raises `NameError: common_tasks` | Rebuilt `common_tasks` defensively at the top of Cell 13 |
|
| 38 |
+
|
| 39 |
+
---
|
| 40 |
+
|
| 41 |
+
## `inference.py` — Code review fixes (2026-04-25)
|
| 42 |
+
|
| 43 |
+
### High-priority fixes
|
| 44 |
+
|
| 45 |
+
| # | Severity | Issue | Fix |
|
| 46 |
+
|---|----------|-------|-----|
|
| 47 |
+
| 1 | 🔴 Bug | `parse_action()` crashes on valid JSON that is not an object (e.g. `[]`) — `AttributeError` not caught by `except (json.JSONDecodeError, KeyError)` | Rewrote with `isinstance(data, dict)` guard, list-unwrapping, field-type validation, and broad `except Exception` |
|
| 48 |
+
| 2 | 🔴 Bug | `parse_action()` markdown/prose stripping is fragile — fails on `Here is the action: {...}` | Extracts first `{...}` substring via `text.find("{")` / `text.rfind("}")` |
|
| 49 |
+
| 3 | 🔴 Reliability | `/grader` call can exceed `httpx` 30s timeout on first use (lazy `RobustnessGrader` bound estimation) | `grade()` now uses `timeout=180.0`; base client uses `httpx.Timeout(connect=10, read=60, write=30, pool=10)` |
|
| 50 |
+
| 4 | 🟡 Bug | `HF_TOKEN` takes precedence over `OPENAI_API_KEY` — if both set with OpenAI endpoint, auth fails | Changed to `API_KEY or OPENAI_API_KEY or HF_TOKEN` priority order |
|
| 51 |
+
| 5 | 🟡 Bug | No JSON-mode enforcement for LLM — models return markdown/prose | Added `response_format={"type": "json_object"}` with fallback for unsupported endpoints |
|
| 52 |
+
|
| 53 |
+
### System prompt fixes
|
| 54 |
+
|
| 55 |
+
| # | Severity | Issue | Fix |
|
| 56 |
+
|---|----------|-------|-----|
|
| 57 |
+
| 6 | 🟡 Design | Prompt says slack bus is controllable, but physics solver overwrites it | Changed to: "avoid adjusting the slack bus — physics overwrites it" |
|
| 58 |
+
| 7 | 🟡 Design | Single-agent mode allows topology actions without safety layer protection | Added: "Prefer NO topology actions unless absolutely necessary" |
|
| 59 |
+
| 8 | 🟡 Design | Multi-agent prompt says "Only for lines in your zone" but observations include boundary lines | Clarified: "Only for visible internal or boundary lines. Boundary-line switching is risky" |
|
| 60 |
+
|
| 61 |
+
### Multi-agent robustness fixes
|
| 62 |
+
|
| 63 |
+
| # | Severity | Issue | Fix |
|
| 64 |
+
|---|----------|-------|-----|
|
| 65 |
+
| 9 | 🟡 Bug | Agent iteration uses `range(num_agents)` — assumes contiguous integer IDs | Changed to `sorted(observations.keys())` |
|
| 66 |
+
| 10 | 🟡 Bug | `safety_reports` assumed to be list, but API returns dict keyed by agent ID | Added `isinstance` check to handle both list and dict formats |
|
| 67 |
+
| 11 | 🟡 Design | Safety correction feedback not fed back to LLM — model repeats same invalid actions | Appended `[SAFETY] {reason}` to agent history when corrections occur |
|
| 68 |
+
|
| 69 |
+
### Other fixes
|
| 70 |
+
|
| 71 |
+
| # | Severity | Issue | Fix |
|
| 72 |
+
|---|----------|-------|-----|
|
| 73 |
+
| 12 | 🟡 Bug | `MAX_STEPS = 50` hardcoded — may truncate future tasks | Changed to `MAX_STEPS = 100` as safety cap; `done` flag is the true terminator |
|
| 74 |
+
| 13 | 🟡 Bug | Default task list excludes `task_karnataka` despite KPTCL multi-agent framing | Added `task_karnataka` to `TASKS` list |
|
| 75 |
+
| 14 | 🟡 Bug | Module docstring says all 3 env vars are required; only API key is | Fixed docstring to document defaults and actual requirements |
|
| 76 |
+
| 15 | 🟡 Bug | `[END]` log prints score at `.2f` but summary prints `.4f` — precision loss | Changed `log_end` to use `:.4f` |
|
| 77 |
+
| 16 | 🟡 Reliability | `OpenAI()` client has no timeout or retry config | Added `timeout=30.0, max_retries=2` |
|
| 78 |
+
| 17 | 🟢 Feature | No `list_tasks()` method on `EnvClient` | Added `list_tasks()` for future task validation |
|
| 79 |
+
|
| 80 |
+
---
|
| 81 |
+
|
| 82 |
+
## GRPO Training — Environment-Grounded Rewards (2026-04-25)
|
| 83 |
+
|
| 84 |
+
### Root Cause: Proxy Reward Disconnect
|
| 85 |
+
|
| 86 |
+
The original `compute_grpo_reward` was a **heuristic proxy scorer** that evaluated JSON format, direction, and proportionality without ever stepping the environment. The model optimized this proxy, which did not correlate with actual grid physics rewards. Result: zero improvement over baseline.
|
| 87 |
+
|
| 88 |
+
### Changes Made
|
| 89 |
+
|
| 90 |
+
#### `src/environment.py`
|
| 91 |
+
|
| 92 |
+
| # | Change | Purpose |
|
| 93 |
+
|---|--------|---------|
|
| 94 |
+
| 1 | Added `_set_state(obs_dict)` method to `OpenGridEnv` | Enables restoring environment to any observed state for reward computation. Rebuilds bus/line state, frequency, and slack injection from observation dicts. |
|
| 95 |
+
|
| 96 |
+
#### `training/train_grpo.py`
|
| 97 |
+
|
| 98 |
+
| # | Severity | Change | Details |
|
| 99 |
+
|---|----------|--------|---------|
|
| 100 |
+
| 2 | 🔴 Critical | Replaced `compute_grpo_reward` with `compute_grpo_reward_env` | New reward function **actually steps the physics simulation**: restores env state → steps with LLM action → measures real reward → runs mini-rollout with heuristic continuation for trajectory awareness |
|
| 101 |
+
| 3 | 🔴 Critical | Added mini-rollout scoring (horizon=3) | After the LLM's action, runs 2 more steps with heuristic policy to capture trajectory-level impact. Combines: `immediate_reward + 0.5 * rollout_reward` |
|
| 102 |
+
| 4 | 🟡 Medium | Increased `num_generations` from 4 → 8 | Wider GRPO group = more reward variance = stronger ranking signal. Prevents the advantage calculation from collapsing to zero. |
|
| 103 |
+
| 5 | 🟡 Medium | Increased random perturbation range from ±15 → ±30 MW | Creates more diverse/stressed grid states during training data generation. Model sees near-blackout and overload scenarios. |
|
| 104 |
+
| 6 | 🟡 Medium | Added adversarial battery drain (every 5th episode) | Forces model to learn actions when batteries are near-empty — a critical edge case the original data lacked. |
|
| 105 |
+
| 7 | 🟡 Medium | Multi-bus perturbations (1-2 buses per step) | Was single-bus. More diverse action patterns create richer state transitions. |
|
| 106 |
+
| 8 | 🟡 Medium | Increased learning rate from 5e-6 → 1e-5 | Slightly more aggressive to capitalize on the now-meaningful reward signal. |
|
| 107 |
+
| 9 | 🟡 Medium | Increased gradient accumulation (effective batch 16) | Smoother gradients for more stable training. |
|
| 108 |
+
| 10 | 🟡 Medium | Steps per episode increased from 10 → 15 | More temporal diversity in observations. |
|
| 109 |
+
| 11 | 🟢 Minor | obs_context stored as JSON string | Fixes Arrow serialization (PyArrow can't handle dicts with int keys). |
|
| 110 |
+
| 12 | 🟢 Minor | Kept legacy `compute_grpo_reward` for test-mode compat | Backward compatibility with `--test-mode` pipeline verification. |
|
| 111 |
+
|
inference.py
ADDED
|
@@ -0,0 +1,535 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OpenGrid Inference Script
|
| 3 |
+
=========================
|
| 4 |
+
Runs an LLM agent against all OpenGrid tasks via the OpenAI-compatible API.
|
| 5 |
+
Supports both single-agent and multi-agent POMDP modes.
|
| 6 |
+
|
| 7 |
+
Optional environment variables:
|
| 8 |
+
API_BASE_URL -- defaults to https://api.openai.com/v1
|
| 9 |
+
MODEL_NAME -- defaults to gpt-4o
|
| 10 |
+
Required (one of):
|
| 11 |
+
OPENAI_API_KEY or HF_TOKEN
|
| 12 |
+
|
| 13 |
+
Emits structured [START], [STEP], [END] logs to stdout.
|
| 14 |
+
|
| 15 |
+
Usage:
|
| 16 |
+
# Single-agent mode (backward compatible)
|
| 17 |
+
python inference.py
|
| 18 |
+
|
| 19 |
+
# Multi-agent mode (uses safety layer + oversight)
|
| 20 |
+
python inference.py --multi
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import os
|
| 24 |
+
import sys
|
| 25 |
+
import json
|
| 26 |
+
import math
|
| 27 |
+
import argparse
|
| 28 |
+
import httpx
|
| 29 |
+
|
| 30 |
+
from openai import OpenAI
|
| 31 |
+
|
| 32 |
+
# ---------- Configuration ----------

# OpenAI-compatible endpoint and model; both overridable via environment.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o")

# Credential resolution order: explicit API_KEY first, then OPENAI_API_KEY
# (preferred when using the OpenAI endpoint), then HF_TOKEN as fallback.
API_KEY = (
    os.environ.get("API_KEY")
    or os.environ.get("OPENAI_API_KEY")
    or os.environ.get("HF_TOKEN")
    or ""
)

# Base URL of the OpenGrid environment server (FastAPI app).
ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")
BENCHMARK = "OpenGrid"
# Safety cap — the environment's 'done' flag is the true terminator
MAX_STEPS = 100
# Episodes scoring at or above this are reported as successful.
SUCCESS_SCORE_THRESHOLD = 0.5

# Task ids to evaluate, in order; must match the server's registered tasks.
TASKS = ["task_easy", "task_medium", "task_hard", "task_karnataka"]
|
| 52 |
+
|
| 53 |
+
SYSTEM_PROMPT_SINGLE = """You are a Power Grid Controller AI. Your goal is to maintain grid stability.
|
| 54 |
+
|
| 55 |
+
Key objectives:
|
| 56 |
+
1. Keep grid frequency close to 50.0 Hz (acceptable: 49.5-50.5 Hz)
|
| 57 |
+
2. Prevent transmission line overloads (rho < 1.0)
|
| 58 |
+
3. Avoid grid islanding (blackout)
|
| 59 |
+
|
| 60 |
+
Available actions:
|
| 61 |
+
1. bus_adjustments: List of {"bus_id": int, "delta": float}
|
| 62 |
+
- Positive delta = increase power injection (discharge battery / ramp up generator)
|
| 63 |
+
- Negative delta = decrease power injection (charge battery / ramp down generator)
|
| 64 |
+
- Only works on battery and generator buses (avoid adjusting the slack bus — physics overwrites it)
|
| 65 |
+
2. topology_actions: List of {"line_id": str, "action": "open" | "close"}
|
| 66 |
+
- Opening a line removes it; closing reconnects. 3-step cooldown.
|
| 67 |
+
- WARNING: Opening lines can cause islanding -> blackout
|
| 68 |
+
- Prefer NO topology actions unless absolutely necessary. Always return "topology_actions": []
|
| 69 |
+
|
| 70 |
+
Strategy:
|
| 71 |
+
- If frequency < 50 Hz -> discharge batteries, ramp up generators
|
| 72 |
+
- If frequency > 50 Hz -> charge batteries, ramp down generators
|
| 73 |
+
- If a line rho > 0.9 -> reduce generation near that line, do NOT open it
|
| 74 |
+
- Prefer minimal actions over aggressive switching
|
| 75 |
+
|
| 76 |
+
Respond with ONLY a valid JSON object. Example:
|
| 77 |
+
{"bus_adjustments": [{"bus_id": 2, "delta": 5.0}], "topology_actions": []}
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
SYSTEM_PROMPT_MULTI = """You are a KPTCL Zone Controller AI managing one zone of the Karnataka power grid.
|
| 81 |
+
You can only see and control buses in YOUR zone. Other zones are managed by other agents.
|
| 82 |
+
|
| 83 |
+
Key objectives:
|
| 84 |
+
1. Keep grid frequency close to 50.0 Hz (you see a noisy reading)
|
| 85 |
+
2. Prevent line overloads in your zone (rho < 1.0)
|
| 86 |
+
3. Coordinate with other zones (don't fight against them)
|
| 87 |
+
4. Avoid actions that would trigger the safety layer
|
| 88 |
+
|
| 89 |
+
Available actions:
|
| 90 |
+
1. bus_adjustments: List of {"bus_id": int, "delta": float}
|
| 91 |
+
- ONLY adjust battery and generator buses in YOUR zone (avoid slack — physics overwrites it)
|
| 92 |
+
- Positive delta = increase power injection
|
| 93 |
+
- Negative delta = decrease power injection
|
| 94 |
+
2. topology_actions: List of {"line_id": str, "action": "open" | "close"}
|
| 95 |
+
- Only for visible internal or boundary lines. Safety layer will block dangerous switches.
|
| 96 |
+
- Boundary-line switching is risky; avoid unless necessary.
|
| 97 |
+
|
| 98 |
+
Strategy:
|
| 99 |
+
- If frequency < 50 Hz -> increase generation/discharge in your zone
|
| 100 |
+
- If frequency > 50 Hz -> decrease generation/charge in your zone
|
| 101 |
+
- Check neighbor signals to understand if other zones are compensating
|
| 102 |
+
- Prefer small corrections over large swings
|
| 103 |
+
|
| 104 |
+
Respond with ONLY a valid JSON object. Example:
|
| 105 |
+
{"bus_adjustments": [{"bus_id": 2, "delta": 5.0}], "topology_actions": []}
|
| 106 |
+
"""
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# ---------- Structured Logging ----------
|
| 110 |
+
|
| 111 |
+
def log_start(task: str, env: str, model: str, mode: str = "single"):
    """Emit the structured [START] log line announcing a new run."""
    fields = f"task={task} env={env} model={model} mode={mode}"
    print("[START] " + fields, flush=True)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def log_step(step: int, action: str, reward: float, done: bool, error=None, agent_id=None):
    """Emit one structured [STEP] log line for a single environment step.

    The optional agent_id adds an ``agent=...`` field (multi-agent mode);
    a falsy error is rendered as the literal string "null".
    """
    parts = [f"[STEP] step={step}"]
    if agent_id is not None:
        parts.append(f" agent={agent_id}")
    parts.append(f" action={action}")
    parts.append(f" reward={reward:.2f}")
    parts.append(f" done={str(done).lower()}")
    parts.append(f" error={str(error) if error else 'null'}")
    print("".join(parts), flush=True)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def clamp_score(s: float) -> float:
    """Ensure score is strictly in (0, 1). Mirrors grader._clamp_score.

    Non-numeric or non-finite input falls back to a neutral 0.5; otherwise
    the value is bounded to [0.02, 0.98] and truncated to 4 decimal places.
    """
    try:
        value = float(s)
    except (TypeError, ValueError):
        return 0.5
    if not math.isfinite(value):
        return 0.5
    bounded = min(max(value, 0.02), 0.98)
    # Truncate (not round) to 4 decimals, then re-bound against edge effects
    truncated = math.floor(bounded * 10000) / 10000
    return min(max(truncated, 0.02), 0.98)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def log_end(success: bool, steps: int, score: float, rewards: list, mode: str = "single"):
    """Emit the structured [END] log line summarizing a finished run.

    The score is passed through clamp_score so the logged value always
    lies in (0, 1); rewards are printed as a comma-joined 2-decimal list.
    """
    final_score = clamp_score(score)
    reward_list = ",".join(f"{r:.2f}" for r in rewards)
    line = (
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={final_score:.4f} rewards={reward_list} mode={mode}"
    )
    print(line, flush=True)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
# ---------- LLM Call ----------
|
| 149 |
+
|
| 150 |
+
def get_model_message(client: OpenAI, step: int, obs_json: str, last_reward: float,
                      history: list, system_prompt: str, zone_name: str = None) -> str:
    """Ask the LLM what action to take given the current observation.

    Builds a user message containing the optional zone tag, the step number,
    the last reward, up to the last 3 history entries, and the serialized
    grid state, then calls the chat-completions API at temperature 0.

    Returns the raw (stripped) model text; on any request failure a no-op
    action JSON string is returned so the caller's parse step still works.
    """
    context = ""
    if zone_name:
        context += f"[Zone: {zone_name}] "
    context += f"Step {step} | Last reward: {last_reward:+.2f}\n"
    if history:
        # Only the most recent 3 entries are forwarded to bound prompt size
        context += "Recent history (last 3):\n" + "\n".join(history[-3:]) + "\n\n"
    context += f"Current Grid State:\n{obs_json}"

    try:
        kwargs = dict(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": context}
            ],
            temperature=0.0,
            max_tokens=300,
        )
        # Use JSON mode if the endpoint supports it (OpenAI-compatible)
        try:
            kwargs["response_format"] = {"type": "json_object"}
            response = client.chat.completions.create(**kwargs)
        except Exception:
            # Fallback: endpoint may not support response_format — retry
            # the same request without it
            kwargs.pop("response_format", None)
            response = client.chat.completions.create(**kwargs)
        return response.choices[0].message.content.strip()
    except Exception as exc:
        # Never propagate API errors to the rollout loop; a no-op action
        # keeps the episode alive
        print(f"[DEBUG] Model request failed: {exc}", flush=True)
        return '{"bus_adjustments": [], "topology_actions": []}'
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
# ---------- Environment Client ----------
|
| 186 |
+
|
| 187 |
+
class EnvClient:
    """HTTP client for the OpenGrid FastAPI environment.

    Wraps the single-agent (/reset, /step), multi-agent (/reset_multi,
    /step_multi) and shared (/state, /grader, /tasks) endpoints. The
    session id returned by a reset call is remembered and reused.
    """

    def __init__(self, base_url: str):
        self.base_url = base_url.rstrip("/")
        # Generous read timeout: environment steps run server-side physics.
        self.client = httpx.Client(
            timeout=httpx.Timeout(connect=10.0, read=60.0, write=30.0, pool=10.0)
        )
        self.session_id = None

    # --- internal helpers ---

    def _url(self, path: str) -> str:
        """Build a full endpoint URL from a path fragment."""
        return f"{self.base_url}{path}"

    @staticmethod
    def _checked(resp):
        """Raise on HTTP errors, then decode the JSON body."""
        resp.raise_for_status()
        return resp.json()

    # --- Single-Agent ---

    def reset(self, task_id: str) -> dict:
        """Start a new single-agent episode; remembers the session id."""
        data = self._checked(
            self.client.post(self._url("/reset"), params={"task_id": task_id})
        )
        self.session_id = data["session_id"]
        return data["observation"]

    def step(self, action_dict: dict) -> dict:
        """Apply one single-agent action to the current session."""
        return self._checked(
            self.client.post(
                self._url("/step"),
                params={"session_id": self.session_id},
                json=action_dict,
            )
        )

    # --- Multi-Agent ---

    def reset_multi(self, task_id: str) -> dict:
        """Start a new multi-agent episode; remembers the session id."""
        data = self._checked(
            self.client.post(self._url("/reset_multi"), params={"task_id": task_id})
        )
        self.session_id = data["session_id"]
        return data

    def step_multi(self, agent_actions: dict) -> dict:
        """Apply one joint action (all agents) to the current session."""
        return self._checked(
            self.client.post(
                self._url("/step_multi"),
                params={"session_id": self.session_id},
                json={"agent_actions": agent_actions},
            )
        )

    # --- Shared ---

    def state(self) -> dict:
        """Fetch the full current grid state for the session."""
        return self._checked(
            self.client.get(self._url("/state"), params={"session_id": self.session_id})
        )

    def grade(self) -> dict:
        """Request the episode grade.

        Grading can trigger lazy bound estimation (multiple rollouts),
        so this call overrides the client default with a 180s timeout.
        """
        return self._checked(
            self.client.get(
                self._url("/grader"),
                params={"session_id": self.session_id},
                timeout=180.0,
            )
        )

    def list_tasks(self) -> list:
        """Fetch available tasks from the server."""
        return self._checked(self.client.get(self._url("/tasks")))

    def close(self):
        """Release the underlying HTTP connection pool."""
        self.client.close()
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
# ---------- Parse Action ----------
|
| 261 |
+
|
| 262 |
+
NOOP_ACTION = {"bus_adjustments": [], "topology_actions": []}
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def parse_action(response_text: str) -> dict:
    """Parse LLM JSON response into an action dict.

    Handles markdown fences, prose preambles, JSON lists, and malformed
    output. Any unparseable input yields a fresh no-op action dict.
    """
    fallback = {"bus_adjustments": [], "topology_actions": []}
    try:
        text = str(response_text).strip()

        # Drop a surrounding markdown code fence, if any
        if text.startswith("```"):
            body = text.splitlines()
            if body and body[0].startswith("```"):
                body = body[1:]
            if body and body[-1].startswith("```"):
                body = body[:-1]
            text = "\n".join(body).strip()

        # Locate the outermost JSON object inside any surrounding prose
        first = text.find("{")
        last = text.rfind("}")
        if first < 0 or last <= first:
            return dict(fallback)

        parsed = json.loads(text[first:last + 1])

        # Some models emit a one-element list wrapping the object
        if isinstance(parsed, list):
            parsed = parsed[0] if parsed else {}
        if not isinstance(parsed, dict):
            return dict(fallback)

        adjustments = parsed.get("bus_adjustments", [])
        topo = parsed.get("topology_actions", [])
        # Normalize non-list values (e.g. null or a string) to empty lists
        if not isinstance(adjustments, list):
            adjustments = []
        if not isinstance(topo, list):
            topo = []
        return {
            "bus_adjustments": adjustments,
            "topology_actions": topo,
        }
    except Exception:
        return dict(fallback)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
# ---------- Single-Agent Runner ----------
|
| 313 |
+
|
| 314 |
+
def run_task_single(client: OpenAI, env: EnvClient, task_id: str) -> dict:
    """Run one task in single-agent mode and return results.

    Drives an observe -> LLM -> act loop for up to MAX_STEPS steps, then
    asks the server-side grader for the final score. Any exception during
    the episode yields the floor score (0.05) instead of propagating.

    Returns a dict with keys: task, score, steps, success.
    """
    history_msgs = []
    rewards = []
    steps_taken = 0
    score = 0.05  # floor score used when the episode errors out
    success = False

    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME, mode="single")

    try:
        obs = env.reset(task_id)
        last_reward = 0.0

        for step_num in range(1, MAX_STEPS + 1):
            # Serialize observation and ask the model for the next action
            obs_json = json.dumps(obs, indent=2)
            message = get_model_message(client, step_num, obs_json, last_reward,
                                        history_msgs, SYSTEM_PROMPT_SINGLE)
            action_dict = parse_action(message)

            result = env.step(action_dict)
            obs = result["observation"]
            reward = result.get("reward", {}).get("value", 0.0)
            done = result.get("done", False)

            rewards.append(reward)
            steps_taken = step_num
            last_reward = reward

            # Truncate long action JSON for logging/history readability
            action_summary = json.dumps(action_dict)
            if len(action_summary) > 200:
                action_summary = action_summary[:200] + "..."

            log_step(step=step_num, action=action_summary, reward=reward, done=done)

            history_msgs.append(f"Step {step_num}: action={action_summary[:80]} -> reward {reward:+.2f}")

            if done:
                break

        # Grade after the loop — clamp to the (0, 1) open interval
        grade_result = env.grade()
        score = clamp_score(grade_result.get("score", 0.5))
        success = score >= SUCCESS_SCORE_THRESHOLD

    except Exception as e:
        # Episode failures are logged and scored at the floor, not raised
        print(f"[DEBUG] Task {task_id} error: {e}", flush=True)
        score = 0.05
        success = False

    log_end(success=success, steps=steps_taken, score=score, rewards=rewards, mode="single")

    return {"task": task_id, "score": score, "steps": steps_taken, "success": success}
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
# ---------- Multi-Agent Runner ----------
|
| 369 |
+
|
| 370 |
+
def run_task_multi(client: OpenAI, env: EnvClient, task_id: str) -> dict:
    """Run one task in multi-agent mode and return results.

    Each step, every agent independently queries the LLM with its own
    partial (zone-local) observation and history; the joint action is then
    submitted in a single /step_multi call. Safety-layer corrections are
    counted and fed back into the offending agent's history.

    Returns a dict with keys: task, score, steps, success,
    safety_interventions. Any exception yields the floor score (0.05).
    """
    rewards = []
    steps_taken = 0
    score = 0.05  # floor score used when the episode errors out
    success = False
    total_safety_interventions = 0

    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME, mode="multi")

    try:
        reset_data = env.reset_multi(task_id)
        num_agents = reset_data["num_agents"]
        zone_info = reset_data["zone_info"]
        observations = reset_data["observations"]

        # Per-agent history and last-reward tracking, keyed by string id
        agent_histories = {str(i): [] for i in range(num_agents)}
        last_rewards = {str(i): 0.0 for i in range(num_agents)}

        print(f"[INFO] Multi-agent mode: {num_agents} agents", flush=True)
        for aid, zi in zone_info.items():
            print(f" Agent {aid}: {zi['zone_name']} ({len(zi['bus_ids'])} buses)", flush=True)

        for step_num in range(1, MAX_STEPS + 1):
            agent_actions = {}

            # Each agent generates its own action based on partial observation
            for agent_id_str in sorted(observations.keys()):
                obs = observations.get(agent_id_str, {})
                zone_name = zone_info.get(agent_id_str, {}).get("zone_name", f"Zone_{agent_id_str}")

                obs_json = json.dumps(obs, indent=2)
                message = get_model_message(
                    client, step_num, obs_json,
                    last_rewards[agent_id_str],
                    agent_histories[agent_id_str],
                    SYSTEM_PROMPT_MULTI,
                    zone_name=zone_name
                )
                action_dict = parse_action(message)
                agent_actions[agent_id_str] = action_dict

            # Submit all actions together
            result = env.step_multi(agent_actions)
            observations = result["observations"]
            team_reward = result.get("team_reward", 0.0)
            done = result.get("done", False)

            # Track safety interventions (dict keyed by agent in the current
            # API; list format from the older API is still accepted)
            safety_reports = result.get("safety_reports", {})
            if isinstance(safety_reports, list):
                # Handle list format from older API
                step_interventions = sum(1 for sr in safety_reports if sr.get("was_corrected", False))
            else:
                step_interventions = sum(
                    1 for sr in safety_reports.values() if sr.get("was_corrected", False)
                )
            total_safety_interventions += step_interventions

            # Feed safety correction feedback into agent histories so the
            # model can see why its previous action was altered
            if isinstance(safety_reports, dict):
                for aid_str, sr in safety_reports.items():
                    if sr.get("was_corrected") and aid_str in agent_histories:
                        reason = sr.get("correction_reason", "action corrected")[:120]
                        agent_histories[aid_str].append(f"[SAFETY] {reason}")

            # Log per-agent rewards and append per-agent history entries
            per_agent_rewards = result.get("rewards", {})
            for agent_id_str in sorted(observations.keys()):
                agent_reward = per_agent_rewards.get(agent_id_str, {}).get("value", 0.0)
                last_rewards[agent_id_str] = agent_reward
                action_summary = json.dumps(agent_actions.get(agent_id_str, {}))
                if len(action_summary) > 100:
                    action_summary = action_summary[:100] + "..."
                agent_histories[agent_id_str].append(
                    f"Step {step_num}: action={action_summary[:60]} -> reward {agent_reward:+.2f}"
                )

            rewards.append(team_reward)
            steps_taken = step_num

            # Log team-level step (coordination score from oversight agent)
            oversight = result.get("oversight_report", {})
            coord_score = oversight.get("coordination_score", 1.0)
            safety_str = f" safety_corrections={step_interventions}" if step_interventions > 0 else ""
            log_step(step=step_num, action=f"team_reward={team_reward:.2f} coord={coord_score:.2f}{safety_str}",
                     reward=team_reward, done=done)

            if done:
                break

        # Grade after the loop — clamp to the (0, 1) open interval
        grade_result = env.grade()
        score = clamp_score(grade_result.get("score", 0.5))
        success = score >= SUCCESS_SCORE_THRESHOLD

    except Exception as e:
        # Episode failures are logged and scored at the floor, not raised
        print(f"[DEBUG] Task {task_id} multi-agent error: {e}", flush=True)
        score = 0.05
        success = False

    print(f"[INFO] Total safety interventions: {total_safety_interventions}", flush=True)
    log_end(success=success, steps=steps_taken, score=score, rewards=rewards, mode="multi")

    return {
        "task": task_id, "score": score, "steps": steps_taken,
        "success": success, "safety_interventions": total_safety_interventions
    }
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
# ---------- Main ----------
|
| 481 |
+
|
| 482 |
+
def main():
    """Run inference on all tasks.

    Parses CLI flags (--multi for multi-agent mode, --tasks to select a
    subset), validates API credentials, runs each selected task with the
    appropriate runner, and prints a final per-task and average summary.
    Exits with status 1 when no API key is configured.
    """
    parser = argparse.ArgumentParser(description="OpenGrid LLM Inference")
    parser.add_argument("--multi", action="store_true",
                        help="Use multi-agent POMDP mode (default: single-agent)")
    parser.add_argument("--tasks", nargs="+", default=TASKS,
                        help="Which tasks to run (default: all)")
    args = parser.parse_args()

    if not API_KEY:
        print("[ERROR] No API key found. Set OPENAI_API_KEY or HF_TOKEN environment variable.", flush=True)
        sys.exit(1)

    mode = "multi-agent" if args.multi else "single-agent"
    print(f"[CONFIG] API_BASE_URL={API_BASE_URL}", flush=True)
    print(f"[CONFIG] MODEL_NAME={MODEL_NAME}", flush=True)
    print(f"[CONFIG] ENV_URL={ENV_URL}", flush=True)
    print(f"[CONFIG] MODE={mode}", flush=True)

    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY, timeout=30.0, max_retries=2)
    env = EnvClient(ENV_URL)

    all_results = []
    # Single dispatch point: both runners share the same signature
    runner = run_task_multi if args.multi else run_task_single

    try:
        for task_id in args.tasks:
            print(f"\n{'='*60}", flush=True)
            print(f"Running task: {task_id} ({mode})", flush=True)
            print(f"{'='*60}", flush=True)

            result = runner(client, env, task_id)
            all_results.append(result)

    finally:
        # Always release the HTTP connection pool, even on error
        env.close()

    # Summary
    print(f"\n{'='*60}", flush=True)
    print(f"FINAL RESULTS ({mode})", flush=True)
    print(f"{'='*60}", flush=True)
    for r in all_results:
        status = "PASS" if r["success"] else "FAIL"
        extra = ""
        # safety_interventions is only present in multi-agent results
        if "safety_interventions" in r:
            extra = f" safety={r['safety_interventions']}"
        print(f" {r['task']}: score={r['score']:.4f} steps={r['steps']} [{status}]{extra}", flush=True)

    avg_score = sum(r["score"] for r in all_results) / len(all_results) if all_results else 0
    print(f"\n Average Score: {avg_score:.4f}", flush=True)


if __name__ == "__main__":
    main()
|
openenv.yaml
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: opengrid
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: app:app
|
| 6 |
+
port: 7860
|
| 7 |
+
|
| 8 |
+
# Environment supports both single-agent and multi-agent POMDP modes.
|
| 9 |
+
# Single-agent: /reset + /step (backward compatible)
|
| 10 |
+
# Multi-agent: /reset_multi + /step_multi (2-3 agents per zone)
|
| 11 |
+
|
| 12 |
+
tasks:
|
| 13 |
+
- id: task_easy
|
| 14 |
+
name: Easy Grid (5 buses, 2 agents, 20% renewables)
|
| 15 |
+
description: Basic frequency control with 2-zone coordination
|
| 16 |
+
agents: 2
|
| 17 |
+
grader:
|
| 18 |
+
endpoint: /grader
|
| 19 |
+
score_range: [0.02, 0.98]
|
| 20 |
+
- id: task_medium
|
| 21 |
+
name: Medium Grid (10 buses, 3 agents, 50% renewables)
|
| 22 |
+
description: Congestion management with 3-zone POMDP and volatile renewables
|
| 23 |
+
agents: 3
|
| 24 |
+
grader:
|
| 25 |
+
endpoint: /grader
|
| 26 |
+
score_range: [0.02, 0.98]
|
| 27 |
+
- id: task_hard
|
| 28 |
+
name: Hard Grid (14 buses, 3 agents, 70% renewables)
|
| 29 |
+
description: High volatility, tight margins, complex topology with safety constraints
|
| 30 |
+
agents: 3
|
| 31 |
+
grader:
|
| 32 |
+
endpoint: /grader
|
| 33 |
+
score_range: [0.02, 0.98]
|
| 34 |
+
- id: task_karnataka
|
| 35 |
+
name: Karnataka KPTCL Grid (5 buses, 2 agents, real-world topology)
|
| 36 |
+
description: Realistic Karnataka power grid with POMDP multi-agent coordination
|
| 37 |
+
agents: 2
|
| 38 |
+
grader:
|
| 39 |
+
endpoint: /grader
|
| 40 |
+
score_range: [0.02, 0.98]
|
pyproject.toml
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=68.0", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "opengrid"
|
| 7 |
+
version = "1.0.0"
|
| 8 |
+
description = "Renewable energy grid load-balancing environment for AI agents"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
license = {text = "MIT"}
|
| 11 |
+
requires-python = ">=3.10"
|
| 12 |
+
authors = [
|
| 13 |
+
{name = "KRISHNA GOYAL", email = "krishnagoyalcse@gmail.com"}
|
| 14 |
+
]
|
| 15 |
+
dependencies = [
|
| 16 |
+
"fastapi",
|
| 17 |
+
"uvicorn",
|
| 18 |
+
"pydantic>=2.0",
|
| 19 |
+
"numpy",
|
| 20 |
+
"networkx",
|
| 21 |
+
"matplotlib",
|
| 22 |
+
"openai",
|
| 23 |
+
"httpx",
|
| 24 |
+
"openenv-core>=0.2.0",
|
| 25 |
+
]
|
| 26 |
+
|
| 27 |
+
[project.urls]
|
| 28 |
+
Homepage = "https://github.com/K446/opengrid"
|
| 29 |
+
|
| 30 |
+
[project.scripts]
|
| 31 |
+
server = "server.app:main"
|
| 32 |
+
|
| 33 |
+
[tool.setuptools.packages.find]
|
| 34 |
+
where = ["."]
|
| 35 |
+
include = ["src*", "server*"]
|
| 36 |
+
|
| 37 |
+
[tool.pyright]
|
| 38 |
+
venvPath = "."
|
| 39 |
+
venv = ".venv"
|
| 40 |
+
pythonVersion = "3.13"
|
| 41 |
+
extraPaths = ["."]
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi
|
| 2 |
+
uvicorn[standard]
|
| 3 |
+
pydantic>=2.0
|
| 4 |
+
numpy
|
| 5 |
+
networkx
|
| 6 |
+
matplotlib
|
| 7 |
+
openai
|
| 8 |
+
httpx
|
| 9 |
+
openenv-core>=0.2.0
|
server/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Server package
|
server/app.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OpenGrid server entry point — used by openenv for multi-mode deployment.
|
| 3 |
+
Re-exports the FastAPI app from the root app module.
|
| 4 |
+
"""
|
| 5 |
+
import sys
|
| 6 |
+
import os
|
| 7 |
+
import uvicorn
|
| 8 |
+
|
| 9 |
+
# Add parent directory to path so we can import from the root package
|
| 10 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 11 |
+
|
| 12 |
+
from app import app # type: ignore[import-not-found] # noqa: E402, F401
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def main():
|
| 16 |
+
"""Entry point for openenv server mode."""
|
| 17 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if __name__ == "__main__":
|
| 21 |
+
main()
|
src/__init__.py
ADDED
|
File without changes
|
src/baseline.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Baseline Policies for OpenGrid
|
| 3 |
+
================================
|
| 4 |
+
Provides two agent implementations:
|
| 5 |
+
1. heuristic_policy — deterministic rule-based baseline for reproducible scoring
|
| 6 |
+
2. llm_policy — LLM-based policy using OpenAI-compatible API
|
| 7 |
+
|
| 8 |
+
Both support GridObservation (single-agent) and ZoneObservation (multi-agent).
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import json
|
| 12 |
+
import logging
|
| 13 |
+
import os
|
| 14 |
+
from typing import List, Union
|
| 15 |
+
|
| 16 |
+
from openai import OpenAI
|
| 17 |
+
from .models import GridAction, BusAdjustment, GridObservation, ZoneObservation
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
# API configuration — HF_TOKEN for Hugging Face endpoints, OPENAI_API_KEY for OpenAI
|
| 22 |
+
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
|
| 23 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o")
|
| 24 |
+
API_KEY = os.getenv("OPENAI_API_KEY", os.getenv("HF_TOKEN", ""))
|
| 25 |
+
|
| 26 |
+
# Cached client instance
|
| 27 |
+
_CLIENT = None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _get_client() -> OpenAI:
    """Lazy-cached client creation.

    Builds the module-level OpenAI client on first use (so importing this
    module never requires credentials) and reuses it afterwards.

    Raises:
        RuntimeError: if neither OPENAI_API_KEY nor HF_TOKEN is set.
    """
    global _CLIENT
    if _CLIENT is None:
        if not API_KEY:
            raise RuntimeError(
                "Missing API key. Set OPENAI_API_KEY or HF_TOKEN environment variable."
            )
        _CLIENT = OpenAI(base_url=API_BASE_URL, api_key=API_KEY, timeout=15.0)
    return _CLIENT
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _obs_buses(obs):
|
| 43 |
+
"""Extract bus list from either GridObservation or ZoneObservation."""
|
| 44 |
+
return getattr(obs, "buses", getattr(obs, "local_buses", []))
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _obs_lines(obs):
|
| 48 |
+
"""Extract line list from either GridObservation or ZoneObservation."""
|
| 49 |
+
if hasattr(obs, "lines"):
|
| 50 |
+
return obs.lines
|
| 51 |
+
internal = getattr(obs, "internal_lines", [])
|
| 52 |
+
boundary = getattr(obs, "boundary_lines", [])
|
| 53 |
+
return list(internal) + list(boundary)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
SYSTEM_PROMPT = """You are a Power Grid Controller AI. Your goal is to maintain grid stability.
|
| 57 |
+
|
| 58 |
+
Key objectives:
|
| 59 |
+
1. Keep grid frequency close to 50.0 Hz (acceptable: 49.5–50.5 Hz)
|
| 60 |
+
2. Prevent transmission line overloads (rho < 1.0)
|
| 61 |
+
3. Avoid grid islanding (blackout)
|
| 62 |
+
|
| 63 |
+
Available actions:
|
| 64 |
+
1. bus_adjustments: List of {"bus_id": int, "delta": float}
|
| 65 |
+
- Positive delta = increase power injection (discharge battery / ramp up generator)
|
| 66 |
+
- Negative delta = decrease power injection (charge battery / ramp down generator)
|
| 67 |
+
- Only works on battery and generator buses (NOT slack, load, solar, or wind)
|
| 68 |
+
- Slack bus injection is computed by physics — adjustments are ignored
|
| 69 |
+
2. topology_actions: List of {"line_id": str, "action": "open" | "close"}
|
| 70 |
+
- Opening a line removes it; closing reconnects. 3-step cooldown after each switch.
|
| 71 |
+
- WARNING: Opening lines can cause islanding → blackout → -100 reward
|
| 72 |
+
- Prefer NO topology actions unless absolutely necessary.
|
| 73 |
+
|
| 74 |
+
Strategy tips:
|
| 75 |
+
- If frequency < 50 Hz: grid needs more generation → discharge batteries or ramp up generators
|
| 76 |
+
- If frequency > 50 Hz: grid has excess generation → charge batteries or ramp down generators
|
| 77 |
+
- If a line rho > 0.9: reduce generation at one end or increase at the other to shift flow
|
| 78 |
+
- Prefer minimal actions. Do-nothing is better than reckless switching.
|
| 79 |
+
|
| 80 |
+
Respond with ONLY a valid JSON object, no markdown, no explanation. Example:
|
| 81 |
+
{"bus_adjustments": [{"bus_id": 2, "delta": 5.0}], "topology_actions": []}
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def parse_action_response(response_text: str) -> GridAction:
    """Parse an LLM response into a GridAction. Falls back to no-op on parse errors."""
    try:
        payload = response_text.strip()

        # Peel off a markdown code fence when present
        if payload.startswith("```"):
            rows = payload.splitlines()
            if rows[0].startswith("```"):
                rows = rows[1:]
            if rows and rows[-1].startswith("```"):
                rows = rows[:-1]
            payload = "\n".join(rows).strip()

        # Grab the first {...} object embedded in the text
        lo = payload.find("{")
        hi = payload.rfind("}")
        if lo == -1 or hi == -1 or hi <= lo:
            return GridAction()

        parsed = json.loads(payload[lo:hi + 1])

        # Some models wrap the object in a one-element list
        if isinstance(parsed, list):
            parsed = parsed[0] if parsed else {}

        return GridAction(**parsed)
    except Exception:
        return GridAction()
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def llm_policy(obs: Union[GridObservation, ZoneObservation]) -> GridAction:
    """LLM-based policy using the OpenAI-compatible API.

    Supports both GridObservation and ZoneObservation.
    Falls back to a no-op GridAction on any request/parse error.
    """
    api = _get_client()
    state_json = obs.model_dump_json()

    try:
        completion = api.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": f"Current Grid State:\n{state_json}"},
            ],
            temperature=0.0,
            max_tokens=300,
        )
        return parse_action_response(completion.choices[0].message.content)
    except Exception as e:
        logger.debug("LLM policy error: %s", e, exc_info=True)
        return GridAction()
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def heuristic_policy(
    obs: Union[GridObservation, ZoneObservation],
) -> GridAction:
    """Rule-based baseline policy for reproducible scoring.

    Strategy:
    - Use batteries and generators for frequency regulation (proportional control)
    - DO NOT open overloaded lines (causes cascading failures)
    - DO NOT adjust the slack bus (overwritten by physics solver)
    - Let the environment/safety layer clamp any out-of-range deltas

    Supports both GridObservation (single-agent) and ZoneObservation (multi-agent).
    """
    adj = []
    freq = obs.grid_frequency
    freq_error = freq - 50.0  # positive = too high, negative = too low

    buses = list(_obs_buses(obs))
    lines = list(_obs_lines(obs))

    batteries = [b for b in buses if b.type == 'battery']
    generators = [b for b in buses if b.type == 'generator']

    # --- 1. Proportional frequency control via batteries ---
    # Deadband of 0.1 Hz avoids chattering on small noise
    if abs(freq_error) > 0.1 and batteries:
        # Distribute correction across all available batteries
        correction_total = -freq_error * 15.0  # stronger gain than naive 2.0
        correction_total = max(-20.0, min(20.0, correction_total))
        per_battery = correction_total / len(batteries)

        for bus in batteries:
            if per_battery > 0 and bus.soc > 0:
                # Discharge — safety layer clamps to actual SOC
                adj.append(BusAdjustment(bus_id=bus.id, delta=per_battery))
            elif per_battery < 0:
                # Charge — safety layer clamps to remaining capacity
                adj.append(BusAdjustment(bus_id=bus.id, delta=per_battery))

    # --- 2. Generator response for larger deviations ---
    # Generators engage only past 0.25 Hz; each is limited by its own
    # ramp_rate (assumed attribute; defaults to 20.0 MW if absent)
    if abs(freq_error) > 0.25:
        for bus in generators:
            delta = -freq_error * 5.0
            ramp = getattr(bus, 'ramp_rate', 20.0)
            delta = max(-ramp, min(ramp, delta))
            adj.append(BusAdjustment(bus_id=bus.id, delta=delta))

    # --- 3. Overload relief via generators (not slack) ---
    # For each overloaded connected line, back off one generator that is
    # still injecting > 5 (units presumably MW — confirm against models);
    # the set ensures each generator is reduced at most once per call.
    adjusted_for_overload = set()
    for line in lines:
        if line.rho > 0.95 and line.connected:
            for bus in generators:
                if bus.id not in adjusted_for_overload and bus.p_injection > 5:
                    adj.append(BusAdjustment(bus_id=bus.id, delta=-3.0))
                    adjusted_for_overload.add(bus.id)
                    break

    # No topology actions — much safer than opening overloaded lines
    return GridAction(bus_adjustments=adj, topology_actions=[])
|
src/environment.py
ADDED
|
@@ -0,0 +1,672 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import math
|
| 3 |
+
from typing import List, Dict, Tuple, Optional
|
| 4 |
+
from .models import (
|
| 5 |
+
GridObservation, GridAction, GridReward, GridInfo,
|
| 6 |
+
LineStatus, BusState, ZoneObservation, ZoneInfo,
|
| 7 |
+
SafetyReport, OversightReport, MultiAgentStepResult,
|
| 8 |
+
)
|
| 9 |
+
from .physics import DCSolver, IslandedException
|
| 10 |
+
from .safety import SafetyLayer
|
| 11 |
+
from .oversight import OversightAgent
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class OpenGridEnv:
    """
    OpenGrid: A renewable energy grid load-balancing environment.

    Supports two modes:
    1. Single-agent (backward compatible): reset()/step()/state()
    2. Multi-agent POMDP: reset_multi()/step_multi() with per-zone
       partial observability, safety layer, and oversight agent.

    The agent(s) must maintain grid stability by:
    - Balancing generation and load (frequency control)
    - Managing transmission line loading (congestion management)
    - Coordinating battery storage and topology switching
    """

    NOMINAL_FREQ = 50.0
    FREQ_DEADBAND = 0.5   # Hz — acceptable deviation band
    FREQ_NOISE_STD = 0.05 # Hz — noise added to POMDP observations
    LINE_NOISE_STD = 0.02 # fraction — noise added to line readings

    def __init__(self, config: Dict) -> None:
        """Build the environment from a task config dict.

        The config supplies grid topology (``num_buses``, ``lines``,
        ``buses``), optional multi-agent zone partitions, and episode
        parameters (``max_steps``, ``seed``).
        """
        self.config = config
        self.num_buses = config['num_buses']
        self.lines_config = config['lines']
        self.buses_config = config['buses']

        # Resolve slack bus from config (not hardcoded to index 0)
        self.slack_bus_id = next(
            (b['id'] for b in self.buses_config if b['type'] == 'slack'), 0
        )

        self.solver = DCSolver(self.num_buses, slack_bus=self.slack_bus_id)
        self.timestep = 0
        self.max_steps = config.get('max_steps', 50)

        # Dynamic per-episode state; populated in reset()
        self.bus_state = []
        self.line_state = []
        self.cooldowns = {}
        self.slack_injection = 0.0
        self._is_blackout = False

        # Build index dicts for O(1) lookups
        self._bus_cfg_by_id = {b['id']: b for b in self.buses_config}
        self._line_cfg_by_id = {l['id']: l for l in self.lines_config}

        # Multi-agent config
        self.num_agents = config.get('num_agents', 1)
        self.zone_assignments = config.get('zone_assignments', {})
        self.zone_names = config.get('zone_names', [])
        self.zone_bus_ids = config.get('zone_bus_ids', {})
        self.internal_lines = config.get('internal_lines', {})
        self.boundary_lines = config.get('boundary_lines', {})

        # Safety and oversight (initialized on first multi-agent use)
        self.safety_layer = SafetyLayer(config)
        self.oversight_agent = OversightAgent(config)

        # Episode tracking for multi-agent rewards
        self._safety_reports_this_step: List[SafetyReport] = []
        self._oversight_report_this_step: Optional[OversightReport] = None

        # Calibrate droop constant to system size so the frequency proxy
        # stays in a sensible range regardless of grid scale.
        total_load = sum(
            b['base_p'] for b in self.buses_config if b['type'] == 'load'
        )
        total_gen = sum(
            b['max_p'] for b in self.buses_config
            if b['type'] in ['slack', 'generator', 'solar', 'wind']
        )
        total_system = max(total_load + total_gen, 50.0)
        self.droop_constant = 2.5 / total_system

        # Per-episode RNG — initialized early so _update_loads_and_renewables never crashes
        self._seed = config.get('seed', 42)
        self._rng = np.random.default_rng(self._seed)

    # ======================================================================
    # State Restoration (for GRPO environment-grounded rewards)
    # ======================================================================

    def _set_state(self, obs_dict: dict) -> None:
        """Restore the environment to a state described by an observation dict.

        This enables environment-grounded GRPO rewards: instead of scoring
        actions with a heuristic proxy, we restore the env to the observed state,
        step with the proposed action, and use the real reward.

        Args:
            obs_dict: A dict from ZoneObservation.model_dump() or
                GridObservation.model_dump(), containing at minimum:
                timestep, grid_frequency, and bus/line state.
        """
        self.timestep = obs_dict.get('timestep', 0)
        self._is_blackout = obs_dict.get('is_blackout', False)
        self.cooldowns = obs_dict.get('cooldowns', {k: 0 for k in self.cooldowns})

        # Restore bus state from observation.
        # ZoneObservation uses 'local_buses'; GridObservation uses 'buses'.
        local_buses = obs_dict.get('local_buses', obs_dict.get('buses', []))
        if local_buses:
            for b_obs in local_buses:
                b_dyn = self._find_bus_state(b_obs['id'])
                if b_dyn is not None:
                    b_dyn['p'] = b_obs.get('p_injection', b_dyn['p'])
                    b_dyn['soc'] = b_obs.get('soc', b_dyn.get('soc', 0.0))

        # Restore line state from observation (any of the three list fields
        # may be absent depending on observation type).
        all_lines = (obs_dict.get('internal_lines', []) or []) + \
                    (obs_dict.get('boundary_lines', []) or []) + \
                    (obs_dict.get('lines', []) or [])
        for l_obs in all_lines:
            l_dyn = self._find_line(l_obs['id'])
            if l_dyn is not None:
                l_dyn['connected'] = l_obs.get('connected', True)
                l_dyn['flow'] = l_obs.get('flow', 0.0)

        # Rebuild lookup indices
        self._bus_state_by_id = {b['id']: b for b in self.bus_state}
        self._line_state_by_id = {l['id']: l for l in self.line_state}

        # Re-derive slack injection from frequency if available
        # (inverse of the droop relation in _compute_frequency).
        freq = obs_dict.get('grid_frequency', self.NOMINAL_FREQ)
        self.slack_injection = (self.NOMINAL_FREQ - freq) / self.droop_constant

        # Update slack bus p to match
        slack_dyn = self._find_bus_state(self.slack_bus_id)
        if slack_dyn is not None:
            slack_dyn['p'] = self.slack_injection

    # ======================================================================
    # Single-Agent API (backward compatible)
    # ======================================================================

    def reset(self) -> GridObservation:
        """Reset the environment to initial state. Returns initial observation."""
        self.timestep = 0
        self.slack_injection = 0.0
        self.cooldowns = {l['id']: 0 for l in self.lines_config}
        self._rng = np.random.default_rng(self._seed)
        self.oversight_agent.reset()

        self.bus_state = []
        for b in self.buses_config:
            init_p = 0.0
            # Initialize generators at 50% capacity so slack doesn't absorb all load
            if b['type'] in ['generator']:
                init_p = b['max_p'] * 0.5
            self.bus_state.append({
                'id': b['id'], 'p': init_p, 'soc': b.get('init_soc', 0.0)
            })
        self.line_state = [
            {'id': l['id'], 'connected': True, 'flow': 0.0}
            for l in self.lines_config
        ]

        # Build O(1) lookup indices for dynamic state
        self._bus_state_by_id = {b['id']: b for b in self.bus_state}
        self._line_state_by_id = {l['id']: l for l in self.line_state}

        self._is_blackout = False
        self._update_loads_and_renewables()
        self._run_power_flow()

        return self._get_obs()

    def step(self, action: GridAction) -> Tuple[GridObservation, GridReward, bool, GridInfo]:
        """Execute one step: apply action, update dynamics, solve physics, compute reward."""
        self.timestep += 1
        reward_components = {"survival": 1.0, "frequency": 0.0, "overload": 0.0, "action_cost": 0.0}
        self._is_blackout = False

        # 1. Apply topology actions (with cooldown enforcement)
        for t_act in action.topology_actions:
            l_id = t_act.line_id
            if l_id not in self.cooldowns:
                # Unknown line id — silently ignore
                continue
            if self.cooldowns[l_id] == 0:
                line = self._find_line(l_id)
                if line is None:
                    continue
                current_status = line['connected']
                new_status = (t_act.action == "close")

                # Only a real status change costs reward and starts a cooldown
                if current_status != new_status:
                    line['connected'] = new_status
                    self.cooldowns[l_id] = 3
                    reward_components['action_cost'] -= 0.5

        # Tick cooldowns
        # NOTE(review): a cooldown set to 3 just above is decremented in the
        # same step (effective lockout of 2 further steps) — confirm intended.
        for l_id in self.cooldowns:
            self.cooldowns[l_id] = max(0, self.cooldowns[l_id] - 1)

        # 2. Apply power adjustment actions
        for adj in action.bus_adjustments:
            bus_cfg = self._find_bus_config(adj.bus_id)
            bus_dyn = self._find_bus_state(adj.bus_id)
            if bus_cfg is None or bus_dyn is None:
                continue

            delta = adj.delta

            if bus_cfg['type'] == 'battery':
                # Sign convention: positive delta = discharge (inject power),
                # negative delta = charge (absorb power).
                max_charge = bus_cfg['capacity'] - bus_dyn['soc']
                max_discharge = bus_dyn['soc']

                if delta > 0:
                    delta = min(delta, max_discharge)
                else:
                    delta = max(delta, -max_charge)

                bus_dyn['soc'] = np.clip(bus_dyn['soc'] - delta, 0.0, bus_cfg['capacity'])
                bus_dyn['p'] = delta

            elif bus_cfg['type'] not in ['load', 'solar', 'wind']:
                # Dispatchable units: ramp-limited, then clamped to [min_p, max_p]
                max_ramp = bus_cfg.get('ramp_rate', 10.0)
                delta = np.clip(delta, -max_ramp, max_ramp)
                new_p = bus_dyn['p'] + delta
                bus_dyn['p'] = np.clip(new_p, bus_cfg['min_p'], bus_cfg['max_p'])

        # 3. Update load/renewable dynamics
        self._update_loads_and_renewables()

        # 4. Solve physics
        try:
            self._run_power_flow()

            # Check line overloads
            for l in self.line_state:
                if l['connected']:
                    flow = l['flow']
                    limit = self._get_line_capacity(l['id'])
                    rho = abs(flow) / limit if limit > 0 else 0.0

                    if rho > 1.0:
                        # Quadratic penalty on the overload fraction
                        reward_components['overload'] -= (rho - 1.0) ** 2 * 20
                    elif rho > 0.8:
                        # Small penalty for operating near the limit
                        reward_components['overload'] -= 0.1

            # Frequency reward
            freq = self._compute_frequency()
            freq_dev = abs(freq - self.NOMINAL_FREQ)
            if freq_dev > self.FREQ_DEADBAND:
                raw_penalty = (freq_dev - self.FREQ_DEADBAND) * 0.5
                reward_components['frequency'] -= min(raw_penalty, 1.5)
            elif freq_dev < 0.1:
                # Bonus for tight frequency regulation
                reward_components['frequency'] += 0.2

        except IslandedException:
            # Grid split into islands — treated as a blackout, episode ends
            self._is_blackout = True
            reward_components['survival'] = -100.0

        done = self._is_blackout or (self.timestep >= self.max_steps)

        total_reward = sum(reward_components.values())
        reward = GridReward(value=total_reward, components=reward_components)
        info = GridInfo(task_id=self.config['id'], is_blackout=self._is_blackout)

        return self._get_obs(), reward, done, info

    def state(self) -> GridObservation:
        """Return current state (alias for observation)."""
        return self._get_obs()

    # ======================================================================
    # Multi-Agent POMDP API
    # ======================================================================

    def reset_multi(self) -> Dict[int, ZoneObservation]:
        """Reset environment and return per-agent partial observations."""
        self.reset()  # Reuse single-agent reset for state initialization
        return {
            agent_id: self._get_zone_obs(agent_id)
            for agent_id in range(self.num_agents)
        }

    def step_multi(self, agent_actions: Dict[int, GridAction]) -> MultiAgentStepResult:
        """Multi-agent step with safety layer and oversight.

        Flow:
        1. Safety layer validates each agent's actions
        2. Combine corrected actions into one GridAction
        3. Run single-agent step with combined action
        4. Oversight agent evaluates coordination
        5. Compute per-agent rewards (local + global + safety + coordination)
        """
        # Snapshot pre-step state for the oversight evaluation
        pre_frequency = self._compute_frequency()
        pre_bus_state = [dict(b) for b in self.bus_state]

        # --- 1. Safety validation per agent ---
        safety_reports: Dict[int, SafetyReport] = {}
        corrected_actions: Dict[int, GridAction] = {}

        for agent_id in range(self.num_agents):
            # Missing agents default to a no-op action
            proposed = agent_actions.get(agent_id, GridAction())
            corrected, report = self.safety_layer.validate_and_correct(
                agent_id=agent_id,
                proposed_action=proposed,
                current_line_state=self.line_state,
                current_bus_state=self.bus_state,
                cooldowns=self.cooldowns,
            )
            corrected_actions[agent_id] = corrected
            safety_reports[agent_id] = report

        self._safety_reports_this_step = safety_reports

        # --- 2. Combine all corrected actions ---
        combined = GridAction(
            bus_adjustments=[
                adj for action in corrected_actions.values()
                for adj in action.bus_adjustments
            ],
            topology_actions=[
                t for action in corrected_actions.values()
                for t in action.topology_actions
            ],
        )

        # --- 3. Run the step ---
        obs, base_reward, done, info = self.step(combined)
        post_frequency = self._compute_frequency()

        # --- 4. Oversight evaluation ---
        oversight_report = self.oversight_agent.evaluate(
            agent_actions=agent_actions,
            safety_reports=safety_reports,
            pre_frequency=pre_frequency,
            post_frequency=post_frequency,
            pre_bus_state=pre_bus_state,
            post_bus_state=self.bus_state,
        )
        self._oversight_report_this_step = oversight_report

        # --- 5. Per-agent rewards ---
        per_agent_rewards = {}
        for agent_id in range(self.num_agents):
            agent_reward = self._compute_agent_reward(
                agent_id=agent_id,
                base_reward=base_reward,
                safety_report=safety_reports.get(agent_id),
                oversight_report=oversight_report,
                is_blackout=info.is_blackout,
            )
            per_agent_rewards[agent_id] = agent_reward

        team_reward = base_reward.value

        # --- 6. Per-agent partial observations ---
        per_agent_obs = {
            agent_id: self._get_zone_obs(agent_id)
            for agent_id in range(self.num_agents)
        }

        # Propagate blackout to observations
        # NOTE(review): the loop variable `obs` shadows the global observation
        # returned by self.step() above; harmless here since the global obs is
        # not used afterwards, but worth renaming.
        if info.is_blackout:
            for obs in per_agent_obs.values():
                obs.is_blackout = True

        return MultiAgentStepResult(
            observations=per_agent_obs,
            rewards=per_agent_rewards,
            team_reward=round(team_reward, 4),
            done=done,
            safety_reports=safety_reports,
            oversight_report=oversight_report,
            info=info,
        )

    def get_zone_info(self) -> Dict[int, ZoneInfo]:
        """Get metadata about each agent's zone."""
        zones = {}
        for agent_id in range(self.num_agents):
            zones[agent_id] = ZoneInfo(
                agent_id=agent_id,
                zone_name=self.zone_names[agent_id] if agent_id < len(self.zone_names) else f"Zone_{agent_id}",
                bus_ids=self.zone_bus_ids.get(agent_id, []),
                boundary_line_ids=self.boundary_lines.get(agent_id, []),
                internal_line_ids=self.internal_lines.get(agent_id, []),
            )
        return zones

    # ======================================================================
    # Multi-Agent Reward Computation
    # ======================================================================

    def _compute_agent_reward(
        self,
        agent_id: int,
        base_reward: GridReward,
        safety_report: Optional[SafetyReport],
        oversight_report: OversightReport,
        is_blackout: bool,
    ) -> GridReward:
        """Compute per-agent reward with composable components.

        Components:
        - survival: shared team component (same for all)
        - frequency: shared (all agents affected equally)
        - local_congestion: penalty for overloads in agent's zone
        - safety_compliance: penalty if safety layer corrected the action
        - coordination: penalty from oversight for selfish/conflicting behavior
        - efficiency: small bonus for minimal actions
        """
        components = {}

        # Shared components (from base reward)
        components['survival'] = base_reward.components.get('survival', 1.0)
        components['frequency'] = base_reward.components.get('frequency', 0.0)

        # Global overload shared equally — ensures no line's penalty is lost
        components['overload_shared'] = base_reward.components.get('overload', 0.0) / max(self.num_agents, 1)

        # Local congestion: additional penalty for overloads on lines in agent's zone
        zone_overload = 0.0
        agent_lines = set(self.internal_lines.get(agent_id, []))
        agent_lines.update(self.boundary_lines.get(agent_id, []))
        for l in self.line_state:
            if l['id'] in agent_lines and l['connected']:
                limit = self._get_line_capacity(l['id'])
                rho = abs(l['flow']) / limit if limit > 0 else 0.0
                if rho > 1.0:
                    zone_overload -= (rho - 1.0) ** 2 * 10
                elif rho > 0.8:
                    zone_overload -= 0.05
        components['local_congestion'] = zone_overload

        # Safety compliance penalty — scaled by how many topology actions
        # were blocked by the safety layer.
        if safety_report and safety_report.was_corrected:
            components['safety_compliance'] = -0.3 * (
                1 + safety_report.blocked_topology_actions
            )
        else:
            components['safety_compliance'] = 0.1  # Bonus for safe actions

        # Coordination penalty from oversight
        coord_penalty = oversight_report.coordination_penalties.get(agent_id, 0.0)
        components['coordination'] = -coord_penalty

        # Action cost, split evenly across agents
        components['action_cost'] = base_reward.components.get('action_cost', 0.0) / max(self.num_agents, 1)

        total = sum(components.values())
        return GridReward(value=round(total, 4), components=components)

    # ======================================================================
    # POMDP Observation
    # ======================================================================

    def _get_zone_obs(self, agent_id: int) -> ZoneObservation:
        """Build partial observation for one agent (POMDP).

        Each agent sees:
        - Only buses in their zone
        - Internal + boundary lines
        - Noisy global frequency
        - Limited neighbor signals
        """
        # Local buses
        zone_bus_ids = set(self.zone_bus_ids.get(agent_id, []))
        local_buses = []
        zone_load = 0.0
        zone_gen = 0.0
        for b in self.bus_state:
            if b['id'] in zone_bus_ids:
                b_cfg = self._find_bus_config(b['id'])
                if b_cfg is None:
                    continue
                local_buses.append(BusState(
                    id=b['id'], type=b_cfg['type'],
                    p_injection=round(b['p'], 4),
                    soc=round(b.get('soc', 0.0), 4),
                    ramp_rate=b_cfg.get('ramp_rate', 0.0),
                ))
                if b_cfg['type'] == 'load':
                    zone_load += abs(b['p'])
                elif b_cfg['type'] in ('generator', 'solar', 'wind', 'slack'):
                    zone_gen += b['p']
                # battery: not classified as load or gen

        # Internal lines (within zone)
        int_line_ids = set(self.internal_lines.get(agent_id, []))
        internal_lines = []
        for l in self.line_state:
            if l['id'] in int_line_ids:
                limit = self._get_line_capacity(l['id'])
                rho = abs(l['flow']) / limit if l['connected'] and limit > 0 else 0.0
                # Add noise to line readings
                noisy_rho = rho + self._rng.normal(0, self.LINE_NOISE_STD) if self._rng else rho
                noisy_rho = max(0.0, noisy_rho)
                internal_lines.append(LineStatus(
                    id=l['id'], connected=l['connected'],
                    flow=round(l['flow'], 4),
                    rho=round(noisy_rho, 4),
                ))

        # Boundary lines (connecting to other zones)
        bnd_line_ids = set(self.boundary_lines.get(agent_id, []))
        boundary_lines = []
        for l in self.line_state:
            if l['id'] in bnd_line_ids:
                limit = self._get_line_capacity(l['id'])
                rho = abs(l['flow']) / limit if l['connected'] and limit > 0 else 0.0
                noisy_rho = rho + self._rng.normal(0, self.LINE_NOISE_STD) if self._rng else rho
                noisy_rho = max(0.0, noisy_rho)
                boundary_lines.append(LineStatus(
                    id=l['id'], connected=l['connected'],
                    flow=round(l['flow'], 4),
                    rho=round(noisy_rho, 4),
                ))

        # Noisy frequency (POMDP — agents don't get perfect readings)
        true_freq = self._compute_frequency()
        noisy_freq = true_freq + (self._rng.normal(0, self.FREQ_NOISE_STD) if self._rng else 0.0)

        # Neighbor signals: average bus injection of other zones
        neighbor_signals = {}
        for other_id in range(self.num_agents):
            if other_id == agent_id:
                continue
            other_bus_ids = self.zone_bus_ids.get(other_id, [])
            if other_bus_ids:
                avg_inj = np.mean([
                    b['p'] for b in self.bus_state if b['id'] in other_bus_ids
                ])
                neighbor_signals[other_id] = round(float(avg_inj), 2)

        # Cooldowns for lines this agent can see
        visible_lines = int_line_ids | bnd_line_ids
        visible_cooldowns = {
            k: v for k, v in self.cooldowns.items() if k in visible_lines
        }

        zone_name = self.zone_names[agent_id] if agent_id < len(self.zone_names) else f"Zone_{agent_id}"

        return ZoneObservation(
            agent_id=agent_id,
            zone_name=zone_name,
            timestep=self.timestep,
            grid_frequency=round(noisy_freq, 4),
            local_buses=local_buses,
            boundary_lines=boundary_lines,
            internal_lines=internal_lines,
            neighbor_signals=neighbor_signals,
            cooldowns=visible_cooldowns,
            is_blackout=False,
            zone_load_mw=round(zone_load, 2),
            zone_gen_mw=round(zone_gen, 2),
        )

    # ======================================================================
    # Internal Methods (unchanged from original)
    # ======================================================================

    def _run_power_flow(self) -> None:
        """Build active line list, solve DC power flow, update line flows and slack injection."""
        active_lines = []
        for l_cfg in self.lines_config:
            l_dyn = self._find_line(l_cfg['id'])
            if l_dyn and l_dyn['connected']:
                active_lines.append({
                    'id': l_cfg['id'], 'from': l_cfg['from'], 'to': l_cfg['to'],
                    'susceptance': l_cfg['susceptance'], 'connected': True
                })

        self.solver.update_grid(active_lines)

        # Injection vector indexed by bus id — bus ids are used directly
        # as array indices, so they must be 0..num_buses-1.
        p_inj = np.zeros(self.num_buses)
        for b_dyn in self.bus_state:
            p_inj[b_dyn['id']] = b_dyn['p']

        theta, flows, slack_inj = self.solver.solve(p_inj)

        self.slack_injection = slack_inj
        slack_dyn = self._find_bus_state(self.slack_bus_id)
        if slack_dyn is not None:
            slack_dyn['p'] = slack_inj

        for l in self.line_state:
            if l['connected'] and l['id'] in flows:
                l['flow'] = flows[l['id']]
            elif not l['connected']:
                l['flow'] = 0.0

    def _compute_frequency(self) -> float:
        """Frequency proxy using droop model, calibrated to system size."""
        return self.NOMINAL_FREQ - self.droop_constant * self.slack_injection

    def _update_loads_and_renewables(self) -> None:
        """Update time-varying loads and renewable generation. Uses per-episode RNG."""
        for b_dyn in self.bus_state:
            b_cfg = self._find_bus_config(b_dyn['id'])
            if b_cfg is None:
                continue

            if b_cfg['type'] == 'load':
                # 24-step daily sinusoid, peaking mid-"day"; loads are
                # negative injections in the 0.8x–1.2x base range.
                daily_cycle = math.sin((self.timestep % 24 - 6) * math.pi / 12)
                b_dyn['p'] = -b_cfg['base_p'] * (0.8 + 0.4 * max(0, daily_cycle))

            elif b_cfg['type'] == 'solar':
                # Solar follows the positive half of the same daily cycle
                solar_cycle = max(0, math.sin((self.timestep % 24 - 6) * math.pi / 12))
                b_dyn['p'] = b_cfg['max_p'] * solar_cycle

            elif b_cfg['type'] == 'wind':
                # Wind does a bounded random walk in [0, max_p]
                wind_delta = self._rng.uniform(-5, 5)
                b_dyn['p'] = float(np.clip(b_dyn['p'] + wind_delta, 0, b_cfg['max_p']))

    def _get_obs(self) -> GridObservation:
        """Build observation from current state."""
        obs_lines = []
        for l in self.line_state:
            limit = self._get_line_capacity(l['id'])
            rho = abs(l['flow']) / limit if l['connected'] and limit > 0 else 0.0
            obs_lines.append(LineStatus(
                id=l['id'], connected=l['connected'], flow=round(l['flow'], 4), rho=round(rho, 4)
            ))

        obs_buses = []
        for b in self.bus_state:
            b_cfg = self._find_bus_config(b['id'])
            if b_cfg is None:
                continue
            obs_buses.append(BusState(
                id=b['id'], type=b_cfg['type'],
                p_injection=round(b['p'], 4),
                soc=round(b.get('soc', 0.0), 4),
                ramp_rate=b_cfg.get('ramp_rate', 0.0)
            ))

        freq = self._compute_frequency()

        return GridObservation(
            timestep=self.timestep,
            grid_frequency=round(freq, 4),
            buses=obs_buses,
            lines=obs_lines,
            cooldowns=self.cooldowns,
            is_blackout=getattr(self, '_is_blackout', False)
        )

    # ---------- Lookup Helpers (O(1) indexed + guarded fallbacks) ----------

    def _find_line(self, line_id: str):
        """Return the dynamic line-state dict for line_id, or None."""
        # Use index if available (built in reset), fall back to linear scan
        idx = getattr(self, '_line_state_by_id', None)
        if idx is not None:
            return idx.get(line_id)
        return next((l for l in self.line_state if l['id'] == line_id), None)

    def _find_bus_config(self, bus_id: int):
        """Return the static bus config dict for bus_id, or None."""
        return self._bus_cfg_by_id.get(bus_id)

    def _find_bus_state(self, bus_id: int):
        """Return the dynamic bus-state dict for bus_id, or None."""
        idx = getattr(self, '_bus_state_by_id', None)
        if idx is not None:
            return idx.get(bus_id)
        return next((b for b in self.bus_state if b['id'] == bus_id), None)

    def _get_line_capacity(self, line_id: str) -> float:
        """Return the configured capacity for line_id (defaults to 1.0 if unknown)."""
        cfg = self._line_cfg_by_id.get(line_id)
        return cfg['capacity'] if cfg else 1.0
|
src/grader.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import math
|
| 3 |
+
import numpy as np
|
| 4 |
+
from typing import Dict, Callable, List
|
| 5 |
+
from .environment import OpenGridEnv
|
| 6 |
+
from .models import GridAction, BusAdjustment, TopologyAction
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _random_thrash_policy(obs, rng: np.random.Generator) -> GridAction:
    """Deliberately bad policy: random topology switching. Used as reward floor.

    Toggles each line with probability ~0.3 (open if connected, close if
    not), keeping the grid unstable across all steps rather than just the
    first. Takes an explicit RNG instance (never global np.random) so the
    floor estimate is reproducible; the comprehension draws exactly one
    rng.random() per line, in line order, matching the original loop.
    """
    toggles = [
        TopologyAction(
            line_id=line.id,
            action="open" if line.connected else "close",
        )
        for line in obs.lines
        if rng.random() > 0.7
    ]
    return GridAction(topology_actions=toggles)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def compute_analytical_ceiling(max_steps: int, per_step_max: float = 1.2) -> float:
    """Compute the theoretical maximum reward for an episode.

    Perfect agent: survives every step (+1.0 survival) and achieves
    tight frequency control bonus (+0.2) every step, with zero overload
    and zero action cost.

        ceiling = max_steps * per_step_max = max_steps * 1.2 (default)

    Args:
        max_steps: Number of steps in the episode.
        per_step_max: Maximum achievable reward per step. Defaults to 1.2
            (1.0 survival + 0.2 frequency bonus); parameterized so tasks
            whose reward shape differs can reuse this helper without a
            hard-coded constant.

    NOTE: The +0.2 frequency bonus requires freq_dev < 0.1 Hz, which needs
    |P_slack| < 0.04 * S_total (from droop model). On high-renewable tasks
    (task_hard) where slack routinely absorbs >50 MW of imbalance, this band
    may be structurally inaccessible. The effective ceiling on such tasks is
    closer to max_steps * 1.0 = 50.0. Scores remain comparable across agents
    on the same task — the ceiling just compresses the achievable range.
    """
    return max_steps * per_step_max
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# Validator requires scores strictly in the open interval (0, 1).
# Using wide epsilon so that even aggressive rounding (e.g. round(x, 1))
# can never produce exactly 0.0 or 1.0.
_SCORE_EPSILON = 0.02
_SCORE_MIN = _SCORE_EPSILON        # 0.02 — lowest score _clamp_score will emit
_SCORE_MAX = 1.0 - _SCORE_EPSILON  # 0.98 — highest score _clamp_score will emit
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _safe_float(x: float) -> float:
|
| 54 |
+
"""Convert to plain Python float; replace NaN/Inf with midpoint."""
|
| 55 |
+
v = float(x)
|
| 56 |
+
if not math.isfinite(v):
|
| 57 |
+
return 0.5 # safe fallback inside (0, 1)
|
| 58 |
+
return v
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _clamp_score(score: float) -> float:
    """Clamp a score into the open interval (0, 1) and truncate to 4 decimals.

    Uses only Python-native min/max so the result is a plain float (no
    numpy-scalar serialisation quirks) that JSON-encodes as a normal number.
    Truncation — not rounding — guarantees downstream rounding (e.g.
    round(0.98500..., 4) == 0.985) can never land on the 0/1 boundary.
    """
    bounded = min(max(_safe_float(score), _SCORE_MIN), _SCORE_MAX)
    # Value is positive here, so floor() is truncation toward zero.
    truncated = math.floor(bounded * 10000) / 10000
    # Re-clamp in case truncation landed exactly on a boundary.
    return min(max(truncated, _SCORE_MIN), _SCORE_MAX)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def normalize_score(cumulative_reward: float, reward_floor: float, reward_ceiling: float,
                    n1_survival_rate: float = 1.0) -> float:
    """
    Shared normalization: maps raw cumulative reward to the open interval (0, 1).
    Used by both /grader endpoint and RobustnessGrader for consistency.

    - reward_floor: empirical worst-case (random thrashing policy, seeded RNG)
    - reward_ceiling: analytical upper bound (perfect survival + perfect frequency bonus)
    - n1_survival_rate: fraction of episodes without blackout (adds up to 10% bonus)

    Scores are clamped to [0.02, 0.98] so they are never exactly 0.0 or 1.0,
    and cannot round to those values, satisfying the OpenEnv Phase-2 validator.
    """
    floor = _safe_float(reward_floor)
    span = _safe_float(reward_ceiling) - floor
    if span < 1.0:
        # Guard against division by a near-zero (or inverted) range.
        span = 1.0

    base = (_safe_float(cumulative_reward) - floor) / span

    # N-1 bonus: up to 10% boost for surviving without blackout, scaled
    # into half the headroom below _SCORE_MAX so top performers still
    # differentiate.
    bonus = float(n1_survival_rate) * 0.1
    headroom = _SCORE_MAX - base
    bonus = min(bonus, headroom * 0.5) if headroom > 0 else 0.0

    return _clamp_score(base + bonus)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class RobustnessGrader:
    """
    Evaluates a policy's performance on an OpenGrid task.

    Scoring:
    - Floor: empirical estimate from adversarial random topology thrashing
      (seeded RNG for reproducibility, n_samples=10 for stability)
    - Ceiling: analytical upper bound = max_steps * 1.2
      (perfect survival + perfect frequency bonus every step)
    - Normalizes cumulative reward to 0.0–1.0
    - Adds N-1 survival bonus (max 10%)

    The heuristic baseline scores ~0.75–0.90, leaving headroom for
    agents that employ active topology management and predictive scheduling.
    """

    def __init__(self, config: Dict):
        # Environment config; 'seed' and 'max_steps' are read with defaults.
        self.config = config
        # Bounds are computed lazily on first use (see _estimate_bounds).
        self.reward_floor = None
        self.reward_ceiling = None

    def _estimate_bounds(self, n_samples: int = 10):
        """Estimate reward bounds.

        Floor: adversarial random thrashing policy (empirical, seeded).
        Ceiling: analytical upper bound (deterministic).

        n_samples=10 to reduce variance in the floor estimate.
        The floor uses mean - std to be conservatively low.
        Each episode gets its own thrash RNG derived from a master seed
        so that changing n_samples doesn't alter existing episodes.
        """
        # Fixed master seed: the sequence of integers() draws below must be
        # stable for floor reproducibility — do not reorder these calls.
        master_rng = np.random.default_rng(seed=12345)

        floors = []
        base_seed = self.config.get('seed', 42)

        for i in range(n_samples):
            # Per-episode thrash RNG — decoupled from other episodes
            thrash_rng = np.random.default_rng(seed=int(master_rng.integers(0, 2**31)))

            # Vary environment seed so floor reflects environment stochasticity
            config_with_seed = {**self.config, 'seed': base_seed + i}
            env = OpenGridEnv(config_with_seed)
            obs = env.reset()
            done = False
            ep_reward = 0
            while not done:
                action = _random_thrash_policy(obs, rng=thrash_rng)
                obs, reward, done, info = env.step(action)
                ep_reward += reward.value
            floors.append(ep_reward)

        # mean - std keeps the floor conservatively low so mediocre
        # policies don't clip against it.
        self.reward_floor = float(np.mean(floors) - np.std(floors))
        logger.debug("Floor estimate: mean=%.2f, std=%.2f, floor=%.2f",
                     np.mean(floors), np.std(floors), self.reward_floor)

        # Ceiling: analytical upper bound (not heuristic)
        max_steps = self.config.get('max_steps', 50)
        analytical_ceiling = compute_analytical_ceiling(max_steps)
        self.reward_ceiling = analytical_ceiling

        # Ensure minimum spread — expand floor downward, not ceiling upward
        if self.reward_ceiling - self.reward_floor < 10.0:
            self.reward_floor = self.reward_ceiling - max(10.0, analytical_ceiling * 0.2)
            logger.debug("Spread too small, adjusted floor to %.2f", self.reward_floor)

    def get_bounds(self) -> Dict[str, float]:
        """Return the reward floor and ceiling, computing if needed."""
        if self.reward_floor is None:
            self._estimate_bounds()
        return {"reward_floor": self.reward_floor, "reward_ceiling": self.reward_ceiling}

    def evaluate_policy(self, policy_fn: Callable, n_episodes: int = 10) -> Dict:
        """Run a policy for n_episodes and return normalized score.

        Each episode uses a different environment seed (offset by 1000 from
        floor estimation seeds) to measure policy robustness across diverse
        wind/load trajectories.
        """
        if self.reward_floor is None:
            self._estimate_bounds()

        base_seed = self.config.get('seed', 42)
        rewards = []
        n1_survivals = 0

        for ep in range(n_episodes):
            # Offset by 1000 to avoid overlap with floor estimation seeds
            config_with_seed = {**self.config, 'seed': base_seed + ep + 1000}
            env = OpenGridEnv(config_with_seed)
            obs = env.reset()
            done = False
            ep_reward = 0

            # NOTE(review): assumes every episode runs at least one step; if
            # the loop body never executed, 'info' below would be unbound —
            # confirm OpenGridEnv guarantees max_steps >= 1.
            while not done:
                action = policy_fn(obs)
                obs, reward, done, info = env.step(action)
                ep_reward += reward.value

            rewards.append(ep_reward)
            # Blackout status is read from the FINAL step's info only.
            if not info.is_blackout:
                n1_survivals += 1

        avg_reward = float(np.mean(rewards))
        n1_rate = n1_survivals / n_episodes
        logger.debug("Policy eval: avg=%.2f, n1_rate=%.2f, episodes=%d",
                     avg_reward, n1_rate, n_episodes)

        final_score = normalize_score(
            cumulative_reward=avg_reward,
            reward_floor=self.reward_floor,
            reward_ceiling=self.reward_ceiling,
            n1_survival_rate=n1_rate
        )

        return {
            "avg_raw_reward": round(avg_reward, 4),
            "n1_survival_rate": round(n1_rate, 4),
            "reward_floor": round(self.reward_floor, 4),
            "reward_ceiling": round(self.reward_ceiling, 4),
            "score": final_score
        }
|
src/models.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Dict, Literal, Optional
|
| 2 |
+
from pydantic import BaseModel, Field
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class TopologyAction(BaseModel):
    """A topology switching action on a transmission line."""
    line_id: str  # id of the line to switch (matches LineStatus.id)
    action: Literal["open", "close"]  # open = disconnect, close = reconnect
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class BusAdjustment(BaseModel):
    """A power injection adjustment on a bus."""
    bus_id: int  # target bus (matches BusState.id)
    delta: float  # MW change (positive = inject more)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class GridAction(BaseModel):
    """Agent action: adjust bus injections and/or switch line topology.

    Both lists may be empty (a no-op action). Defaults use
    Field(default_factory=list) rather than a bare mutable `[]`, for
    consistency with the other models in this module (pydantic copies
    plain-list defaults per-instance, but default_factory is the explicit,
    idiomatic form).
    """
    bus_adjustments: List[BusAdjustment] = Field(default_factory=list)
    topology_actions: List[TopologyAction] = Field(default_factory=list)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class LineStatus(BaseModel):
    """Current state of a transmission line."""
    id: str  # unique line identifier
    connected: bool  # False when the line is open (switched out)
    flow: float = 0.0  # power flow on the line (MW, signed)
    rho: float = Field(0.0, ge=0.0, description="Loading percentage (flow/capacity)")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class BusState(BaseModel):
    """Current state of a bus (generator, load, battery, or renewable)."""
    id: int  # unique bus identifier
    type: Literal["slack", "generator", "load", "battery", "solar", "wind"]
    p_injection: float  # net injection (MW); negative for loads
    soc: float = Field(0.0, ge=0.0, description="State of charge (MWh)")
    ramp_rate: float = 0.0  # max injection change per step — presumably MW/step; confirm in tasks config
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class GridObservation(BaseModel):
    """Full grid observation returned by reset()/step()/state()."""
    timestep: int  # current simulation step
    grid_frequency: float  # system frequency (Hz)
    buses: List[BusState]
    lines: List[LineStatus]
    # Presumably keyed by line/asset id with remaining cooldown steps as
    # values — confirm against the environment's cooldown bookkeeping.
    cooldowns: Dict[str, int]
    is_blackout: bool = False  # True once the episode has ended in blackout

    def __repr__(self) -> str:
        # Compact one-line summary (overrides pydantic's default field dump).
        return (
            f"GridObservation(t={self.timestep}, f={self.grid_frequency:.2f}, "
            f"buses={len(self.buses)}, lines={len(self.lines)}, "
            f"blackout={self.is_blackout})"
        )
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class GridReward(BaseModel):
    """Reward signal with component breakdown."""
    value: float  # total scalar reward for the step
    components: Dict[str, float]  # per-component contributions summing into value
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class GridInfo(BaseModel):
    """Episode info (metadata alongside reward)."""
    task_id: str  # identifier of the active task/difficulty
    is_blackout: bool  # True if the grid has blacked out this episode
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# ---------------------------------------------------------------------------
|
| 70 |
+
# Multi-Agent POMDP Models
|
| 71 |
+
# ---------------------------------------------------------------------------
|
| 72 |
+
|
| 73 |
+
class ZoneInfo(BaseModel):
    """Metadata about an agent's zone."""
    agent_id: int
    zone_name: str
    bus_ids: List[int]  # buses owned by this zone
    boundary_line_ids: List[str]  # lines crossing into other zones
    internal_line_ids: List[str]  # lines entirely within this zone
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class ZoneObservation(BaseModel):
    """Partial observation for one agent under POMDP.

    Each agent sees only:
    - Their local buses (within their zone)
    - Boundary lines (connecting to other zones)
    - Internal lines (within their zone)
    - A noisy estimate of global grid frequency
    - Limited communication signals from neighboring agents
    """
    agent_id: int
    zone_name: str
    timestep: int
    grid_frequency: float  # noisy — Gaussian noise added
    local_buses: List[BusState]
    boundary_lines: List[LineStatus]
    internal_lines: List[LineStatus]
    neighbor_signals: Dict[int, float] = Field(
        default_factory=dict,
        description="Limited info from other agents: {agent_id: their avg bus injection}"
    )
    cooldowns: Dict[str, int] = Field(default_factory=dict)  # asset id -> steps remaining
    is_blackout: bool = False
    zone_load_mw: float = 0.0  # aggregate load in this zone (MW)
    zone_gen_mw: float = 0.0  # aggregate generation in this zone (MW)

    def __repr__(self) -> str:
        # Compact one-line summary (overrides pydantic's default field dump).
        return (
            f"ZoneObservation(agent={self.agent_id}, zone={self.zone_name}, "
            f"t={self.timestep}, f={self.grid_frequency:.2f}, "
            f"buses={len(self.local_buses)}, blackout={self.is_blackout})"
        )
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class SafetyReport(BaseModel):
    """Report from the safety layer about action corrections."""
    agent_id: int
    was_corrected: bool  # True if the safety layer altered the agent's action
    correction_reason: str = ""
    n1_violations_detected: int = 0
    proposed_topology_actions: int = 0
    blocked_topology_actions: int = 0
    original_total_delta_mw: float = 0.0  # sum of requested bus deltas (MW)
    corrected_total_delta_mw: float = 0.0  # sum after safety correction (MW)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class OversightReport(BaseModel):
    """Report from the oversight agent about multi-agent coordination."""
    coordination_score: float = Field(
        1.0, description="1.0 = perfect cooperation, 0.0 = total conflict"
    )
    conflicting_actions_detected: int = 0
    selfish_actions_detected: int = 0
    coordination_penalties: Dict[int, float] = Field(default_factory=dict)  # agent_id -> penalty
    global_frequency_contribution: Dict[int, float] = Field(
        default_factory=dict,
        description="Each agent's net impact on frequency deviation"
    )
    notes: List[str] = Field(default_factory=list)  # human-readable rule-engine findings
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class MultiAgentAction(BaseModel):
    """Request body for /step_multi: per-agent actions keyed by agent_id.

    NOTE: JSON object keys arrive as strings; pydantic coerces them to int
    for the Dict[int, ...] annotation.
    """
    agent_actions: Dict[int, GridAction] = Field(
        default_factory=dict,
        description="Actions for each agent, keyed by agent_id"
    )
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class MultiAgentStepResult(BaseModel):
    """Result of a multi-agent step — per-agent observations, rewards, reports."""
    observations: Dict[int, ZoneObservation]  # keyed by agent_id
    rewards: Dict[int, GridReward]  # per-agent rewards, keyed by agent_id
    team_reward: float  # shared team-level reward for the step
    done: bool  # True when the episode has terminated
    safety_reports: Dict[int, SafetyReport] = Field(
        default_factory=dict,
        description="Per-agent safety reports, keyed by agent_id"
    )
    oversight_report: OversightReport
    info: GridInfo
|
src/oversight.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Oversight Agent — Multi-Agent Coordination Monitor
|
| 3 |
+
===================================================
|
| 4 |
+
A rule-based meta-agent that monitors coordination quality across zones.
|
| 5 |
+
|
| 6 |
+
Responsibilities:
|
| 7 |
+
1. Detect conflicting actions (agents pulling frequency opposite ways)
|
| 8 |
+
2. Detect selfish behavior (local improvement at global cost)
|
| 9 |
+
3. Assign coordination penalties to agents
|
| 10 |
+
4. Track safety layer intervention frequency
|
| 11 |
+
|
| 12 |
+
This is NOT a trained agent — it's a deterministic rule engine that
|
| 13 |
+
provides additional reward signal to guide multi-agent learning.
|
| 14 |
+
|
| 15 |
+
References:
|
| 16 |
+
- Symphony: Multi-Agent Intelligence in a Collective Fabric (Gradient, 2025)
|
| 17 |
+
- Massgen: When Multiple LLMs Think Together (Gradient, 2025)
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import logging
|
| 21 |
+
import math
|
| 22 |
+
from typing import Dict, List
|
| 23 |
+
from .models import GridAction, SafetyReport, OversightReport
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class OversightAgent:
    """Rule-based oversight agent for multi-agent coordination.

    Sits above zone agents and evaluates whether their combined actions
    are globally beneficial or harmful. Produces an OversightReport
    with coordination scores and penalties.
    """

    def __init__(self, config: Dict):
        self.config = config
        # Zone metadata; read from config but not referenced in evaluate().
        self.zone_assignments = config.get('zone_assignments', {})
        self.num_agents = config.get('num_agents', 1)
        # Cumulative safety-correction count per agent across the episode;
        # cleared by reset().
        self.intervention_history: Dict[int, int] = {
            i: 0 for i in range(self.num_agents)
        }

    def evaluate(
        self,
        agent_actions: Dict[int, GridAction],
        safety_reports: Dict[int, SafetyReport],
        pre_frequency: float,
        post_frequency: float,
        pre_bus_state: List[Dict],
        post_bus_state: List[Dict],
    ) -> OversightReport:
        """Evaluate multi-agent coordination quality.

        Args:
            agent_actions: {agent_id: GridAction} — proposed actions
            safety_reports: {agent_id: SafetyReport} — per-agent safety results
            pre_frequency: Grid frequency before this step
            post_frequency: Grid frequency after this step
            pre_bus_state: Bus states before actions
            post_bus_state: Bus states after actions

        Returns:
            OversightReport with scores, penalties, and notes

        NOTE(review): pre_bus_state and post_bus_state are accepted but not
        used anywhere in this method — kept for interface stability or
        future rules; confirm before removing.
        """
        notes = []
        penalties: Dict[int, float] = {i: 0.0 for i in range(self.num_agents)}
        conflicts = 0
        selfish_count = 0

        # --- 1. Track safety interventions ---
        for agent_id, report in safety_reports.items():
            # Validate agent_id is within expected range
            if agent_id not in self.intervention_history:
                notes.append(f"WARNING: unknown agent_id {agent_id} in safety report")
                continue
            if report.was_corrected:
                self.intervention_history[agent_id] += 1
                # Penalty scales with repeated violations (capped at 5x)
                repeat_count = self.intervention_history[agent_id]
                penalties[agent_id] += 0.1 * min(repeat_count, 5)
                notes.append(
                    f"Agent {agent_id}: safety correction #{repeat_count}"
                )

        # --- 2. Detect conflicting frequency actions ---
        # If agents are pushing frequency in opposite directions, that's waste
        net_deltas = {}
        for agent_id, action in agent_actions.items():
            total_delta = sum(a.delta for a in action.bus_adjustments)
            n_topo = len(action.topology_actions)
            if n_topo > 0:
                notes.append(
                    f"Agent {agent_id}: {n_topo} topology action(s) "
                    f"not included in conflict analysis"
                )
            net_deltas[agent_id] = total_delta

        if len(net_deltas) >= 2:
            deltas = list(net_deltas.values())
            # Check if some agents inject and others withdraw significantly
            # (2.0 MW dead-band filters out trim adjustments)
            injectors = [d for d in deltas if d > 2.0]
            withdrawers = [d for d in deltas if d < -2.0]
            if injectors and withdrawers:
                conflicts += 1
                notes.append(
                    "Conflicting actions: some agents inject while others withdraw"
                )
                # Penalize the agent pushing AGAINST the needed direction
                freq_error = 50.0 - pre_frequency

                if abs(freq_error) > 0.1:
                    # Clear direction needed — penalize the opposing side
                    for agent_id, delta in net_deltas.items():
                        # If freq < 50 (need more injection) but agent withdraws
                        if freq_error > 0.1 and delta < -2.0:
                            penalties[agent_id] += 0.2
                            selfish_count += 1
                            notes.append(
                                f"Agent {agent_id}: withdrew {delta:.1f} MW "
                                f"when grid needed injection"
                            )
                        # If freq > 50 (need less injection) but agent injects
                        elif freq_error < -0.1 and delta > 2.0:
                            penalties[agent_id] += 0.2
                            selfish_count += 1
                            notes.append(
                                f"Agent {agent_id}: injected {delta:.1f} MW "
                                f"when grid had excess"
                            )
                else:
                    # Near-nominal: penalize all significant participants equally
                    for agent_id, delta in net_deltas.items():
                        if abs(delta) > 2.0:
                            penalties[agent_id] += 0.1
                            notes.append(
                                f"Agent {agent_id}: conflicting injection "
                                f"({delta:+.1f} MW) with no clear grid need"
                            )

        # --- 3. Evaluate frequency impact per agent ---
        freq_contribution: Dict[int, float] = {}
        freq_dev_before = abs(pre_frequency - 50.0)
        freq_dev_after = abs(post_frequency - 50.0)
        freq_improved = freq_dev_after < freq_dev_before

        for agent_id in range(self.num_agents):
            # Net MW delta (not frequency impact — would need droop constant)
            total_delta = net_deltas.get(agent_id, 0.0)
            freq_contribution[agent_id] = round(total_delta, 4)

        # --- 4. Compute coordination score ---
        # Sub-linear scaling: diminishing penalty per additional incident
        # prevents score from collapsing to 0.0 for mildly bad teams
        safety_corrections = sum(
            1 for r in safety_reports.values() if r.was_corrected
        )

        conflict_penalty = 1.0 - math.exp(-conflicts * 0.3)
        selfish_penalty = 1.0 - math.exp(-selfish_count * 0.2)
        safety_penalty = 1.0 - math.exp(-safety_corrections * 0.2)

        # Weighted mix: conflicts hurt most, then selfishness/safety equally.
        base_score = (1.0
                      - 0.4 * conflict_penalty
                      - 0.3 * selfish_penalty
                      - 0.3 * safety_penalty)

        # Frequency improvement bonus / degradation penalty
        if freq_improved:
            base_score += 0.1
        else:
            degradation = freq_dev_after - freq_dev_before
            base_score -= min(degradation * 0.5, 0.2)

        coordination_score = max(0.0, min(1.0, base_score))

        return OversightReport(
            coordination_score=round(coordination_score, 4),
            conflicting_actions_detected=conflicts,
            selfish_actions_detected=selfish_count,
            coordination_penalties=penalties,
            global_frequency_contribution=freq_contribution,
            notes=notes,
        )

    def reset(self):
        """Reset intervention history for a new episode."""
        self.intervention_history = {
            i: 0 for i in range(self.num_agents)
        }
|
src/physics.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DC Power Flow Solver
|
| 3 |
+
====================
|
| 4 |
+
Implements the standard DC approximation: B * θ = P
|
| 5 |
+
|
| 6 |
+
Assumptions:
|
| 7 |
+
- Flat voltage profile (|V| ≈ 1.0 p.u.)
|
| 8 |
+
- Small angle differences (sin(θ) ≈ θ)
|
| 9 |
+
- Negligible resistance (R ≈ 0, only susceptance used)
|
| 10 |
+
|
| 11 |
+
Flow sign convention:
|
| 12 |
+
flow = b * (θ_from - θ_to)
|
| 13 |
+
Positive flow = power flowing from 'from' bus to 'to' bus.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import logging
|
| 17 |
+
import warnings
|
| 18 |
+
import numpy as np
|
| 19 |
+
from typing import List, Dict, Tuple
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class IslandedException(Exception):
    """Signals that the grid graph is not fully connected (islanded).

    Raise site is in the solver's connectivity check — see DCSolver.
    """
    pass
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class DCSolver:
|
| 29 |
+
"""DC power flow solver with graph-based islanding detection.
|
| 30 |
+
|
| 31 |
+
The slack bus absorbs any power imbalance and has its voltage angle
|
| 32 |
+
fixed to 0 (reference). By default this is bus 0, but can be
|
| 33 |
+
configured via the slack_bus parameter.
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
def __init__(self, num_buses: int, slack_bus: int = 0):
|
| 37 |
+
self.num_buses = num_buses
|
| 38 |
+
self.slack_bus = slack_bus
|
| 39 |
+
self.B = np.zeros((num_buses, num_buses))
|
| 40 |
+
self.line_map = {}
|
| 41 |
+
self._grid_loaded = False
|
| 42 |
+
|
| 43 |
+
def update_grid(self, lines: List[Dict]):
|
| 44 |
+
"""Rebuild the B matrix and check connectivity.
|
| 45 |
+
|
| 46 |
+
Skips zero-susceptance lines (no electrical contribution).
|
| 47 |
+
Validates bus indices to prevent silent corruption.
|
| 48 |
+
"""
|
| 49 |
+
self.B = np.zeros((self.num_buses, self.num_buses))
|
| 50 |
+
self.line_map = {}
|
| 51 |
+
|
| 52 |
+
# Union-Find for O(n) connectivity check (replaces NetworkX)
|
| 53 |
+
parent = list(range(self.num_buses))
|
| 54 |
+
rank = [0] * self.num_buses
|
| 55 |
+
|
| 56 |
+
def find(x):
|
| 57 |
+
while parent[x] != x:
|
| 58 |
+
parent[x] = parent[parent[x]] # path compression
|
| 59 |
+
x = parent[x]
|
| 60 |
+
return x
|
| 61 |
+
|
| 62 |
+
def union(x, y):
|
| 63 |
+
rx, ry = find(x), find(y)
|
| 64 |
+
if rx == ry:
|
| 65 |
+
return
|
| 66 |
+
if rank[rx] < rank[ry]:
|
| 67 |
+
rx, ry = ry, rx
|
| 68 |
+
parent[ry] = rx
|
| 69 |
+
if rank[rx] == rank[ry]:
|
| 70 |
+
rank[rx] += 1
|
| 71 |
+
|
| 72 |
+
for line in lines:
|
| 73 |
+
if line['connected']:
|
| 74 |
+
i, j = line['from'], line['to']
|
| 75 |
+
b = line['susceptance']
|
| 76 |
+
|
| 77 |
+
# Validate bus indices
|
| 78 |
+
if not (0 <= i < self.num_buses and 0 <= j < self.num_buses):
|
| 79 |
+
raise ValueError(
|
| 80 |
+
f"Line {line['id']}: bus indices ({i}, {j}) out of range "
|
| 81 |
+
f"for {self.num_buses} buses"
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
# Skip zero-susceptance lines (no electrical contribution)
|
| 85 |
+
if abs(b) < 1e-12:
|
| 86 |
+
continue
|
| 87 |
+
|
| 88 |
+
self.B[i, j] -= b
|
| 89 |
+
self.B[j, i] -= b
|
| 90 |
+
self.B[i, i] += b
|
| 91 |
+
self.B[j, j] += b
|
| 92 |
+
|
| 93 |
+
self.line_map[line['id']] = (i, j, b)
|
| 94 |
+
union(i, j)
|
| 95 |
+
|
| 96 |
+
# Connectivity check via union-find
|
| 97 |
+
root = find(0)
|
| 98 |
+
if not all(find(i) == root for i in range(self.num_buses)):
|
| 99 |
+
# Build component info for diagnostics
|
| 100 |
+
components = {}
|
| 101 |
+
for i in range(self.num_buses):
|
| 102 |
+
r = find(i)
|
| 103 |
+
components.setdefault(r, []).append(i)
|
| 104 |
+
comp_sizes = [len(c) for c in components.values()]
|
| 105 |
+
raise IslandedException(
|
| 106 |
+
f"Grid is islanded: {len(components)} components, "
|
| 107 |
+
f"sizes={comp_sizes}"
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
self._grid_loaded = True
|
| 111 |
+
|
| 112 |
+
def solve(self, p_inj: np.ndarray) -> Tuple[np.ndarray, Dict[str, float], float]:
|
| 113 |
+
"""Solve DC power flow: B_red * θ_red = P_red.
|
| 114 |
+
|
| 115 |
+
Args:
|
| 116 |
+
p_inj: Real power injection at each bus (MW). Shape must be (num_buses,).
|
| 117 |
+
|
| 118 |
+
Returns:
|
| 119 |
+
(theta, line_flows, slack_injection) tuple.
|
| 120 |
+
theta: voltage angles (radians). Slack bus angle = 0.
|
| 121 |
+
line_flows: {line_id: flow_MW}. Positive = from→to direction.
|
| 122 |
+
slack_injection: MW absorbed/injected by the slack bus.
|
| 123 |
+
"""
|
| 124 |
+
if not self._grid_loaded:
|
| 125 |
+
raise RuntimeError("DCSolver.solve() called before update_grid()")
|
| 126 |
+
|
| 127 |
+
# Validate input
|
| 128 |
+
p_inj = np.asarray(p_inj).ravel()
|
| 129 |
+
if len(p_inj) != self.num_buses:
|
| 130 |
+
raise ValueError(
|
| 131 |
+
f"p_inj length {len(p_inj)} != num_buses {self.num_buses}"
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
# Remove slack bus row/column
|
| 135 |
+
mask = np.arange(self.num_buses) != self.slack_bus
|
| 136 |
+
B_red = self.B[np.ix_(mask, mask)]
|
| 137 |
+
p_red = p_inj[mask]
|
| 138 |
+
|
| 139 |
+
try:
|
| 140 |
+
theta_red = np.linalg.solve(B_red, p_red)
|
| 141 |
+
except np.linalg.LinAlgError:
|
| 142 |
+
raise IslandedException("Grid is islanded (singular B matrix)")
|
| 143 |
+
|
| 144 |
+
# Check conditioning
|
| 145 |
+
cond = np.linalg.cond(B_red)
|
| 146 |
+
if cond > 1e12:
|
| 147 |
+
warnings.warn(
|
| 148 |
+
f"DCSolver: B_red is ill-conditioned (cond={cond:.2e}). "
|
| 149 |
+
f"Results may be numerically unreliable.",
|
| 150 |
+
RuntimeWarning,
|
| 151 |
+
stacklevel=2,
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
# Insert slack bus angle (= 0)
|
| 155 |
+
theta = np.zeros(self.num_buses)
|
| 156 |
+
theta[mask] = theta_red
|
| 157 |
+
|
| 158 |
+
# Compute line flows
|
| 159 |
+
flows = {}
|
| 160 |
+
for line_id, (i, j, b) in self.line_map.items():
|
| 161 |
+
flows[line_id] = (theta[i] - theta[j]) * b
|
| 162 |
+
|
| 163 |
+
# Slack injection from power balance (more robust than summing flows)
|
| 164 |
+
slack_injection = -float(p_inj[mask].sum())
|
| 165 |
+
|
| 166 |
+
return theta, flows, slack_injection
|
| 167 |
+
|
| 168 |
+
def __repr__(self):
|
| 169 |
+
return (
|
| 170 |
+
f"DCSolver(num_buses={self.num_buses}, slack={self.slack_bus}, "
|
| 171 |
+
f"lines={len(self.line_map)}, loaded={self._grid_loaded})"
|
| 172 |
+
)
|
src/safety.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Safety Layer — Hard Constraint Filter for OpenGrid
|
| 3 |
+
===================================================
|
| 4 |
+
Validates agent actions BEFORE they are applied to the environment.
|
| 5 |
+
If constraints are violated, actions are projected to the nearest safe alternative.
|
| 6 |
+
|
| 7 |
+
This is the core safety innovation: constraint violations should NEVER
|
| 8 |
+
reach the physics engine. The safety layer catches them first.
|
| 9 |
+
|
| 10 |
+
Checks:
|
| 11 |
+
1. Anti-Islanding: topology actions that would disconnect the grid are blocked
|
| 12 |
+
2. N-1 Security: for each critical line, simulate failure → check grid survives
|
| 13 |
+
3. Generation Limits: bus adjustments respect ramp rates and capacity
|
| 14 |
+
4. Zone Boundary: agents can only adjust buses in their assigned zone
|
| 15 |
+
|
| 16 |
+
References:
|
| 17 |
+
- KPTCL N-1 security criterion (Indian Grid Code, IEGC)
|
| 18 |
+
- Control Barrier Functions for safe RL (Ames et al., 2019)
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import logging
|
| 22 |
+
import networkx as nx
|
| 23 |
+
import numpy as np
|
| 24 |
+
from typing import List, Dict, Tuple
|
| 25 |
+
from .models import GridAction, BusAdjustment, TopologyAction, SafetyReport
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class SafetyLayer:
    """Hard constraint filter that validates and corrects agent actions.

    The safety layer sits between agents and the environment:
        Agent proposes action → SafetyLayer validates → corrected action → Environment

    If an action would cause a constraint violation, it is PROJECTED to the
    nearest safe alternative (not just rejected). This preserves the agent's
    intent while enforcing safety, and provides a richer training signal
    than binary accept/reject.

    NOTE: connectivity checks use nx.MultiGraph keyed by line id so that
    parallel lines between the same pair of buses are tracked individually
    (a simple nx.Graph would collapse them, falsely reporting islanding
    when one of two parallel lines is opened).
    """

    def __init__(self, config: Dict):
        """Args:
            config: Grid config with 'num_buses', 'lines', 'buses', and
                optionally 'zone_assignments' ({bus_id: agent_id}). Zone
                enforcement is active only when zone_assignments is
                non-empty (multi-agent mode).
        """
        self.config = config
        self.num_buses = config['num_buses']
        self.lines_config = config['lines']
        self.buses_config = config['buses']
        self.zone_assignments = config.get('zone_assignments', {})
        self.zone_enforcement = bool(self.zone_assignments)

        # Build config indexes for O(1) lookups
        self._bus_cfg_by_id = {b['id']: b for b in self.buses_config}
        self._line_cfg_by_id = {l['id']: l for l in self.lines_config}

    def validate_and_correct(
        self,
        agent_id: int,
        proposed_action: GridAction,
        current_line_state: List[Dict],
        current_bus_state: List[Dict],
        cooldowns: Dict[str, int],
    ) -> Tuple[GridAction, SafetyReport]:
        """Full validation pipeline for one agent's proposed action.

        Args:
            agent_id: Zone owner id of the acting agent.
            proposed_action: The raw action proposed by the agent.
            current_line_state: Dynamic line dicts ('id', 'connected').
            current_bus_state: Dynamic bus dicts ('id', 'p', 'soc').
            cooldowns: {line_id: remaining cooldown steps}.

        Returns:
            corrected_action: Safe version of the proposed action
            report: Details of what was checked and corrected
        """
        corrections = []
        n1_violations = 0

        # Track original action stats for the report delta comparison.
        original_delta = sum(abs(a.delta) for a in proposed_action.bus_adjustments)
        proposed_topo_count = len(proposed_action.topology_actions)
        blocked_topo_count = 0

        # Build bus state index for O(1) lookups
        bus_dyn_by_id = {b['id']: b for b in current_bus_state}

        # --- 1. Zone boundary enforcement ---
        safe_bus_adj = []
        for adj in proposed_action.bus_adjustments:
            bus_zone = self.zone_assignments.get(adj.bus_id, -1)
            if not self.zone_enforcement or bus_zone == agent_id:
                # Agent owns this bus, or single-agent mode
                safe_bus_adj.append(adj)
            else:
                corrections.append(
                    f"Blocked bus {adj.bus_id} adjustment: "
                    f"belongs to zone {bus_zone}, not agent {agent_id}"
                )

        # --- 2. Generation limit enforcement ---
        # Aggregate adjustments per bus to prevent double-spending
        bus_deltas: Dict[int, float] = {}
        for adj in safe_bus_adj:
            bus_deltas[adj.bus_id] = bus_deltas.get(adj.bus_id, 0.0) + adj.delta

        clamped_bus_adj = []
        for bus_id, total_delta in bus_deltas.items():
            bus_cfg = self._bus_cfg_by_id.get(bus_id)
            bus_dyn = bus_dyn_by_id.get(bus_id)
            if bus_cfg is None or bus_dyn is None:
                corrections.append(f"Blocked bus {bus_id}: not found")
                continue

            delta = total_delta
            bus_type = bus_cfg['type']

            # Loads and renewables can't be directly adjusted
            if bus_type in ['load', 'solar', 'wind']:
                corrections.append(
                    f"Blocked bus {bus_id}: type '{bus_type}' is not controllable"
                )
                continue

            # Enforce ramp rate
            max_ramp = bus_cfg.get('ramp_rate', 20.0)
            if abs(delta) > max_ramp:
                delta = np.clip(delta, -max_ramp, max_ramp)
                corrections.append(
                    f"Clamped bus {bus_id} delta to ramp rate ±{max_ramp}"
                )

            # Enforce battery SoC limits (delta > 0 == discharge)
            if bus_type == 'battery':
                soc = bus_dyn.get('soc', 0.0)
                capacity = bus_cfg.get('capacity', 50.0)
                if delta > 0 and delta > soc:
                    delta = soc
                    corrections.append(
                        f"Clamped bus {bus_id} discharge to SoC={soc:.1f}"
                    )
                elif delta < 0 and abs(delta) > (capacity - soc):
                    delta = -(capacity - soc)
                    corrections.append(
                        f"Clamped bus {bus_id} charge to remaining capacity"
                    )

            # Enforce generator limits
            # NOTE: This is a best-effort projection based on pre-step state.
            # If multiple agents adjust the same bus via different zones,
            # the environment provides a secondary clamp.
            if bus_type in ['slack', 'generator']:
                current_p = bus_dyn.get('p', 0.0)
                new_p = current_p + delta
                min_p = bus_cfg.get('min_p', -50)
                max_p = bus_cfg.get('max_p', 100)
                if new_p < min_p or new_p > max_p:
                    new_p = np.clip(new_p, min_p, max_p)
                    delta = new_p - current_p
                    corrections.append(
                        f"Clamped bus {bus_id} to generation limits "
                        f"[{min_p}, {max_p}]"
                    )

            clamped_bus_adj.append(BusAdjustment(bus_id=bus_id, delta=delta))

        # --- 3. Topology safety (anti-islanding + N-1) ---
        # Build base graph once for all topology checks
        base_graph = self._build_connectivity_graph(current_line_state)

        safe_topo = []
        approved_opens: set = set()  # Track approved opens for cumulative check
        for t_act in proposed_action.topology_actions:
            line_id = t_act.line_id

            # Check cooldown
            if cooldowns.get(line_id, 0) > 0:
                corrections.append(
                    f"Blocked {line_id} switch: cooldown active "
                    f"({cooldowns[line_id]} steps)"
                )
                blocked_topo_count += 1
                continue

            # Check if opening this line would island the grid
            # (cumulative: checks against already-approved opens)
            if t_act.action == "open":
                if self._would_island(
                    line_id, base_graph, additional_opens=approved_opens
                ):
                    corrections.append(
                        f"Blocked opening {line_id}: would island the grid"
                    )
                    blocked_topo_count += 1
                    n1_violations += 1
                    continue
                approved_opens.add(line_id)

            safe_topo.append(t_act)

        # --- 4. N-1 check on final combined action ---
        if safe_topo:
            n1_fails = self._check_n1_post_action(safe_topo, current_line_state)
            if n1_fails > 0:
                n1_violations += n1_fails
                corrections.append(
                    f"N-1 warning: {n1_fails} lines would leave grid "
                    f"vulnerable after action"
                )

        corrected_action = GridAction(
            bus_adjustments=clamped_bus_adj,
            topology_actions=safe_topo
        )

        corrected_delta = sum(abs(a.delta) for a in clamped_bus_adj)

        was_corrected = len(corrections) > 0
        report = SafetyReport(
            agent_id=agent_id,
            was_corrected=was_corrected,
            correction_reason="; ".join(corrections) if corrections else "",
            n1_violations_detected=n1_violations,
            proposed_topology_actions=proposed_topo_count,
            blocked_topology_actions=blocked_topo_count,
            original_total_delta_mw=round(original_delta, 4),
            corrected_total_delta_mw=round(corrected_delta, 4),
        )

        return corrected_action, report

    def _build_connectivity_graph(
        self, current_line_state: List[Dict]
    ) -> nx.MultiGraph:
        """Build the connectivity graph from current line state (once).

        Edges are keyed by line id so parallel lines between the same bus
        pair remain distinct and removable individually.
        """
        G = nx.MultiGraph()
        G.add_nodes_from(range(self.num_buses))

        line_dyn_by_id = {l['id']: l for l in current_line_state}
        for l_cfg in self.lines_config:
            l_dyn = line_dyn_by_id.get(l_cfg['id'])
            if l_dyn is not None and l_dyn.get('connected', True):
                G.add_edge(l_cfg['from'], l_cfg['to'], key=l_cfg['id'])

        return G

    def _would_island(
        self,
        line_id: str,
        base_graph: nx.MultiGraph,
        additional_opens: set = None,
    ) -> bool:
        """Check if opening a line would disconnect the grid.

        Takes cumulative approved opens into account so that
        multiple simultaneous opens are correctly checked.
        """
        additional_opens = additional_opens or set()

        # Unknown lines can't affect topology
        line_cfg = self._line_cfg_by_id.get(line_id)
        if line_cfg is None:
            return False

        # Build a test graph with all proposed removals
        G = base_graph.copy()
        # Remove previously approved opens (by line-id key, so only the
        # specific line is removed even when a parallel line exists)
        for open_id in additional_opens:
            open_cfg = self._line_cfg_by_id.get(open_id)
            if open_cfg and G.has_edge(open_cfg['from'], open_cfg['to'], open_id):
                G.remove_edge(open_cfg['from'], open_cfg['to'], key=open_id)

        # Remove the line under test
        if G.has_edge(line_cfg['from'], line_cfg['to'], line_id):
            G.remove_edge(line_cfg['from'], line_cfg['to'], key=line_id)

        return not nx.is_connected(G)

    def _check_n1_post_action(
        self,
        topo_actions: List[TopologyAction],
        current_line_state: List[Dict],
    ) -> int:
        """Check N-1 security after applying proposed topology actions.

        For each remaining connected line, simulate its loss and check
        connectivity. Uses keyed edge removal/restoration instead of
        rebuilding the full graph for each contingency, and counts each
        line independently (parallel lines are separate contingencies).

        Returns the number of single-line contingencies that would island.
        """
        # Build the post-action line state
        post_state = {}
        for l_dyn in current_line_state:
            post_state[l_dyn['id']] = l_dyn.get('connected', True)
        for t_act in topo_actions:
            post_state[t_act.line_id] = (t_act.action == "close")

        # Build post-action graph once
        G = nx.MultiGraph()
        G.add_nodes_from(range(self.num_buses))

        connected_lines = []
        for l_cfg in self.lines_config:
            if post_state.get(l_cfg['id'], True):
                G.add_edge(l_cfg['from'], l_cfg['to'], key=l_cfg['id'])
                connected_lines.append(l_cfg)

        # Test each contingency via keyed edge removal/restoration
        n1_failures = 0
        for l_cfg in connected_lines:
            u, v = l_cfg['from'], l_cfg['to']
            G.remove_edge(u, v, key=l_cfg['id'])
            if not nx.is_connected(G):
                n1_failures += 1
            G.add_edge(u, v, key=l_cfg['id'])  # restore

        return n1_failures

    def reset(self):
        """Reset any per-episode state (future-proofing)."""
        pass
|
src/tasks.py
ADDED
|
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Grid Generator & Task Definitions
|
| 3 |
+
===================================
|
| 4 |
+
Generates reproducible power grid configurations for OpenGrid RL tasks.
|
| 5 |
+
|
| 6 |
+
Procedural grids use Watts-Strogatz small-world topology with
|
| 7 |
+
configurable difficulty (bus count, renewable penetration).
|
| 8 |
+
|
| 9 |
+
The Karnataka task is a hand-crafted 15-bus grid based on the
|
| 10 |
+
actual KPTCL transmission map.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import copy
|
| 14 |
+
import networkx as nx
|
| 15 |
+
import numpy as np
|
| 16 |
+
from typing import Dict, List, Tuple
|
| 17 |
+
|
| 18 |
+
__all__ = ['generate_procedural_grid', 'generate_karnataka_task', 'TASKS', 'get_task']
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# KPTCL-inspired zone names
|
| 22 |
+
def _get_zone_names(num_agents: int) -> List[str]:
|
| 23 |
+
"""Get human-readable zone names for a given agent count."""
|
| 24 |
+
base_names = [
|
| 25 |
+
"Bengaluru_Region", "Mysuru_Region", "Kalburagi_Region",
|
| 26 |
+
"Belagavi_Region", "Mangaluru_Region",
|
| 27 |
+
]
|
| 28 |
+
if num_agents <= len(base_names):
|
| 29 |
+
return base_names[:num_agents]
|
| 30 |
+
return [f"Zone_{i}" for i in range(num_agents)]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _partition_into_zones(G: nx.Graph, num_agents: int) -> Dict[int, int]:
    """Partition graph nodes into balanced, connected zones.

    Returns mapping of {bus_id: agent_id}.
    Guarantees: every bus is assigned, each zone has at least 1 node,
    and zones are roughly balanced in size.

    NOTE: Uses greedy modularity which is deterministic for a given graph
    structure but not guaranteed across NetworkX versions.
    """
    nodes = sorted(G.nodes())
    n = len(nodes)

    if n <= num_agents:
        # Trivial case: 1 bus per agent
        return {node: i for i, node in enumerate(nodes)}

    try:
        communities = nx.community.greedy_modularity_communities(G, cutoff=num_agents)
        # Largest community first so merges/splits below behave predictably.
        communities = [set(c) for c in sorted(communities, key=len, reverse=True)]
    except Exception:
        # Fallback: round-robin assignment by node index
        communities = [set() for _ in range(num_agents)]
        for i, node in enumerate(nodes):
            communities[i % num_agents].add(node)

    # If we got more communities than agents, merge smallest into largest
    while len(communities) > num_agents:
        smallest = communities.pop()
        communities[0] = communities[0] | smallest

    # If we got fewer, split the largest using topology-aware bisection
    while len(communities) < num_agents:
        largest = max(communities, key=len)
        communities.remove(largest)

        # Attempt topology-aware split
        subG = G.subgraph(largest).copy()
        split_done = False
        if nx.is_connected(subG) and len(largest) >= 2:
            # Find edge whose removal creates the most balanced partition.
            # Iterate a snapshot of the edges: the loop body removes and
            # re-adds edges, and mutating the live edge view while
            # iterating it can raise RuntimeError (dict changed size
            # during iteration).
            best_edge, best_balance = None, float('inf')
            target = len(largest) / 2
            for u, v in list(subG.edges()):
                subG.remove_edge(u, v)
                components = list(nx.connected_components(subG))
                if len(components) == 2:
                    balance = abs(len(components[0]) - target) + abs(len(components[1]) - target)
                    if balance < best_balance:
                        best_edge = (u, v)
                        best_balance = balance
                subG.add_edge(u, v)
            if best_edge:
                subG.remove_edge(*best_edge)
                parts = list(nx.connected_components(subG))
                communities.extend(parts)
                split_done = True

        if not split_done:
            # Fallback: arbitrary split by sorted node order
            largest_list = sorted(largest)
            half = len(largest) // 2
            communities.append(set(largest_list[:half]))
            communities.append(set(largest_list[half:]))

    # Ensure no empty zones
    for i, comm in enumerate(communities):
        if len(comm) == 0:
            # Steal a node from the largest community
            largest = max(communities, key=len)
            stolen = next(iter(largest))
            largest.remove(stolen)
            communities[i] = {stolen}

    zone_map = {}
    for agent_id, comm in enumerate(communities):
        for node in comm:
            zone_map[node] = agent_id

    return zone_map
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def _classify_lines(
|
| 116 |
+
lines_config: List[Dict], zone_assignments: Dict[int, int]
|
| 117 |
+
) -> Tuple[Dict[int, List[str]], Dict[int, List[str]]]:
|
| 118 |
+
"""Classify lines as internal (both endpoints in same zone) or boundary.
|
| 119 |
+
|
| 120 |
+
Returns:
|
| 121 |
+
internal_lines: {agent_id: [line_ids within this zone]}
|
| 122 |
+
boundary_lines: {agent_id: [line_ids on this zone's boundary]}
|
| 123 |
+
"""
|
| 124 |
+
agents = set(zone_assignments.values())
|
| 125 |
+
internal = {a: [] for a in agents}
|
| 126 |
+
boundary = {a: [] for a in agents}
|
| 127 |
+
|
| 128 |
+
for line in lines_config:
|
| 129 |
+
from_zone = zone_assignments.get(line['from'])
|
| 130 |
+
to_zone = zone_assignments.get(line['to'])
|
| 131 |
+
|
| 132 |
+
# Skip lines with unassigned bus endpoints
|
| 133 |
+
if from_zone is None or to_zone is None:
|
| 134 |
+
continue
|
| 135 |
+
|
| 136 |
+
if from_zone == to_zone:
|
| 137 |
+
internal[from_zone].append(line['id'])
|
| 138 |
+
else:
|
| 139 |
+
boundary[from_zone].append(line['id'])
|
| 140 |
+
boundary[to_zone].append(line['id'])
|
| 141 |
+
|
| 142 |
+
return internal, boundary
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def generate_procedural_grid(difficulty: str = "easy", seed: int = 42) -> Dict:
    """Generate a reproducible grid configuration for a given difficulty level.

    Easy:   5 buses, 20% renewables — simple balancing
    Medium: 10 buses, 50% renewables — congestion management
    Hard:   14 buses, 70% renewables — volatile supply, tight margins
    (Any unrecognized difficulty string falls through to "hard".)

    Guarantees: at least 30% of non-slack buses are loads, and at least 1 battery.
    Includes multi-agent zone assignments for POMDP mode.

    Reproducibility depends on the exact ORDER of rng calls below — do not
    reorder the sampling statements without accepting changed outputs for
    a given seed.
    """
    rng = np.random.default_rng(seed)

    if difficulty == "easy":
        n_buses = 5
        renewable_mix = 0.2
        max_steps = 50
        num_agents = 2  # Small grid: 2 agents
    elif difficulty == "medium":
        n_buses = 10
        renewable_mix = 0.5
        max_steps = 50
        num_agents = 3
    else:  # Hard
        n_buses = 14
        renewable_mix = 0.7
        max_steps = 50
        num_agents = 3

    # Watts-Strogatz small-world topology; seeded so the same (difficulty,
    # seed) pair always yields the same graph.
    G = nx.connected_watts_strogatz_graph(n_buses, k=4, p=0.3, seed=seed)

    # Generate bus types with guaranteed minimums
    n_non_slack = n_buses - 1
    min_loads = max(2, int(n_non_slack * 0.3))  # At least 30% loads
    min_batteries = 1
    # NOTE(review): assumes n_non_slack >= min_loads + min_batteries; holds
    # for all difficulties defined above (n_buses >= 5) — confirm before
    # adding smaller presets.

    types = ['slack']

    # Assign guaranteed loads first
    assigned = []
    for _ in range(min_loads):
        assigned.append('load')
    for _ in range(min_batteries):
        assigned.append('battery')

    # Fill remaining slots with renewable_mix probability
    remaining = n_non_slack - len(assigned)
    for _ in range(remaining):
        r = rng.random()
        if r < renewable_mix:
            # rng.choice returns np.str_; str() normalizes to a plain str
            assigned.append(str(rng.choice(['solar', 'wind'])))
        elif r < renewable_mix + 0.15:
            assigned.append('battery')
        else:
            assigned.append('load')

    # Shuffle to avoid spatial bias (loads always first)
    rng.shuffle(assigned)
    types.extend(assigned)

    # Estimate total load for slack bus sizing
    load_estimates = []
    buses = []
    lines = []

    for i, t in enumerate(types):
        # Only loads draw base power; everything else starts at 0 MW.
        base_p = float(rng.uniform(20, 50)) if t == 'load' else 0
        if t == 'load':
            load_estimates.append(base_p)

        # Set max_p based on bus type
        if t == 'battery':
            max_p = float(rng.uniform(30, 60))  # batteries can discharge
        elif t in ['solar', 'wind', 'generator']:
            max_p = float(rng.uniform(50, 100))
        elif t == 'slack':
            # Slack max_p sized to cover expected imbalance
            max_p = 0  # placeholder, set below
        else:
            max_p = 0

        buses.append({
            'id': i, 'type': t,
            'base_p': base_p,
            'max_p': max_p,
            'min_p': 0 if t in ['solar', 'wind', 'generator'] else -50,
            'capacity': 50 if t == 'battery' else 0,
            'init_soc': 25.0 if t == 'battery' else 0,  # batteries start half-charged
            'ramp_rate': 20.0 if t not in ['load', 'solar', 'wind'] else 0.0,
        })

    # Size slack bus to cover expected imbalance
    total_load_est = sum(load_estimates) if load_estimates else 100
    slack_max_p = max(100, total_load_est * 0.6)
    for b in buses:
        if b['type'] == 'slack':
            b['max_p'] = slack_max_p
            b['min_p'] = -slack_max_p

    # One line per graph edge; uniform susceptance, randomized capacity.
    for idx, (u, v) in enumerate(G.edges()):
        lines.append({
            'id': f"L_{idx}",
            'from': u, 'to': v,
            'susceptance': 50.0,
            'capacity': float(rng.uniform(80, 150))
        })

    # Multi-agent zone assignment
    zone_assignments = _partition_into_zones(G, num_agents)
    internal_lines, boundary_lines = _classify_lines(lines, zone_assignments)

    zone_names = _get_zone_names(num_agents)

    # Build per-zone bus lists
    zone_bus_ids = {a: [] for a in range(num_agents)}
    for bus_id, agent_id in zone_assignments.items():
        zone_bus_ids[agent_id].append(bus_id)

    return {
        "id": f"task_{difficulty}",
        "num_buses": n_buses,
        "buses": buses,
        "lines": lines,
        "max_steps": max_steps,
        "seed": seed,
        "difficulty": difficulty,
        # Multi-agent fields
        "num_agents": num_agents,
        "zone_assignments": zone_assignments,  # {bus_id: agent_id}
        "zone_names": zone_names,
        "zone_bus_ids": zone_bus_ids,  # {agent_id: [bus_ids]}
        "internal_lines": internal_lines,  # {agent_id: [line_ids]}
        "boundary_lines": boundary_lines,  # {agent_id: [line_ids]}
    }
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def generate_karnataka_task(seed: int = 808) -> Dict:
    """Build the fixed 15-bus Karnataka (KPTCL) grid task.

    A highly realistic topology based on the actual Karnataka KPTCL
    transmission map. Nodes carry real GPS coordinates (lat/lon) so the
    frontend can render them on a GIS map.

    Args:
        seed: Stored in the returned config for bookkeeping; the topology
            itself is fully deterministic and does not use randomness.

    Returns:
        Task config dict with buses, lines, and 4-agent zone metadata
        (same schema as the procedurally generated tasks).
    """
    nodes = [
        {"id": 0, "name": "Raichur_TPS", "type": "slack", "lat": 16.20, "lon": 77.36, "max_p": 200, "base_p": 0},
        {"id": 1, "name": "Kalaburagi", "type": "load", "lat": 17.33, "lon": 76.83, "max_p": 0, "base_p": 40},
        {"id": 2, "name": "Belagavi", "type": "load", "lat": 15.85, "lon": 74.50, "max_p": 0, "base_p": 35},
        {"id": 3, "name": "Hubballi", "type": "load", "lat": 15.36, "lon": 75.13, "max_p": 0, "base_p": 45},
        {"id": 4, "name": "Ballari_TPS", "type": "generator", "lat": 15.14, "lon": 76.92, "max_p": 150, "base_p": 0},
        {"id": 5, "name": "Chitradurga_Wind", "type": "wind", "lat": 14.23, "lon": 76.40, "max_p": 80, "base_p": 0},
        {"id": 6, "name": "Pavagada_Solar", "type": "solar", "lat": 14.10, "lon": 77.27, "max_p": 120, "base_p": 0},
        {"id": 7, "name": "Sharavathi_Hydro", "type": "generator", "lat": 14.18, "lon": 74.83, "max_p": 100, "base_p": 0},
        {"id": 8, "name": "Shivamogga", "type": "load", "lat": 13.93, "lon": 75.57, "max_p": 0, "base_p": 30},
        {"id": 9, "name": "Mangaluru", "type": "load", "lat": 12.87, "lon": 74.88, "max_p": 0, "base_p": 50},
        {"id": 10, "name": "Hassan_BESS", "type": "battery", "lat": 13.01, "lon": 76.10, "max_p": 50, "base_p": 0},
        {"id": 11, "name": "Mysuru", "type": "load", "lat": 12.30, "lon": 76.64, "max_p": 0, "base_p": 40},
        {"id": 12, "name": "Nelamangala", "type": "battery", "lat": 13.10, "lon": 77.39, "max_p": 50, "base_p": 0},
        {"id": 13, "name": "Bengaluru_City", "type": "load", "lat": 12.97, "lon": 77.59, "max_p": 0, "base_p": 120},
        {"id": 14, "name": "Kolar_Solar", "type": "solar", "lat": 13.13, "lon": 78.13, "max_p": 60, "base_p": 0},
    ]

    edges = [
        (0,1), (0,4), (4,5), (4,6), (5,3), (3,2), (3,7),
        (7,8), (8,9), (8,10), (9,10),  # (9,10) added: connects Mangaluru within zone 2
        (10,11), (10,12), (5,12),
        (6,12), (12,13), (13,14), (11,13)
    ]

    # Expand the compact node stubs into full bus records with
    # operational limits matching the procedural-task schema.
    buses = [
        {
            'id': n['id'], 'name': n['name'], 'type': n['type'],
            'lat': n['lat'], 'lon': n['lon'],
            'base_p': n['base_p'], 'max_p': n['max_p'],
            # Generation sources cannot absorb power; other bus types may
            # sink up to 50 MW (e.g. battery charging, controllable load).
            'min_p': 0 if n['type'] in ['solar', 'wind', 'generator'] else -50,
            'capacity': 100 if n['type'] == 'battery' else 0,
            'init_soc': 50.0 if n['type'] == 'battery' else 0,
            # Loads and renewables are not dispatchable, so no ramp limit.
            'ramp_rate': 40.0 if n['type'] not in ['load', 'solar', 'wind'] else 0.0,
        }
        for n in nodes
    ]

    # Line IDs encode their endpoints ("L_<from>_<to>") so downstream
    # consumers can recover the topology from IDs alone.
    lines = [
        {
            'id': f"L_{u}_{v}", 'from': u, 'to': v,
            'susceptance': 80.0, 'capacity': 150.0
        }
        for u, v in edges
    ]

    # Realistic agents based on regional discoms/SLDC zones
    zone_assignments = {
        0: 0, 1: 0, 4: 0,               # North Zone (Raichur/Bellary)
        2: 1, 3: 1, 5: 1, 7: 1, 8: 1,   # Hubli/Central Zone
        9: 2, 10: 2, 11: 2,             # Mysuru/Coast Zone
        6: 3, 12: 3, 13: 3, 14: 3       # Bengaluru Zone
    }

    internal_lines, boundary_lines = _classify_lines(lines, zone_assignments)

    # Invert the bus->agent map into per-agent bus lists.
    zone_bus_ids = {a: [] for a in range(4)}
    for b_id, a_id in zone_assignments.items():
        zone_bus_ids[a_id].append(b_id)

    return {
        "id": "task_karnataka",
        "num_buses": len(buses),
        "buses": buses,
        "lines": lines,
        "max_steps": 50,
        "seed": seed,
        "difficulty": "karnataka",
        "num_agents": 4,
        "zone_assignments": zone_assignments,
        "zone_names": ["Kalaburagi_Region", "Hubballi_Region", "Mysuru_Region", "Bengaluru_Region"],
        "zone_bus_ids": zone_bus_ids,
        "internal_lines": internal_lines,
        "boundary_lines": boundary_lines,
    }
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def get_task(task_id: str) -> Dict:
    """Return a fresh, mutable deep copy of the task config for *task_id*.

    Raises:
        ValueError: if *task_id* is not a registered task.
    """
    generator = _TASK_GENERATORS.get(task_id)
    if generator is None:
        raise ValueError(
            f"Unknown task: {task_id}. "
            f"Available: {list(_TASK_GENERATORS.keys())}"
        )
    return copy.deepcopy(generator())
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
# Registry of task factories. Each entry rebuilds the config on every
# call, so get_task() can hand out independent copies with fixed seeds.
_TASK_GENERATORS = {
    "task_easy": lambda: generate_procedural_grid("easy", seed=101),
    "task_medium": lambda: generate_procedural_grid("medium", seed=102),
    "task_hard": lambda: generate_procedural_grid("hard", seed=103),
    "task_karnataka": lambda: generate_karnataka_task(),
}

# Deterministic tasks — same seed always produces the same grid
# NOTE: These are shared instances. Use get_task() for a mutable copy.
TASKS = {
    "task_easy": generate_procedural_grid("easy", seed=101),
    "task_medium": generate_procedural_grid("medium", seed=102),
    "task_hard": generate_procedural_grid("hard", seed=103),
    "task_karnataka": generate_karnataka_task()
}
|
src/visualization.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Grid Visualization — Dashboard Generator
|
| 3 |
+
==========================================
|
| 4 |
+
Generates a base64-encoded PNG dashboard with two panels:
|
| 5 |
+
1. Grid topology with bus-type coloring and line-loading heat map
|
| 6 |
+
2. Frequency stability trace over time
|
| 7 |
+
|
| 8 |
+
Supports both GridObservation (single-agent) and ZoneObservation (multi-agent).
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import io
|
| 12 |
+
import base64
|
| 13 |
+
import logging
|
| 14 |
+
from typing import List, Optional, Sequence, Dict, Tuple
|
| 15 |
+
|
| 16 |
+
import matplotlib
|
| 17 |
+
matplotlib.use('Agg') # Non-interactive backend for server use
|
| 18 |
+
import matplotlib.pyplot as plt
|
| 19 |
+
from matplotlib.lines import Line2D
|
| 20 |
+
import networkx as nx
|
| 21 |
+
|
| 22 |
+
from .models import GridObservation
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _parse_line_endpoints(line_id: str) -> Optional[Tuple[int, int]]:
|
| 28 |
+
"""Parse line ID format 'L_<from>_<to>' into endpoint bus IDs.
|
| 29 |
+
|
| 30 |
+
Returns (from, to) on success, None on parse failure.
|
| 31 |
+
Requires exactly the format L_<int>_<int>.
|
| 32 |
+
"""
|
| 33 |
+
try:
|
| 34 |
+
parts = line_id.split('_')
|
| 35 |
+
if len(parts) == 3 and parts[0] == "L":
|
| 36 |
+
return int(parts[1]), int(parts[2])
|
| 37 |
+
except (ValueError, IndexError):
|
| 38 |
+
pass
|
| 39 |
+
return None
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def generate_dashboard(
    history: Sequence,
    current_obs,
    config: Optional[Dict] = None,
) -> str:
    """Generate a base64-encoded PNG dashboard image.

    Renders two side-by-side panels: (1) the grid topology with bus-type
    node colors and a line-loading heat map, (2) the grid frequency trace
    over the episode with the 49.5–50.5 Hz "normal" band shaded.

    Args:
        history: Sequence of observation objects for frequency trace.
            Each must expose ``timestep`` and ``grid_frequency``.
        current_obs: Current GridObservation or ZoneObservation for topology.
        config: Optional grid config dict. When provided, line endpoints
            are read from config (robust) instead of parsed from IDs.

    Returns:
        Base64-encoded PNG image string (without data URI prefix).
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    try:
        # Support both GridObservation and ZoneObservation:
        # single-agent obs have .buses/.lines; zone obs have
        # .local_buses plus separate internal/boundary line lists.
        buses = getattr(current_obs, "buses",
                        getattr(current_obs, "local_buses", []))
        lines = getattr(current_obs, "lines", None)
        if lines is None:
            internal = getattr(current_obs, "internal_lines", [])
            boundary = getattr(current_obs, "boundary_lines", [])
            lines = list(internal) + list(boundary)

        # Build line endpoint lookup from config if available
        line_endpoints: Dict[str, Tuple[int, int]] = {}
        if config:
            for l_cfg in config.get("lines", []):
                line_endpoints[l_cfg["id"]] = (l_cfg["from"], l_cfg["to"])

        # --- Plot 1: Grid Topology ---
        G = nx.Graph()

        # Node color by bus type (gray fallback for unknown types).
        color_map = {}
        for bus in buses:
            G.add_node(bus.id)
            if bus.type in ['generator', 'slack']:
                color_map[bus.id] = '#2ecc71'  # green
            elif bus.type == 'load':
                color_map[bus.id] = '#e74c3c'  # red
            elif bus.type == 'battery':
                color_map[bus.id] = '#3498db'  # blue
            else:
                color_map[bus.id] = '#f1c40f'  # yellow (renewables)

        # Build graph with line data as edge attributes
        for line in lines:
            # Get endpoints from config (preferred) or parse from ID;
            # lines whose endpoints cannot be determined are skipped.
            if line.id in line_endpoints:
                u, v = line_endpoints[line.id]
            else:
                parsed = _parse_line_endpoints(line.id)
                if parsed is None:
                    continue
                u, v = parsed

            G.add_edge(u, v, line_id=line.id, rho=line.rho,
                       connected=line.connected)

        # Build edge colors in G.edges() order (correct alignment).
        # rho thresholds: >0.9 overloaded (red), >0.7 stressed (orange),
        # otherwise healthy (green); disconnected lines are dashed gray.
        edge_colors = []
        edge_styles = []
        for u, v, data in G.edges(data=True):
            connected = data.get('connected', True)
            rho = abs(data.get('rho', 0.0))

            if not connected:
                edge_colors.append('lightgray')
                edge_styles.append('dashed')
            elif rho > 0.9:
                edge_colors.append('#e74c3c')  # red
                edge_styles.append('solid')
            elif rho > 0.7:
                edge_colors.append('#e67e22')  # orange
                edge_styles.append('solid')
            else:
                edge_colors.append('#2ecc71')  # green
                edge_styles.append('solid')

        node_colors = [color_map.get(n, 'gray') for n in G.nodes()]

        # Use config coordinates if available (stable layout). Only use
        # them when EVERY graph node has a coordinate, otherwise fall
        # back to a seeded spring layout so the picture is deterministic.
        pos = None
        if config:
            bus_coords = {}
            for b_cfg in config.get("buses", []):
                if "lon" in b_cfg and "lat" in b_cfg:
                    bus_coords[b_cfg["id"]] = (b_cfg["lon"], b_cfg["lat"])
            if len(bus_coords) == G.number_of_nodes():
                pos = bus_coords

        if pos is None and G.number_of_nodes() > 0:
            pos = nx.spring_layout(G, seed=42)

        if G.number_of_nodes() > 0 and pos:
            # Split edges by style; zips stay aligned because both lists
            # were built in the same G.edges() iteration order above.
            solid_edges = [
                (u, v) for (u, v, _), s in zip(G.edges(data=True), edge_styles)
                if s == 'solid'
            ]
            solid_colors = [
                c for c, s in zip(edge_colors, edge_styles) if s == 'solid'
            ]
            dashed_edges = [
                (u, v) for (u, v, _), s in zip(G.edges(data=True), edge_styles)
                if s == 'dashed'
            ]
            dashed_colors = [
                c for c, s in zip(edge_colors, edge_styles) if s == 'dashed'
            ]

            nx.draw_networkx_nodes(
                G, pos, ax=ax1, node_color=node_colors, node_size=300
            )
            nx.draw_networkx_labels(G, pos, ax=ax1, font_size=8)

            if solid_edges:
                nx.draw_networkx_edges(
                    G, pos, ax=ax1, edgelist=solid_edges,
                    edge_color=solid_colors, width=2, style='solid'
                )
            if dashed_edges:
                nx.draw_networkx_edges(
                    G, pos, ax=ax1, edgelist=dashed_edges,
                    edge_color=dashed_colors, width=1, style='dashed'
                )

            # Legend
            legend_elements = [
                Line2D([0], [0], marker='o', color='w',
                       markerfacecolor='#2ecc71', markersize=10,
                       label='Generator/Slack'),
                Line2D([0], [0], marker='o', color='w',
                       markerfacecolor='#e74c3c', markersize=10,
                       label='Load'),
                Line2D([0], [0], marker='o', color='w',
                       markerfacecolor='#3498db', markersize=10,
                       label='Battery'),
                Line2D([0], [0], marker='o', color='w',
                       markerfacecolor='#f1c40f', markersize=10,
                       label='Renewable'),
            ]
            ax1.legend(handles=legend_elements, loc='upper left', fontsize=7)
        else:
            ax1.text(0.5, 0.5, "No buses in observation",
                     ha='center', va='center', transform=ax1.transAxes)

        ax1.set_title("Grid Topology & Loading")

        # --- Plot 2: Frequency Trace ---
        if history:
            # Sort defensively by timestep in case history arrives unordered.
            history_sorted = sorted(history, key=lambda h: h.timestep)
            timesteps = [h.timestep for h in history_sorted]
            freqs = [h.grid_frequency for h in history_sorted]

            ax2.plot(timesteps, freqs, label='Frequency (Hz)',
                     color='#2980b9', linewidth=1.5)
            ax2.axhline(y=50.0, color='k', linestyle='--', linewidth=0.8)
            ax2.fill_between(timesteps, 49.5, 50.5,
                             color='green', alpha=0.1, label='Normal band')
            ax2.legend(fontsize=8)
        else:
            ax2.text(0.5, 0.5, "No frequency history",
                     ha='center', va='center', transform=ax2.transAxes)

        ax2.set_title("Frequency Stability")
        ax2.set_xlabel("Timestep")
        ax2.set_ylabel("Hz")
        ax2.set_ylim(48.5, 51.5)

        fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight')
        buf.seek(0)
        return base64.b64encode(buf.read()).decode('utf-8')

    finally:
        # Always release the figure — matplotlib keeps figures alive
        # globally, which would leak memory in a long-running server.
        plt.close(fig)
|
static/app.js
ADDED
|
@@ -0,0 +1,680 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// OpenGrid Control Room
|
| 2 |
+
const API = window.location.origin;
|
| 3 |
+
const AGENT_COLORS = ['#00bfff','#ff69b4','#ff6347','#32cd32','#9370db','#ffa500'];
|
| 4 |
+
const AGENT_NAMES = ['Bengaluru','Mysuru','Kalburagi','Hassan','Tumakuru','Bagalkot'];
|
| 5 |
+
|
| 6 |
+
// Real Karnataka state boundary path (source: @svg-maps/india)
|
| 7 |
+
const KARNATAKA_PATH = "m 124.338,505.46021 -0.617,-0.44733 0.776,-0.16422 -0.063,-0.8604 1.544,-0.77275 0.48,-0.70223 0.476,0.96821 0.881,0.0413 1.521,-0.74857 0.512,-1.53442 -0.938,-0.17228 0.62,-0.86141 0.404,0.86745 0.379,-0.0181 -0.412,-1.05888 1.641,-3.03861 -0.711,-0.35364 -0.968,0.47151 -0.458,-0.38889 1.391,-1.25837 1.141,0.50879 -0.068,-1.30269 0.567,-0.8997 -0.205,-0.93495 -1.688,-0.57629 -0.027,-0.50476 -1.422,-0.24583 -0.407,0.51987 0.312,-0.51181 -0.538,-0.73446 0.051,-1.1828 0.369,-0.24886 0.389,0.56622 0.156,-0.64581 -0.554,-0.135 -0.079,-1.12941 -0.891,-0.14911 0.075,-0.95309 -0.652,0.58133 -0.327,-0.41207 0.683,-0.18639 -0.196,-0.9007 0.79,0.92891 0.32,-1.12336 0.758,-0.0786 -0.063,0.39998 0.572,0.23676 0.284,-1.11026 1.444,-0.57126 0.104,-1.2241 0.432,0.74655 1.118,-0.14407 0.474,1.77622 1.304,-0.51987 0.135,-0.67805 0.996,0.0504 -0.625,-0.72439 0.746,-0.8191 0.043,-0.88055 3.282,-1.21706 1.441,0.0192 -0.248,-1.88302 1.091,-0.48057 0.066,-0.60249 -0.842,-0.44329 0.238,-0.33752 1.924,-0.0121 0.034,0.3486 1.225,-0.50375 1.062,1.64625 1.016,0 -0.135,0.69014 0.684,0.0373 1.401,-0.74252 0.119,-1.76514 1.19,0.0494 1.035,-0.52289 0.759,0.28311 0.772,-0.47957 0.515,0.92992 1.629,-0.45438 0.114,-0.9672 0.706,0.10276 0.024,0.73447 0.719,0.40703 0.619,-0.20251 -0.049,-1.65431 -0.596,-0.0151 0.725,-0.57931 0.002,-0.68712 -1.057,-1.6664 0.714,-0.83722 -0.047,-1.16568 -1.129,-0.91884 0.15,-0.85738 -0.592,-0.16422 -0.131,-0.72741 0.78,-0.19646 0.414,-1.88201 0.878,0.4302 0.285,0.99642 0.96,-0.20352 1.367,1.13646 0.469,-1.15761 0.779,0.81405 0.529,-0.69215 0.134,1.39841 0.785,0.64883 2.583,-0.66294 0.506,0.53196 0.889,-0.79693 0.877,0.55916 0.264,0.96015 -0.072,-1.13243 1.508,-0.56823 0.659,0.96922 1.418,-0.42618 0.181,0.86343 0.616,-0.0262 0.552,-1.2634 -0.964,-0.12593 0.234,-1.3037 -0.827,0.0463 -0.06,-0.80197 0.926,-0.54304 -0.661,-0.0191 0.474,-0.61155 -0.546,-0.44733 -0.175,-1.14955 1.758,-0.20553 0.273,-0.88459 1.268,-0.35766 -0.062,-1.16265 0.781,-0.0373 
0.001,-0.96115 1.038,-0.0242 -0.001,1.27348 0.863,-1.45483 1.02,1.77017 0.573,0.1743 0.159,-1.01455 0.617,-0.24079 -0.249,-0.98735 0.985,-0.11384 0.532,-0.86746 -1.061,-0.67301 0.067,-0.90271 1.3,0.65386 2.379,-1.03067 0.026,-2.60337 0.773,-0.14206 -0.16,-0.75159 0.445,-0.51584 -0.957,-0.41912 0.661,-1.51628 0.707,-0.0796 0.755,0.56923 0.186,-0.46546 0.52,0.69316 1.072,-0.008 -0.279,-0.93496 1.14,-0.47453 0.43,-1.41956 0.746,0.0645 -0.226,-0.76772 1.039,-1.27851 -0.101,-0.84126 1.616,-0.99742 0.517,0.51987 0.577,-0.38386 0.002,1.03772 0.845,0.269 -1.074,1.7198 1.624,0.0917 0.607,1.02866 0.938,-0.40804 -0.015,-0.62465 0.847,0.33953 0,0 1.11,0.2952 -0.81,1.70972 0.701,1.07298 0.059,1.15661 -1.148,1.00649 0.974,0.96115 1.129,0.37378 0.151,0.52592 -0.197,0.50576 -0.424,-0.25087 -0.15,1.209 -0.657,-0.11788 0.241,0.83219 -0.501,-0.0524 -0.482,1.20598 -0.497,-0.19243 -0.316,0.55916 -0.134,0.41509 1.287,0.40501 -0.083,0.37479 -2.338,1.22814 -0.218,2.41597 1.049,0.33349 0.243,0.55815 0.54,-0.71029 0.439,0.5229 0.867,-0.29319 0.04,0.66193 1.965,0.59442 -0.034,0.72036 -0.752,-0.18336 0.098,-0.48461 -0.258,0.59946 -0.617,0.134 -0.007,0.56521 -0.783,-0.21964 -0.013,0.54203 -1.307,0.0504 0.531,0.50879 -0.157,0.70222 -0.605,0.39595 -0.995,-0.35968 -0.368,1.80544 0.429,0.27202 -1.552,1.23318 0.386,0.24079 -0.812,1.03369 2.148,1.26239 0.77,2.12078 -0.963,2.15403 0.372,2.84517 -0.704,-0.0887 -0.296,1.50218 0.909,0.0564 -0.037,0.73648 -1.015,0.33852 0.343,0.5511 0.763,-0.2025 -0.109,1.3712 -1.522,0.41509 0.5,1.0357 -0.758,0.15516 -0.268,0.58132 -1.458,-0.16019 -0.097,0.3899 -1.189,0.12191 1.036,1.42158 1.22,0.50879 1.44,0.37176 1.732,-0.28613 2.033,0.83622 -0.027,1.10724 -0.53,-0.23676 -0.653,0.7657 -0.682,-0.11284 -0.286,0.39393 0.025,0.55614 0.46,0.0212 -0.568,1.41352 0.064,0.93395 0.476,0.26698 -0.391,1.59084 0.405,0.6186 -0.014,1.7742 0,0 -5.454,-0.80499 -2.208,0.37379 -1.622,0.9007 -0.915,1.47195 0.871,0.1884 -0.433,1.93137 1.711,1.64222 -0.184,0.7385 -0.728,-0.6045 
-1.092,0.41408 -0.056,2.91167 -1.145,-0.13803 -0.032,0.3355 1.193,1.21202 0.715,2.46333 1.007,-0.11788 0.12,0.81406 0.68,0.0907 0.34,0.5773 -0.906,0.82212 0.78,0.59543 -0.01,0.97425 -0.536,0.13601 0.459,0.48158 -1.574,2.61647 -0.792,-0.11788 0.123,-0.51583 -0.967,-0.008 -0.395,0.4171 -2.2,-0.39796 -1.67,-1.34803 -0.475,0.42113 0.216,1.18784 -0.435,1.01455 1.342,0.40904 0.765,-0.2821 -0.329,3.46982 -1.432,0.53599 0.371,0.96821 -0.793,2.87338 0.828,1.60897 0.583,0.10075 0.893,1.1828 2.16,-0.21258 -0.62,0.71835 -0.046,0.98633 -0.596,-0.30325 -0.627,0.50375 -0.084,0.94805 1.486,1.13344 -0.528,0.47251 0.271,0.69316 2.34,0.0121 0.538,0.65286 0.623,-0.0846 0.143,-1.60595 0.842,-1.08709 1.67,0.57729 1.03,-0.43423 -0.033,1.19086 1.667,0.1471 -0.081,0.84932 0.594,0.82615 0.668,-0.43524 -0.852,-1.8004 0.223,-0.64278 1.187,0.32743 0.259,0.81305 0.87,0.009 0.087,2.61143 -2.317,-0.3093 -0.272,0.77174 -0.606,0.11284 -0.067,0.61155 0.946,-0.48662 0.12,0.32643 -1.45,1.73995 1.197,0.3365 -0.162,0.82514 1.151,0.008 -0.48,0.70525 0.413,0.58032 -0.744,0.63372 -0.03,-0.38487 -0.881,0.0242 0.114,-1.85279 -0.843,-0.91682 -3.478,0.71734 -0.549,-1.09213 -2.039,-0.23978 -0.322,-0.96921 0.241,-1.61301 -1.637,-0.35766 -0.098,0.50577 -0.954,0.14911 -0.162,0.60449 0.907,0.0826 1.005,1.33896 -0.919,0.91884 2.193,2.09761 0.114,0.60853 -0.502,0.0876 -1.023,1.70468 0.506,1.54953 0.395,0.21158 0.313,-0.83522 0.706,0.70324 0.737,-0.47554 1.493,0.134 -0.091,-1.42461 -0.803,-0.6186 0.809,-0.16422 -0.137,-0.92488 0.441,-0.37983 0.037,1.24325 0.547,-0.46042 0.138,0.42617 0.467,-0.72339 0.348,1.19691 1.182,-0.38386 0.274,0.68006 0.826,-0.4302 1.362,0.2277 -0.332,0.77476 1.021,0.0474 -0.161,2.61646 0.695,-0.0846 0.092,-0.58435 0.522,0.45539 0.154,-1.25535 0.762,0.59141 0.828,-0.58536 0.537,0.3496 0.324,-0.16926 -0.55,-0.43624 0.809,-0.44028 0.442,-0.0363 -0.136,0.54505 0.666,-0.19746 0.276,0.6186 0.086,-1.24527 1.374,-0.48259 -0.051,-0.49669 1.082,-0.53297 -0.447,-1.03671 0.25,-0.56723 1.438,0.0796 
0.515,0.73345 1.148,-1.15156 0.243,1.23519 -0.745,0.2831 0.044,1.30169 0.444,-0.005 0.406,-0.89566 1.102,0.18941 0.07,-0.73145 1.516,0.98937 0.098,1.37926 -0.697,0.93798 0.512,0.50173 -0.084,0.55715 -0.865,-0.0816 -0.12,0.4574 0.469,0.68309 1.57,-0.28815 0.1,0.54506 0.7,0.15616 -0.224,1.28859 0.93,-0.66394 3.414,0.19545 -0.746,4.80576 0.884,0.48965 -0.636,0.26497 0.508,0.35262 0.695,-0.33146 0.241,0.44632 0.749,-0.20553 1.027,1.12638 0.729,-0.94402 0.457,0.80499 -0.184,1.24425 -0.581,0.36573 0.589,0.66091 -1.263,0.79391 0.402,0.47957 -0.545,0.33751 0.056,0.62163 -1.11,0.45639 0.133,1.46793 -0.738,-0.11486 0.275,1.46087 -1.203,0.11788 -0.689,-0.70726 -0.886,1.73994 -1.298,-0.005 -0.428,2.03715 0,0 -2.093,-0.37478 -1.548,-1.55457 -0.666,-0.0756 -0.281,1.08406 -0.42,-0.004 -0.75,-1.15459 -0.435,0.7657 -0.326,-0.17833 0.528,-0.73245 -0.35,-0.48057 -2.781,0.95812 0.306,1.00952 -1.425,2.63964 -0.578,-0.20956 -0.533,0.52994 -0.504,-0.54002 -1.339,0.35666 0.157,0.78484 -0.582,1.35407 0.177,1.09314 0.583,0.15515 -0.649,0.67402 1.043,-0.19042 -0.107,1.47699 -0.34,1.17273 -1.279,1.50318 -1.518,0.46849 -0.095,1.27851 5.457,0.74958 0.881,1.32285 -1.654,2.04924 -0.607,1.53744 -3.686,0.12292 -0.157,1.20799 -0.505,-0.269 0.073,1.05888 -0.775,1.89308 -1.251,-0.60147 -0.699,0.42415 -0.864,-0.84327 -0.902,-0.0877 -0.308,0.39997 -2.601,0.44129 0.136,1.076 -0.789,-0.26195 -0.316,-1.11429 -0.716,0.26598 0.195,-0.41106 -0.57,-0.40803 -0.663,0.0413 -0.276,0.76872 -1.254,-0.38788 -1.49,2.97816 0.469,0.90473 -0.285,0.58435 -0.435,-0.48562 -1.471,-0.28512 -3.897,0.009 -0.412,-1.29161 -0.758,-0.58939 -1.106,0.91783 -0.584,-0.0897 0,0 -0.566,-0.89365 0.471,-0.34255 -0.235,-0.77678 -1.521,0.48561 -1.318,-1.56061 -1.12,0.0746 -0.722,-1.36415 -1.59,0.27001 -0.003,-2.55803 -2.375,1.01354 -2.464,-0.37278 -1.096,-0.93294 -0.517,-1.86185 -1.73,0.19444 -0.323,-0.81909 -0.82,0.19545 -0.572,-1.09817 -1.219,-0.17933 -1.97,-2.90361 -1.331,-0.005 0.047,-1.72382 -1.168,-0.86343 0.021,-0.95712 1.17,-0.18034 
-0.168,-0.78686 -1.542,0.8725 -0.125,-0.73447 -1.125,-0.59946 -0.09,-0.98634 1.068,-0.45136 -1.071,-0.56823 -1.126,1.05183 -0.449,-1.34098 -0.885,0.17329 -0.339,-0.3496 0.161,-0.92388 -1.351,-0.46042 -0.063,0.67401 -0.739,0.13602 0.039,-1.17374 -0.891,0.11788 0.106,-0.29318 -0.574,-0.15012 0.499,-0.65689 -0.342,-0.6448 -2.621,0.77376 0,0 -0.965,-2.10365 -2.634,-10.6573 -0.512,-6.16488 -1.337,-5.02237 -0.768,-1.72786 -0.809,-0.39594 -0.627,-1.24728 -0.64,-3.47486 -0.611,-0.87048 -1.843,-6.03994 0.826,-0.61357 -0.599,-0.48662 -0.181,0.68611 -0.971,0.0302 -0.313,-1.75002 -0.524,-0.54808 0.32,-0.28814 -0.384,-1.61905 -0.669,-0.71633 -0.622,0.65084 -2.291,-1.75103 0.587,-0.13299 0.157,-0.89768 -0.396,-0.0121 -0.308,-1.05989 0,0 0.879,-0.538 0.754,0.24986 -0.068,-0.91279 0.831,0.64278 1.22,-0.98231 -0.176,-0.52289 0.851,-1.06593 -0.502,-1.21605 0.235,-1.02664 0.676,-0.8876 -0.318,-0.85033 -1.029,-0.74857 1.761,-0.76671 -0.278,-1.61602 -0.957,-0.46446 -0.003,-1.18481 -0.548,-1.00952 0.647,-0.7939 -0.69,-0.86645 0.298,-0.95108 -0.278,-0.97123 -0.496,-0.26799 -0.384,0.38083 -0.483,-0.69517 -0.35,0.62969 -0.986,0.0383 z";
|
| 8 |
+
|
| 9 |
+
let state = {
|
| 10 |
+
sessionId: null, task: 'task_karnataka', step: 0, done: false,
|
| 11 |
+
numAgents: 0, zoneInfo: {}, observations: {}, taskConfigs: {},
|
| 12 |
+
rewardHistory: [], freqHistory: [], perAgentRewards: {},
|
| 13 |
+
totalReward: 0, autoRunning: false, autoTimer: null,
|
| 14 |
+
safetyTotal: 0, lastOversight: null, mapScale: 1, alarms: []
|
| 15 |
+
};
|
| 16 |
+
|
| 17 |
+
// --- Init ---
// Wire up the task-selector buttons, fetch the task configs from the
// server, then start the first episode and hide the loading overlay.
document.addEventListener('DOMContentLoaded', () => {
    document.querySelectorAll('.task-btn').forEach(btn => {
        btn.addEventListener('click', () => {
            // Single-select behavior: clear 'active' everywhere, then mark this button.
            document.querySelectorAll('.task-btn').forEach(b => b.classList.remove('active'));
            btn.classList.add('active');
            state.task = btn.dataset.task;
        });
    });
    fetch(`${API}/tasks`).then(r=>r.json()).then(d=>{
        // Index task configs by id for later lookup.
        d.forEach(t => state.taskConfigs[t.id] = t);
        resetEpisode(); // reset only after configs are loaded
        setTimeout(() => document.getElementById('loading').classList.add('hidden'), 800);
    });
});
|
| 32 |
+
|
| 33 |
+
// --- API Calls ---
|
| 34 |
+
// Start a fresh multi-agent episode for the currently selected task.
// Clears all client-side episode state, then POSTs /reset_multi and
// stores the returned session id, zone metadata and initial observations.
async function resetEpisode() {
    stopAutoRun();
    state.step = 0; state.done = false; state.totalReward = 0;
    state.rewardHistory = []; state.freqHistory = []; state.safetyTotal = 0;
    state.alarms = [];
    // NOTE(review): mapFitted appears to be a global declared elsewhere in
    // this file; presumably forces the map to re-fit to the new topology.
    mapFitted = false;
    document.getElementById('alarmLog').innerHTML = '';
    document.getElementById('simStatus').textContent = 'RUNNING';
    try {
        const r = await fetch(`${API}/reset_multi?task_id=${state.task}`, {method:'POST'});
        const d = await r.json();
        state.sessionId = d.session_id;
        state.numAgents = d.num_agents;
        state.zoneInfo = d.zone_info;
        state.observations = d.observations;
        // One reward series per agent, keyed by agent index.
        state.perAgentRewards = {};
        for (let i = 0; i < d.num_agents; i++) state.perAgentRewards[i] = [];
        updateAll();
    } catch(e) { showAlert('critical', 'Reset failed: ' + e.message); }
}
|
| 54 |
+
|
| 55 |
+
// Advance the episode one step: build a heuristic action for every agent,
// POST the joint action to /step_multi, and fold the response (rewards,
// frequency, safety corrections, oversight report) into client state.
async function stepEpisode() {
    if (!state.sessionId || state.done) return;
    const actions = {};
    // The server keys agents by stringified index, so use String(i) throughout.
    for (let i = 0; i < state.numAgents; i++) {
        const obs = state.observations[String(i)];
        actions[String(i)] = generateHeuristicAction(i, obs);
    }
    try {
        const r = await fetch(`${API}/step_multi?session_id=${state.sessionId}`, {
            method: 'POST', headers: {'Content-Type':'application/json'},
            body: JSON.stringify({agent_actions: actions})
        });
        const d = await r.json();
        state.step++;
        state.observations = d.observations;
        state.totalReward += d.team_reward;
        state.rewardHistory.push(d.team_reward);
        state.lastOversight = d.oversight_report;
        state.done = d.done;
        const freq = getAvgFreq(d.observations);
        state.freqHistory.push(freq);
        // safety_reports is a string-keyed dict {"0": {...}, "1": {...}}, not an array
        Object.values(d.safety_reports || {}).forEach(sr => { if (sr.was_corrected) state.safetyTotal++; });
        for (const [aid, rew] of Object.entries(d.rewards)) {
            if (!state.perAgentRewards[aid]) state.perAgentRewards[aid] = [];
            state.perAgentRewards[aid].push(rew.value);
        }
        if (d.done) {
            // Episode ended: distinguish a blackout failure from normal completion.
            document.getElementById('simStatus').textContent = d.info.is_blackout ? 'BLACKOUT' : 'COMPLETE';
            stopAutoRun();
        }
        updateAll(d);
    } catch(e) { showAlert('critical', 'Step failed: ' + e.message); stopAutoRun(); }
}
|
| 89 |
+
|
| 90 |
+
/**
 * Fetch the episode grade from /grader and render it into #episodeScore,
 * color-coded green/amber/red at the 0.7 / 0.4 thresholds.
 * Failures show a warning alert rather than breaking the panel.
 */
async function getGrade() {
  if (!state.sessionId) return;
  try {
    const r = await fetch(`${API}/grader?session_id=${state.sessionId}`);
    // Guard non-2xx responses: d.score would be undefined and .toFixed would throw.
    if (!r.ok) throw new Error(`HTTP ${r.status}`);
    const d = await r.json();
    document.getElementById('episodeScore').textContent = d.score.toFixed(4);
    document.getElementById('episodeScore').style.color =
      d.score > 0.7 ? 'var(--status-normal)' : d.score > 0.4 ? 'var(--status-warning)' : 'var(--status-critical)';
  } catch(e) { showAlert('warning', 'Grade failed: ' + e.message); }
}
|
| 100 |
+
|
| 101 |
+
// --- Heuristic Agent ---
|
| 102 |
+
// Proportional frequency controller used as the in-browser baseline agent:
// nudges every dispatchable unit (generator/battery) toward 50 Hz.
// Gain 8 MW/Hz, deltas clamped to ±15 MW, sub-0.5 MW nudges dropped as noise.
function generateHeuristicAction(agentId, obs) {
  if (!obs) return {bus_adjustments: [], topology_actions: []};
  const freqError = 50.0 - (obs.grid_frequency || 50);
  const adjustments = [];
  for (const bus of (obs.local_buses || [])) {
    // Exclude slack — physics solver overwrites its injection; adjusting it wastes the action
    if (bus.type !== 'battery' && bus.type !== 'generator') continue;
    const clamped = Math.min(15, Math.max(-15, freqError * 8));
    if (Math.abs(clamped) > 0.5) {
      adjustments.push({bus_id: bus.id, delta: Math.round(clamped * 10) / 10});
    }
  }
  return {bus_adjustments: adjustments, topology_actions: []};
}
|
| 118 |
+
|
| 119 |
+
// --- Auto Run ---
|
| 120 |
+
// Toggle the auto-run loop from the toolbar button.
function toggleAutoRun() {
  if (state.autoRunning) {
    stopAutoRun();
    return;
  }
  state.autoRunning = true;
  document.getElementById('btnAutoRun').classList.add('active');
  autoStep();
}
|
| 124 |
+
// Halt the auto-run loop: cancel any pending timer and clear the button highlight.
function stopAutoRun() {
  const pending = state.autoTimer;
  if (pending) clearTimeout(pending);
  state.autoRunning = false;
  document.getElementById('btnAutoRun').classList.remove('active');
}
|
| 129 |
+
// One tick of the auto-run loop: step the episode, then re-arm a 200 ms timer
// unless the run was cancelled or the episode finished while stepping.
async function autoStep() {
  if (!state.autoRunning || state.done) {
    stopAutoRun();
    return;
  }
  await stepEpisode();
  const keepGoing = state.autoRunning && !state.done;
  if (keepGoing) state.autoTimer = setTimeout(autoStep, 200);
}
|
| 134 |
+
|
| 135 |
+
// --- UI Updates ---
|
| 136 |
+
// Refresh every dashboard panel. stepData (the raw /step_multi response) is
// optional — panels that need per-step detail receive it, the rest read state.
function updateAll(stepData) {
  for (const refresh of [updateHeader, updateFrequency, updateSystemSummary, updateOversight]) {
    refresh();
  }
  updateAgentCards(stepData);
  updateLeaderboard();
  updateGridMap();
  updateCharts();
  updateAlarmLog(stepData);
}
|
| 147 |
+
|
| 148 |
+
// Mean grid frequency across agent observations (falls back to the global
// state.observations when no dict is passed); 50 Hz when there is no data.
function getAvgFreq(obs) {
  const source = obs || state.observations;
  const freqs = Object.values(source || {}).map(o => o.grid_frequency || 50);
  if (!freqs.length) return 50;
  return freqs.reduce((total, f) => total + f, 0) / freqs.length;
}
|
| 153 |
+
|
| 154 |
+
// Populate the top status bar: step counter, agent count, cumulative team
// reward, task name, color-coded average frequency, and blackout flag.
function updateHeader() {
  const byId = (id) => document.getElementById(id);
  const maxSteps = state.taskConfigs[state.task]?.max_steps || 50;
  byId('headerStep').textContent = `${state.step} / ${maxSteps}`;
  byId('headerAgents').textContent = `${state.numAgents} Active`;
  byId('headerReward').textContent = state.totalReward.toFixed(2);
  byId('headerEpisode').textContent = state.task.replace('task_','').toUpperCase();
  const freq = getAvgFreq();
  const freqEl = byId('headerFreq');
  freqEl.textContent = freq.toFixed(2) + ' Hz';
  freqEl.className = 'value ' + freqClass(freq);
  byId('totalSteps').textContent = state.step;
  const blackedOut = state.done && byId('simStatus').textContent === 'BLACKOUT';
  byId('blackoutStatus').textContent = blackedOut ? 'Yes' : 'No';
}
|
| 167 |
+
|
| 168 |
+
// Render the semicircular frequency gauge (#freqArc) as inline SVG plus the
// deviation readout and the grid-condition badge. The gauge spans 49–51 Hz,
// left-to-right, with the needle at the current fleet-average frequency.
function updateFrequency() {
  const freq = getAvgFreq();
  const cls = freqClass(freq);
  const colors = {normal:'#00e5a0',warning:'#ffd700',critical:'#ff3d3d'};
  const col = colors[cls];
  // Arc gauge — fixed 200x110 viewBox; arc center (100,100), radius 80.
  const container = document.getElementById('freqArc');
  const W=200, H=110, cx=100, cy=100, r=80;
  const minF=49, maxF=51;
  // Needle position as a 0..1 fraction of the 49–51 Hz span, clamped.
  const pct = Math.max(0,Math.min(1,(freq-minF)/(maxF-minF)));
  // Angles in standard math orientation: PI = left end of arc, 0 = right end.
  const startA=Math.PI, endA=0;
  const needleA = startA - pct*(startA-endA);
  let svg = `<svg viewBox="0 0 ${W} ${H}" xmlns="http://www.w3.org/2000/svg">`;
  // Background arc
  svg += `<path d="M${cx-r},${cy} A${r},${r} 0 0,1 ${cx+r},${cy}" fill="none" stroke="rgba(255,255,255,0.06)" stroke-width="10" stroke-linecap="round"/>`;
  // Colored segments: red/amber/green bands over the same 49–51 Hz scale.
  const segs = [{f:49,t:49.5,c:'#ff3d3d'},{f:49.5,t:49.85,c:'#ffd700'},{f:49.85,t:50.15,c:'#00e5a0'},{f:50.15,t:50.5,c:'#ffd700'},{f:50.5,t:51,c:'#ff3d3d'}];
  segs.forEach(s => {
    // Map band endpoints to arc angles (y is negated because SVG y grows downward).
    const a1=Math.PI-((s.f-minF)/(maxF-minF))*Math.PI;
    const a2=Math.PI-((s.t-minF)/(maxF-minF))*Math.PI;
    const x1=cx+r*Math.cos(a1),y1=cy-r*Math.sin(a1);
    const x2=cx+r*Math.cos(a2),y2=cy-r*Math.sin(a2);
    svg += `<path d="M${x1},${y1} A${r},${r} 0 0,0 ${x2},${y2}" fill="none" stroke="${s.c}" stroke-width="6" opacity="0.25" stroke-linecap="round"/>`;
  });
  // Needle — drawn slightly short of the rim (r-12) with a hub dot.
  const nx=cx+(r-12)*Math.cos(needleA), ny=cy-(r-12)*Math.sin(needleA);
  svg += `<line x1="${cx}" y1="${cy}" x2="${nx}" y2="${ny}" stroke="${col}" stroke-width="2.5" stroke-linecap="round"/>`;
  svg += `<circle cx="${cx}" cy="${cy}" r="4" fill="${col}"/>`;
  // Value text
  svg += `<text x="${cx}" y="${cy-20}" text-anchor="middle" fill="${col}" font-family="JetBrains Mono" font-size="28" font-weight="700" style="text-shadow:0 0 15px ${col}40">${freq.toFixed(2)}</text>`;
  svg += `<text x="${cx}" y="${cy-6}" text-anchor="middle" fill="#90a4ae" font-family="Inter" font-size="11">Hz</text>`;
  // Scale labels
  svg += `<text x="18" y="${cy+14}" fill="#546e7a" font-size="8" font-family="JetBrains Mono">49.0</text>`;
  svg += `<text x="${W-30}" y="${cy+14}" fill="#546e7a" font-size="8" font-family="JetBrains Mono">51.0</text>`;
  svg += `<text x="${cx}" y="12" text-anchor="middle" fill="#546e7a" font-size="8" font-family="JetBrains Mono">50.0</text>`;
  svg += '</svg>';
  container.innerHTML = svg;
  document.getElementById('freqDev').textContent = `Deviation: ${(freq-50).toFixed(3)} Hz | Nominal: 50.00 Hz`;
  // Grid condition badge — thresholds on |deviation|: 0.15 / 0.3 / 0.5 Hz.
  // NOTE: these are tighter than freqClass()'s 0.5 / 1.0 Hz bands; the badge
  // escalates earlier than the gauge color by design of these thresholds.
  const gc = document.getElementById('gridCondition');
  const dev = Math.abs(freq-50);
  if(dev<0.15){gc.textContent='NORMAL';gc.className='grid-condition normal';}
  else if(dev<0.3){gc.textContent='CONSERVATIVE OPS';gc.className='grid-condition conservative';}
  else if(dev<0.5){gc.textContent='CONSERVATION ALERT';gc.className='grid-condition alert';}
  else{gc.textContent='EMERGENCY';gc.className='grid-condition emergency';}
}
|
| 214 |
+
|
| 215 |
+
// Severity bucket for a frequency value: deviation from 50 Hz below 0.5 Hz is
// 'normal', below 1 Hz is 'warning', anything beyond that is 'critical'.
function freqClass(f) {
  const dev = Math.abs(f - 50);
  if (dev < 0.5) return 'normal';
  return dev < 1 ? 'warning' : 'critical';
}
|
| 216 |
+
|
| 217 |
+
// Aggregate zone totals into the system summary panel: generation, load,
// net balance, connected-line count, and overload count.
// NOTE(review): boundary lines appear in two agents' observations, so shared
// lines are counted once per zone here — same as the original tally; confirm intended.
function updateSystemSummary() {
  let gen = 0, load = 0, connected = 0, overloaded = 0, totalLines = 0;
  for (const obs of Object.values(state.observations)) {
    gen += obs.zone_gen_mw || 0;
    load += obs.zone_load_mw || 0;
    const zoneLines = (obs.internal_lines || []).concat(obs.boundary_lines || []);
    for (const line of zoneLines) {
      totalLines++;
      if (!line.connected) continue;
      connected++;
      if (line.rho > 1) overloaded++;
    }
  }
  document.getElementById('totalGen').textContent = gen.toFixed(1) + ' MW';
  document.getElementById('totalLoad').textContent = load.toFixed(1) + ' MW';
  document.getElementById('netBalance').textContent = (gen - load).toFixed(1) + ' MW';
  document.getElementById('linesConnected').textContent = `${connected} / ${totalLines}`;
  const olEl = document.getElementById('linesOverloaded');
  olEl.textContent = overloaded;
  olEl.style.color = overloaded > 0 ? 'var(--status-critical)' : 'var(--status-normal)';
}
|
| 235 |
+
|
| 236 |
+
// Render the latest oversight-agent report: coordination score (color-coded
// at 0.7 / 0.4), conflict and selfish-action counters, and the running
// safety-correction total. Silently skips when no report has arrived yet.
function updateOversight() {
  const report = state.lastOversight;
  if (!report) return;
  const score = report.coordination_score;
  const scoreEl = document.getElementById('coordScore');
  scoreEl.textContent = score.toFixed(2);
  scoreEl.style.color =
    score > 0.7 ? 'var(--status-normal)'
    : score > 0.4 ? 'var(--status-warning)'
    : 'var(--status-critical)';
  document.getElementById('conflicts').textContent = report.conflicting_actions_detected;
  document.getElementById('safetyCorrTotal').textContent = state.safetyTotal;
  document.getElementById('selfishActions').textContent = report.selfish_actions_detected;
}
|
| 246 |
+
|
| 247 |
+
// Scan the current state for alarm conditions (frequency deviation, line
// overload/congestion, safety-layer corrections, blackout) and prepend any new
// entries to the scrolling alarm log. Runs only when called with step data.
function updateAlarmLog(stepData) {
  if (!stepData) return;
  const logEl = document.getElementById('alarmLog');
  let newAlarms = [];
  // SCADA-style timestamp derived from the step counter, zero-padded.
  const timeStr = `T+${String(state.step).padStart(2,'0')}s`;

  // Check frequency — warn past 0.5 Hz deviation, critical past 1 Hz.
  const freq = getAvgFreq();
  if (Math.abs(freq - 50) > 0.5) {
    newAlarms.push({t: timeStr, msg: `FREQ DEVIATION: ${freq.toFixed(2)} Hz`, type: Math.abs(freq-50)>1?'crit'?'crit':'crit':'warn'});
  }

  // Check lines and safety
  // NOTE(review): boundary lines are visible to both neighboring agents, so a
  // shared overload can produce two identical entries per step — confirm intended.
  for (const [aid, obs] of Object.entries(state.observations)) {
    (obs.internal_lines||[]).concat(obs.boundary_lines||[]).forEach(l => {
      if (l.rho > 1.0) newAlarms.push({t: timeStr, msg: `OVERLOAD: Line ${l.id} at ${(l.rho*100).toFixed(0)}%`, type: 'crit'});
      else if (l.rho > 0.9) newAlarms.push({t: timeStr, msg: `CONGESTION: Line ${l.id} at ${(l.rho*100).toFixed(0)}%`, type: 'warn'});
    });
    const sr = stepData.safety_reports?.[aid];
    if (sr && sr.was_corrected) {
      newAlarms.push({t: timeStr, msg: `AGENT ${aid} SAFETY CORRECTED`, type: 'warn'});
    }
  }

  // Blackout is signalled via the status element set in stepEpisode().
  if (state.done && document.getElementById('simStatus').textContent==='BLACKOUT') {
    newAlarms.push({t: timeStr, msg: `SYSTEM COLLAPSE - BLACKOUT`, type: 'crit'});
  }

  if (newAlarms.length > 0) {
    // Newest first; cap the backlog so the DOM doesn't grow without bound.
    state.alarms = [...newAlarms, ...state.alarms].slice(0, 50); // Keep last 50
    logEl.innerHTML = state.alarms.map(a => `<div class="alarm-entry ${a.type}"><span class="alarm-time">[${a.t}]</span>${a.msg}</div>`).join('');
  }
}
|
| 280 |
+
|
| 281 |
+
// Rebuild one status card per agent: step and cumulative reward, zone
// load/generation, safety-layer verdict, and a per-agent reward sparkline.
// stepData is the raw /step_multi response; when absent (e.g. right after
// reset), reward/safety fields fall back to neutral defaults.
function updateAgentCards(stepData) {
  const container = document.getElementById('agentCards');
  container.innerHTML = '';
  for (let i = 0; i < state.numAgents; i++) {
    const obs = state.observations[String(i)];
    const zi = state.zoneInfo[String(i)] || {};
    const sr = stepData?.safety_reports?.[String(i)];
    const rew = stepData?.rewards?.[String(i)];
    const cumReward = (state.perAgentRewards[i]||[]).reduce((a,b)=>a+b,0);
    const wasCorrected = sr?.was_corrected || false;
    const cardClass = wasCorrected ? 'warning' : 'active';
    const html = `
      <div class="agent-card ${cardClass}">
        <div class="agent-header">
          <div class="agent-name">
            <span class="agent-dot" style="background:${AGENT_COLORS[i]}"></span>
            Agent ${i} - ${zi.zone_name||AGENT_NAMES[i]}
          </div>
          <span class="agent-status-badge ${wasCorrected?'corrected':'active'}">${wasCorrected?'Corrected':'Safe'}</span>
        </div>
        <div class="agent-metrics">
          <div class="agent-metric">
            <div class="label">Step Reward</div>
            <div class="value" style="color:${(rew?.value||0)>=0?'var(--status-normal)':'var(--status-critical)'}">${(rew?.value||0).toFixed(2)}</div>
          </div>
          <div class="agent-metric">
            <div class="label">Cumulative</div>
            <div class="value">${cumReward.toFixed(1)}</div>
          </div>
          <div class="agent-metric">
            <div class="label">Zone Load</div>
            <div class="value">${(obs?.zone_load_mw||0).toFixed(0)} MW</div>
          </div>
          <div class="agent-metric">
            <div class="label">Zone Gen</div>
            <div class="value">${(obs?.zone_gen_mw||0).toFixed(0)} MW</div>
          </div>
        </div>
        <div class="safety-shield ${wasCorrected?'corrected':'safe'}">
          ${wasCorrected?'⚠ Safety Corrected':'▣ Safety OK'}
          ${sr?.blocked_topology_actions ? ` | ${sr.blocked_topology_actions} blocked` : ''}
        </div>
        <div class="sparkline-container"><svg id="spark${i}"></svg></div>
      </div>`;
    container.innerHTML += html;
  }
  // Draw sparklines in a second pass: assigning innerHTML above re-parses the
  // container, which would wipe any SVG content drawn into an earlier card.
  for (let i = 0; i < state.numAgents; i++) {
    drawSparkline(`spark${i}`, state.perAgentRewards[i]||[], AGENT_COLORS[i]);
  }
}
|
| 332 |
+
|
| 333 |
+
// Rank agents by cumulative reward and render the mini leaderboard list,
// with #1/#2/#3 medals for the top three placements.
function updateLeaderboard() {
  const board = document.getElementById('leaderboard');
  const entries = [];
  for (let id = 0; id < state.numAgents; id++) {
    const total = (state.perAgentRewards[id] || []).reduce((sum, v) => sum + v, 0);
    const zone = state.zoneInfo[String(id)] || {};
    entries.push({id: id, name: zone.zone_name || AGENT_NAMES[id], score: total});
  }
  entries.sort((x, y) => y.score - x.score);
  const medals = ['#1', '#2', '#3'];
  board.innerHTML = entries.map((a, rank) => `
    <li>
      <span class="agent-label">
        <span class="agent-dot" style="background:${AGENT_COLORS[a.id]};width:6px;height:6px;border-radius:50%;display:inline-block;"></span>
        ${medals[rank] || ' '} ${a.name}
      </span>
      <span class="score" style="color:${AGENT_COLORS[a.id]}">${a.score.toFixed(1)}</span>
    </li>`).join('');
}
|
| 351 |
+
|
| 352 |
+
// --- Grid Map (Leaflet) ---
|
| 353 |
+
// Lazily-created Leaflet map instance; null until updateGridMap() first runs.
let leafletMap = null;
// Layer groups cleared and redrawn on every update: transmission-line
// polylines, bus-node markers, and zone score badges.
let mapLayers = { lines: null, nodes: null, badges: null };
// True once the viewport has been fitted to the bus extent; reset in
// resetEpisode() so a new task re-fits the map.
let mapFitted = false;
|
| 356 |
+
|
| 357 |
+
// Create the Leaflet map once (idempotent): dark CARTO basemap, Karnataka
// default view, and three layer groups that updateGridMap() redraws per step.
function initLeafletMap() {
  const container = document.getElementById('gridMap');
  if (leafletMap) return;  // already initialized — keep this idempotent

  // Karnataka bounds (south-west / north-east corners, [lat, lon]).
  const kaBounds = [[11.5, 73.5], [18.5, 79.0]];

  leafletMap = L.map(container, {
    center: [14.5, 76.5],
    zoom: 7,
    zoomControl: true,
    attributionControl: false,
    minZoom: 5,
    maxZoom: 15,
    // Canvas renderer: the many circleMarkers redraw faster than SVG nodes.
    preferCanvas: true,
  });

  // Dark tile layer for SCADA aesthetic
  L.tileLayer('https://{s}.basemaps.cartocdn.com/dark_all/{z}/{x}/{y}{r}.png', {
    subdomains: 'abcd',
    maxZoom: 19,
  }).addTo(leafletMap);

  // Attribution (small, bottom-right) — added manually because the default
  // attribution control was disabled above.
  L.control.attribution({position: 'bottomright', prefix: false})
    .addAttribution('© <a href="https://carto.com/">CARTO</a>')
    .addTo(leafletMap);

  // Layer groups for easy clearing
  mapLayers.lines = L.layerGroup().addTo(leafletMap);
  mapLayers.nodes = L.layerGroup().addTo(leafletMap);
  mapLayers.badges = L.layerGroup().addTo(leafletMap);

  // Fix Leaflet size after container is fully rendered — Leaflet caches the
  // container size at construction, which can be 0 while layout settles.
  setTimeout(() => {
    leafletMap.invalidateSize();
    leafletMap.fitBounds(kaBounds, { padding: [20, 20] });
  }, 200);
}
|
| 396 |
+
|
| 397 |
+
// Redraw the entire grid map: transmission-line polylines (colored by loading),
// bus markers (typed icon colors, zone rings, labels), and zone score badges.
// Positions come from task-config GPS coordinates when present; otherwise buses
// are spread in circles around hard-coded Karnataka zone centers.
function updateGridMap() {
  if (!leafletMap) initLeafletMap();

  // Clear previous layers — the whole map is rebuilt from scratch each step.
  mapLayers.lines.clearLayers();
  mapLayers.nodes.clearLayers();
  mapLayers.badges.clearLayers();

  const typeIcons = {slack:'S',generator:'G',load:'L',battery:'B',solar:'PV',wind:'W'};
  const typeColors = {slack:'#00e5a0',generator:'#f5a623',load:'#e94560',battery:'#4a90d9',solar:'#ffeb3b',wind:'#64ffda'};

  // Collect buses — merge static config with runtime state. The task config
  // supplies the full bus list (incl. lat/lon); runtime observations supply
  // the live p_injection for buses each agent can see.
  let allBuses = [];
  const taskCfg = state.taskConfigs[state.task];
  const runtimeState = {};
  for (const obs of Object.values(state.observations)) {
    (obs.local_buses||[]).forEach(b => { runtimeState[b.id] = b; });
  }
  if (taskCfg && taskCfg.buses) {
    allBuses = taskCfg.buses.map(b => {
      const rt = runtimeState[b.id];
      return {...b, p_injection: rt ? rt.p_injection : (b.base_p || 0)};
    });
  } else {
    // No config loaded yet — fall back to whatever the observations expose.
    allBuses = Object.values(runtimeState);
  }

  const hasGPS = allBuses.some(b => b.lat !== undefined && b.lon !== undefined);

  // For non-GPS tasks, generate fake positions around Karnataka center
  const busPositions = {};
  const zones = [
    {id:0, lat:16.8, lon:76.8, color:AGENT_COLORS[0], label:'Kalaburagi'},
    {id:1, lat:15.2, lon:75.2, color:AGENT_COLORS[1], label:'Hubballi'},
    {id:2, lat:12.8, lon:75.5, color:AGENT_COLORS[2], label:'Mysuru'},
    {id:3, lat:13.2, lon:77.5, color:AGENT_COLORS[3], label:'Bengaluru'},
  ];

  allBuses.forEach((b, idx) => {
    const aid = findAgent(b.id);
    let lat, lon;
    if (hasGPS && b.lat !== undefined && b.lon !== undefined) {
      lat = b.lat;
      lon = b.lon;
    } else {
      // Fallback: spread around zone center — evenly spaced on a 0.3-degree
      // circle, ordered by the bus's index within its zone.
      const zd = zones[aid >= 0 && aid < zones.length ? aid : 0];
      const zBuses = allBuses.filter(bb => findAgent(bb.id) === aid);
      const zi = zBuses.indexOf(b);
      const a = (zi / Math.max(zBuses.length, 1)) * Math.PI * 2;
      lat = zd.lat + Math.cos(a) * 0.3;
      lon = zd.lon + Math.sin(a) * 0.3;
    }
    busPositions[b.id] = {lat, lon, bus: b, agent: aid};
  });

  // Draw transmission lines — dedupe by line id since boundary lines appear
  // in both neighboring agents' observations.
  const drawnLines = new Set();
  for (const obs of Object.values(state.observations)) {
    (obs.internal_lines||[]).concat(obs.boundary_lines||[]).forEach(l => {
      if (drawnLines.has(l.id)) return;
      drawnLines.add(l.id);
      // Endpoint bus ids are encoded in the line id as "L_<from>_<to>".
      const parts = l.id.replace('L_','').split('_');
      const fromId = parseInt(parts[0]);
      const toId = parseInt(parts[1]);
      const from = busPositions[fromId];
      const to = busPositions[toId];
      if (!from || !to) return;

      // Color by loading ratio rho: grey when open, red >100%, orange >80%.
      const lc = !l.connected ? '#4a5568' : l.rho > 1 ? '#ff1744' : l.rho > 0.8 ? '#ff9100' : '#e91e63';
      const w = !l.connected ? 1.5 : l.rho > 0.8 ? 5 : 3;

      const polyline = L.polyline(
        [[from.lat, from.lon], [to.lat, to.lon]],
        { color: lc, weight: w, dashArray: l.connected ? '10 5' : '4 4', opacity: 0.9 }
      );
      // Hover tooltip only for lines actually carrying flow.
      if (l.connected && Math.abs(l.flow) > 0.5) {
        polyline.bindTooltip(`${l.id}: ${l.flow.toFixed(0)} MW (${(l.rho*100).toFixed(0)}%)`, {
          permanent: false, className: 'leaflet-tooltip-dark'
        });
      }
      mapLayers.lines.addLayer(polyline);
    });
  }
  // Ensure lines are visible above tiles
  if (drawnLines.size > 0) {
    mapLayers.lines.eachLayer(l => { if (l.bringToFront) l.bringToFront(); });
  }

  // Draw bus markers — outer zone-colored ring + inner type-colored node,
  // plus a name label below and an MW injection label above.
  for (const [bid, pos] of Object.entries(busPositions)) {
    const b = pos.bus;
    const col = AGENT_COLORS[pos.agent] || '#4a5568';
    const fill = typeColors[b.type] || '#666';
    // Marker radius by role: slack biggest, loads smallest.
    const r = b.type === 'slack' ? 12 : b.type === 'load' ? 7 : 9;
    const inj = (b.p_injection !== undefined ? b.p_injection : 0);
    const busLabel = b.name || `${b.type} ${b.id}`;
    const icon = typeIcons[b.type] || '?';

    // Outer ring (zone color)
    const outerRing = L.circleMarker([pos.lat, pos.lon], {
      radius: r + 4, fillColor: 'transparent', fillOpacity: 0,
      color: col, weight: 1.5, opacity: 0.4
    });
    mapLayers.nodes.addLayer(outerRing);

    // Inner node
    const marker = L.circleMarker([pos.lat, pos.lon], {
      radius: r, fillColor: fill, fillOpacity: 0.9,
      color: col, weight: 1, opacity: 0.6
    });

    // Rich tooltip
    const tooltipHtml = `
      <div style="font-family:'JetBrains Mono',monospace;font-size:11px;min-width:120px;">
        <b style="color:${fill}">${icon}</b> <b>${busLabel}</b><br>
        <span style="color:#888">Type:</span> ${b.type}<br>
        <span style="color:#888">Injection:</span> <b>${inj.toFixed(1)} MW</b><br>
        <span style="color:#888">Zone:</span> ${state.zoneInfo[String(pos.agent)]?.zone_name || 'Agent ' + pos.agent}
      </div>`;
    marker.bindTooltip(tooltipHtml, { className: 'leaflet-tooltip-dark', direction: 'top', offset: [0, -r] });
    mapLayers.nodes.addLayer(marker);

    // Label under node (divIcon with negative anchor places it below the marker).
    const labelIcon = L.divIcon({
      className: 'bus-label-icon',
      html: `<span style="color:${fill};text-shadow:0 0 4px #000;font-size:9px;font-family:'JetBrains Mono',monospace;white-space:nowrap;">${busLabel}</span>`,
      iconSize: [80, 14],
      iconAnchor: [40, -r - 2],
    });
    L.marker([pos.lat, pos.lon], { icon: labelIcon, interactive: false }).addTo(mapLayers.nodes);

    // MW label above node
    const mwIcon = L.divIcon({
      className: 'bus-mw-icon',
      html: `<span style="color:#e0e0e0;text-shadow:0 0 4px #000;font-size:10px;font-weight:700;font-family:'JetBrains Mono',monospace;">${inj.toFixed(0)}</span>`,
      iconSize: [40, 14],
      iconAnchor: [20, r + 16],
    });
    L.marker([pos.lat, pos.lon], { icon: mwIcon, interactive: false }).addTo(mapLayers.nodes);
  }

  // Zone badge overlays — one score badge per active agent at the zone center.
  zones.slice(0, state.numAgents).forEach(z => {
    const zi = state.zoneInfo[String(z.id)] || {};
    const name = zi.zone_name || z.label || AGENT_NAMES[z.id];
    const cum = (state.perAgentRewards[z.id] || []).reduce((a, b) => a + b, 0);

    const badgeIcon = L.divIcon({
      className: 'zone-badge-leaflet',
      html: `<div style="background:rgba(10,14,26,0.85);border:1px solid ${z.color};border-radius:6px;padding:4px 10px;text-align:center;white-space:nowrap;">
        <div style="color:${z.color};font-size:11px;font-weight:700;font-family:'JetBrains Mono',monospace;">${name}</div>
        <div style="color:${z.color};font-size:10px;font-family:'JetBrains Mono',monospace;opacity:0.8">${cum.toFixed(1)} pts</div>
      </div>`,
      iconSize: [120, 36],
      iconAnchor: [60, 50],
    });
    L.marker([z.lat, z.lon], { icon: badgeIcon, interactive: false }).addTo(mapLayers.badges);
  });

  // Fit map to bus extent on first data load (GPS tasks only — synthetic
  // positions keep the default Karnataka view). mapFitted prevents refits
  // from fighting the user's own pan/zoom on later updates.
  if (!mapFitted && allBuses.length > 0) {
    const lats = allBuses.filter(b => b.lat).map(b => b.lat);
    const lons = allBuses.filter(b => b.lon).map(b => b.lon);
    if (lats.length > 0) {
      leafletMap.fitBounds([
        [Math.min(...lats) - 0.5, Math.min(...lons) - 0.5],
        [Math.max(...lats) + 0.5, Math.max(...lons) + 0.5]
      ]);
      mapFitted = true;
    }
  }
}
|
| 570 |
+
|
| 571 |
+
// Populate and position the DOM bus tooltip next to the cursor.
// Reads bus metadata from the hovered node's data-* attributes.
// NOTE(review): appears unused by the Leaflet map, which uses bindTooltip —
// presumably a leftover from an earlier SVG renderer; confirm before removal.
function showBusTooltip(e, node) {
  const tip = document.getElementById('busTooltip');
  const zone = state.zoneInfo[node.dataset.agent] || {};
  document.getElementById('ttTitle').textContent = `Bus ${node.dataset.bus} (${node.dataset.type})`;
  document.getElementById('ttType').textContent = node.dataset.type;
  document.getElementById('ttInj').textContent = node.dataset.inj + ' MW';
  document.getElementById('ttZone').textContent = zone.zone_name || 'Zone ' + node.dataset.agent;
  tip.style.left = (e.clientX + 12) + 'px';
  tip.style.top = (e.clientY - 20) + 'px';
  tip.classList.add('visible');
}
|
| 582 |
+
// Hide the bus hover tooltip.
function hideBusTooltip() {
  const tip = document.getElementById('busTooltip');
  tip.classList.remove('visible');
}
|
| 583 |
+
|
| 584 |
+
// Resolve which agent's zone owns a bus id; -1 when no zone lists it.
function findAgent(busId) {
  const owner = Object.entries(state.zoneInfo)
    .find(([, zone]) => (zone.bus_ids || []).includes(busId));
  return owner ? parseInt(owner[0]) : -1;
}
|
| 590 |
+
|
| 591 |
+
// --- Charts ---
|
| 592 |
+
// Draw a tiny polyline sparkline of the last 30 data points into the SVG
// element #id. The y-scale spans the full history's min/max (not just the
// visible window) so the shape stays stable as points scroll off.
function drawSparkline(id, data, color) {
  const svgEl = document.getElementById(id);
  if (!svgEl || !data.length) return;
  const w = svgEl.clientWidth || 120;
  const h = svgEl.clientHeight || 22;
  const lo = Math.min(...data);
  const span = (Math.max(...data) - lo) || 1;
  const recent = data.slice(-30);
  const denom = recent.length - 1 || 1;
  const pts = recent
    .map((v, i) => `${(i / denom) * w},${h - (((v - lo) / span) * h * 0.8 + h * 0.1)}`)
    .join(' ');
  svgEl.innerHTML = `<polyline points="${pts}" fill="none" stroke="${color}" stroke-width="1.5" opacity="0.8"/>`;
}
|
| 601 |
+
|
| 602 |
+
// Redraw the two history strip charts: team reward (auto-scaled y-axis) and
// average frequency (pinned to the 49–51 Hz band).
function updateCharts() {
  drawChart('rewardChart', state.rewardHistory, 'var(--chart-reward)', 'Reward');
  drawChart('freqChart', state.freqHistory, 'var(--chart-supply)', 'Hz', 49, 51);
}
|
| 608 |
+
|
| 609 |
+
// Render a line chart (gridlines, axis labels, filled area) as inline SVG
// into #containerId. fixedMin/fixedMax pin the y-axis when provided
// (used for the 49–51 Hz frequency band); otherwise it auto-scales to the
// data. `label` is accepted but not currently rendered.
function drawChart(containerId, data, color, label, fixedMin, fixedMax) {
  const el = document.getElementById(containerId);
  if (!el) return;
  const W = el.clientWidth||300, H = el.clientHeight||140;
  // Empty-history placeholder instead of a blank panel.
  if (!data.length) { el.innerHTML = `<svg viewBox="0 0 ${W} ${H}"><text x="${W/2}" y="${H/2}" text-anchor="middle" fill="var(--text-muted)" font-size="11">Waiting for data...</text></svg>`; return; }
  const pad = {t:10,r:10,b:20,l:40};
  const cw = W-pad.l-pad.r, ch = H-pad.t-pad.b;
  const min = fixedMin !== undefined ? fixedMin : Math.min(...data);
  const max = fixedMax !== undefined ? fixedMax : Math.max(...data);
  const range = max-min||1;  // avoid divide-by-zero on flat data
  // Data points mapped into the padded plot area; `||1` guards a 1-point series.
  const pts = data.map((v,i) => `${pad.l+(i/(data.length-1||1))*cw},${pad.t+ch-(((v-min)/range)*ch)}`).join(' ');
  let svg = `<svg viewBox="0 0 ${W} ${H}" xmlns="http://www.w3.org/2000/svg">`;
  // Grid lines — 5 horizontal rules with y-axis value labels.
  for(let i=0;i<=4;i++){const y=pad.t+ch*i/4;const v=(max-((max-min)*i/4)).toFixed(1);svg+=`<line x1="${pad.l}" y1="${y}" x2="${W-pad.r}" y2="${y}" stroke="rgba(255,255,255,0.05)"/><text x="${pad.l-4}" y="${y+3}" text-anchor="end" fill="var(--text-muted)" font-size="8" font-family="JetBrains Mono">${v}</text>`;}
  svg += `<polyline points="${pts}" fill="none" stroke="${color}" stroke-width="1.5"/>`;
  // Fill area — close the polyline down to the x-axis for a translucent fill.
  const firstX = pad.l, lastX = pad.l+(data.length-1)/(data.length-1||1)*cw;
  svg += `<polygon points="${pts} ${lastX},${pad.t+ch} ${firstX},${pad.t+ch}" fill="${color}" opacity="0.08"/>`;
  svg += '</svg>';
  el.innerHTML = svg;
  // Gen mix chart — piggybacks the donut refresh on the freq-chart redraw.
  // NOTE(review): works because updateCharts() always draws freqChart, but
  // this coupling is easy to break if the chart order changes.
  if (containerId === 'freqChart') updateGenMix();
}
|
| 632 |
+
|
| 633 |
+
// Render the generation-mix donut: positive bus injections summed by bus
// type, drawn as pie slices with a center total-MW readout.
function updateGenMix() {
  const el = document.getElementById('genMixChart');
  if (!el) return;
  const W = el.clientWidth||200, H = el.clientHeight||140;
  let types = {};
  // Only positive injections count as generation; loads (negative) are skipped.
  for (const obs of Object.values(state.observations)) {
    (obs.local_buses||[]).forEach(b => {
      if (b.p_injection > 0) types[b.type] = (types[b.type]||0) + b.p_injection;
    });
  }
  const total = Object.values(types).reduce((a,b)=>a+b,0) || 1;  // avoid /0 when idle
  const colors = {slack:'#00e5a0',generator:'#f5a623',solar:'#ffeb3b',wind:'#64ffda',battery:'#4a90d9'};
  let svg = `<svg viewBox="0 0 ${W} ${H}">`;
  const cx=W/2, cy=H/2-5, r=Math.min(W,H)*0.3;
  // Slices start at 12 o'clock and sweep clockwise.
  let startAngle = -Math.PI/2;
  for (const [type, val] of Object.entries(types)) {
    const pct = val/total;
    const endAngle = startAngle + pct * Math.PI*2;
    const x1=cx+r*Math.cos(startAngle), y1=cy+r*Math.sin(startAngle);
    const x2=cx+r*Math.cos(endAngle), y2=cy+r*Math.sin(endAngle);
    // large-arc flag is required for slices over half the circle.
    const large = pct > 0.5 ? 1 : 0;
    svg += `<path d="M${cx},${cy} L${x1},${y1} A${r},${r} 0 ${large},1 ${x2},${y2} Z" fill="${colors[type]||'#666'}" opacity="0.8"/>`;
    const mid = (startAngle+endAngle)/2;
    // Label slices over 8% only — smaller ones would overlap.
    if (pct > 0.08) {
      const lx=cx+(r+14)*Math.cos(mid), ly=cy+(r+14)*Math.sin(mid);
      svg += `<text x="${lx}" y="${ly}" text-anchor="middle" fill="var(--text-secondary)" font-size="8">${type} ${(pct*100).toFixed(0)}%</text>`;
    }
    startAngle = endAngle;
  }
  // Punch out the donut hole and print the MW total in the center.
  svg += `<circle cx="${cx}" cy="${cy}" r="${r*0.55}" fill="var(--bg-card)"/>`;
  svg += `<text x="${cx}" y="${cy-2}" text-anchor="middle" fill="var(--text-primary)" font-family="JetBrains Mono" font-size="14" font-weight="700">${total.toFixed(0)}</text>`;
  svg += `<text x="${cx}" y="${cy+10}" text-anchor="middle" fill="var(--text-muted)" font-size="8">MW</text>`;
  svg += '</svg>';
  el.innerHTML = svg;
}
|
| 668 |
+
|
| 669 |
+
// --- Alerts ---
|
| 670 |
+
// Flash the alert banner with the given severity class ('critical'/'warning')
// and message; auto-hides after 5 seconds.
function showAlert(type, msg) {
  const banner = document.getElementById('alertBanner');
  banner.className = `alert-banner ${type} visible`;
  document.getElementById('alertText').textContent = msg;
  setTimeout(() => banner.classList.remove('visible'), 5000);
}
|
| 676 |
+
// Hide the alert banner immediately (wired to the banner's Dismiss button).
function dismissAlert() {
  const banner = document.getElementById('alertBanner');
  banner.classList.remove('visible');
}
|
| 677 |
+
|
| 678 |
+
// --- Map Controls ---
|
| 679 |
+
// Scale the grid map view by `factor` (>1 zooms in, <1 zooms out) and redraw.
function zoomMap(factor) {
  state.mapScale = state.mapScale * factor;
  updateGridMap();
}
|
| 680 |
+
// Restore the grid map to its default 1:1 scale and redraw.
function resetMapView() {
  state.mapScale = 1;
  updateGridMap();
}
|
static/index.html
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 6 |
+
<meta name="description" content="OpenGrid — Multi-Agent POMDP Power Grid Control Room with Safe RL">
|
| 7 |
+
<title>OpenGrid | Control Room</title>
|
| 8 |
+
<link rel="stylesheet" href="/static/style.css">
|
| 9 |
+
<link rel="stylesheet" href="https://unpkg.com/leaflet@1.9.4/dist/leaflet.css" />
|
| 10 |
+
<script src="https://unpkg.com/leaflet@1.9.4/dist/leaflet.js"></script>
|
| 11 |
+
<link rel="icon" href="/static/logo.png" type="image/png">
|
| 12 |
+
</head>
|
| 13 |
+
<body>
|
| 14 |
+
|
| 15 |
+
<!-- Loading Overlay -->
|
| 16 |
+
<div class="loading-overlay" id="loading">
|
| 17 |
+
<div class="loading-spinner"></div>
|
| 18 |
+
<div class="loading-text">OpenGrid — Initializing Control Room</div>
|
| 19 |
+
</div>
|
| 20 |
+
|
| 21 |
+
<!-- Alert Banner -->
|
| 22 |
+
<div class="alert-banner" id="alertBanner">
|
| 23 |
+
<span id="alertText"></span>
|
| 24 |
+
<button class="dismiss" onclick="dismissAlert()">Dismiss</button>
|
| 25 |
+
</div>
|
| 26 |
+
|
| 27 |
+
<!-- Main Layout -->
|
| 28 |
+
<div class="control-room">
|
| 29 |
+
|
| 30 |
+
<!-- ===== HEADER ===== -->
|
| 31 |
+
<header class="header">
|
| 32 |
+
<div class="header-brand">
|
| 33 |
+
<img src="/static/logo.png" alt="OpenGrid" class="logo-img" style="width:32px;height:32px;border-radius:6px;">
|
| 34 |
+
<div>
|
| 35 |
+
<h1>OpenGrid</h1>
|
| 36 |
+
<div class="sub">Multi-Agent Power Grid Control Room</div>
|
| 37 |
+
</div>
|
| 38 |
+
</div>
|
| 39 |
+
|
| 40 |
+
<div class="sim-badge">
|
| 41 |
+
<span class="dot"></span>
|
| 42 |
+
<span id="simStatus">READY</span>
|
| 43 |
+
</div>
|
| 44 |
+
|
| 45 |
+
<div class="header-stats">
|
| 46 |
+
<div class="header-stat">
|
| 47 |
+
<span class="label">Episode</span>
|
| 48 |
+
<span class="value normal" id="headerEpisode">--</span>
|
| 49 |
+
</div>
|
| 50 |
+
<div class="header-stat">
|
| 51 |
+
<span class="label">Step</span>
|
| 52 |
+
<span class="value" id="headerStep">0 / 50</span>
|
| 53 |
+
</div>
|
| 54 |
+
<div class="header-stat">
|
| 55 |
+
<span class="label">Frequency</span>
|
| 56 |
+
<span class="value normal" id="headerFreq">50.00 Hz</span>
|
| 57 |
+
</div>
|
| 58 |
+
<div class="header-stat">
|
| 59 |
+
<span class="label">Agents</span>
|
| 60 |
+
<span class="value normal" id="headerAgents">--</span>
|
| 61 |
+
</div>
|
| 62 |
+
<div class="header-stat">
|
| 63 |
+
<span class="label">Team Reward</span>
|
| 64 |
+
<span class="value" id="headerReward">0.00</span>
|
| 65 |
+
</div>
|
| 66 |
+
</div>
|
| 67 |
+
</header>
|
| 68 |
+
|
| 69 |
+
<!-- ===== LEFT PANEL ===== -->
|
| 70 |
+
<aside class="left-panel">
|
| 71 |
+
|
| 72 |
+
<!-- Frequency Display -->
|
| 73 |
+
<div class="card">
|
| 74 |
+
<div class="card-title">Grid Frequency</div>
|
| 75 |
+
<div class="freq-display">
|
| 76 |
+
<div class="freq-arc-container" id="freqArc"></div>
|
| 77 |
+
<div class="freq-deviation" id="freqDev">Deviation: 0.00 Hz | Nominal: 50.00 Hz</div>
|
| 78 |
+
<div class="grid-condition normal" id="gridCondition">NORMAL</div>
|
| 79 |
+
</div>
|
| 80 |
+
</div>
|
| 81 |
+
|
| 82 |
+
<!-- System Summary -->
|
| 83 |
+
<div class="card">
|
| 84 |
+
<div class="card-title">System Summary</div>
|
| 85 |
+
<div class="stat-row highlight">
|
| 86 |
+
<span class="label">Total Generation</span>
|
| 87 |
+
<span class="value" id="totalGen">-- MW</span>
|
| 88 |
+
</div>
|
| 89 |
+
<div class="stat-row">
|
| 90 |
+
<span class="label">Total Load</span>
|
| 91 |
+
<span class="value" id="totalLoad">-- MW</span>
|
| 92 |
+
</div>
|
| 93 |
+
<div class="stat-row">
|
| 94 |
+
<span class="label">Net Balance</span>
|
| 95 |
+
<span class="value" id="netBalance">-- MW</span>
|
| 96 |
+
</div>
|
| 97 |
+
<div class="stat-row">
|
| 98 |
+
<span class="label">Lines Connected</span>
|
| 99 |
+
<span class="value" id="linesConnected">--</span>
|
| 100 |
+
</div>
|
| 101 |
+
<div class="stat-row">
|
| 102 |
+
<span class="label">Lines Overloaded</span>
|
| 103 |
+
<span class="value" id="linesOverloaded" style="color: var(--status-normal);">0</span>
|
| 104 |
+
</div>
|
| 105 |
+
</div>
|
| 106 |
+
|
| 107 |
+
<!-- Coordination -->
|
| 108 |
+
<div class="card">
|
| 109 |
+
<div class="card-title">Oversight Agent</div>
|
| 110 |
+
<div class="coord-score">
|
| 111 |
+
<div class="big-value" id="coordScore" style="color: var(--status-normal);">1.00</div>
|
| 112 |
+
<div style="font-size:10px; color: var(--text-secondary); margin-top:4px;">Coordination Score</div>
|
| 113 |
+
</div>
|
| 114 |
+
<div class="stat-row">
|
| 115 |
+
<span class="label">Conflicts</span>
|
| 116 |
+
<span class="value" id="conflicts">0</span>
|
| 117 |
+
</div>
|
| 118 |
+
<div class="stat-row">
|
| 119 |
+
<span class="label">Safety Corrections</span>
|
| 120 |
+
<span class="value" id="safetyCorrTotal">0</span>
|
| 121 |
+
</div>
|
| 122 |
+
<div class="stat-row">
|
| 123 |
+
<span class="label">Selfish Actions</span>
|
| 124 |
+
<span class="value" id="selfishActions">0</span>
|
| 125 |
+
</div>
|
| 126 |
+
</div>
|
| 127 |
+
|
| 128 |
+
<!-- Exception Log -->
|
| 129 |
+
<div class="card" style="flex:1; display:flex; flex-direction:column; overflow:hidden;">
|
| 130 |
+
<div class="card-title" style="color: var(--status-warning);">Exception Log</div>
|
| 131 |
+
<div class="alarm-log" id="alarmLog">
|
| 132 |
+
<!-- Populated by JS -->
|
| 133 |
+
</div>
|
| 134 |
+
</div>
|
| 135 |
+
|
| 136 |
+
<!-- Task Selector -->
|
| 137 |
+
<div class="card" style="flex-shrink:0;">
|
| 138 |
+
<div class="card-title">Task & Controls</div>
|
| 139 |
+
<div class="task-selector" id="taskSelector">
|
| 140 |
+
<button class="task-btn" data-task="task_easy">Easy</button>
|
| 141 |
+
<button class="task-btn" data-task="task_medium">Medium</button>
|
| 142 |
+
<button class="task-btn" data-task="task_hard">Hard</button>
|
| 143 |
+
<button class="task-btn active" data-task="task_karnataka" style="color: #ffeb3b; border-color: rgba(255,235,59,0.3);">Karnataka</button>
|
| 144 |
+
</div>
|
| 145 |
+
<div class="controls-row" style="margin-top: var(--gap-sm);">
|
| 146 |
+
<button class="ctrl-btn active" id="btnReset" onclick="resetEpisode()">Reset</button>
|
| 147 |
+
<button class="ctrl-btn" id="btnStep" onclick="stepEpisode()">Step</button>
|
| 148 |
+
<button class="ctrl-btn" id="btnAutoRun" onclick="toggleAutoRun()">Auto</button>
|
| 149 |
+
</div>
|
| 150 |
+
</div>
|
| 151 |
+
|
| 152 |
+
</aside>
|
| 153 |
+
|
| 154 |
+
<!-- ===== CENTER PANEL (Grid Map) ===== -->
|
| 155 |
+
<main class="center-panel" id="centerPanel">
|
| 156 |
+
<div class="grid-map" id="gridMap"></div>
|
| 157 |
+
<div class="bus-tooltip" id="busTooltip">
|
| 158 |
+
<div class="tt-title" id="ttTitle">Bus 0</div>
|
| 159 |
+
<div class="tt-row"><span>Type</span><span class="tt-val" id="ttType">--</span></div>
|
| 160 |
+
<div class="tt-row"><span>Injection</span><span class="tt-val" id="ttInj">-- MW</span></div>
|
| 161 |
+
<div class="tt-row"><span>Zone</span><span class="tt-val" id="ttZone">--</span></div>
|
| 162 |
+
</div>
|
| 163 |
+
</main>
|
| 164 |
+
|
| 165 |
+
<!-- ===== RIGHT PANEL (Agent Monitor) ===== -->
|
| 166 |
+
<aside class="right-panel">
|
| 167 |
+
<div class="card">
|
| 168 |
+
<div class="card-title">Agent Leaderboard</div>
|
| 169 |
+
<ul class="leaderboard" id="leaderboard">
|
| 170 |
+
<!-- Populated by JS -->
|
| 171 |
+
</ul>
|
| 172 |
+
</div>
|
| 173 |
+
|
| 174 |
+
<div id="agentCards">
|
| 175 |
+
<!-- Populated by JS -->
|
| 176 |
+
</div>
|
| 177 |
+
</aside>
|
| 178 |
+
|
| 179 |
+
<!-- ===== BOTTOM PANEL ===== -->
|
| 180 |
+
<footer class="bottom-panel">
|
| 181 |
+
|
| 182 |
+
<!-- Reward History Chart -->
|
| 183 |
+
<div class="bottom-card">
|
| 184 |
+
<div class="card-title">Reward History</div>
|
| 185 |
+
<div class="chart-area" id="rewardChart"></div>
|
| 186 |
+
</div>
|
| 187 |
+
|
| 188 |
+
<!-- Frequency Trend -->
|
| 189 |
+
<div class="bottom-card">
|
| 190 |
+
<div class="card-title">Frequency Trend</div>
|
| 191 |
+
<div class="chart-area" id="freqChart"></div>
|
| 192 |
+
</div>
|
| 193 |
+
|
| 194 |
+
<!-- Generation Mix -->
|
| 195 |
+
<div class="bottom-card">
|
| 196 |
+
<div class="card-title">Generation Mix</div>
|
| 197 |
+
<div class="chart-area" id="genMixChart"></div>
|
| 198 |
+
</div>
|
| 199 |
+
|
| 200 |
+
<!-- Episode Score -->
|
| 201 |
+
<div class="bottom-card">
|
| 202 |
+
<div class="card-title">Episode Score</div>
|
| 203 |
+
<div class="coord-score" style="flex:1; display:flex; flex-direction:column; justify-content:center;">
|
| 204 |
+
<div class="big-value" id="episodeScore" style="color: var(--chart-reward); font-size: 36px;">--</div>
|
| 205 |
+
<div style="font-size:10px; color: var(--text-secondary); margin-top:4px;">Grader Score</div>
|
| 206 |
+
<div style="font-size:11px; margin-top:8px;">
|
| 207 |
+
<span style="color: var(--text-secondary);">Steps:</span>
|
| 208 |
+
<span id="totalSteps" style="font-family: 'JetBrains Mono'; font-weight:600;">0</span>
|
| 209 |
+
<span style="color: var(--text-secondary); margin-left:8px;">Blackout:</span>
|
| 210 |
+
<span id="blackoutStatus" style="font-family: 'JetBrains Mono'; font-weight:600; color: var(--status-normal);">No</span>
|
| 211 |
+
</div>
|
| 212 |
+
</div>
|
| 213 |
+
<div class="controls-row">
|
| 214 |
+
<button class="ctrl-btn" onclick="getGrade()">Grade</button>
|
| 215 |
+
<button class="ctrl-btn danger" onclick="resetEpisode()">New Episode</button>
|
| 216 |
+
</div>
|
| 217 |
+
</div>
|
| 218 |
+
|
| 219 |
+
</footer>
|
| 220 |
+
|
| 221 |
+
</div>
|
| 222 |
+
|
| 223 |
+
<script src="/static/app.js"></script>
|
| 224 |
+
</body>
|
| 225 |
+
</html>
|
static/karnataka.svg
ADDED
|
|
static/logo.png
ADDED
|
Git LFS Details
|
static/style.css
ADDED
|
@@ -0,0 +1,935 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* ============================================================================
|
| 2 |
+
OpenGrid KPTCL-SLDC Control Room — Design System
|
| 3 |
+
Inspired by ERCOT control room aesthetics, adapted for Karnataka grid
|
| 4 |
+
============================================================================ */
|
| 5 |
+
|
| 6 |
+
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&family=JetBrains+Mono:wght@400;500;600;700&display=swap');
|
| 7 |
+
|
| 8 |
+
/* ---------- CSS Custom Properties ---------- */
|
| 9 |
+
:root {
|
| 10 |
+
/* Background layers */
|
| 11 |
+
--bg-primary: #0a0e1a;
|
| 12 |
+
--bg-secondary: #0f1628;
|
| 13 |
+
--bg-tertiary: #141d35;
|
| 14 |
+
--bg-glass: rgba(15, 22, 40, 0.85);
|
| 15 |
+
--bg-card: rgba(15, 22, 40, 0.7);
|
| 16 |
+
|
| 17 |
+
/* Operational states */
|
| 18 |
+
--status-normal: #00e5a0;
|
| 19 |
+
--status-warning: #ffd700;
|
| 20 |
+
--status-critical:#ff3d3d;
|
| 21 |
+
--status-offline: #4a5568;
|
| 22 |
+
--status-overload:#ff6b35;
|
| 23 |
+
|
| 24 |
+
/* Voltage colors */
|
| 25 |
+
--voltage-400kv: #e94560;
|
| 26 |
+
--voltage-220kv: #f5a623;
|
| 27 |
+
--voltage-110kv: #7ed321;
|
| 28 |
+
--voltage-66kv: #4a90d9;
|
| 29 |
+
|
| 30 |
+
/* Agent identity colors */
|
| 31 |
+
--agent-0: #00bfff;
|
| 32 |
+
--agent-1: #ff69b4;
|
| 33 |
+
--agent-2: #ff6347;
|
| 34 |
+
|
| 35 |
+
/* Text */
|
| 36 |
+
--text-primary: #e8eaf6;
|
| 37 |
+
--text-secondary: #90a4ae;
|
| 38 |
+
--text-accent: #00e5a0;
|
| 39 |
+
--text-danger: #ff5252;
|
| 40 |
+
--text-muted: #546e7a;
|
| 41 |
+
|
| 42 |
+
/* Chart */
|
| 43 |
+
--chart-demand: #00bfff;
|
| 44 |
+
--chart-supply: #00e5a0;
|
| 45 |
+
--chart-reward: #ffd700;
|
| 46 |
+
|
| 47 |
+
/* Spacing */
|
| 48 |
+
--gap-sm: 8px;
|
| 49 |
+
--gap-md: 12px;
|
| 50 |
+
--gap-lg: 16px;
|
| 51 |
+
--gap-xl: 20px;
|
| 52 |
+
|
| 53 |
+
/* Radius */
|
| 54 |
+
--radius-sm: 6px;
|
| 55 |
+
--radius-md: 10px;
|
| 56 |
+
--radius-lg: 14px;
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
/* ---------- Reset & Base ---------- */
|
| 60 |
+
*, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
|
| 61 |
+
|
| 62 |
+
html, body {
|
| 63 |
+
height: 100%;
|
| 64 |
+
background: var(--bg-primary);
|
| 65 |
+
color: var(--text-primary);
|
| 66 |
+
font-family: 'Inter', 'Segoe UI', sans-serif;
|
| 67 |
+
font-size: 13px;
|
| 68 |
+
line-height: 1.5;
|
| 69 |
+
overflow: hidden;
|
| 70 |
+
-webkit-font-smoothing: antialiased;
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
/* Subtle scanline overlay */
|
| 74 |
+
body::before {
|
| 75 |
+
content: '';
|
| 76 |
+
position: fixed;
|
| 77 |
+
top: 0; left: 0; right: 0; bottom: 0;
|
| 78 |
+
pointer-events: none;
|
| 79 |
+
z-index: 9999;
|
| 80 |
+
background: repeating-linear-gradient(
|
| 81 |
+
0deg,
|
| 82 |
+
transparent,
|
| 83 |
+
transparent 2px,
|
| 84 |
+
rgba(0,0,0,0.03) 2px,
|
| 85 |
+
rgba(0,0,0,0.03) 4px
|
| 86 |
+
);
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
/* ---------- Layout ---------- */
|
| 90 |
+
.control-room {
|
| 91 |
+
display: grid;
|
| 92 |
+
grid-template-rows: 52px 1fr 180px;
|
| 93 |
+
grid-template-columns: 260px 1fr 300px;
|
| 94 |
+
grid-template-areas:
|
| 95 |
+
"header header header"
|
| 96 |
+
"left center right"
|
| 97 |
+
"bottom bottom bottom";
|
| 98 |
+
height: 100vh;
|
| 99 |
+
gap: 1px;
|
| 100 |
+
background: rgba(255,255,255,0.04);
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
/* ---------- Header ---------- */
|
| 104 |
+
.header {
|
| 105 |
+
grid-area: header;
|
| 106 |
+
background: linear-gradient(90deg, #0a0e1a, #0f2040);
|
| 107 |
+
display: flex;
|
| 108 |
+
align-items: center;
|
| 109 |
+
padding: 0 var(--gap-lg);
|
| 110 |
+
gap: var(--gap-lg);
|
| 111 |
+
border-bottom: 1px solid rgba(0,229,160,0.15);
|
| 112 |
+
z-index: 10;
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
.header-brand {
|
| 116 |
+
display: flex;
|
| 117 |
+
align-items: center;
|
| 118 |
+
gap: var(--gap-sm);
|
| 119 |
+
flex-shrink: 0;
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
.header-brand .logo {
|
| 123 |
+
width: 28px;
|
| 124 |
+
height: 28px;
|
| 125 |
+
background: linear-gradient(135deg, #00e5a0, #00bfff);
|
| 126 |
+
border-radius: 6px;
|
| 127 |
+
display: flex;
|
| 128 |
+
align-items: center;
|
| 129 |
+
justify-content: center;
|
| 130 |
+
font-weight: 700;
|
| 131 |
+
font-size: 14px;
|
| 132 |
+
color: #0a0e1a;
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
.header-brand h1 {
|
| 136 |
+
font-size: 14px;
|
| 137 |
+
font-weight: 600;
|
| 138 |
+
letter-spacing: 0.5px;
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
.header-brand .sub {
|
| 142 |
+
font-size: 10px;
|
| 143 |
+
color: var(--text-secondary);
|
| 144 |
+
letter-spacing: 1px;
|
| 145 |
+
text-transform: uppercase;
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
.header-stats {
|
| 149 |
+
display: flex;
|
| 150 |
+
gap: var(--gap-lg);
|
| 151 |
+
margin-left: auto;
|
| 152 |
+
align-items: center;
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
.header-stat {
|
| 156 |
+
display: flex;
|
| 157 |
+
flex-direction: column;
|
| 158 |
+
align-items: center;
|
| 159 |
+
padding: 4px 12px;
|
| 160 |
+
border-radius: var(--radius-sm);
|
| 161 |
+
background: rgba(255,255,255,0.04);
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
.header-stat .label {
|
| 165 |
+
font-size: 9px;
|
| 166 |
+
text-transform: uppercase;
|
| 167 |
+
letter-spacing: 1px;
|
| 168 |
+
color: var(--text-secondary);
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
.header-stat .value {
|
| 172 |
+
font-family: 'JetBrains Mono', monospace;
|
| 173 |
+
font-size: 14px;
|
| 174 |
+
font-weight: 600;
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
.header-stat .value.normal { color: var(--status-normal); }
|
| 178 |
+
.header-stat .value.warning { color: var(--status-warning); }
|
| 179 |
+
.header-stat .value.critical { color: var(--status-critical); }
|
| 180 |
+
|
| 181 |
+
.sim-badge {
|
| 182 |
+
display: flex;
|
| 183 |
+
align-items: center;
|
| 184 |
+
gap: 6px;
|
| 185 |
+
padding: 4px 10px;
|
| 186 |
+
border-radius: 20px;
|
| 187 |
+
background: rgba(0,229,160,0.1);
|
| 188 |
+
border: 1px solid rgba(0,229,160,0.25);
|
| 189 |
+
font-size: 10px;
|
| 190 |
+
font-weight: 600;
|
| 191 |
+
color: var(--status-normal);
|
| 192 |
+
text-transform: uppercase;
|
| 193 |
+
letter-spacing: 1px;
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
.sim-badge .dot {
|
| 197 |
+
width: 6px; height: 6px;
|
| 198 |
+
background: var(--status-normal);
|
| 199 |
+
border-radius: 50%;
|
| 200 |
+
animation: pulse-dot 2s infinite;
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
@keyframes pulse-dot {
|
| 204 |
+
0%, 100% { opacity: 1; box-shadow: 0 0 0 0 rgba(0,229,160,0.4); }
|
| 205 |
+
50% { opacity: 0.7; box-shadow: 0 0 0 4px rgba(0,229,160,0); }
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
/* ---------- Left Panel ---------- */
|
| 209 |
+
.left-panel {
|
| 210 |
+
grid-area: left;
|
| 211 |
+
background: var(--bg-secondary);
|
| 212 |
+
padding: var(--gap-md);
|
| 213 |
+
overflow-y: auto;
|
| 214 |
+
display: flex;
|
| 215 |
+
flex-direction: column;
|
| 216 |
+
gap: var(--gap-md);
|
| 217 |
+
border-right: 1px solid rgba(255,255,255,0.05);
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
/* ---------- Cards (shared) ---------- */
|
| 221 |
+
.card {
|
| 222 |
+
background: var(--bg-card);
|
| 223 |
+
border: 1px solid rgba(255,255,255,0.06);
|
| 224 |
+
border-radius: var(--radius-md);
|
| 225 |
+
padding: var(--gap-md);
|
| 226 |
+
backdrop-filter: blur(8px);
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
.card-title {
|
| 230 |
+
font-size: 10px;
|
| 231 |
+
font-weight: 600;
|
| 232 |
+
text-transform: uppercase;
|
| 233 |
+
letter-spacing: 1.5px;
|
| 234 |
+
color: var(--text-secondary);
|
| 235 |
+
margin-bottom: var(--gap-sm);
|
| 236 |
+
padding-bottom: 6px;
|
| 237 |
+
border-bottom: 1px solid rgba(255,255,255,0.06);
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
/* ---------- Alarm Log ---------- */
|
| 241 |
+
.alarm-log {
|
| 242 |
+
flex: 1;
|
| 243 |
+
max-height: 90px;
|
| 244 |
+
overflow-y: auto;
|
| 245 |
+
font-family: 'JetBrains Mono', monospace;
|
| 246 |
+
font-size: 10px;
|
| 247 |
+
line-height: 1.4;
|
| 248 |
+
display: flex;
|
| 249 |
+
flex-direction: column;
|
| 250 |
+
gap: 4px;
|
| 251 |
+
}
|
| 252 |
+
.alarm-entry {
|
| 253 |
+
padding: 4px 6px;
|
| 254 |
+
background: rgba(255,255,255,0.03);
|
| 255 |
+
border-left: 2px solid transparent;
|
| 256 |
+
border-radius: 2px;
|
| 257 |
+
}
|
| 258 |
+
.alarm-time { color: var(--text-muted); margin-right: 6px; }
|
| 259 |
+
.alarm-entry.warn { border-left-color: var(--status-warning); background: rgba(255,152,0,0.05); color: #ffb74d; }
|
| 260 |
+
.alarm-entry.crit { border-left-color: var(--status-critical); background: rgba(244,67,54,0.05); color: #ef5350; }
|
| 261 |
+
.alarm-entry.info { border-left-color: var(--status-normal); }
|
| 262 |
+
|
| 263 |
+
/* ---------- Frequency Display ---------- */
|
| 264 |
+
.freq-display {
|
| 265 |
+
text-align: center;
|
| 266 |
+
padding: var(--gap-md) var(--gap-sm);
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
.freq-arc-container {
|
| 270 |
+
position: relative;
|
| 271 |
+
width: 200px;
|
| 272 |
+
height: 110px;
|
| 273 |
+
margin: 0 auto;
|
| 274 |
+
}
|
| 275 |
+
|
| 276 |
+
.freq-arc-container svg { overflow: visible; }
|
| 277 |
+
|
| 278 |
+
.freq-value {
|
| 279 |
+
font-family: 'JetBrains Mono', monospace;
|
| 280 |
+
font-size: 32px;
|
| 281 |
+
font-weight: 700;
|
| 282 |
+
letter-spacing: -1px;
|
| 283 |
+
transition: color 0.3s;
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
.freq-value.normal { color: var(--status-normal); text-shadow: 0 0 20px rgba(0,229,160,0.3); }
|
| 287 |
+
.freq-value.warning { color: var(--status-warning); text-shadow: 0 0 20px rgba(255,215,0,0.3); }
|
| 288 |
+
.freq-value.critical { color: var(--status-critical); text-shadow: 0 0 20px rgba(255,61,61,0.3); animation: freq-blink 0.5s infinite; }
|
| 289 |
+
|
| 290 |
+
@keyframes freq-blink {
|
| 291 |
+
0%, 100% { opacity: 1; }
|
| 292 |
+
50% { opacity: 0.6; }
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
.freq-deviation {
|
| 296 |
+
margin-top: 4px;
|
| 297 |
+
font-family: 'JetBrains Mono', monospace;
|
| 298 |
+
font-size: 10px;
|
| 299 |
+
color: var(--text-secondary);
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
/* Grid condition badge */
|
| 303 |
+
.grid-condition {
|
| 304 |
+
display: flex;
|
| 305 |
+
align-items: center;
|
| 306 |
+
justify-content: center;
|
| 307 |
+
gap: 6px;
|
| 308 |
+
margin-top: var(--gap-sm);
|
| 309 |
+
padding: 5px 10px;
|
| 310 |
+
border-radius: 20px;
|
| 311 |
+
font-size: 10px;
|
| 312 |
+
font-weight: 600;
|
| 313 |
+
text-transform: uppercase;
|
| 314 |
+
letter-spacing: 0.8px;
|
| 315 |
+
}
|
| 316 |
+
.grid-condition.normal { background: rgba(0,229,160,0.1); color: var(--status-normal); border: 1px solid rgba(0,229,160,0.2); }
|
| 317 |
+
.grid-condition.conservative { background: rgba(255,215,0,0.08); color: var(--status-warning); border: 1px solid rgba(255,215,0,0.15); }
|
| 318 |
+
.grid-condition.alert { background: rgba(255,107,53,0.1); color: var(--status-overload); border: 1px solid rgba(255,107,53,0.2); }
|
| 319 |
+
.grid-condition.emergency { background: rgba(255,61,61,0.1); color: var(--status-critical); border: 1px solid rgba(255,61,61,0.2); animation: cond-pulse 1s infinite; }
|
| 320 |
+
|
| 321 |
+
@keyframes cond-pulse {
|
| 322 |
+
0%,100% { box-shadow: 0 0 0 0 rgba(255,61,61,0.2); }
|
| 323 |
+
50% { box-shadow: 0 0 0 4px rgba(255,61,61,0); }
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
/* ---------- System Summary ---------- */
|
| 327 |
+
.stat-row {
|
| 328 |
+
display: flex;
|
| 329 |
+
justify-content: space-between;
|
| 330 |
+
align-items: center;
|
| 331 |
+
padding: 4px 0;
|
| 332 |
+
font-size: 12px;
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
.stat-row .label { color: var(--text-secondary); }
|
| 336 |
+
.stat-row .value {
|
| 337 |
+
font-family: 'JetBrains Mono', monospace;
|
| 338 |
+
font-weight: 500;
|
| 339 |
+
}
|
| 340 |
+
|
| 341 |
+
.stat-row.highlight .value {
|
| 342 |
+
color: var(--status-normal);
|
| 343 |
+
font-weight: 600;
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
/* Progress bars */
|
| 347 |
+
.progress-bar {
|
| 348 |
+
height: 4px;
|
| 349 |
+
background: rgba(255,255,255,0.06);
|
| 350 |
+
border-radius: 2px;
|
| 351 |
+
overflow: hidden;
|
| 352 |
+
margin-top: 4px;
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
.progress-bar-fill {
|
| 356 |
+
height: 100%;
|
| 357 |
+
border-radius: 2px;
|
| 358 |
+
transition: width 0.5s;
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
/* ---------- Center Panel (Grid Map) ---------- */
|
| 362 |
+
.center-panel {
|
| 363 |
+
grid-area: center;
|
| 364 |
+
background: var(--bg-tertiary);
|
| 365 |
+
position: relative;
|
| 366 |
+
overflow: hidden;
|
| 367 |
+
}
|
| 368 |
+
|
| 369 |
+
.grid-map {
|
| 370 |
+
width: 100%;
|
| 371 |
+
height: 100%;
|
| 372 |
+
}
|
| 373 |
+
|
| 374 |
+
.grid-map svg {
|
| 375 |
+
width: 100%;
|
| 376 |
+
height: 100%;
|
| 377 |
+
}
|
| 378 |
+
|
| 379 |
+
/* SVG map styles */
|
| 380 |
+
.zone-polygon {
|
| 381 |
+
opacity: 0.06;
|
| 382 |
+
transition: opacity 0.4s;
|
| 383 |
+
cursor: pointer;
|
| 384 |
+
filter: blur(0.5px);
|
| 385 |
+
}
|
| 386 |
+
.zone-polygon:hover { opacity: 0.18; }
|
| 387 |
+
|
| 388 |
+
.substation-node { cursor: pointer; }
|
| 389 |
+
.substation-node:hover .node-outer { stroke-width: 2.5; filter: url(#glow); }
|
| 390 |
+
.substation-node:hover .node-label { opacity: 1; }
|
| 391 |
+
|
| 392 |
+
.node-label {
|
| 393 |
+
font-family: 'Inter', sans-serif;
|
| 394 |
+
font-size: 8px;
|
| 395 |
+
fill: var(--text-secondary);
|
| 396 |
+
text-anchor: middle;
|
| 397 |
+
pointer-events: none;
|
| 398 |
+
opacity: 0.7;
|
| 399 |
+
transition: opacity 0.2s;
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
.node-mw {
|
| 403 |
+
font-family: 'JetBrains Mono', monospace;
|
| 404 |
+
font-size: 9px;
|
| 405 |
+
fill: var(--text-primary);
|
| 406 |
+
text-anchor: middle;
|
| 407 |
+
pointer-events: none;
|
| 408 |
+
font-weight: 500;
|
| 409 |
+
}
|
| 410 |
+
|
| 411 |
+
.line-flow {
|
| 412 |
+
fill: none;
|
| 413 |
+
stroke-linecap: round;
|
| 414 |
+
}
|
| 415 |
+
|
| 416 |
+
/* Animated flow on lines: march the dash pattern along the stroke. */
@keyframes dash-flow {
    to { stroke-dashoffset: -24; }
}

.line-animated { animation: dash-flow 1.2s linear infinite; }

/* Flip direction for power flowing the opposite way. */
.line-animated.reverse { animation-direction: reverse; }
|
| 426 |
+
|
| 427 |
+
.flow-label {
|
| 428 |
+
font-family: 'JetBrains Mono', monospace;
|
| 429 |
+
font-size: 8px;
|
| 430 |
+
fill: rgba(232,234,246,0.6);
|
| 431 |
+
text-anchor: middle;
|
| 432 |
+
pointer-events: none;
|
| 433 |
+
}
|
| 434 |
+
|
| 435 |
+
.zone-badge { font-family: 'Inter', sans-serif; pointer-events: none; }
|
| 436 |
+
.zone-badge-bg {
|
| 437 |
+
rx: 8;
|
| 438 |
+
fill: rgba(10, 14, 26, 0.88);
|
| 439 |
+
stroke-width: 1;
|
| 440 |
+
backdrop-filter: blur(6px);
|
| 441 |
+
}
|
| 442 |
+
.zone-badge-name { font-size: 10px; font-weight: 600; text-anchor: middle; }
|
| 443 |
+
.zone-badge-status { font-size: 8px; text-anchor: middle; fill: var(--text-secondary); }
|
| 444 |
+
.zone-badge-reward { font-size: 9px; text-anchor: middle; font-weight: 600; font-family: 'JetBrains Mono', monospace; }
|
| 445 |
+
|
| 446 |
+
/* Bus tooltip */
|
| 447 |
+
.bus-tooltip {
|
| 448 |
+
position: absolute;
|
| 449 |
+
background: rgba(10, 14, 26, 0.95);
|
| 450 |
+
border: 1px solid rgba(0,229,160,0.2);
|
| 451 |
+
border-radius: var(--radius-sm);
|
| 452 |
+
padding: 8px 10px;
|
| 453 |
+
font-size: 11px;
|
| 454 |
+
pointer-events: none;
|
| 455 |
+
z-index: 20;
|
| 456 |
+
min-width: 140px;
|
| 457 |
+
backdrop-filter: blur(12px);
|
| 458 |
+
box-shadow: 0 4px 20px rgba(0,0,0,0.4);
|
| 459 |
+
display: none;
|
| 460 |
+
}
|
| 461 |
+
.bus-tooltip.visible { display: block; }
|
| 462 |
+
.bus-tooltip .tt-title {
|
| 463 |
+
font-weight: 600;
|
| 464 |
+
margin-bottom: 4px;
|
| 465 |
+
padding-bottom: 4px;
|
| 466 |
+
border-bottom: 1px solid rgba(255,255,255,0.08);
|
| 467 |
+
}
|
| 468 |
+
.bus-tooltip .tt-row {
|
| 469 |
+
display: flex;
|
| 470 |
+
justify-content: space-between;
|
| 471 |
+
padding: 1px 0;
|
| 472 |
+
}
|
| 473 |
+
.bus-tooltip .tt-row .tt-val {
|
| 474 |
+
font-family: 'JetBrains Mono', monospace;
|
| 475 |
+
font-weight: 500;
|
| 476 |
+
}
|
| 477 |
+
|
| 478 |
+
/* Map overlay controls */
|
| 479 |
+
.map-controls {
|
| 480 |
+
position: absolute;
|
| 481 |
+
top: var(--gap-md);
|
| 482 |
+
right: var(--gap-md);
|
| 483 |
+
display: flex;
|
| 484 |
+
flex-direction: column;
|
| 485 |
+
gap: 4px;
|
| 486 |
+
z-index: 5;
|
| 487 |
+
}
|
| 488 |
+
|
| 489 |
+
.map-btn {
|
| 490 |
+
width: 32px; height: 32px;
|
| 491 |
+
background: var(--bg-glass);
|
| 492 |
+
border: 1px solid rgba(255,255,255,0.1);
|
| 493 |
+
border-radius: var(--radius-sm);
|
| 494 |
+
color: var(--text-secondary);
|
| 495 |
+
font-size: 14px;
|
| 496 |
+
cursor: pointer;
|
| 497 |
+
display: flex;
|
| 498 |
+
align-items: center;
|
| 499 |
+
justify-content: center;
|
| 500 |
+
backdrop-filter: blur(8px);
|
| 501 |
+
transition: all 0.2s;
|
| 502 |
+
}
|
| 503 |
+
|
| 504 |
+
.map-btn:hover {
|
| 505 |
+
background: rgba(0,229,160,0.15);
|
| 506 |
+
color: var(--status-normal);
|
| 507 |
+
border-color: rgba(0,229,160,0.3);
|
| 508 |
+
}
|
| 509 |
+
|
| 510 |
+
/* ---------- Right Panel (Agent Monitor) ---------- */
|
| 511 |
+
.right-panel {
|
| 512 |
+
grid-area: right;
|
| 513 |
+
background: var(--bg-secondary);
|
| 514 |
+
padding: var(--gap-md);
|
| 515 |
+
overflow-y: auto;
|
| 516 |
+
display: flex;
|
| 517 |
+
flex-direction: column;
|
| 518 |
+
gap: var(--gap-md);
|
| 519 |
+
border-left: 1px solid rgba(255,255,255,0.05);
|
| 520 |
+
}
|
| 521 |
+
|
| 522 |
+
/* Agent cards */
|
| 523 |
+
.agent-card {
|
| 524 |
+
border-radius: var(--radius-md);
|
| 525 |
+
padding: var(--gap-md);
|
| 526 |
+
background: var(--bg-card);
|
| 527 |
+
border: 1px solid rgba(255,255,255,0.06);
|
| 528 |
+
backdrop-filter: blur(8px);
|
| 529 |
+
transition: border-color 0.3s, box-shadow 0.3s;
|
| 530 |
+
}
|
| 531 |
+
|
| 532 |
+
.agent-card.active {
|
| 533 |
+
border-color: rgba(0,229,160,0.2);
|
| 534 |
+
}
|
| 535 |
+
|
| 536 |
+
.agent-card.warning {
|
| 537 |
+
border-color: rgba(255,215,0,0.3);
|
| 538 |
+
box-shadow: 0 0 12px rgba(255,215,0,0.05);
|
| 539 |
+
}
|
| 540 |
+
|
| 541 |
+
.agent-card.critical {
|
| 542 |
+
border-color: rgba(255,61,61,0.3);
|
| 543 |
+
box-shadow: 0 0 12px rgba(255,61,61,0.08);
|
| 544 |
+
animation: card-pulse 1.5s infinite;
|
| 545 |
+
}
|
| 546 |
+
|
| 547 |
+
@keyframes card-pulse {
|
| 548 |
+
0%, 100% { box-shadow: 0 0 12px rgba(255,61,61,0.08); }
|
| 549 |
+
50% { box-shadow: 0 0 20px rgba(255,61,61,0.15); }
|
| 550 |
+
}
|
| 551 |
+
|
| 552 |
+
.agent-header {
|
| 553 |
+
display: flex;
|
| 554 |
+
justify-content: space-between;
|
| 555 |
+
align-items: center;
|
| 556 |
+
margin-bottom: var(--gap-sm);
|
| 557 |
+
}
|
| 558 |
+
|
| 559 |
+
.agent-name {
|
| 560 |
+
font-size: 12px;
|
| 561 |
+
font-weight: 600;
|
| 562 |
+
display: flex;
|
| 563 |
+
align-items: center;
|
| 564 |
+
gap: 6px;
|
| 565 |
+
}
|
| 566 |
+
|
| 567 |
+
.agent-dot {
|
| 568 |
+
width: 8px; height: 8px;
|
| 569 |
+
border-radius: 50%;
|
| 570 |
+
flex-shrink: 0;
|
| 571 |
+
}
|
| 572 |
+
|
| 573 |
+
.agent-status-badge {
|
| 574 |
+
font-size: 9px;
|
| 575 |
+
font-weight: 600;
|
| 576 |
+
padding: 2px 6px;
|
| 577 |
+
border-radius: 10px;
|
| 578 |
+
text-transform: uppercase;
|
| 579 |
+
letter-spacing: 0.5px;
|
| 580 |
+
}
|
| 581 |
+
|
| 582 |
+
.agent-status-badge.active {
|
| 583 |
+
background: rgba(0,229,160,0.15);
|
| 584 |
+
color: var(--status-normal);
|
| 585 |
+
}
|
| 586 |
+
|
| 587 |
+
.agent-status-badge.corrected {
|
| 588 |
+
background: rgba(255,215,0,0.15);
|
| 589 |
+
color: var(--status-warning);
|
| 590 |
+
}
|
| 591 |
+
|
| 592 |
+
.agent-metrics {
|
| 593 |
+
display: grid;
|
| 594 |
+
grid-template-columns: 1fr 1fr;
|
| 595 |
+
gap: 6px;
|
| 596 |
+
margin-top: var(--gap-sm);
|
| 597 |
+
}
|
| 598 |
+
|
| 599 |
+
.agent-metric {
|
| 600 |
+
padding: 6px 8px;
|
| 601 |
+
background: rgba(255,255,255,0.02);
|
| 602 |
+
border-radius: var(--radius-sm);
|
| 603 |
+
}
|
| 604 |
+
|
| 605 |
+
.agent-metric .label {
|
| 606 |
+
font-size: 9px;
|
| 607 |
+
text-transform: uppercase;
|
| 608 |
+
letter-spacing: 0.5px;
|
| 609 |
+
color: var(--text-muted);
|
| 610 |
+
}
|
| 611 |
+
|
| 612 |
+
.agent-metric .value {
|
| 613 |
+
font-family: 'JetBrains Mono', monospace;
|
| 614 |
+
font-size: 14px;
|
| 615 |
+
font-weight: 600;
|
| 616 |
+
margin-top: 2px;
|
| 617 |
+
}
|
| 618 |
+
|
| 619 |
+
/* Safety shield */
|
| 620 |
+
.safety-shield {
|
| 621 |
+
margin-top: var(--gap-sm);
|
| 622 |
+
padding: 6px 8px;
|
| 623 |
+
border-radius: var(--radius-sm);
|
| 624 |
+
display: flex;
|
| 625 |
+
align-items: center;
|
| 626 |
+
gap: 6px;
|
| 627 |
+
font-size: 10px;
|
| 628 |
+
font-weight: 600;
|
| 629 |
+
text-transform: uppercase;
|
| 630 |
+
letter-spacing: 0.5px;
|
| 631 |
+
}
|
| 632 |
+
|
| 633 |
+
.safety-shield.safe {
|
| 634 |
+
background: rgba(0,229,160,0.08);
|
| 635 |
+
border: 1px solid rgba(0,229,160,0.15);
|
| 636 |
+
color: var(--status-normal);
|
| 637 |
+
}
|
| 638 |
+
|
| 639 |
+
.safety-shield.corrected {
|
| 640 |
+
background: rgba(255,215,0,0.08);
|
| 641 |
+
border: 1px solid rgba(255,215,0,0.2);
|
| 642 |
+
color: var(--status-warning);
|
| 643 |
+
}
|
| 644 |
+
|
| 645 |
+
.safety-shield.violated {
|
| 646 |
+
background: rgba(255,61,61,0.08);
|
| 647 |
+
border: 1px solid rgba(255,61,61,0.2);
|
| 648 |
+
color: var(--status-critical);
|
| 649 |
+
}
|
| 650 |
+
|
| 651 |
+
/* Sparkline */
|
| 652 |
+
.sparkline-container {
|
| 653 |
+
margin-top: var(--gap-sm);
|
| 654 |
+
height: 30px;
|
| 655 |
+
background: rgba(255,255,255,0.02);
|
| 656 |
+
border-radius: var(--radius-sm);
|
| 657 |
+
padding: 4px;
|
| 658 |
+
}
|
| 659 |
+
|
| 660 |
+
.sparkline-container svg {
|
| 661 |
+
width: 100%;
|
| 662 |
+
height: 100%;
|
| 663 |
+
}
|
| 664 |
+
|
| 665 |
+
/* ---------- Bottom Panel ---------- */
|
| 666 |
+
.bottom-panel {
|
| 667 |
+
grid-area: bottom;
|
| 668 |
+
background: var(--bg-secondary);
|
| 669 |
+
display: grid;
|
| 670 |
+
grid-template-columns: 2fr 1fr 1fr 1fr;
|
| 671 |
+
gap: 1px;
|
| 672 |
+
border-top: 1px solid rgba(255,255,255,0.05);
|
| 673 |
+
}
|
| 674 |
+
|
| 675 |
+
.bottom-card {
|
| 676 |
+
background: var(--bg-card);
|
| 677 |
+
padding: var(--gap-md);
|
| 678 |
+
display: flex;
|
| 679 |
+
flex-direction: column;
|
| 680 |
+
}
|
| 681 |
+
|
| 682 |
+
.chart-area {
|
| 683 |
+
flex: 1;
|
| 684 |
+
position: relative;
|
| 685 |
+
min-height: 0;
|
| 686 |
+
}
|
| 687 |
+
|
| 688 |
+
.chart-area canvas, .chart-area svg {
|
| 689 |
+
width: 100%;
|
| 690 |
+
height: 100%;
|
| 691 |
+
}
|
| 692 |
+
|
| 693 |
+
/* Reward chart */
|
| 694 |
+
.reward-history {
|
| 695 |
+
flex: 1;
|
| 696 |
+
}
|
| 697 |
+
|
| 698 |
+
/* Controls */
|
| 699 |
+
.controls-row {
|
| 700 |
+
display: flex;
|
| 701 |
+
gap: var(--gap-sm);
|
| 702 |
+
margin-top: var(--gap-sm);
|
| 703 |
+
}
|
| 704 |
+
|
| 705 |
+
.ctrl-btn {
|
| 706 |
+
flex: 1;
|
| 707 |
+
padding: 6px 10px;
|
| 708 |
+
background: rgba(255,255,255,0.04);
|
| 709 |
+
border: 1px solid rgba(255,255,255,0.1);
|
| 710 |
+
border-radius: var(--radius-sm);
|
| 711 |
+
color: var(--text-primary);
|
| 712 |
+
font-family: 'Inter', sans-serif;
|
| 713 |
+
font-size: 11px;
|
| 714 |
+
font-weight: 500;
|
| 715 |
+
cursor: pointer;
|
| 716 |
+
transition: all 0.2s;
|
| 717 |
+
text-align: center;
|
| 718 |
+
}
|
| 719 |
+
|
| 720 |
+
.ctrl-btn:hover {
|
| 721 |
+
background: rgba(0,229,160,0.1);
|
| 722 |
+
border-color: rgba(0,229,160,0.3);
|
| 723 |
+
}
|
| 724 |
+
|
| 725 |
+
.ctrl-btn.active {
|
| 726 |
+
background: rgba(0,229,160,0.15);
|
| 727 |
+
border-color: var(--status-normal);
|
| 728 |
+
color: var(--status-normal);
|
| 729 |
+
}
|
| 730 |
+
|
| 731 |
+
.ctrl-btn.danger {
|
| 732 |
+
border-color: rgba(255,61,61,0.3);
|
| 733 |
+
}
|
| 734 |
+
|
| 735 |
+
.ctrl-btn.danger:hover {
|
| 736 |
+
background: rgba(255,61,61,0.1);
|
| 737 |
+
border-color: rgba(255,61,61,0.5);
|
| 738 |
+
color: var(--status-critical);
|
| 739 |
+
}
|
| 740 |
+
|
| 741 |
+
/* Task selector */
|
| 742 |
+
.task-selector {
|
| 743 |
+
display: flex;
|
| 744 |
+
gap: 4px;
|
| 745 |
+
}
|
| 746 |
+
|
| 747 |
+
.task-btn {
|
| 748 |
+
flex: 1;
|
| 749 |
+
padding: 4px 8px;
|
| 750 |
+
background: rgba(255,255,255,0.03);
|
| 751 |
+
border: 1px solid rgba(255,255,255,0.08);
|
| 752 |
+
border-radius: var(--radius-sm);
|
| 753 |
+
color: var(--text-secondary);
|
| 754 |
+
font-size: 10px;
|
| 755 |
+
font-weight: 500;
|
| 756 |
+
cursor: pointer;
|
| 757 |
+
transition: all 0.2s;
|
| 758 |
+
text-transform: uppercase;
|
| 759 |
+
letter-spacing: 0.5px;
|
| 760 |
+
}
|
| 761 |
+
|
| 762 |
+
.task-btn:hover { border-color: rgba(0,229,160,0.3); color: var(--text-primary); }
|
| 763 |
+
.task-btn.active { background: rgba(0,229,160,0.1); border-color: var(--status-normal); color: var(--status-normal); }
|
| 764 |
+
|
| 765 |
+
/* Leaderboard */
|
| 766 |
+
.leaderboard {
|
| 767 |
+
list-style: none;
|
| 768 |
+
}
|
| 769 |
+
|
| 770 |
+
.leaderboard li {
|
| 771 |
+
display: flex;
|
| 772 |
+
justify-content: space-between;
|
| 773 |
+
align-items: center;
|
| 774 |
+
padding: 5px 0;
|
| 775 |
+
font-size: 11px;
|
| 776 |
+
border-bottom: 1px solid rgba(255,255,255,0.03);
|
| 777 |
+
}
|
| 778 |
+
|
| 779 |
+
.leaderboard li:last-child { border-bottom: none; }
|
| 780 |
+
|
| 781 |
+
.leaderboard .agent-label {
|
| 782 |
+
display: flex;
|
| 783 |
+
align-items: center;
|
| 784 |
+
gap: 6px;
|
| 785 |
+
}
|
| 786 |
+
|
| 787 |
+
.leaderboard .score {
|
| 788 |
+
font-family: 'JetBrains Mono', monospace;
|
| 789 |
+
font-weight: 600;
|
| 790 |
+
font-size: 12px;
|
| 791 |
+
}
|
| 792 |
+
|
| 793 |
+
/* Coordination score */
|
| 794 |
+
.coord-score {
|
| 795 |
+
text-align: center;
|
| 796 |
+
padding: var(--gap-sm);
|
| 797 |
+
}
|
| 798 |
+
|
| 799 |
+
.coord-score .big-value {
|
| 800 |
+
font-family: 'JetBrains Mono', monospace;
|
| 801 |
+
font-size: 28px;
|
| 802 |
+
font-weight: 700;
|
| 803 |
+
}
|
| 804 |
+
|
| 805 |
+
/* Alert banner */
|
| 806 |
+
.alert-banner {
|
| 807 |
+
position: fixed;
|
| 808 |
+
top: 52px;
|
| 809 |
+
left: 0; right: 0;
|
| 810 |
+
z-index: 100;
|
| 811 |
+
padding: 8px var(--gap-lg);
|
| 812 |
+
display: flex;
|
| 813 |
+
align-items: center;
|
| 814 |
+
gap: var(--gap-sm);
|
| 815 |
+
font-size: 12px;
|
| 816 |
+
font-weight: 500;
|
| 817 |
+
transform: translateY(-100%);
|
| 818 |
+
transition: transform 0.3s;
|
| 819 |
+
}
|
| 820 |
+
|
| 821 |
+
.alert-banner.visible { transform: translateY(0); }
|
| 822 |
+
|
| 823 |
+
.alert-banner.critical {
|
| 824 |
+
background: rgba(255,61,61,0.15);
|
| 825 |
+
border-bottom: 1px solid rgba(255,61,61,0.3);
|
| 826 |
+
color: var(--status-critical);
|
| 827 |
+
}
|
| 828 |
+
|
| 829 |
+
.alert-banner.warning {
|
| 830 |
+
background: rgba(255,215,0,0.1);
|
| 831 |
+
border-bottom: 1px solid rgba(255,215,0,0.2);
|
| 832 |
+
color: var(--status-warning);
|
| 833 |
+
}
|
| 834 |
+
|
| 835 |
+
.alert-banner .dismiss {
|
| 836 |
+
margin-left: auto;
|
| 837 |
+
background: none;
|
| 838 |
+
border: 1px solid currentColor;
|
| 839 |
+
border-radius: var(--radius-sm);
|
| 840 |
+
color: inherit;
|
| 841 |
+
padding: 2px 8px;
|
| 842 |
+
font-size: 10px;
|
| 843 |
+
cursor: pointer;
|
| 844 |
+
opacity: 0.7;
|
| 845 |
+
}
|
| 846 |
+
|
| 847 |
+
.alert-banner .dismiss:hover { opacity: 1; }
|
| 848 |
+
|
| 849 |
+
/* Scrollbar (WebKit): thin, translucent thumb that brightens on hover. */
::-webkit-scrollbar { width: 4px; }
::-webkit-scrollbar-track { background: transparent; }
::-webkit-scrollbar-thumb {
    background: rgba(255,255,255,0.1);
    border-radius: 2px;
}
::-webkit-scrollbar-thumb:hover { background: rgba(255,255,255,0.2); }
|
| 854 |
+
|
| 855 |
+
/* Loading state: full-screen overlay shown until the app is ready. */
.loading-overlay {
    position: fixed;
    top: 0; right: 0; bottom: 0; left: 0;
    z-index: 1000;
    display: flex;
    flex-direction: column;
    align-items: center;
    justify-content: center;
    background: var(--bg-primary);
    transition: opacity 0.5s; /* fades out before being removed */
}

.loading-overlay.hidden {
    opacity: 0;
    pointer-events: none; /* let clicks pass through while fading */
}

/* Ring spinner: transparent track with one accented quadrant. */
.loading-spinner {
    width: 40px;
    height: 40px;
    border-radius: 50%;
    border: 3px solid rgba(0,229,160,0.15);
    border-top-color: var(--status-normal);
    animation: spin 0.8s linear infinite;
}

@keyframes spin { to { transform: rotate(360deg); } }

.loading-text {
    margin-top: var(--gap-md);
    color: var(--text-secondary);
    font-size: 12px;
    letter-spacing: 2px;
    text-transform: uppercase;
}
|
| 890 |
+
|
| 891 |
+
/* ── Leaflet Overrides ── */
|
| 892 |
+
.grid-map .leaflet-container {
|
| 893 |
+
background: var(--bg-primary) !important;
|
| 894 |
+
}
|
| 895 |
+
|
| 896 |
+
.leaflet-tooltip-dark {
|
| 897 |
+
background: rgba(10, 14, 26, 0.92) !important;
|
| 898 |
+
border: 1px solid rgba(0, 229, 160, 0.3) !important;
|
| 899 |
+
color: #e0e0e0 !important;
|
| 900 |
+
font-family: 'JetBrains Mono', monospace !important;
|
| 901 |
+
font-size: 11px !important;
|
| 902 |
+
border-radius: 6px !important;
|
| 903 |
+
padding: 6px 10px !important;
|
| 904 |
+
box-shadow: 0 4px 20px rgba(0,0,0,0.6) !important;
|
| 905 |
+
}
|
| 906 |
+
|
| 907 |
+
.leaflet-tooltip-dark::before {
|
| 908 |
+
border-top-color: rgba(10, 14, 26, 0.92) !important;
|
| 909 |
+
}
|
| 910 |
+
|
| 911 |
+
.bus-label-icon, .bus-mw-icon, .zone-badge-leaflet {
|
| 912 |
+
background: none !important;
|
| 913 |
+
border: none !important;
|
| 914 |
+
text-align: center;
|
| 915 |
+
}
|
| 916 |
+
|
| 917 |
+
/* Dark zoom controls: restyle Leaflet's +/- buttons to match the theme.
   !important is required to beat Leaflet's bundled stylesheet. */
.leaflet-control-zoom a {
    background: rgba(15, 22, 40, 0.9) !important;
    border-color: rgba(0, 229, 160, 0.2) !important;
    color: var(--status-normal) !important;
    font-family: 'JetBrains Mono', monospace !important;
}
.leaflet-control-zoom a:hover { background: rgba(0, 229, 160, 0.15) !important; }
|
| 927 |
+
|
| 928 |
+
.leaflet-control-attribution {
|
| 929 |
+
background: rgba(10, 14, 26, 0.6) !important;
|
| 930 |
+
color: #555 !important;
|
| 931 |
+
font-size: 9px !important;
|
| 932 |
+
}
|
| 933 |
+
.leaflet-control-attribution a {
|
| 934 |
+
color: #666 !important;
|
| 935 |
+
}
|
tests/__init__.py
ADDED
|
File without changes
|
tests/test_multi_agent.py
ADDED
|
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for multi-agent POMDP features:
|
| 3 |
+
- Zone assignment and partitioning
|
| 4 |
+
- Partial observability (ZoneObservation)
|
| 5 |
+
- Safety layer (action validation and correction)
|
| 6 |
+
- Oversight agent (coordination monitoring)
|
| 7 |
+
- Multi-agent step (combined pipeline)
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import copy
|
| 11 |
+
import unittest
|
| 12 |
+
|
| 13 |
+
import networkx as nx
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
from src.environment import OpenGridEnv
|
| 17 |
+
from src.tasks import TASKS
|
| 18 |
+
from src.models import GridAction, BusAdjustment, TopologyAction, ZoneObservation
|
| 19 |
+
from src.safety import SafetyLayer
|
| 20 |
+
from src.oversight import OversightAgent
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def task(task_id: str):
    """Return an isolated deep copy of the named task config.

    Tests mutate config dicts in place, so each test gets its own copy
    to keep mutations from leaking through the shared ``TASKS`` registry.
    """
    source_config = TASKS[task_id]
    return copy.deepcopy(source_config)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class TestZoneAssignment(unittest.TestCase):
    """Tests for multi-agent zone partitioning."""

    def test_all_buses_assigned(self):
        """Every bus should be assigned to exactly one zone."""
        for task_id, cfg in TASKS.items():
            assignments = cfg['zone_assignments']
            for bus_id in range(cfg['num_buses']):
                self.assertIn(bus_id, assignments,
                              f"Bus {bus_id} not assigned in {task_id}")

    def test_zone_count_matches(self):
        """Number of zones should match num_agents."""
        for task_id, cfg in TASKS.items():
            distinct_agents = set(cfg['zone_assignments'].values())
            self.assertEqual(len(distinct_agents), cfg['num_agents'],
                             f"Zone count mismatch in {task_id}")

    def test_no_empty_zones(self):
        """Each zone should have at least 1 bus."""
        for task_id, cfg in TASKS.items():
            for aid in range(cfg['num_agents']):
                self.assertGreater(len(cfg['zone_bus_ids'][aid]), 0,
                                   f"Empty zone {aid} in {task_id}")

    def test_lines_classified(self):
        """All lines should be classified as internal or boundary."""
        for task_id, cfg in TASKS.items():
            # Union of every agent's internal and boundary line IDs.
            classified = set()
            for aid in range(cfg['num_agents']):
                classified.update(cfg['internal_lines'].get(aid, []))
                classified.update(cfg['boundary_lines'].get(aid, []))

            declared = {line['id'] for line in cfg['lines']}
            self.assertEqual(declared, classified,
                             f"Unclassified lines in {task_id}")
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class TestPartialObservability(unittest.TestCase):
    """Tests for POMDP zone observations."""

    def test_partial_obs_returns_zone_obs(self):
        """reset_multi should return ZoneObservation for each agent."""
        cfg = task("task_easy")
        observations = OpenGridEnv(cfg).reset_multi()

        self.assertEqual(len(observations), cfg["num_agents"],
                         "Should have one observation per agent")
        for aid, zone_view in observations.items():
            self.assertIsInstance(zone_view, ZoneObservation)
            self.assertEqual(zone_view.agent_id, aid)

    def test_partial_obs_only_shows_local_buses(self):
        """Each agent should only see buses in their zone."""
        cfg = task("task_medium")
        observations = OpenGridEnv(cfg).reset_multi()

        for aid, zone_view in observations.items():
            visible = {bus.id for bus in zone_view.local_buses}
            self.assertEqual(visible, set(cfg['zone_bus_ids'][aid]),
                             f"Agent {aid} sees wrong buses")

    def test_frequency_has_noise(self):
        """POMDP observations should have noisy frequency readings."""
        cfg = task("task_easy")
        env = OpenGridEnv(cfg)
        env.reset_multi()

        # Compare each zone's (noisy) reading against the true full-state
        # frequency taken from the same reset.
        truth = env.state()
        deviations = [
            abs(env._get_zone_obs(aid).grid_frequency - truth.grid_frequency)
            for aid in range(cfg['num_agents'])
        ]

        # At least one agent should see noisy frequency
        self.assertTrue(any(d > 0.001 for d in deviations),
                        "No frequency noise detected in POMDP observations")
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class TestSafetyLayer(unittest.TestCase):
    """Tests for the safety constraint filter."""

    def setUp(self):
        self.config = task("task_medium")
        self.safety = SafetyLayer(self.config)
        self.env = OpenGridEnv(self.config)
        self.env.reset()

    def _run_safety(self, agent_id, action):
        """Push one proposed action through the safety layer with current env state."""
        return self.safety.validate_and_correct(
            agent_id=agent_id,
            proposed_action=action,
            current_line_state=self.env.line_state,
            current_bus_state=self.env.bus_state,
            cooldowns=self.env.cooldowns,
        )

    def test_zone_boundary_enforcement(self):
        """Agent should not be able to adjust buses in another zone."""
        zone_0 = set(self.config['zone_bus_ids'][0])
        foreign_bus = next(
            (b['id'] for b in self.config['buses'] if b['id'] not in zone_0),
            None,
        )
        if foreign_bus is None:
            self.skipTest("All buses in agent 0's zone (trivial grid)")

        corrected, report = self._run_safety(0, GridAction(bus_adjustments=[
            BusAdjustment(bus_id=foreign_bus, delta=10.0)
        ]))

        self.assertTrue(report.was_corrected, "Should have corrected cross-zone action")
        self.assertEqual(len(corrected.bus_adjustments), 0,
                         "Cross-zone adjustment should be removed")

    def test_safe_action_passes_through(self):
        """A valid action within the agent's zone should not be corrected."""
        zone_0 = set(self.config['zone_bus_ids'][0])
        controllable = next(
            (b['id'] for b in self.config['buses']
             if b['id'] in zone_0 and b['type'] in ('generator', 'battery', 'slack')),
            None,
        )
        if controllable is None:
            self.skipTest("No controllable bus in agent 0's zone")

        corrected, _report = self._run_safety(0, GridAction(bus_adjustments=[
            BusAdjustment(bus_id=controllable, delta=5.0)
        ]))

        # Should pass through (may have minor clamping)
        self.assertEqual(len(corrected.bus_adjustments), 1,
                         "Valid action should produce one adjustment")

    def test_islanding_blocked(self):
        """Opening a bridge line should be blocked by safety layer."""
        graph = nx.Graph()
        graph.add_edges_from((ln['from'], ln['to']) for ln in self.config['lines'])
        bridges = list(nx.bridges(graph))
        if not bridges:
            self.skipTest("No bridges in grid topology")

        # Match the bridge's endpoints against the line list in either direction.
        endpoints = bridges[0]
        line_id = next(
            ln['id'] for ln in self.config['lines']
            if (ln['from'], ln['to']) in (endpoints, endpoints[::-1])
        )

        corrected, report = self._run_safety(0, GridAction(topology_actions=[
            TopologyAction(line_id=line_id, action="open")
        ]))

        self.assertTrue(report.was_corrected, "Bridge opening should be blocked")
        self.assertEqual(len(corrected.topology_actions), 0,
                         "Bridge opening should be removed")

    def test_duplicate_battery_adjustments_aggregated(self):
        """Multiple adjustments to the same battery should be aggregated."""
        battery = next(
            (b for b in self.config['buses'] if b['type'] == 'battery'), None
        )
        if battery is None:
            self.skipTest("No battery in task")

        bus_id = battery['id']
        owner = self.config['zone_assignments'].get(bus_id, 0)

        # Pin the SOC so the expected discharge cap is known.
        for bus in self.env.bus_state:
            if bus['id'] == bus_id:
                bus['soc'] = 10.0

        corrected, _report = self._run_safety(owner, GridAction(bus_adjustments=[
            BusAdjustment(bus_id=bus_id, delta=8.0),
            BusAdjustment(bus_id=bus_id, delta=8.0),
        ]))

        combined = sum(adj.delta for adj in corrected.bus_adjustments)
        self.assertLessEqual(combined, 10.0,
                             "Combined discharge should not exceed SOC")
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class TestOversightAgent(unittest.TestCase):
    """Tests for the coordination oversight agent."""

    def test_no_conflict_scores_high(self):
        """Cooperative actions should score high coordination."""
        cfg = task("task_easy")
        monitor = OversightAgent(cfg)

        # Both agents inject (cooperative)
        proposals = {
            0: GridAction(bus_adjustments=[BusAdjustment(bus_id=0, delta=5.0)]),
            1: GridAction(bus_adjustments=[BusAdjustment(bus_id=1, delta=3.0)]),
        }

        verdict = monitor.evaluate(
            agent_actions=proposals,
            safety_reports={},
            pre_frequency=49.8,
            post_frequency=49.9,
            pre_bus_state=[],
            post_bus_state=[],
        )

        self.assertGreater(verdict.coordination_score, 0.5,
                           "Cooperative actions should score > 0.5")

    def test_reset_clears_history(self):
        """Resetting oversight should clear intervention history."""
        monitor = OversightAgent(task("task_easy"))
        monitor.intervention_history[0] = 5
        monitor.reset()
        self.assertEqual(monitor.intervention_history[0], 0)
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
class TestMultiAgentStep(unittest.TestCase):
    """Integration tests for the full multi-agent pipeline."""

    @staticmethod
    def _noop_actions(cfg):
        """Build an empty GridAction for every agent in the task."""
        return {aid: GridAction() for aid in range(cfg['num_agents'])}

    def test_multi_agent_step_returns_result(self):
        """step_multi should return a complete MultiAgentStepResult."""
        cfg = task("task_easy")
        env = OpenGridEnv(cfg)
        env.reset_multi()

        outcome = env.step_multi(self._noop_actions(cfg))

        agent_count = cfg['num_agents']
        self.assertEqual(len(outcome.observations), agent_count)
        self.assertEqual(len(outcome.rewards), agent_count)
        self.assertIsInstance(outcome.team_reward, float)
        self.assertIsInstance(outcome.done, bool)
        self.assertEqual(len(outcome.safety_reports), agent_count)

    def test_safety_reports_match_agent_ids(self):
        """Safety reports should contain all expected agent IDs."""
        cfg = task("task_easy")
        env = OpenGridEnv(cfg)
        env.reset_multi()

        outcome = env.step_multi(self._noop_actions(cfg))

        self.assertEqual(set(outcome.safety_reports.keys()),
                         set(range(cfg['num_agents'])),
                         "Safety report agent IDs should match expected agents")

    def test_multi_agent_episode_completes(self):
        """A full multi-agent episode should complete without errors."""
        cfg = task("task_easy")
        env = OpenGridEnv(cfg)
        env.reset_multi()

        finished = False
        taken = 0
        # Hard-cap a few steps past max_steps so a broken env can't loop forever.
        while not finished and taken < cfg['max_steps'] + 5:
            finished = env.step_multi(self._noop_actions(cfg)).done
            taken += 1

        self.assertTrue(finished, "Episode should terminate")
        self.assertLessEqual(taken, cfg['max_steps'] + 1)

    def test_backward_compatibility(self):
        """Single-agent reset/step should still work after multi-agent changes."""
        for task_id in TASKS:
            env = OpenGridEnv(task(task_id))
            first_obs = env.reset()
            self.assertGreater(len(first_obs.buses), 0,
                               f"No buses in {task_id}")

            next_obs, reward, _done, _info = env.step(GridAction())
            self.assertEqual(next_obs.timestep, 1)
            self.assertIsInstance(reward.value, float)
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
# Allow running this test module directly: `python tests/test_multi_agent.py`.
if __name__ == '__main__':
    unittest.main()
|
tests/test_solver.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for core simulation components:
|
| 3 |
+
- DC power flow solver
|
| 4 |
+
- Environment lifecycle (reset, step, terminate)
|
| 5 |
+
- Grading system (scoring, bounds, reproducibility)
|
| 6 |
+
- Baseline heuristic policy
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import copy
|
| 10 |
+
import unittest
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
from src.physics import DCSolver, IslandedException
|
| 15 |
+
from src.environment import OpenGridEnv
|
| 16 |
+
from src.tasks import TASKS
|
| 17 |
+
from src.models import GridAction, BusAdjustment
|
| 18 |
+
from src.grader import RobustnessGrader, compute_analytical_ceiling
|
| 19 |
+
from src.baseline import heuristic_policy
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def task(task_id: str):
    """Return a deep copy of the registered config for *task_id*.

    Copying ensures no test can mutate the shared TASKS registry and
    contaminate later tests.
    """
    config = TASKS[task_id]
    return copy.deepcopy(config)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class TestDCSolver(unittest.TestCase):
    """Unit tests for the DC power-flow solver on a 3-bus triangle grid."""

    def setUp(self):
        """Build a fully connected 3-bus network and load it into a solver."""
        self.num_buses = 3
        self.lines = [
            {'id': 'L01', 'from': 0, 'to': 1, 'susceptance': 100, 'connected': True},
            {'id': 'L12', 'from': 1, 'to': 2, 'susceptance': 50, 'connected': True},
            {'id': 'L02', 'from': 0, 'to': 2, 'susceptance': 100, 'connected': True},
        ]
        self.solver = DCSolver(self.num_buses)
        self.solver.update_grid(self.lines)

    def test_power_flow_balance(self):
        """Slack bus should absorb any generation/load imbalance."""
        injections = np.array([0.0, 50.0, -100.0])
        theta, flows, slack_inj = self.solver.solve(injections)

        # Flows must be reported for the lines we defined.
        self.assertIn('L01', flows)
        self.assertIn('L02', flows)

    def test_islanding_detection(self):
        """Disconnecting lines to island bus 2 should raise IslandedException."""
        islanded_topology = [
            {'id': 'L01', 'from': 0, 'to': 1, 'susceptance': 100, 'connected': True},
            {'id': 'L12', 'from': 1, 'to': 2, 'susceptance': 50, 'connected': False},
            {'id': 'L02', 'from': 0, 'to': 2, 'susceptance': 100, 'connected': False},
        ]
        with self.assertRaises(IslandedException):
            self.solver.update_grid(islanded_topology)

    def test_slack_injection_returned(self):
        """solve() should return slack bus injection as third element."""
        outcome = self.solver.solve(np.array([0.0, 50.0, -100.0]))
        self.assertEqual(len(outcome), 3)
        _, _, slack_inj = outcome
        # The slack bus covers the 50 MW net deficit of the other buses.
        self.assertAlmostEqual(slack_inj, 50.0, places=0)

    def test_solve_before_update_raises(self):
        """Calling solve() on a fresh solver should raise RuntimeError."""
        uninitialised = DCSolver(3)
        with self.assertRaises(RuntimeError):
            uninitialised.solve(np.array([0.0, 10.0, -10.0]))

    def test_invalid_bus_index_raises(self):
        """Lines referencing out-of-range bus IDs should raise ValueError."""
        solver = DCSolver(3)
        with self.assertRaises(ValueError):
            solver.update_grid([
                {'id': 'L_bad', 'from': 0, 'to': 99, 'susceptance': 50, 'connected': True},
            ])
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class TestEnvironment(unittest.TestCase):
    """Lifecycle tests for the single-agent OpenGrid environment."""

    def test_reset_returns_observation(self):
        """reset() should return a valid GridObservation."""
        obs = OpenGridEnv(task("task_easy")).reset()
        self.assertEqual(obs.timestep, 0)
        self.assertGreater(len(obs.buses), 0, "Observation should have buses")
        self.assertGreater(len(obs.lines), 0, "Observation should have lines")

    def test_step_returns_tuple(self):
        """step() should return (obs, reward, done, info)."""
        env = OpenGridEnv(task("task_easy"))
        env.reset()
        obs, reward, done, info = env.step(GridAction())
        self.assertEqual(obs.timestep, 1)
        self.assertIsInstance(reward.value, float)
        self.assertIsInstance(done, bool)

    def test_reproducibility(self):
        """Running the same task twice should produce identical initial observations."""
        first = OpenGridEnv(task("task_easy")).reset()
        second = OpenGridEnv(task("task_easy")).reset()

        self.assertEqual(first.grid_frequency, second.grid_frequency)
        self.assertEqual(len(first.buses), len(second.buses))

    def test_episode_terminates(self):
        """Episode should end after max_steps."""
        config = task("task_easy")
        env = OpenGridEnv(config)
        env.reset()
        finished = False
        count = 0
        # Hard cap of 100 so a non-terminating env still ends the test.
        while not finished and count < 100:
            _, _, finished, _ = env.step(GridAction())
            count += 1
        self.assertTrue(finished, "Episode should terminate")
        self.assertLessEqual(count, config["max_steps"])

    def test_frequency_reasonable(self):
        """Frequency should stay in a reasonable range for do-nothing agent."""
        env = OpenGridEnv(task("task_easy"))
        obs = env.reset()
        for _ in range(10):
            obs, _, done, _ = env.step(GridAction())
            if done:
                break
        self.assertGreater(obs.grid_frequency, 40.0,
                           "Frequency below reasonable minimum")
        self.assertLess(obs.grid_frequency, 60.0,
                        "Frequency above reasonable maximum")
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class TestGrader(unittest.TestCase):
    """Scoring-bound checks for the robustness grader."""

    def test_grader_score_range(self):
        """Grader should return score strictly in (0, 1) — never 0.0 or 1.0."""
        outcome = RobustnessGrader(task("task_easy")).evaluate_policy(
            heuristic_policy, n_episodes=1)
        self.assertGreater(outcome["score"], 0.0)
        self.assertLess(outcome["score"], 1.0)

    def test_grader_all_tasks(self):
        """Grader should work on all registered tasks."""
        for task_id, config in TASKS.items():
            # Deep-copy so the shared registry entry is never mutated.
            outcome = RobustnessGrader(copy.deepcopy(config)).evaluate_policy(
                heuristic_policy, n_episodes=1)
            self.assertIn("score", outcome, f"Missing 'score' for {task_id}")
            self.assertIn("avg_raw_reward", outcome,
                          f"Missing 'avg_raw_reward' for {task_id}")
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class TestBaseline(unittest.TestCase):
    """Sanity checks for the heuristic baseline policy."""

    def test_heuristic_returns_valid_action(self):
        """Heuristic policy should return a valid GridAction."""
        initial_obs = OpenGridEnv(task("task_easy")).reset()
        self.assertIsInstance(heuristic_policy(initial_obs), GridAction)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class TestReproducibility(unittest.TestCase):
    """Determinism and analytical-bound checks for the grading pipeline."""

    def test_floor_deterministic(self):
        """Two calls to _estimate_bounds should produce identical floors (seeded RNG)."""
        first = RobustnessGrader(task("task_easy"))
        first._estimate_bounds(n_samples=3)

        second = RobustnessGrader(task("task_easy"))
        second._estimate_bounds(n_samples=3)

        self.assertEqual(first.reward_floor, second.reward_floor,
                         "Floor should be deterministic with same seed")

    def test_ceiling_is_analytical(self):
        """Ceiling should be max_steps * 1.2, not an empirical estimate."""
        config = task("task_easy")
        bounds = RobustnessGrader(config).get_bounds()
        self.assertEqual(bounds["reward_ceiling"],
                         compute_analytical_ceiling(config["max_steps"]),
                         "Ceiling should match analytical formula")

    def test_heuristic_score_below_one(self):
        """With analytical ceiling, heuristic should score < 1.0 (not degenerate)."""
        outcome = RobustnessGrader(task("task_easy")).evaluate_policy(
            heuristic_policy, n_episodes=1)
        self.assertLess(outcome["score"], 1.0)
        self.assertGreater(outcome["score"], 0.0)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
# Allow running this test module directly: `python tests/test_solver.py`.
if __name__ == '__main__':
    unittest.main()
|
training/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Training module for OpenGrid GRPO pipeline
|
training/opengrid_grpo_colab.ipynb
ADDED
|
@@ -0,0 +1,635 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# 🔋 OpenGrid — GRPO Training Notebook\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"**Multi-Agent RL for Power Grid Operations**\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"This notebook trains an LLM (Qwen 2.5 1.5B) to operate a power grid using GRPO (Group Relative Policy Optimization).\n",
|
| 12 |
+
"\n",
|
| 13 |
+
"- **Environment**: OpenGrid — multi-agent POMDP with safety layer & oversight agent\n",
|
| 14 |
+
"- **Task**: Maintain 50 Hz frequency, prevent line overloads, avoid blackouts\n",
|
| 15 |
+
"- **Training**: TRL GRPOTrainer + Unsloth 4-bit quantization\n",
|
| 16 |
+
"\n",
|
| 17 |
+
"⚡ **Runtime**: Select `T4 GPU` from Runtime → Change runtime type"
|
| 18 |
+
]
|
| 19 |
+
},
|
| 20 |
+
{
|
| 21 |
+
"cell_type": "markdown",
|
| 22 |
+
"metadata": {},
|
| 23 |
+
"source": [
|
| 24 |
+
"## 1. Install Dependencies"
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"cell_type": "code",
|
| 29 |
+
"execution_count": null,
|
| 30 |
+
"metadata": {},
|
| 31 |
+
"outputs": [],
|
| 32 |
+
"source": [
|
| 33 |
+
"%%capture\n",
|
| 34 |
+
"!pip install unsloth\n",
|
| 35 |
+
"!pip install --no-deps trl peft accelerate bitsandbytes\n",
|
| 36 |
+
"!pip install fastapi uvicorn pydantic numpy networkx matplotlib openai httpx datasets"
|
| 37 |
+
]
|
| 38 |
+
},
|
| 39 |
+
{
|
| 40 |
+
"cell_type": "markdown",
|
| 41 |
+
"metadata": {},
|
| 42 |
+
"source": [
|
| 43 |
+
"## 2. Clone OpenGrid Repository"
|
| 44 |
+
]
|
| 45 |
+
},
|
| 46 |
+
{
|
| 47 |
+
"cell_type": "code",
|
| 48 |
+
"execution_count": null,
|
| 49 |
+
"metadata": {},
|
| 50 |
+
"outputs": [],
|
| 51 |
+
"source": [
|
| 52 |
+
"import os\n",
|
| 53 |
+
"\n",
|
| 54 |
+
"# ⚠️ UPDATE THIS with your actual repo URL\n",
|
| 55 |
+
"REPO_URL = \"https://github.com/krishnagoyal099/Opengrid_env.git\"\n",
|
| 56 |
+
"\n",
|
| 57 |
+
"if not os.path.exists(\"opengrid\"):\n",
|
| 58 |
+
" !git clone {REPO_URL} opengrid\n",
|
| 59 |
+
"else:\n",
|
| 60 |
+
" !cd opengrid && git pull\n",
|
| 61 |
+
"\n",
|
| 62 |
+
"os.chdir(\"opengrid\")\n",
|
| 63 |
+
"print(f\"Working directory: {os.getcwd()}\")\n",
|
| 64 |
+
"!ls -la"
|
| 65 |
+
]
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"cell_type": "markdown",
|
| 69 |
+
"metadata": {},
|
| 70 |
+
"source": [
|
| 71 |
+
"## 3. Verify GPU & Environment"
|
| 72 |
+
]
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"cell_type": "code",
|
| 76 |
+
"execution_count": null,
|
| 77 |
+
"metadata": {},
|
| 78 |
+
"outputs": [],
|
| 79 |
+
"source": [
|
| 80 |
+
"import torch\n",
|
| 81 |
+
"print(f\"PyTorch: {torch.__version__}\")\n",
|
| 82 |
+
"print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
|
| 83 |
+
"if torch.cuda.is_available():\n",
|
| 84 |
+
" print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
|
| 85 |
+
" print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\n",
|
| 86 |
+
"else:\n",
|
| 87 |
+
" print(\"⚠️ No GPU detected! Go to Runtime → Change runtime type → T4 GPU\")"
|
| 88 |
+
]
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
"cell_type": "code",
|
| 92 |
+
"execution_count": null,
|
| 93 |
+
"metadata": {},
|
| 94 |
+
"outputs": [],
|
| 95 |
+
"source": [
|
| 96 |
+
"# Verify OpenGrid imports work\n",
|
| 97 |
+
"import sys\n",
|
| 98 |
+
"sys.path.insert(0, '.')\n",
|
| 99 |
+
"\n",
|
| 100 |
+
"from src.environment import OpenGridEnv\n",
|
| 101 |
+
"from src.tasks import TASKS\n",
|
| 102 |
+
"from src.models import GridAction, BusAdjustment\n",
|
| 103 |
+
"\n",
|
| 104 |
+
"print(f\"Available tasks: {list(TASKS.keys())}\")\n",
|
| 105 |
+
"for tid, cfg in TASKS.items():\n",
|
| 106 |
+
" print(f\" {tid}: {cfg['num_buses']} buses, {cfg['num_agents']} agents, {cfg.get('difficulty','')}\")"
|
| 107 |
+
]
|
| 108 |
+
},
|
| 109 |
+
{
|
| 110 |
+
"cell_type": "markdown",
|
| 111 |
+
"metadata": {},
|
| 112 |
+
"source": [
|
| 113 |
+
"## 4. Run Test Mode (Pipeline Verification)"
|
| 114 |
+
]
|
| 115 |
+
},
|
| 116 |
+
{
|
| 117 |
+
"cell_type": "code",
|
| 118 |
+
"execution_count": null,
|
| 119 |
+
"metadata": {},
|
| 120 |
+
"outputs": [],
|
| 121 |
+
"source": [
|
| 122 |
+
"!python training/train_grpo.py --test-mode"
|
| 123 |
+
]
|
| 124 |
+
},
|
| 125 |
+
{
|
| 126 |
+
"cell_type": "markdown",
|
| 127 |
+
"metadata": {},
|
| 128 |
+
"source": [
|
| 129 |
+
"## 5. Baseline Evaluation (Before Training)\n",
|
| 130 |
+
"\n",
|
| 131 |
+
"Run the heuristic policy to get baseline scores. We'll compare against this after training."
|
| 132 |
+
]
|
| 133 |
+
},
|
| 134 |
+
{
|
| 135 |
+
"cell_type": "code",
|
| 136 |
+
"execution_count": null,
|
| 137 |
+
"metadata": {},
|
| 138 |
+
"outputs": [],
|
| 139 |
+
"source": [
|
| 140 |
+
"import json\n",
|
| 141 |
+
"import re\n",
|
| 142 |
+
"import numpy as np\n",
|
| 143 |
+
"from src.environment import OpenGridEnv\n",
|
| 144 |
+
"from src.tasks import TASKS\n",
|
| 145 |
+
"from src.models import GridAction, BusAdjustment\n",
|
| 146 |
+
"from training.train_grpo import (\n",
|
| 147 |
+
" rollout_multi_agent, format_observation_prompt, extract_action\n",
|
| 148 |
+
")\n",
|
| 149 |
+
"\n",
|
| 150 |
+
"def heuristic_generate(prompt):\n",
|
| 151 |
+
" \"\"\"Simple proportional controller as baseline.\"\"\"\n",
|
| 152 |
+
" freq_match = re.search(r'Frequency: ([\\d.]+)', prompt)\n",
|
| 153 |
+
" freq = float(freq_match.group(1)) if freq_match else 50.0\n",
|
| 154 |
+
" error = 50.0 - freq\n",
|
| 155 |
+
" delta = max(-20, min(20, error * 10))\n",
|
| 156 |
+
" bus_match = re.search(r'Bus (\\d+) \\((generator|battery|slack)\\)', prompt)\n",
|
| 157 |
+
" if bus_match:\n",
|
| 158 |
+
" return json.dumps({\"bus_adjustments\": [{\"bus_id\": int(bus_match.group(1)), \"delta\": round(delta, 1)}], \"topology_actions\": []})\n",
|
| 159 |
+
" return json.dumps({\"bus_adjustments\": [], \"topology_actions\": []})\n",
|
| 160 |
+
"\n",
|
| 161 |
+
"# Evaluate baseline on all tasks\n",
|
| 162 |
+
"baseline_results = {}\n",
|
| 163 |
+
"for task_id in [\"task_easy\", \"task_medium\", \"task_karnataka\"]:\n",
|
| 164 |
+
" if task_id not in TASKS:\n",
|
| 165 |
+
" continue\n",
|
| 166 |
+
" config = TASKS[task_id]\n",
|
| 167 |
+
" rewards = []\n",
|
| 168 |
+
" import copy\n",
|
| 169 |
+
" for ep in range(5):\n",
|
| 170 |
+
" ep_config = copy.deepcopy(config)\n",
|
| 171 |
+
" ep_config['seed'] = 42 + ep\n",
|
| 172 |
+
" env = OpenGridEnv(ep_config)\n",
|
| 173 |
+
" result = rollout_multi_agent(env, heuristic_generate, ep_config)\n",
|
| 174 |
+
" rewards.append(result['total_reward'])\n",
|
| 175 |
+
" baseline_results[task_id] = {\n",
|
| 176 |
+
" \"avg_reward\": np.mean(rewards),\n",
|
| 177 |
+
" \"std_reward\": np.std(rewards),\n",
|
| 178 |
+
" \"rewards\": rewards\n",
|
| 179 |
+
" }\n",
|
| 180 |
+
" print(f\"[BASELINE] {task_id}: {np.mean(rewards):.2f} ± {np.std(rewards):.2f}\")\n",
|
| 181 |
+
"\n",
|
| 182 |
+
"# Save baseline for later comparison\n",
|
| 183 |
+
"import pickle\n",
|
| 184 |
+
"os.makedirs(\"training/outputs\", exist_ok=True)\n",
|
| 185 |
+
"with open(\"training/outputs/baseline_results.pkl\", \"wb\") as f:\n",
|
| 186 |
+
" pickle.dump(baseline_results, f)\n",
|
| 187 |
+
"print(\"\\n✅ Baseline scores saved.\")"
|
| 188 |
+
]
|
| 189 |
+
},
|
| 190 |
+
{
|
| 191 |
+
"cell_type": "markdown",
|
| 192 |
+
"metadata": {},
|
| 193 |
+
"source": [
|
| 194 |
+
"## 6. Load Model with Unsloth (4-bit Quantized)"
|
| 195 |
+
]
|
| 196 |
+
},
|
| 197 |
+
{
|
| 198 |
+
"cell_type": "code",
|
| 199 |
+
"execution_count": null,
|
| 200 |
+
"metadata": {},
|
| 201 |
+
"outputs": [],
|
| 202 |
+
"source": [
|
| 203 |
+
"from unsloth import FastLanguageModel\n",
|
| 204 |
+
"\n",
|
| 205 |
+
"MODEL_NAME = \"unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit\"\n",
|
| 206 |
+
"\n",
|
| 207 |
+
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
|
| 208 |
+
" model_name=MODEL_NAME,\n",
|
| 209 |
+
" max_seq_length=2048,\n",
|
| 210 |
+
" load_in_4bit=True,\n",
|
| 211 |
+
")\n",
|
| 212 |
+
"\n",
|
| 213 |
+
"model = FastLanguageModel.get_peft_model(\n",
|
| 214 |
+
" model,\n",
|
| 215 |
+
" r=16,\n",
|
| 216 |
+
" lora_alpha=16,\n",
|
| 217 |
+
" lora_dropout=0,\n",
|
| 218 |
+
" target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
|
| 219 |
+
" \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
|
| 220 |
+
")\n",
|
| 221 |
+
"\n",
|
| 222 |
+
"if tokenizer.pad_token is None:\n",
|
| 223 |
+
" tokenizer.pad_token = tokenizer.eos_token\n",
|
| 224 |
+
"\n",
|
| 225 |
+
"print(f\"✅ Model loaded: {MODEL_NAME}\")\n",
|
| 226 |
+
"print(f\" Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}\")"
|
| 227 |
+
]
|
| 228 |
+
},
|
| 229 |
+
{
|
| 230 |
+
"cell_type": "markdown",
|
| 231 |
+
"metadata": {},
|
| 232 |
+
"source": [
|
| 233 |
+
"## 7. Generate Training Prompts from Environment"
|
| 234 |
+
]
|
| 235 |
+
},
|
| 236 |
+
{
|
| 237 |
+
"cell_type": "code",
|
| 238 |
+
"execution_count": null,
|
| 239 |
+
"metadata": {},
|
| 240 |
+
"outputs": [],
|
| 241 |
+
"source": [
|
| 242 |
+
"import copy\n",
|
| 243 |
+
"import json as _json\n",
|
| 244 |
+
"import numpy as np\n",
|
| 245 |
+
"from training.train_grpo import SYSTEM_PROMPT, format_observation_prompt\n",
|
| 246 |
+
"\n",
|
| 247 |
+
"TRAIN_TASK = \"task_karnataka\" # Change to task_easy for faster first run\n",
|
| 248 |
+
"NUM_EPISODES = 30\n",
|
| 249 |
+
"\n",
|
| 250 |
+
"task_config = TASKS[TRAIN_TASK]\n",
|
| 251 |
+
"base_seed = task_config.get('seed', 42)\n",
|
| 252 |
+
"\n",
|
| 253 |
+
"prompts = []\n",
|
| 254 |
+
"obs_contexts = [] # stored as JSON strings to satisfy PyArrow schema inference\n",
|
| 255 |
+
"\n",
|
| 256 |
+
"for episode in range(NUM_EPISODES):\n",
|
| 257 |
+
" ep_config = copy.deepcopy(task_config)\n",
|
| 258 |
+
" ep_config['seed'] = base_seed + episode\n",
|
| 259 |
+
" env = OpenGridEnv(ep_config)\n",
|
| 260 |
+
" zone_obs = env.reset_multi()\n",
|
| 261 |
+
"\n",
|
| 262 |
+
" for t in range(min(10, task_config['max_steps'])):\n",
|
| 263 |
+
" for agent_id, obs in zone_obs.items():\n",
|
| 264 |
+
" # model_dump_json() → json.loads() ensures all keys are strings\n",
|
| 265 |
+
" obs_dict = _json.loads(obs.model_dump_json())\n",
|
| 266 |
+
" prompt_text = format_observation_prompt(obs_dict, zone_name=obs.zone_name)\n",
|
| 267 |
+
" messages = [\n",
|
| 268 |
+
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
|
| 269 |
+
" {\"role\": \"user\", \"content\": prompt_text},\n",
|
| 270 |
+
" ]\n",
|
| 271 |
+
" formatted = tokenizer.apply_chat_template(\n",
|
| 272 |
+
" messages, tokenize=False, add_generation_prompt=True\n",
|
| 273 |
+
" )\n",
|
| 274 |
+
" prompts.append(formatted)\n",
|
| 275 |
+
" # Store as JSON string — flat scalar, no schema-inference issues\n",
|
| 276 |
+
" obs_contexts.append(_json.dumps(obs_dict))\n",
|
| 277 |
+
"\n",
|
| 278 |
+
" # Advance env with diverse random actions (no slack bus)\n",
|
| 279 |
+
" random_actions = {}\n",
|
| 280 |
+
" for aid in range(env.num_agents):\n",
|
| 281 |
+
" zone_buses = task_config['zone_bus_ids'].get(aid, [])\n",
|
| 282 |
+
" controllable = [bid for bid in zone_buses\n",
|
| 283 |
+
" if next((b for b in task_config['buses'] if b['id'] == bid), {}).get('type')\n",
|
| 284 |
+
" in ['generator', 'battery']]\n",
|
| 285 |
+
" adj = []\n",
|
| 286 |
+
" if controllable:\n",
|
| 287 |
+
" bid = np.random.choice(controllable)\n",
|
| 288 |
+
" adj = [BusAdjustment(bus_id=int(bid), delta=float(np.random.uniform(-15, 15)))]\n",
|
| 289 |
+
" random_actions[aid] = GridAction(bus_adjustments=adj)\n",
|
| 290 |
+
"\n",
|
| 291 |
+
" result = env.step_multi(random_actions)\n",
|
| 292 |
+
" if result.done:\n",
|
| 293 |
+
" break\n",
|
| 294 |
+
" zone_obs = result.observations\n",
|
| 295 |
+
"\n",
|
| 296 |
+
"print(f\"✅ Generated {len(prompts)} training prompts\")\n",
|
| 297 |
+
"print(f\"\\nSample prompt (first 400 chars):\")\n",
|
| 298 |
+
"print(prompts[0][:400])"
|
| 299 |
+
]
|
| 300 |
+
},
|
| 301 |
+
{
|
| 302 |
+
"cell_type": "markdown",
|
| 303 |
+
"metadata": {},
|
| 304 |
+
"source": [
|
| 305 |
+
"## 8. Define GRPO Reward Function"
|
| 306 |
+
]
|
| 307 |
+
},
|
| 308 |
+
{
|
| 309 |
+
"cell_type": "code",
|
| 310 |
+
"execution_count": null,
|
| 311 |
+
"metadata": {},
|
| 312 |
+
"outputs": [],
|
| 313 |
+
"source": [
|
| 314 |
+
"import json as _json\n",
|
| 315 |
+
"from training.train_grpo import compute_grpo_reward, extract_action\n",
|
| 316 |
+
"\n",
|
| 317 |
+
"def reward_fn(completions, obs_context=None, **kwargs):\n",
|
| 318 |
+
" \"\"\"GRPO-compatible reward function for OpenGrid.\n",
|
| 319 |
+
" obs_context arrives as JSON strings from the dataset column.\n",
|
| 320 |
+
" \"\"\"\n",
|
| 321 |
+
" texts = []\n",
|
| 322 |
+
" for c in completions:\n",
|
| 323 |
+
" if isinstance(c, list):\n",
|
| 324 |
+
" text = c[-1]['content'] if c else \"\"\n",
|
| 325 |
+
" else:\n",
|
| 326 |
+
" text = str(c)\n",
|
| 327 |
+
" texts.append(text)\n",
|
| 328 |
+
"\n",
|
| 329 |
+
" # Deserialize JSON strings → dicts for the reward scorer\n",
|
| 330 |
+
" if obs_context is None:\n",
|
| 331 |
+
" batch_obs = [None] * len(texts)\n",
|
| 332 |
+
" else:\n",
|
| 333 |
+
" batch_obs = [\n",
|
| 334 |
+
" _json.loads(ctx) if isinstance(ctx, str) else ctx\n",
|
| 335 |
+
" for ctx in obs_context\n",
|
| 336 |
+
" ]\n",
|
| 337 |
+
" return compute_grpo_reward(texts, batch_obs)\n",
|
| 338 |
+
"\n",
|
| 339 |
+
"# Quick sanity test\n",
|
| 340 |
+
"test_rewards = reward_fn([\n",
|
| 341 |
+
" '{\"bus_adjustments\": [{\"bus_id\": 1, \"delta\": 5.0}], \"topology_actions\": []}',\n",
|
| 342 |
+
" 'invalid json here',\n",
|
| 343 |
+
"])\n",
|
| 344 |
+
"print(f\"Test rewards: {test_rewards}\")\n",
|
| 345 |
+
"assert len(test_rewards) == 2, \"reward_fn must return one score per completion\"\n",
|
| 346 |
+
"print(\"✅ reward_fn OK\")"
|
| 347 |
+
]
|
| 348 |
+
},
|
| 349 |
+
{
|
| 350 |
+
"cell_type": "markdown",
|
| 351 |
+
"metadata": {},
|
| 352 |
+
"source": [
|
| 353 |
+
"## 9. Train with GRPO 🚀"
|
| 354 |
+
]
|
| 355 |
+
},
|
| 356 |
+
{
|
| 357 |
+
"cell_type": "code",
|
| 358 |
+
"execution_count": null,
|
| 359 |
+
"metadata": {},
|
| 360 |
+
"outputs": [],
|
| 361 |
+
"source": [
|
| 362 |
+
"from trl import GRPOTrainer, GRPOConfig\n",
|
| 363 |
+
"from datasets import Dataset\n",
|
| 364 |
+
"\n",
|
| 365 |
+
"_cuda_ok = torch.cuda.is_available()\n",
|
| 366 |
+
"_bf16 = _cuda_ok and torch.cuda.is_bf16_supported()\n",
|
| 367 |
+
"_fp16 = _cuda_ok and not _bf16\n",
|
| 368 |
+
"\n",
|
| 369 |
+
"grpo_config = GRPOConfig(\n",
|
| 370 |
+
" output_dir=\"training/outputs/grpo_checkpoints\",\n",
|
| 371 |
+
" num_train_epochs=1,\n",
|
| 372 |
+
" per_device_train_batch_size=2,\n",
|
| 373 |
+
" gradient_accumulation_steps=4,\n",
|
| 374 |
+
" learning_rate=5e-6,\n",
|
| 375 |
+
" logging_steps=5,\n",
|
| 376 |
+
" save_steps=50,\n",
|
| 377 |
+
" max_completion_length=256,\n",
|
| 378 |
+
" num_generations=4,\n",
|
| 379 |
+
" report_to=\"none\",\n",
|
| 380 |
+
" remove_unused_columns=False,\n",
|
| 381 |
+
" bf16=_bf16,\n",
|
| 382 |
+
" fp16=_fp16,\n",
|
| 383 |
+
")\n",
|
| 384 |
+
"\n",
|
| 385 |
+
"# obs_contexts are JSON strings — PyArrow handles flat strings with no issues\n",
|
| 386 |
+
"train_dataset = Dataset.from_dict({\"prompt\": prompts, \"obs_context\": obs_contexts})\n",
|
| 387 |
+
"print(f\"Dataset: {len(train_dataset)} rows, columns: {train_dataset.column_names}\")\n",
|
| 388 |
+
"\n",
|
| 389 |
+
"trainer = GRPOTrainer(\n",
|
| 390 |
+
" model=model,\n",
|
| 391 |
+
" args=grpo_config,\n",
|
| 392 |
+
" train_dataset=train_dataset,\n",
|
| 393 |
+
" reward_funcs=reward_fn,\n",
|
| 394 |
+
" processing_class=tokenizer,\n",
|
| 395 |
+
")\n",
|
| 396 |
+
"\n",
|
| 397 |
+
"print(f\"Training on {len(prompts)} prompts, {grpo_config.num_train_epochs} epoch(s)\")\n",
|
| 398 |
+
"print(f\"Effective batch size: {grpo_config.per_device_train_batch_size * grpo_config.gradient_accumulation_steps}\")\n",
|
| 399 |
+
"print(\"\\n🚀 Starting GRPO training...\")\n",
|
| 400 |
+
"\n",
|
| 401 |
+
"train_result = trainer.train()\n",
|
| 402 |
+
"\n",
|
| 403 |
+
"print(\"\\n✅ Training complete!\")\n",
|
| 404 |
+
"print(f\" Total steps: {trainer.state.global_step}\")"
|
| 405 |
+
]
|
| 406 |
+
},
|
| 407 |
+
{
|
| 408 |
+
"cell_type": "markdown",
|
| 409 |
+
"metadata": {},
|
| 410 |
+
"source": [
|
| 411 |
+
"## 10. Save Trained Model"
|
| 412 |
+
]
|
| 413 |
+
},
|
| 414 |
+
{
|
| 415 |
+
"cell_type": "code",
|
| 416 |
+
"execution_count": null,
|
| 417 |
+
"metadata": {},
|
| 418 |
+
"outputs": [],
|
| 419 |
+
"source": [
|
| 420 |
+
"OUTPUT_PATH = \"training/outputs/trained_model\"\n",
|
| 421 |
+
"trainer.save_model(OUTPUT_PATH)\n",
|
| 422 |
+
"tokenizer.save_pretrained(OUTPUT_PATH)\n",
|
| 423 |
+
"print(f\"✅ Model saved to {OUTPUT_PATH}\")"
|
| 424 |
+
]
|
| 425 |
+
},
|
| 426 |
+
{
|
| 427 |
+
"cell_type": "markdown",
|
| 428 |
+
"metadata": {},
|
| 429 |
+
"source": [
|
| 430 |
+
"## 11. Evaluate Trained Model (After Training)"
|
| 431 |
+
]
|
| 432 |
+
},
|
| 433 |
+
{
|
| 434 |
+
"cell_type": "code",
|
| 435 |
+
"execution_count": null,
|
| 436 |
+
"metadata": {},
|
| 437 |
+
"outputs": [],
|
| 438 |
+
"source": [
|
| 439 |
+
"from transformers import pipeline\n",
|
| 440 |
+
"\n",
|
| 441 |
+
"# Create generation function from trained model\n",
|
| 442 |
+
"FastLanguageModel.for_inference(model)\n",
|
| 443 |
+
"\n",
|
| 444 |
+
"def trained_generate(prompt):\n",
|
| 445 |
+
" \"\"\"Generate action using the trained model.\"\"\"\n",
|
| 446 |
+
" messages = [\n",
|
| 447 |
+
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
|
| 448 |
+
" {\"role\": \"user\", \"content\": prompt},\n",
|
| 449 |
+
" ]\n",
|
| 450 |
+
" formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
|
| 451 |
+
" inputs = tokenizer(formatted, return_tensors=\"pt\").to(model.device)\n",
|
| 452 |
+
" with torch.no_grad():\n",
|
| 453 |
+
" outputs = model.generate(\n",
|
| 454 |
+
" **inputs,\n",
|
| 455 |
+
" max_new_tokens=256,\n",
|
| 456 |
+
" temperature=0.3,\n",
|
| 457 |
+
" do_sample=True,\n",
|
| 458 |
+
" )\n",
|
| 459 |
+
" response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)\n",
|
| 460 |
+
" return response\n",
|
| 461 |
+
"\n",
|
| 462 |
+
"# Evaluate on same tasks as baseline\n",
|
| 463 |
+
"trained_results = {}\n",
|
| 464 |
+
"for task_id in [\"task_easy\", \"task_medium\", \"task_karnataka\"]:\n",
|
| 465 |
+
" if task_id not in TASKS:\n",
|
| 466 |
+
" continue\n",
|
| 467 |
+
" config = TASKS[task_id]\n",
|
| 468 |
+
" rewards = []\n",
|
| 469 |
+
" import copy\n",
|
| 470 |
+
" for ep in range(5):\n",
|
| 471 |
+
" ep_config = copy.deepcopy(config)\n",
|
| 472 |
+
" ep_config['seed'] = 42 + ep\n",
|
| 473 |
+
" env = OpenGridEnv(ep_config)\n",
|
| 474 |
+
" result = rollout_multi_agent(env, trained_generate, ep_config)\n",
|
| 475 |
+
" rewards.append(result['total_reward'])\n",
|
| 476 |
+
" print(f\" {task_id} ep{ep}: reward={result['total_reward']:.2f}, blackout={result['is_blackout']}\")\n",
|
| 477 |
+
" trained_results[task_id] = {\n",
|
| 478 |
+
" \"avg_reward\": np.mean(rewards),\n",
|
| 479 |
+
" \"std_reward\": np.std(rewards),\n",
|
| 480 |
+
" \"rewards\": rewards\n",
|
| 481 |
+
" }\n",
|
| 482 |
+
" print(f\"[TRAINED] {task_id}: {np.mean(rewards):.2f} ± {np.std(rewards):.2f}\\n\")"
|
| 483 |
+
]
|
| 484 |
+
},
|
| 485 |
+
{
|
| 486 |
+
"cell_type": "markdown",
|
| 487 |
+
"metadata": {},
|
| 488 |
+
"source": [
|
| 489 |
+
"## 12. Generate Before/After Plots 📊"
|
| 490 |
+
]
|
| 491 |
+
},
|
| 492 |
+
{
|
| 493 |
+
"cell_type": "code",
|
| 494 |
+
"execution_count": null,
|
| 495 |
+
"metadata": {},
|
| 496 |
+
"outputs": [],
|
| 497 |
+
"source": [
|
| 498 |
+
"import matplotlib.pyplot as plt\n",
|
| 499 |
+
"import pickle\n",
|
| 500 |
+
"\n",
|
| 501 |
+
"# Load baseline\n",
|
| 502 |
+
"with open(\"training/outputs/baseline_results.pkl\", \"rb\") as f:\n",
|
| 503 |
+
" baseline_results = pickle.load(f)\n",
|
| 504 |
+
"\n",
|
| 505 |
+
"# ── Plot 1: Before vs After Bar Chart ──\n",
|
| 506 |
+
"common_tasks = [t for t in baseline_results if t in trained_results]\n",
|
| 507 |
+
"fig, ax = plt.subplots(figsize=(10, 6))\n",
|
| 508 |
+
"x = np.arange(len(common_tasks))\n",
|
| 509 |
+
"width = 0.35\n",
|
| 510 |
+
"\n",
|
| 511 |
+
"before_vals = [baseline_results[t]['avg_reward'] for t in common_tasks]\n",
|
| 512 |
+
"after_vals = [trained_results[t]['avg_reward'] for t in common_tasks]\n",
|
| 513 |
+
"\n",
|
| 514 |
+
"bars1 = ax.bar(x - width/2, before_vals, width, label='Heuristic Baseline', color='#ff6b6b', alpha=0.8)\n",
|
| 515 |
+
"bars2 = ax.bar(x + width/2, after_vals, width, label='GRPO Trained', color='#00d4aa', alpha=0.8)\n",
|
| 516 |
+
"\n",
|
| 517 |
+
"ax.set_xlabel('Task', fontsize=12)\n",
|
| 518 |
+
"ax.set_ylabel('Average Episode Reward', fontsize=12)\n",
|
| 519 |
+
"ax.set_title('OpenGrid — GRPO Training: Before vs After', fontsize=14, fontweight='bold')\n",
|
| 520 |
+
"ax.set_xticks(x)\n",
|
| 521 |
+
"ax.set_xticklabels([t.replace('task_', '').title() for t in common_tasks])\n",
|
| 522 |
+
"ax.legend(fontsize=11)\n",
|
| 523 |
+
"ax.grid(True, alpha=0.3, axis='y')\n",
|
| 524 |
+
"\n",
|
| 525 |
+
"# Fix label positioning for negative bar heights\n",
|
| 526 |
+
"for bars in (bars1, bars2):\n",
|
| 527 |
+
" for bar in bars:\n",
|
| 528 |
+
" h = bar.get_height()\n",
|
| 529 |
+
" ax.text(\n",
|
| 530 |
+
" bar.get_x() + bar.get_width() / 2.,\n",
|
| 531 |
+
" h + (2 if h >= 0 else -5),\n",
|
| 532 |
+
" f'{h:.1f}',\n",
|
| 533 |
+
" ha='center', va='bottom' if h >= 0 else 'top', fontsize=10\n",
|
| 534 |
+
" )\n",
|
| 535 |
+
"\n",
|
| 536 |
+
"plt.tight_layout()\n",
|
| 537 |
+
"plt.savefig('training/outputs/before_after.png', dpi=150)\n",
|
| 538 |
+
"plt.show()\n",
|
| 539 |
+
"print(\"✅ Saved: training/outputs/before_after.png\")"
|
| 540 |
+
]
|
| 541 |
+
},
|
| 542 |
+
{
|
| 543 |
+
"cell_type": "code",
|
| 544 |
+
"execution_count": null,
|
| 545 |
+
"metadata": {},
|
| 546 |
+
"outputs": [],
|
| 547 |
+
"source": [
|
| 548 |
+
"# ── Plot 2: Training Reward Curve ──\n",
|
| 549 |
+
"history = trainer.state.log_history\n",
|
| 550 |
+
"\n",
|
| 551 |
+
"steps = [h['step'] for h in history if 'loss' in h]\n",
|
| 552 |
+
"losses = [h['loss'] for h in history if 'loss' in h]\n",
|
| 553 |
+
"\n",
|
| 554 |
+
"fig, ax = plt.subplots(figsize=(10, 5))\n",
|
| 555 |
+
"ax.plot(steps, losses, color='#ff6b6b', linewidth=1.5, alpha=0.6, label='Loss')\n",
|
| 556 |
+
"if len(losses) > 10:\n",
|
| 557 |
+
" window = min(20, len(losses) // 3)\n",
|
| 558 |
+
" smoothed = np.convolve(losses, np.ones(window)/window, mode='valid')\n",
|
| 559 |
+
" ax.plot(steps[window-1:], smoothed, color='#ff6b6b', linewidth=2.5, label=f'Smoothed (w={window})')\n",
|
| 560 |
+
"\n",
|
| 561 |
+
"ax.set_xlabel('Training Step', fontsize=12)\n",
|
| 562 |
+
"ax.set_ylabel('Loss', fontsize=12)\n",
|
| 563 |
+
"ax.set_title('OpenGrid GRPO — Training Loss', fontsize=14, fontweight='bold')\n",
|
| 564 |
+
"ax.legend()\n",
|
| 565 |
+
"ax.grid(True, alpha=0.3)\n",
|
| 566 |
+
"plt.tight_layout()\n",
|
| 567 |
+
"plt.savefig('training/outputs/training_loss.png', dpi=150)\n",
|
| 568 |
+
"plt.show()\n",
|
| 569 |
+
"print(\"✅ Saved: training/outputs/training_loss.png\")"
|
| 570 |
+
]
|
| 571 |
+
},
|
| 572 |
+
{
|
| 573 |
+
"cell_type": "markdown",
|
| 574 |
+
"metadata": {},
|
| 575 |
+
"source": [
|
| 576 |
+
"## 13. Summary & Next Steps\n",
|
| 577 |
+
"\n",
|
| 578 |
+
"### Results Table"
|
| 579 |
+
]
|
| 580 |
+
},
|
| 581 |
+
{
|
| 582 |
+
"cell_type": "code",
|
| 583 |
+
"execution_count": null,
|
| 584 |
+
"metadata": {},
|
| 585 |
+
"outputs": [],
|
| 586 |
+
"source": [
|
| 587 |
+
"print(\"=\"*60)\n",
|
| 588 |
+
"print(\" OpenGrid GRPO Training — Results Summary\")\n",
|
| 589 |
+
"print(\"=\"*60)\n",
|
| 590 |
+
"\n",
|
| 591 |
+
"# Rebuild common_tasks in case Cell 12 was skipped\n",
|
| 592 |
+
"common_tasks = [t for t in baseline_results if t in trained_results]\n",
|
| 593 |
+
"\n",
|
| 594 |
+
"print(f\"{'Task':<20} {'Baseline':>12} {'Trained':>12} {'Δ':>10}\")\n",
|
| 595 |
+
"print(\"-\"*60)\n",
|
| 596 |
+
"for t in common_tasks:\n",
|
| 597 |
+
" b = baseline_results[t]['avg_reward']\n",
|
| 598 |
+
" a = trained_results[t]['avg_reward']\n",
|
| 599 |
+
" delta = a - b\n",
|
| 600 |
+
" arrow = '↑' if delta > 0 else '↓'\n",
|
| 601 |
+
" print(f\"{t:<20} {b:>10.2f} {a:>10.2f} {arrow} {abs(delta):.2f}\")\n",
|
| 602 |
+
"print(\"=\"*60)"
|
| 603 |
+
]
|
| 604 |
+
},
|
| 605 |
+
{
|
| 606 |
+
"cell_type": "code",
|
| 607 |
+
"execution_count": null,
|
| 608 |
+
"metadata": {},
|
| 609 |
+
"outputs": [],
|
| 610 |
+
"source": [
|
| 611 |
+
"# Download plots for your README\n",
|
| 612 |
+
"from google.colab import files\n",
|
| 613 |
+
"files.download('training/outputs/before_after.png')\n",
|
| 614 |
+
"files.download('training/outputs/training_loss.png')"
|
| 615 |
+
]
|
| 616 |
+
}
|
| 617 |
+
],
|
| 618 |
+
"metadata": {
|
| 619 |
+
"accelerator": "GPU",
|
| 620 |
+
"colab": {
|
| 621 |
+
"gpuType": "T4",
|
| 622 |
+
"provenance": []
|
| 623 |
+
},
|
| 624 |
+
"kernelspec": {
|
| 625 |
+
"display_name": "Python 3",
|
| 626 |
+
"name": "python3"
|
| 627 |
+
},
|
| 628 |
+
"language_info": {
|
| 629 |
+
"name": "python",
|
| 630 |
+
"version": "3.10.0"
|
| 631 |
+
}
|
| 632 |
+
},
|
| 633 |
+
"nbformat": 4,
|
| 634 |
+
"nbformat_minor": 0
|
| 635 |
+
}
|
training/train_grpo.py
ADDED
|
@@ -0,0 +1,827 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OpenGrid GRPO Training Script
|
| 3 |
+
==============================
|
| 4 |
+
Uses TRL's GRPOTrainer to train an LLM for multi-agent power grid control.
|
| 5 |
+
|
| 6 |
+
The LLM receives grid observations (partial, per-zone) as text prompts,
|
| 7 |
+
generates JSON actions, and is trained via GRPO to maximize grid stability rewards.
|
| 8 |
+
|
| 9 |
+
Compatible with:
|
| 10 |
+
- Unsloth for 4-bit quantized training (recommended)
|
| 11 |
+
- HuggingFace TRL GRPOTrainer
|
| 12 |
+
- Colab / HF Spaces with GPU
|
| 13 |
+
|
| 14 |
+
Usage:
|
| 15 |
+
# Quick test (no GPU needed, just verifies the pipeline)
|
| 16 |
+
python training/train_grpo.py --test-mode
|
| 17 |
+
|
| 18 |
+
# Full training on GPU
|
| 19 |
+
python training/train_grpo.py --model Qwen/Qwen2.5-1.5B-Instruct --epochs 3
|
| 20 |
+
|
| 21 |
+
# With Unsloth quantization (faster, less memory)
|
| 22 |
+
python training/train_grpo.py --model unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit --use-unsloth
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import argparse
|
| 26 |
+
import copy
|
| 27 |
+
import json
|
| 28 |
+
import random
|
| 29 |
+
import sys
|
| 30 |
+
import os
|
| 31 |
+
import re
|
| 32 |
+
import time
|
| 33 |
+
from pathlib import Path
|
| 34 |
+
|
| 35 |
+
# Add project root to path
|
| 36 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 37 |
+
|
| 38 |
+
import numpy as np
|
| 39 |
+
import matplotlib
|
| 40 |
+
matplotlib.use('Agg')
|
| 41 |
+
import matplotlib.pyplot as plt
|
| 42 |
+
|
| 43 |
+
from src.environment import OpenGridEnv
|
| 44 |
+
from src.tasks import TASKS
|
| 45 |
+
from src.models import GridAction, BusAdjustment, TopologyAction
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# ============================================================================
|
| 49 |
+
# Prompt Engineering
|
| 50 |
+
# ============================================================================
|
| 51 |
+
|
| 52 |
+
SYSTEM_PROMPT = """You are an AI power grid operator for the Karnataka Power Transmission Corporation (KPTCL).
|
| 53 |
+
You manage one zone of a multi-agent grid. Your goal: keep frequency at 50.0 Hz, avoid line overloads, and prevent blackouts.
|
| 54 |
+
|
| 55 |
+
You receive partial observations of your zone and must output a JSON action.
|
| 56 |
+
Respond ONLY with valid JSON matching this schema:
|
| 57 |
+
{"bus_adjustments": [{"bus_id": <int>, "delta": <float>}], "topology_actions": []}
|
| 58 |
+
|
| 59 |
+
Rules:
|
| 60 |
+
- Positive delta = inject more power (discharge battery / increase generation)
|
| 61 |
+
- Negative delta = reduce injection (charge battery / decrease generation)
|
| 62 |
+
- Only adjust buses in YOUR zone
|
| 63 |
+
- Keep frequency close to 50.0 Hz
|
| 64 |
+
- Avoid overloading lines (rho > 1.0 is dangerous)"""
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def format_observation_prompt(obs_dict: dict, zone_name: str = "") -> str:
    """Render a zone observation dict as a human-readable LLM prompt.

    The prompt leads with the frequency status line, then lists local
    buses, stressed lines (loading above 80%), neighbor-zone injection
    signals, and the zone power balance, and ends with an instruction
    to respond in JSON.
    """
    frequency = obs_dict.get('grid_frequency', 50.0)
    step = obs_dict.get('timestep', 0)
    deviation = frequency - 50.0

    pieces = [f"[Zone: {zone_name}] Step {step} | Frequency: {frequency:.3f} Hz"]

    # Escalating frequency flags: >0.3 Hz is critical, >0.1 Hz a warning.
    if abs(deviation) > 0.3:
        pieces.append(f" [!] CRITICAL: {deviation:+.3f} Hz deviation!")
    elif abs(deviation) > 0.1:
        pieces.append(f" WARNING: {deviation:+.3f} Hz deviation")

    # Per-bus injections; batteries additionally report state of charge.
    local_buses = obs_dict.get('local_buses', [])
    if local_buses:
        pieces.append("\n\nYour buses:")
        for bus in local_buses:
            entry = f" Bus {bus['id']} ({bus['type']}): {bus['p_injection']:.1f} MW"
            if bus['type'] == 'battery':
                entry += f" | SoC: {bus['soc']:.1f} MWh"
            pieces.append(f"\n{entry}")

    # Flag connected lines whose loading ratio exceeds 0.8.
    candidate_lines = obs_dict.get('internal_lines', []) + obs_dict.get('boundary_lines', [])
    stressed = [ln for ln in candidate_lines
                if ln.get('rho', 0) > 0.8 and ln.get('connected', True)]
    if stressed:
        pieces.append("\n\n[!] Stressed lines:")
        for ln in stressed:
            pieces.append(f"\n {ln['id']}: {ln['rho']:.2f} loading ({ln['flow']:.1f} MW)")

    # Average injection reported by each neighboring zone.
    neighbor_signals = obs_dict.get('neighbor_signals', {})
    if neighbor_signals:
        pieces.append("\n\nNeighbor zones (avg injection):")
        for zone_id, avg_mw in neighbor_signals.items():
            pieces.append(f"\n Zone {zone_id}: {avg_mw:.1f} MW")

    # Net zone balance: generation minus load.
    load_mw = obs_dict.get('zone_load_mw', 0)
    gen_mw = obs_dict.get('zone_gen_mw', 0)
    if load_mw or gen_mw:
        pieces.append(
            f"\n\nZone balance: Gen={gen_mw:.1f} MW, Load={load_mw:.1f} MW, Net={gen_mw-load_mw:.1f} MW"
        )

    pieces.append("\n\nWhat action do you take? Respond with JSON only.")
    return "".join(pieces)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def extract_action(text: str) -> GridAction:
    """Parse LLM output to a GridAction, with fallback for malformed JSON.

    Finds the outermost ``{...}`` span in the response, parses it, and
    builds the action from its ``bus_adjustments`` / ``topology_actions``
    lists. Any parse or validation failure yields a safe no-op action.
    """
    text = text.strip()

    # Greedily grab the outermost {...} span in the response.
    json_match = re.search(r'\{[\s\S]*\}', text)
    if json_match:
        try:
            data = json.loads(json_match.group())
            return GridAction(
                bus_adjustments=[
                    BusAdjustment(**a) for a in data.get('bus_adjustments', [])
                ],
                topology_actions=[
                    TopologyAction(**t) for t in data.get('topology_actions', [])
                ],
            )
        except Exception:
            # Fix: the original `except (json.JSONDecodeError, Exception)` was
            # redundant — Exception already subsumes JSONDecodeError. The broad
            # catch itself is intentional: model-validation errors (bad keys,
            # wrong field types) must also fall through to the no-op action.
            pass

    # Fallback: no-op action
    return GridAction()
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# ============================================================================
|
| 140 |
+
# Environment Rollout
|
| 141 |
+
# ============================================================================
|
| 142 |
+
|
| 143 |
+
def rollout_single_agent(env: OpenGridEnv, generate_fn, task_config: dict) -> dict:
    """Play one episode in single-agent mode and summarize it.

    Each step renders the observation to text, asks ``generate_fn`` for a
    completion, parses it into an action, and applies it. Returns the
    episode total, per-step rewards, step count, blackout flag, and the
    mean per-step reward.
    """
    observation = env.reset()
    step_rewards = []
    blackout = False

    for _ in range(task_config['max_steps']):
        prompt = format_observation_prompt(
            observation.model_dump(), zone_name="Full_Grid"
        )
        completion = generate_fn(prompt)
        observation, reward, done, info = env.step(extract_action(completion))
        step_rewards.append(reward.value)

        if done:
            blackout = info.is_blackout
            break

    episode_total = sum(step_rewards)
    n_steps = len(step_rewards)
    return {
        "total_reward": episode_total,
        "rewards": step_rewards,
        "steps": n_steps,
        "is_blackout": blackout,
        "avg_reward": episode_total / max(n_steps, 1),
    }
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def rollout_multi_agent(env: OpenGridEnv, generate_fn, task_config: dict) -> dict:
    """Play one episode in multi-agent mode and summarize it.

    Every agent receives its own zone observation, produces a completion
    via ``generate_fn``, and all parsed actions are applied jointly with
    ``env.step_multi``. Tracks team reward, per-agent rewards, and how
    many times the safety layer corrected an agent's action.
    """
    zone_obs = env.reset_multi()
    team_rewards = []
    agent_rewards = {aid: [] for aid in range(env.num_agents)}
    corrections = 0
    blackout = False

    for _ in range(task_config['max_steps']):
        actions = {
            agent_id: extract_action(
                generate_fn(
                    format_observation_prompt(obs.model_dump(), zone_name=obs.zone_name)
                )
            )
            for agent_id, obs in zone_obs.items()
        }

        result = env.step_multi(actions)

        team_rewards.append(result.team_reward)
        for agent_id, reward in result.rewards.items():
            agent_rewards[agent_id].append(reward.value)

        # safety_reports maps agent id -> SafetyReport; count corrected actions.
        corrections += sum(
            1 for report in result.safety_reports.values() if report.was_corrected
        )

        if result.done:
            blackout = result.info.is_blackout
            break

        zone_obs = result.observations

    episode_total = sum(team_rewards)
    n_steps = len(team_rewards)
    return {
        "total_reward": episode_total,
        "rewards": team_rewards,
        "per_agent_rewards": agent_rewards,
        "steps": n_steps,
        "is_blackout": blackout,
        "safety_interventions": corrections,
        "avg_reward": episode_total / max(n_steps, 1),
    }
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
# ============================================================================
|
| 227 |
+
# GRPO Reward Functions
|
| 228 |
+
# ============================================================================
|
| 229 |
+
|
| 230 |
+
# NOTE(review): this cache is declared but never read or written —
# _get_reward_env below builds a fresh environment on every call. Either
# wire the cache up or remove it; confirm no external code relies on the
# module attribute first.
_REWARD_ENV_CACHE = {}


def _get_reward_env(task_config: dict) -> OpenGridEnv:
    """Build, reset, and return a fresh environment for reward scoring.

    The task config is deep-copied so reward-side stepping never mutates
    the shared configuration dict used elsewhere in training.
    """
    env = OpenGridEnv(copy.deepcopy(task_config))
    env.reset()
    return env
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def compute_grpo_reward_env(
    completions: list,
    observations: list,
    task_config: dict,
    horizon: int = 3,
) -> list:
    """Environment-grounded reward: step the actual physics simulation.

    For each LLM-generated action:
    1. Restore the env to the observation state
    2. Step with the proposed action and get the real reward
    3. Run a short rollout (horizon steps) with heuristic continuation
       to capture trajectory-level impact
    4. Add format/schema bonuses

    This directly addresses the proxy-reward disconnect that caused
    the original GRPO training to show zero improvement.

    Args:
        completions: Raw LLM output strings to score.
        observations: Observation dicts (or their JSON-string form, as
            TRL may pass them) that each completion was generated from;
            ``None`` entries score 0.0.
        task_config: Task configuration used to build the scoring envs.
        horizon: Total lookahead depth — the LLM action plus
            ``horizon - 1`` heuristic continuation steps.

    Returns:
        One float per completion, clamped to [-1.0, 1.0].
    """
    # Imported lazily so module import does not pull in the baseline policy.
    from src.baseline import heuristic_policy

    rewards = []
    for completion, obs_dict in zip(completions, observations):
        if obs_dict is None:
            rewards.append(0.0)
            continue

        # Deserialize if needed (TRL may pass strings)
        if isinstance(obs_dict, str):
            try:
                obs_dict = json.loads(obs_dict)
            except (json.JSONDecodeError, TypeError):
                rewards.append(0.0)
                continue

        action = extract_action(completion)
        has_adjustments = bool(action.bus_adjustments)

        # ── 1. Format reward (small but keeps gradient alive) ──
        format_score = 0.0
        if has_adjustments:
            format_score += 0.05
        else:
            freq = obs_dict.get('grid_frequency', 50.0)
            if abs(freq - 50.0) < 0.05:
                format_score += 0.05  # No-op when stable is fine
            else:
                format_score -= 0.05  # No-op during deviation is bad

        # ── 2. Environment-grounded reward ──
        try:
            env = _get_reward_env(task_config)
            env._set_state(obs_dict)

            # Step with the LLM's proposed action
            obs_after, reward, done, info = env.step(action)
            env_score = reward.value

            # Blackout = catastrophic
            if info.is_blackout:
                rewards.append(-1.0)
                continue

            # ── 3. Mini-rollout: what happens next? ──
            # Run a few more steps with heuristic to measure trajectory impact
            rollout_reward = 0.0
            for _ in range(horizon - 1):
                if done:
                    break
                h_action = heuristic_policy(obs_after)
                obs_after, r, done, info = env.step(h_action)
                rollout_reward += r.value
                if info.is_blackout:
                    rollout_reward -= 10.0
                    break

            # Combine: immediate reward + discounted future
            total_env_score = env_score + 0.5 * rollout_reward

            # Normalize to [-1, 1] range
            # Typical per-step reward is ~0.5 to 1.5, rollout adds ~1-4
            # So total_env_score is roughly in [-10, 4] range
            normalized = total_env_score / 5.0

        except Exception:
            # Fix: dropped the unused `as e` binding from the original.
            # Fallback: use lightweight heuristic scoring when env stepping
            # fails for any reason (bad state dict, physics error, etc.).
            normalized = _compute_heuristic_score(action, obs_dict)

        total = format_score + normalized
        rewards.append(max(-1.0, min(1.0, total)))

    return rewards
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def _compute_heuristic_score(action: GridAction, obs_dict: dict) -> float:
    """Lightweight fallback scorer when env rollout fails.

    Scores an action on three heuristics — push frequency toward 50 Hz
    in the right direction, use a roughly proportional magnitude
    (~15 MW per Hz of error), and avoid large moves when the grid is
    already stable — with the result clamped to [-0.5, 0.5].
    """
    frequency = obs_dict.get('grid_frequency', 50.0)
    error = frequency - 50.0
    error_magnitude = abs(error)

    # No adjustments at all: nothing to score.
    if not action.bus_adjustments:
        return 0.0

    net_delta = sum(adj.delta for adj in action.bus_adjustments)
    score = 0.0

    if error_magnitude > 0.05:
        # Direction: under-frequency wants injection, over-frequency reduction.
        pushes_right_way = (error < 0 < net_delta) or (error > 0 > net_delta)
        score += 0.3 if pushes_right_way else -0.3

        # Proportionality: reward magnitudes near the ~15 MW/Hz ideal.
        target = error_magnitude * 15.0
        applied = abs(net_delta)
        if applied > 0.1:
            closeness = min(applied, target) / max(applied, target, 0.1)
            score += 0.2 * closeness

    # Stability bonus: small moves while frequency is near-nominal.
    if error_magnitude < 0.05 and abs(net_delta) < 2.0:
        score += 0.1

    return max(-0.5, min(0.5, score))
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
# Keep old function for backward compat / test mode
|
| 368 |
+
def compute_grpo_reward(completions: list, observations: list, env_url: str = None) -> list:
    """Legacy heuristic reward (used in test mode only)."""
    scores = []
    for completion, obs in zip(completions, observations):
        parsed = extract_action(completion)
        scores.append(_compute_heuristic_score(parsed, obs or {}))
    return scores
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
# ============================================================================
|
| 375 |
+
# Training Loop
|
| 376 |
+
# ============================================================================
|
| 377 |
+
|
| 378 |
+
def train_grpo(args):
|
| 379 |
+
"""Main GRPO training loop using TRL."""
|
| 380 |
+
try:
|
| 381 |
+
from trl import GRPOTrainer, GRPOConfig
|
| 382 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 383 |
+
except ImportError:
|
| 384 |
+
print("ERROR: TRL not installed. Run: pip install trl transformers")
|
| 385 |
+
print("For quantized training: pip install unsloth")
|
| 386 |
+
sys.exit(1)
|
| 387 |
+
|
| 388 |
+
print(f"[TRAIN] Model: {args.model}")
|
| 389 |
+
print(f"[TRAIN] Task: {args.task}")
|
| 390 |
+
print(f"[TRAIN] Epochs: {args.epochs}")
|
| 391 |
+
print(f"[TRAIN] Batch size: {args.batch_size}")
|
| 392 |
+
|
| 393 |
+
# Load model
|
| 394 |
+
if args.use_unsloth:
|
| 395 |
+
try:
|
| 396 |
+
from unsloth import FastLanguageModel
|
| 397 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 398 |
+
model_name=args.model,
|
| 399 |
+
max_seq_length=2048,
|
| 400 |
+
load_in_4bit=True,
|
| 401 |
+
)
|
| 402 |
+
model = FastLanguageModel.get_peft_model(
|
| 403 |
+
model,
|
| 404 |
+
r=16, lora_alpha=16, lora_dropout=0,
|
| 405 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
|
| 406 |
+
"gate_proj", "up_proj", "down_proj"],
|
| 407 |
+
)
|
| 408 |
+
print("[TRAIN] Loaded with Unsloth 4-bit quantization")
|
| 409 |
+
except ImportError:
|
| 410 |
+
print("WARNING: Unsloth not available, falling back to standard loading")
|
| 411 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
| 412 |
+
model = AutoModelForCausalLM.from_pretrained(args.model)
|
| 413 |
+
else:
|
| 414 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
| 415 |
+
model = AutoModelForCausalLM.from_pretrained(args.model)
|
| 416 |
+
|
| 417 |
+
if tokenizer.pad_token is None:
|
| 418 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 419 |
+
|
| 420 |
+
# Prepare training data: observation prompts from the environment
|
| 421 |
+
task_config = copy.deepcopy(TASKS[args.task])
|
| 422 |
+
base_seed = task_config.get('seed', 42)
|
| 423 |
+
|
| 424 |
+
# Generate prompts with diverse grid states:
|
| 425 |
+
# - Larger random perturbations (-30 to +30 MW)
|
| 426 |
+
# - Adversarial states (drained batteries, high frequency deviation)
|
| 427 |
+
# - More steps per episode for temporal diversity
|
| 428 |
+
print("[TRAIN] Generating training prompts from environment...")
|
| 429 |
+
prompts = []
|
| 430 |
+
obs_contexts = []
|
| 431 |
+
rng = np.random.RandomState(base_seed)
|
| 432 |
+
|
| 433 |
+
steps_per_episode = min(15, task_config['max_steps'])
|
| 434 |
+
|
| 435 |
+
for episode in range(args.num_prompts):
|
| 436 |
+
ep_config = copy.deepcopy(task_config)
|
| 437 |
+
ep_config['seed'] = base_seed + episode
|
| 438 |
+
env = OpenGridEnv(ep_config)
|
| 439 |
+
zone_obs = env.reset_multi()
|
| 440 |
+
|
| 441 |
+
# Adversarial injection: every 5th episode, drain batteries
|
| 442 |
+
if episode % 5 == 0:
|
| 443 |
+
for b in env.bus_state:
|
| 444 |
+
b_cfg = env._find_bus_config(b['id'])
|
| 445 |
+
if b_cfg and b_cfg['type'] == 'battery':
|
| 446 |
+
b['soc'] = max(1.0, b['soc'] * 0.1) # Near-empty
|
| 447 |
+
|
| 448 |
+
for t in range(steps_per_episode):
|
| 449 |
+
for agent_id, obs in zone_obs.items():
|
| 450 |
+
obs_dict = json.loads(obs.model_dump_json())
|
| 451 |
+
prompt_text = format_observation_prompt(obs_dict, zone_name=obs.zone_name)
|
| 452 |
+
|
| 453 |
+
messages = [
|
| 454 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 455 |
+
{"role": "user", "content": prompt_text},
|
| 456 |
+
]
|
| 457 |
+
|
| 458 |
+
formatted = tokenizer.apply_chat_template(
|
| 459 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 460 |
+
)
|
| 461 |
+
prompts.append(formatted)
|
| 462 |
+
obs_contexts.append(json.dumps(obs_dict)) # Store as string for Arrow compat
|
| 463 |
+
|
| 464 |
+
# Larger random perturbations for state diversity
|
| 465 |
+
random_actions = {}
|
| 466 |
+
for agent_id in range(env.num_agents):
|
| 467 |
+
zone_buses = task_config['zone_bus_ids'].get(agent_id, [])
|
| 468 |
+
controllable = [
|
| 469 |
+
bid for bid in zone_buses
|
| 470 |
+
if next((b for b in task_config['buses'] if b['id'] == bid), {}).get('type')
|
| 471 |
+
in ['generator', 'battery']
|
| 472 |
+
]
|
| 473 |
+
adj = []
|
| 474 |
+
if controllable:
|
| 475 |
+
# Pick 1-2 buses with larger perturbations
|
| 476 |
+
n_adj = min(len(controllable), rng.randint(1, 3))
|
| 477 |
+
chosen = rng.choice(controllable, size=n_adj, replace=False)
|
| 478 |
+
for bid in chosen:
|
| 479 |
+
adj.append(BusAdjustment(
|
| 480 |
+
bus_id=int(bid),
|
| 481 |
+
delta=float(rng.uniform(-30, 30)) # Was ±15
|
| 482 |
+
))
|
| 483 |
+
random_actions[agent_id] = GridAction(bus_adjustments=adj)
|
| 484 |
+
|
| 485 |
+
result = env.step_multi(random_actions)
|
| 486 |
+
if result.done:
|
| 487 |
+
break
|
| 488 |
+
zone_obs = result.observations
|
| 489 |
+
|
| 490 |
+
print(f"[TRAIN] Generated {len(prompts)} training prompts")
|
| 491 |
+
|
| 492 |
+
# GRPO reward function: environment-grounded
|
| 493 |
+
def reward_fn(completions, obs_context=None, **kwargs):
    """Environment-grounded GRPO reward.

    Steps the actual physics simulation to score each action,
    rather than using a disconnected heuristic proxy.
    """

    def _as_text(completion):
        # Chat-style completions arrive as a message list; keep the last turn.
        if isinstance(completion, list):
            return completion[-1]['content'] if completion else ""
        return str(completion)

    def _as_obs(ctx):
        # Contexts were serialized to JSON strings for Arrow compatibility;
        # tolerate malformed payloads by falling back to None.
        if not isinstance(ctx, str):
            return ctx
        try:
            return json.loads(ctx)
        except (json.JSONDecodeError, TypeError):
            return None

    texts = [_as_text(c) for c in completions]
    contexts = obs_context if obs_context is not None else [None] * len(texts)
    obs_dicts = [_as_obs(ctx) for ctx in contexts]

    return compute_grpo_reward_env(texts, obs_dicts, task_config, horizon=3)
|
| 522 |
+
|
| 523 |
+
# GRPO Config — tuned for sustained learning signal
|
| 524 |
+
grpo_config = GRPOConfig(
|
| 525 |
+
output_dir=str(Path(args.output_dir) / "grpo_checkpoints"),
|
| 526 |
+
num_train_epochs=args.epochs,
|
| 527 |
+
per_device_train_batch_size=args.batch_size,
|
| 528 |
+
gradient_accumulation_steps=max(1, 16 // args.batch_size),
|
| 529 |
+
learning_rate=1e-5, # Was 5e-6 — slightly more aggressive
|
| 530 |
+
logging_steps=5,
|
| 531 |
+
save_steps=50,
|
| 532 |
+
max_completion_length=256,
|
| 533 |
+
num_generations=8, # Was 4 — wider group for better ranking signal
|
| 534 |
+
report_to="none",
|
| 535 |
+
remove_unused_columns=False,
|
| 536 |
+
)
|
| 537 |
+
|
| 538 |
+
# Create dataset — include obs_context so TRL passes it to reward_fn
|
| 539 |
+
from datasets import Dataset
|
| 540 |
+
train_dataset = Dataset.from_dict({
|
| 541 |
+
"prompt": prompts,
|
| 542 |
+
"obs_context": obs_contexts,
|
| 543 |
+
})
|
| 544 |
+
|
| 545 |
+
# Initialize trainer
|
| 546 |
+
trainer = GRPOTrainer(
|
| 547 |
+
model=model,
|
| 548 |
+
args=grpo_config,
|
| 549 |
+
train_dataset=train_dataset,
|
| 550 |
+
reward_funcs=reward_fn,
|
| 551 |
+
processing_class=tokenizer,
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
# Train
|
| 555 |
+
print("[TRAIN] Starting GRPO training...")
|
| 556 |
+
train_result = trainer.train()
|
| 557 |
+
|
| 558 |
+
# Save model
|
| 559 |
+
output_path = Path(args.output_dir) / "trained_model"
|
| 560 |
+
trainer.save_model(str(output_path))
|
| 561 |
+
tokenizer.save_pretrained(str(output_path))
|
| 562 |
+
print(f"[TRAIN] Model saved to {output_path}")
|
| 563 |
+
|
| 564 |
+
return train_result
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
# ============================================================================
|
| 568 |
+
# Evaluation & Plotting
|
| 569 |
+
# ============================================================================
|
| 570 |
+
|
| 571 |
+
def evaluate_model(generate_fn, task_ids=None, n_episodes=3, multi_agent=True):
    """Evaluate a model across tasks. Returns per-task results.

    Each episode uses a distinct seed to produce meaningful variance.
    """
    selected = list(TASKS.keys()) if task_ids is None else task_ids
    rollout = rollout_multi_agent if multi_agent else rollout_single_agent

    results = {}
    for task_id in selected:
        base_config = TASKS[task_id]
        base_seed = base_config.get('seed', 42)

        rewards = []
        for ep in range(n_episodes):
            # Fresh deep copy + offset seed -> independent rollout per episode.
            cfg = copy.deepcopy(base_config)
            cfg['seed'] = base_seed + ep
            episode = rollout(OpenGridEnv(cfg), generate_fn, cfg)
            rewards.append(episode['total_reward'])

        results[task_id] = {
            "avg_reward": np.mean(rewards),
            "std_reward": np.std(rewards),
            "rewards": rewards,
        }

    return results
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
def plot_training_curves(training_log: list, output_path: str):
    """Generate reward curves from training log."""
    if not training_log:
        print("[PLOT] No training data to plot.")
        return

    fig, (ax_reward, ax_loss) = plt.subplots(1, 2, figsize=(14, 5))

    # Left panel: raw per-step reward.
    rewards = [entry.get('reward', 0) for entry in training_log]
    ax_reward.plot(range(len(rewards)), rewards, color='#00d4aa',
                   linewidth=1.5, alpha=0.6, label='Step Reward')

    # Overlay a moving-average curve once there is enough data to smooth.
    if len(rewards) > 10:
        window = min(20, len(rewards) // 5)
        smoothed = np.convolve(rewards, np.ones(window) / window, mode='valid')
        ax_reward.plot(range(window - 1, len(rewards)), smoothed, color='#00d4aa',
                       linewidth=2.5, label=f'Smoothed (window={window})')

    ax_reward.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
    ax_reward.set_xlabel('Training Step')
    ax_reward.set_ylabel('Reward')
    ax_reward.set_title('GRPO Training — Reward Curve')
    ax_reward.legend()
    ax_reward.grid(True, alpha=0.3)

    # Right panel: loss, when the log entries carry one.
    losses = [entry['loss'] for entry in training_log if 'loss' in entry]
    if losses:
        ax_loss.plot(range(len(losses)), losses, color='#ff6b6b', linewidth=1.5)
        ax_loss.set_xlabel('Training Step')
        ax_loss.set_ylabel('Loss')
        ax_loss.grid(True, alpha=0.3)
    else:
        ax_loss.text(0.5, 0.5, 'Loss data not available', ha='center', va='center',
                     transform=ax_loss.transAxes, fontsize=14, color='gray')
    ax_loss.set_title('Training Loss')

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"[PLOT] Saved training curves to {output_path}")
|
| 651 |
+
|
| 652 |
+
|
| 653 |
+
def plot_before_after(before_results: dict, after_results: dict, output_path: str):
    """Generate before/after comparison chart."""
    fig, ax = plt.subplots(figsize=(10, 6))

    tasks = list(before_results.keys())
    positions = np.arange(len(tasks))
    width = 0.35

    before_vals = [before_results[t]['avg_reward'] for t in tasks]
    after_vals = [after_results[t]['avg_reward'] for t in tasks]

    before_bars = ax.bar(positions - width / 2, before_vals, width,
                         label='Before Training', color='#ff6b6b', alpha=0.8)
    after_bars = ax.bar(positions + width / 2, after_vals, width,
                        label='After Training', color='#00d4aa', alpha=0.8)

    ax.set_xlabel('Task')
    ax.set_ylabel('Average Episode Reward')
    ax.set_title('OpenGrid — GRPO Training: Before vs After')
    ax.set_xticks(positions)
    ax.set_xticklabels([t.replace('task_', '').title() for t in tasks])
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')

    # Annotate each bar with its value; flip anchor and offset for negative bars.
    for bar in (*before_bars, *after_bars):
        height = bar.get_height()
        above = height >= 0
        ax.text(bar.get_x() + bar.get_width() / 2., height + (1 if above else -1),
                f'{height:.1f}', ha='center',
                va='bottom' if above else 'top', fontsize=9)

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"[PLOT] Saved before/after comparison to {output_path}")
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
# ============================================================================
|
| 692 |
+
# Test Mode
|
| 693 |
+
# ============================================================================
|
| 694 |
+
|
| 695 |
+
def run_test_mode():
    """Quick pipeline verification without GPU. Runs a few episodes with heuristic."""
    banner = "=" * 60
    print("\n" + banner)
    print(" OpenGrid GRPO Training — TEST MODE")
    print(" (Verifies the pipeline without training)")
    print(banner + "\n")

    # Test 1: Prompt generation
    print("[TEST] Generating prompts...")
    env = OpenGridEnv(TASKS["task_easy"])
    for agent_id, obs in env.reset_multi().items():
        prompt = format_observation_prompt(obs.model_dump(), zone_name=obs.zone_name)
        print(f"\n--- Agent {agent_id} ({obs.zone_name}) ---")
        print(prompt[:500])

    # Test 2: Action extraction
    print("\n[TEST] Testing action extraction...")
    extraction_inputs = [
        '{"bus_adjustments": [{"bus_id": 1, "delta": 5.0}], "topology_actions": []}',
        'Here is my action: {"bus_adjustments": [], "topology_actions": []}',
        'invalid garbage',
    ]
    for raw in extraction_inputs:
        parsed = extract_action(raw)
        print(f" Input: {raw[:60]}... -> {len(parsed.bus_adjustments)} adjustments")

    # Test 3: Multi-agent rollout with heuristic
    print("\n[TEST] Running multi-agent rollout...")
    from src.baseline import heuristic_policy

    def heuristic_generate(prompt):
        """Pseudo-LLM: use heuristic policy and format as JSON."""
        # Extract frequency from prompt (handles negative/signed values)
        match = re.search(r'Frequency:\s*([-+]?\d+(?:\.\d+)?)', prompt)
        freq = 50.0 if match is None else float(match.group(1))

        # Simple proportional control, clamped
        delta = min(20, max(-20, (50.0 - freq) * 10))

        # Find controllable buses (generator/battery, NOT slack — physics overwrites it)
        controllable = re.findall(r'Bus (\d+) \((generator|battery)\)', prompt)
        if not controllable:
            return json.dumps({"bus_adjustments": [], "topology_actions": []})

        # Distribute the correction evenly across controllable buses.
        share = delta / len(controllable)
        return json.dumps({
            "bus_adjustments": [
                {"bus_id": int(m[0]), "delta": round(share, 1)}
                for m in controllable
            ],
            "topology_actions": [],
        })

    for task_id in ["task_easy", "task_medium"]:
        config = copy.deepcopy(TASKS[task_id])
        env = OpenGridEnv(config)
        result = rollout_multi_agent(env, heuristic_generate, config)
        print(f" {task_id}: reward={result['total_reward']:.2f}, "
              f"steps={result['steps']}, blackout={result['is_blackout']}, "
              f"safety_interventions={result['safety_interventions']}")

    # Test 4: Reward function
    print("\n[TEST] Testing GRPO reward function...")
    test_completions = [
        '{"bus_adjustments": [{"bus_id": 1, "delta": 5.0}], "topology_actions": []}',
        '{"bus_adjustments": [], "topology_actions": []}',
        'not valid json at all',
    ]
    test_obs = [{"grid_frequency": 49.5}, {"grid_frequency": 50.0}, {"grid_frequency": 50.3}]
    for completion, reward in zip(test_completions,
                                  compute_grpo_reward(test_completions, test_obs)):
        print(f" Reward: {reward:.2f} for: {completion[:50]}...")

    # Test 5: Generate plots
    output_dir = Path("training/outputs")
    output_dir.mkdir(parents=True, exist_ok=True)

    fake_log = [{"reward": np.random.normal(0.5, 0.3) + i * 0.01, "loss": 2.0 - i * 0.02}
                for i in range(100)]
    plot_training_curves(fake_log, str(output_dir / "test_training_curves.png"))

    fake_before = {t: {"avg_reward": np.random.uniform(20, 35)} for t in TASKS}
    fake_after = {t: {"avg_reward": np.random.uniform(40, 55)} for t in TASKS}
    plot_before_after(fake_before, fake_after, str(output_dir / "test_before_after.png"))

    print("\n" + banner)
    print(" [OK] ALL TESTS PASSED - Pipeline is ready for GPU training")
    print(banner)
|
| 788 |
+
|
| 789 |
+
# ============================================================================
|
| 790 |
+
# Main
|
| 791 |
+
# ============================================================================
|
| 792 |
+
|
| 793 |
+
def main():
    """CLI entry point: parse arguments, then run test mode or GRPO training.

    Flags mirror the training pipeline's knobs (model, task, epochs, batch
    size, prompt-generation episode count, output directory, Unsloth 4-bit
    mode) plus a GPU-free --test-mode verification path.
    """
    parser = argparse.ArgumentParser(description="OpenGrid GRPO Training")
    parser.add_argument("--model", default="Qwen/Qwen2.5-1.5B-Instruct",
                        help="HuggingFace model name or path")
    parser.add_argument("--task", default="task_easy", choices=list(TASKS.keys()),
                        help="Which task to train on")
    parser.add_argument("--epochs", type=int, default=3, help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, default=2, help="Batch size per device")
    parser.add_argument("--num-prompts", type=int, default=50,
                        help="Number of episodes to generate prompts from")
    parser.add_argument("--output-dir", default="training/outputs",
                        help="Directory for checkpoints and plots")
    parser.add_argument("--use-unsloth", action="store_true",
                        help="Use Unsloth for 4-bit quantized training")
    parser.add_argument("--test-mode", action="store_true",
                        help="Run pipeline verification without GPU")

    args = parser.parse_args()

    # Test mode verifies the pipeline (prompts, parsing, rollouts, plots)
    # without touching a GPU, then exits.
    if args.test_mode:
        run_test_mode()
        return

    # Create output directory before training writes checkpoints/plots into it.
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # train_grpo saves the model itself; its return value is not needed here
    # (previously bound to an unused local).
    train_grpo(args)

    print("\n[DONE] Training complete!")
    print(f" Output: {args.output_dir}")


if __name__ == "__main__":
    main()
|
validate-submission.sh
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# validate-submission.sh — three-step sanity check for an OpenEnv submission:
#   1. the deployed HF Space answers POST /reset
#   2. the repo's Docker image builds
#   3. `openenv validate` passes inside the repo
# Usage: ./validate-submission.sh <ping_url> [repo_dir]
set -uo pipefail

DOCKER_BUILD_TIMEOUT=600
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BOLD='\033[1m'
NC='\033[0m'

PING_URL="${1:-}"
REPO_DIR="${2:-.}"

if [ -z "$PING_URL" ]; then
  printf "Usage: %s <ping_url> [repo_dir]\n" "$0"
  exit 1
fi

# Resolve to an absolute path and fail fast if the directory doesn't exist.
# (Previously a failed `cd` silently left REPO_DIR empty, so later steps
# probed paths like "/Dockerfile" and produced confusing errors.)
REPO_DIR="$(cd "$REPO_DIR" 2>/dev/null && pwd)" || {
  printf "Error: repo_dir '%s' is not a directory\n" "${2:-.}"
  exit 1
}
PING_URL="${PING_URL%/}"  # strip trailing slash so "$PING_URL/reset" stays well-formed
PASS=0

log() { printf "[%s] %b\n" "$(date -u +%H:%M:%S)" "$*"; }
pass() { log "${GREEN}PASSED${NC} -- $1"; PASS=$((PASS + 1)); }
fail() { log "${RED}FAILED${NC} -- $1"; }
hint() { printf " ${YELLOW}Hint:${NC} %b\n" "$1"; }
stop_at() {
  printf "\n${RED}${BOLD}Validation stopped at %s.${NC} Fix the above before continuing.\n" "$1"
  exit 1
}

printf "\n${BOLD}========================================${NC}\n"
printf "${BOLD} OpenEnv Submission Validator${NC}\n"
printf "${BOLD}========================================${NC}\n"
log "Repo: $REPO_DIR"
log "Ping URL: $PING_URL"
printf "\n"

log "${BOLD}Step 1/3: Pinging HF Space${NC} ($PING_URL/reset) ..."
# "000" sentinel covers curl failures (timeout, DNS, refused connection).
HTTP_CODE=$(curl -s -o /dev/null -w "%{http_code}" -X POST \
  -H "Content-Type: application/json" -d '{}' \
  "$PING_URL/reset" --max-time 30 2>/dev/null || printf "000")

if [ "$HTTP_CODE" = "200" ]; then
  pass "HF Space is live and responds to /reset"
elif [ "$HTTP_CODE" = "000" ]; then
  fail "HF Space not reachable (connection failed or timed out)"
  stop_at "Step 1"
else
  fail "HF Space /reset returned HTTP $HTTP_CODE (expected 200)"
  stop_at "Step 1"
fi

log "${BOLD}Step 2/3: Running docker build${NC} ..."
if ! command -v docker &>/dev/null; then
  fail "docker command not found"
  stop_at "Step 2"
fi

# Accept a Dockerfile either at the repo root or under server/.
if [ -f "$REPO_DIR/Dockerfile" ]; then
  DOCKER_CONTEXT="$REPO_DIR"
elif [ -f "$REPO_DIR/server/Dockerfile" ]; then
  DOCKER_CONTEXT="$REPO_DIR/server"
else
  fail "No Dockerfile found"
  stop_at "Step 2"
fi

log " Found Dockerfile in $DOCKER_CONTEXT"
BUILD_OK=false
BUILD_OUTPUT=$(timeout "$DOCKER_BUILD_TIMEOUT" docker build "$DOCKER_CONTEXT" 2>&1) && BUILD_OK=true

if [ "$BUILD_OK" = true ]; then
  pass "Docker build succeeded"
else
  fail "Docker build failed"
  # Show only the tail — the failing layer is at the end of build output.
  printf "%s\n" "$BUILD_OUTPUT" | tail -20
  stop_at "Step 2"
fi

log "${BOLD}Step 3/3: Running openenv validate${NC} ..."
if ! command -v openenv &>/dev/null; then
  fail "openenv command not found"
  hint "Install it: pip install openenv-core"
  stop_at "Step 3"
fi

VALIDATE_OK=false
VALIDATE_OUTPUT=$(cd "$REPO_DIR" && openenv validate 2>&1) && VALIDATE_OK=true

if [ "$VALIDATE_OK" = true ]; then
  pass "openenv validate passed"
  [ -n "$VALIDATE_OUTPUT" ] && log " $VALIDATE_OUTPUT"
else
  fail "openenv validate failed"
  printf "%s\n" "$VALIDATE_OUTPUT"
  stop_at "Step 3"
fi

printf "\n${BOLD}========================================${NC}\n"
printf "${GREEN}${BOLD} All 3/3 checks passed!${NC}\n"
printf "${BOLD}========================================${NC}\n\n"
exit 0
|