feat: curriculum training + Karnataka scenarios + repo cleanup
Browse files
- Add Karnataka scenario variants (easy/medium/hard) with same 15-bus topology
but different operating conditions (renewables, load, line capacity)
- Add curriculum training mode (--curriculum flag) that chains through
karnataka_easy -> medium -> hard -> full with checkpoint transfer
- Restructure repo: docs/ for documentation, scripts/ for utilities
- Clean up generated blobs (codebase_summary.md, inference_output.txt)
- Frontend: dynamic task groups (Procedural green, Karnataka gold)
- Map: no tiles for procedural grids (dark canvas), locked bounds for Karnataka
- Fix zone names: generic for procedural, KPTCL-specific for Karnataka
- Fix Raichur TPS GPS coordinates (16.36, 77.34)
- All 7 training pipeline checks pass
- .gitignore +4 -9
- changes.md → docs/changes.md +0 -0
- inference.py +3 -1
- open-grid logo.png +3 -0
- run_training.py +2 -2
- scripts/generate_code_md.py +92 -0
- scripts/get_scores.py +37 -0
- scripts/verify_training.py +161 -0
- src/scenarios.py +113 -0
- src/tasks.py +23 -14
- static/app.js +96 -26
- static/index.html +57 -53
- static/style.css +199 -70
- training/train_grpo.py +99 -6
.gitignore
CHANGED
|
@@ -13,14 +13,12 @@ build/
|
|
| 13 |
# Generated / temporary files
|
| 14 |
inference_output.txt
|
| 15 |
codebase_summary.md
|
| 16 |
-
generate_code_md.py
|
| 17 |
uv.lock
|
| 18 |
|
| 19 |
-
# Reference docs (not part of submission)
|
| 20 |
-
guide.md
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
project-spec.md
|
| 24 |
pyrightconfig.json
|
| 25 |
|
| 26 |
# Training outputs (large files — push separately or add to HF)
|
|
@@ -31,6 +29,3 @@ training/outputs/
|
|
| 31 |
# OS files
|
| 32 |
Thumbs.db
|
| 33 |
.DS_Store
|
| 34 |
-
|
| 35 |
-
# Duplicate test file (tests/ directory has the real one)
|
| 36 |
-
test_multiagent.py
|
|
|
|
| 13 |
# Generated / temporary files
|
| 14 |
inference_output.txt
|
| 15 |
codebase_summary.md
|
|
|
|
| 16 |
uv.lock
|
| 17 |
|
| 18 |
+
# Reference docs (moved to docs/ — not part of submission)
|
| 19 |
+
docs/guide.md
|
| 20 |
+
docs/detailed_judging_criteria.md
|
| 21 |
+
docs/project-spec.md
|
|
|
|
| 22 |
pyrightconfig.json
|
| 23 |
|
| 24 |
# Training outputs (large files — push separately or add to HF)
|
|
|
|
| 29 |
# OS files
|
| 30 |
Thumbs.db
|
| 31 |
.DS_Store
|
|
|
|
|
|
|
|
|
changes.md → docs/changes.md
RENAMED
|
File without changes
|
inference.py
CHANGED
|
@@ -48,7 +48,9 @@ BENCHMARK = "OpenGrid"
|
|
| 48 |
MAX_STEPS = 100
|
| 49 |
SUCCESS_SCORE_THRESHOLD = 0.5
|
| 50 |
|
| 51 |
-
TASKS = ["task_easy", "task_medium", "task_hard",
|
|
|
|
|
|
|
| 52 |
|
| 53 |
SYSTEM_PROMPT_SINGLE = """You are a Power Grid Controller AI. Your goal is to maintain grid stability.
|
| 54 |
|
|
|
|
| 48 |
MAX_STEPS = 100
|
| 49 |
SUCCESS_SCORE_THRESHOLD = 0.5
|
| 50 |
|
| 51 |
+
TASKS = ["task_easy", "task_medium", "task_hard",
|
| 52 |
+
"karnataka_easy", "karnataka_medium", "karnataka_hard",
|
| 53 |
+
"task_karnataka"]
|
| 54 |
|
| 55 |
SYSTEM_PROMPT_SINGLE = """You are a Power Grid Controller AI. Your goal is to maintain grid stability.
|
| 56 |
|
open-grid logo.png
ADDED
|
Git LFS Details
|
run_training.py
CHANGED
|
@@ -79,7 +79,7 @@ def run_grpo_training():
|
|
| 79 |
return json.dumps({"bus_adjustments": [], "topology_actions": []})
|
| 80 |
|
| 81 |
baseline_results = {}
|
| 82 |
-
for task_id in ["task_easy", "task_medium", "task_karnataka"]:
|
| 83 |
if task_id not in TASKS:
|
| 84 |
continue
|
| 85 |
config = TASKS[task_id]
|
|
@@ -238,7 +238,7 @@ def run_grpo_training():
|
|
| 238 |
return tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
|
| 239 |
|
| 240 |
trained_results = {}
|
| 241 |
-
for task_id in ["task_easy", "task_medium", "task_karnataka"]:
|
| 242 |
if task_id not in TASKS:
|
| 243 |
continue
|
| 244 |
config = TASKS[task_id]
|
|
|
|
| 79 |
return json.dumps({"bus_adjustments": [], "topology_actions": []})
|
| 80 |
|
| 81 |
baseline_results = {}
|
| 82 |
+
for task_id in ["task_easy", "task_medium", "karnataka_easy", "karnataka_medium", "karnataka_hard", "task_karnataka"]:
|
| 83 |
if task_id not in TASKS:
|
| 84 |
continue
|
| 85 |
config = TASKS[task_id]
|
|
|
|
| 238 |
return tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
|
| 239 |
|
| 240 |
trained_results = {}
|
| 241 |
+
for task_id in ["task_easy", "task_medium", "karnataka_easy", "karnataka_medium", "karnataka_hard", "task_karnataka"]:
|
| 242 |
if task_id not in TASKS:
|
| 243 |
continue
|
| 244 |
config = TASKS[task_id]
|
scripts/generate_code_md.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
def generate_tree(dir_path, ignore_dirs=None, prefix=""):
    """Render the contents of ``dir_path`` as a box-drawing text tree.

    Args:
        dir_path: Directory whose entries are listed.
        ignore_dirs: Entry names to leave out. The filter is applied to every
            entry name, files as well as directories. Defaults to a built-in
            set of VCS, virtualenv, IDE and build directory names.
        prefix: Text placed in front of every line produced at this depth;
            recursive calls extend it by one level.

    Returns:
        One line per visible entry in sorted order, each terminated by a
        newline, with sub-directories expanded recursively. A directory that
        raises PermissionError when listed yields an empty string.
    """
    if ignore_dirs is None:
        ignore_dirs = {'.git', '__pycache__', 'venv', '.venv', 'env', 'node_modules', '.idea', '.vscode', 'build', 'dist'}

    try:
        entries = os.listdir(dir_path)
    except PermissionError:
        return ""

    visible = sorted(name for name in entries if name not in ignore_dirs)
    final_index = len(visible) - 1

    chunks = []
    for index, name in enumerate(visible):
        # The last entry at a level closes the branch; earlier ones keep the
        # vertical rule running for the siblings that follow.
        if index == final_index:
            chunks.append(f"{prefix}└── {name}\n")
            child_prefix = prefix + " "
        else:
            chunks.append(f"{prefix}├── {name}\n")
            child_prefix = prefix + "│ "

        child_path = os.path.join(dir_path, name)
        if os.path.isdir(child_path):
            chunks.append(generate_tree(child_path, ignore_dirs, child_prefix))

    return "".join(chunks)
|
| 32 |
+
|
| 33 |
+
def generate_markdown(output_file="codebase_summary.md", source_dir=".", ignore_dirs=None, ignore_exts=None):
    """Write a markdown summary of a source tree: its layout, then every file.

    The output starts with a fenced text block holding the directory tree
    (from ``generate_tree``), followed by one ``## `relative/path` `` section
    per file with the file's content in a fenced block tagged by extension.

    Args:
        output_file: Path of the markdown file to create (overwritten).
        source_dir: Root directory to summarise.
        ignore_dirs: Directory names pruned from the walk and the tree.
            Defaults to common VCS, virtualenv, IDE and build directories.
        ignore_exts: Lower-case file extensions (with leading dot) to skip.
            Defaults to common binary, image, archive and media extensions.

    Files that cannot be decoded as UTF-8, or that raise any other error while
    being read, get a one-line note inside their code block instead of content.
    """
    if ignore_dirs is None:
        ignore_dirs = {'.git', '__pycache__', 'venv', '.venv', 'env', 'node_modules', '.idea', '.vscode', 'build', 'dist'}

    # Common binary and non-text files to ignore
    if ignore_exts is None:
        ignore_exts = {'.pyc', '.pyo', '.pyd', '.so', '.dll', '.exe', '.bin',
                       '.png', '.jpg', '.jpeg', '.gif', '.ico', '.pdf',
                       '.zip', '.tar', '.gz', '.mp4', '.mp3', '.sqlite3'}

    # Resolve the output location once. Comparing only the bare file name with
    # ``output_file`` never matches when ``output_file`` contains a directory
    # part, which would make the walk read the half-written output into itself.
    output_abs = os.path.abspath(output_file)
    script_name = os.path.basename(__file__)

    print(f"Generating {output_file}...")

    with open(output_file, 'w', encoding='utf-8') as outfile:
        outfile.write("# Codebase Structure\n\n")
        outfile.write("```text\n")
        outfile.write(f"{os.path.basename(os.path.abspath(source_dir))}/\n")
        outfile.write(generate_tree(source_dir, ignore_dirs))
        outfile.write("```\n\n")

        outfile.write("# Code Files\n\n")

        for root, dirs, files in os.walk(source_dir):
            # Modify dirs in-place to skip ignored directories in os.walk
            dirs[:] = [d for d in dirs if d not in ignore_dirs]

            for file in files:
                ext = os.path.splitext(file)[1].lower()
                file_path = os.path.join(root, file)

                # Skip the output file itself: by exact location (works for
                # path-valued output_file) and by the original name match.
                is_output = file == output_file or os.path.abspath(file_path) == output_abs
                # Also skip binary files/images and this generator script.
                if ext in ignore_exts or is_output or file == script_name:
                    continue

                rel_path = os.path.relpath(file_path, source_dir)

                outfile.write(f"## `{rel_path}`\n\n")

                # Determine language for markdown block
                lang = ext[1:] if ext else "text"
                if lang == "txt":
                    lang = "text"

                outfile.write(f"```{lang}\n")
                try:
                    with open(file_path, 'r', encoding='utf-8') as infile:
                        content = infile.read()
                        outfile.write(content)
                        # Ensure there is a newline at the end of the content before closing the code block
                        if content and not content.endswith('\n'):
                            outfile.write('\n')
                except UnicodeDecodeError:
                    outfile.write("// File appears to be binary or has an unsupported encoding and could not be read.\n")
                except Exception as e:
                    # Best-effort: record the failure in the summary and keep going.
                    outfile.write(f"// Error reading file: {e}\n")
                outfile.write("```\n\n")

    print(f"Successfully generated {output_file}!")
|
| 89 |
+
|
| 90 |
+
if __name__ == "__main__":
|
| 91 |
+
generate_markdown()
|
| 92 |
+
|
scripts/get_scores.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Evaluate heuristic baseline on all tasks and print scores."""
|
| 2 |
+
|
| 3 |
+
import copy
|
| 4 |
+
import json
|
| 5 |
+
import sys
|
| 6 |
+
|
| 7 |
+
from src.tasks import TASKS
|
| 8 |
+
from src.grader import RobustnessGrader
|
| 9 |
+
from src.baseline import heuristic_policy
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def main(n_episodes: int = 10):
    """Score the heuristic baseline on every registered task and print results.

    Args:
        n_episodes: Number of episodes passed to ``evaluate_policy`` per task.

    Returns:
        Dict keyed by task id. Each value is the grader's result mapping, or
        ``{"error": <message>}`` when evaluating that task raised.
    """
    scores = {}

    for task_id, task_cfg in TASKS.items():
        try:
            # The grader is handed its own deep copy of the task config.
            outcome = RobustnessGrader(copy.deepcopy(task_cfg)).evaluate_policy(
                heuristic_policy, n_episodes=n_episodes
            )
            scores[task_id] = outcome

            print(f"{task_id}:")
            for metric, value in outcome.items():
                print(f" {metric}: {value}")
            print()

        except Exception as exc:
            # One failing task is reported and recorded; the rest still run.
            scores[task_id] = {"error": str(exc)}
            print(f"{task_id}: FAILED — {exc}\n")

    return scores
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
if __name__ == "__main__":
|
| 36 |
+
episodes = int(sys.argv[1]) if len(sys.argv) > 1 else 10
|
| 37 |
+
main(n_episodes=episodes)
|
scripts/verify_training.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Comprehensive training pipeline verification.
|
| 3 |
+
Tests: scenarios, reward functions, policies, GRPO integration, safety.
|
| 4 |
+
"""
|
| 5 |
+
import json
|
| 6 |
+
import copy
|
| 7 |
+
import sys
|
| 8 |
+
sys.path.insert(0, ".")
|
| 9 |
+
|
| 10 |
+
from src.tasks import TASKS, get_task
|
| 11 |
+
from src.environment import OpenGridEnv
|
| 12 |
+
from src.models import GridAction, GridObservation
|
| 13 |
+
from src.grader import RobustnessGrader
|
| 14 |
+
from src.baseline import heuristic_policy
|
| 15 |
+
from src.safety import SafetyLayer
|
| 16 |
+
|
| 17 |
+
print("=" * 60)
|
| 18 |
+
print(" COMPREHENSIVE TRAINING PIPELINE VERIFICATION")
|
| 19 |
+
print("=" * 60)
|
| 20 |
+
|
| 21 |
+
errors = []
|
| 22 |
+
|
| 23 |
+
# --- 1. Scenario loading ---
|
| 24 |
+
print("\n[1/7] Scenario Loading...")
|
| 25 |
+
expected_tasks = ["task_easy", "task_medium", "task_hard",
|
| 26 |
+
"task_karnataka", "karnataka_easy", "karnataka_medium", "karnataka_hard"]
|
| 27 |
+
for tid in expected_tasks:
|
| 28 |
+
if tid not in TASKS:
|
| 29 |
+
errors.append(f"Missing task: {tid}")
|
| 30 |
+
print(f" FAIL: {tid} not in TASKS")
|
| 31 |
+
else:
|
| 32 |
+
cfg = TASKS[tid]
|
| 33 |
+
print(f" OK: {tid} - {cfg['num_buses']}b/{cfg['num_agents']}a zones={cfg['zone_names']}")
|
| 34 |
+
|
| 35 |
+
# --- 2. Environment step for each scenario ---
|
| 36 |
+
print("\n[2/7] Environment Step Test...")
|
| 37 |
+
for tid in expected_tasks:
|
| 38 |
+
try:
|
| 39 |
+
cfg = get_task(tid)
|
| 40 |
+
env = OpenGridEnv(cfg)
|
| 41 |
+
obs = env.reset()
|
| 42 |
+
action = GridAction.model_validate_json(
|
| 43 |
+
json.dumps({"bus_adjustments": [], "topology_actions": []})
|
| 44 |
+
)
|
| 45 |
+
obs2, reward, done, info = env.step(action)
|
| 46 |
+
freq = obs2.grid_frequency
|
| 47 |
+
r = reward.value
|
| 48 |
+
print(f" OK: {tid} - freq={freq:.2f}Hz reward={r:.2f}")
|
| 49 |
+
except Exception as e:
|
| 50 |
+
errors.append(f"Env step failed for {tid}: {e}")
|
| 51 |
+
print(f" FAIL: {tid} - {e}")
|
| 52 |
+
|
| 53 |
+
# --- 3. Reward function (GRPO) test ---
|
| 54 |
+
print("\n[3/7] GRPO Reward Function Test...")
|
| 55 |
+
from training.train_grpo import compute_grpo_reward_env
|
| 56 |
+
test_completions = [
|
| 57 |
+
'{"bus_adjustments": [{"bus_id": 0, "delta": 5.0}], "topology_actions": []}',
|
| 58 |
+
'{"bus_adjustments": [], "topology_actions": []}',
|
| 59 |
+
'not valid json',
|
| 60 |
+
]
|
| 61 |
+
test_observations = [
|
| 62 |
+
{"grid_frequency": 49.5, "buses": [], "lines": []},
|
| 63 |
+
{"grid_frequency": 50.0, "buses": [], "lines": []},
|
| 64 |
+
{"grid_frequency": 48.0, "buses": [], "lines": []},
|
| 65 |
+
]
|
| 66 |
+
try:
|
| 67 |
+
cfg = get_task("karnataka_easy")
|
| 68 |
+
rewards = compute_grpo_reward_env(test_completions, test_observations, cfg, horizon=1)
|
| 69 |
+
for i, r in enumerate(rewards):
|
| 70 |
+
print(f" Completion {i}: reward={r:.3f}")
|
| 71 |
+
print(f" OK: GRPO rewards computed for {len(rewards)} completions")
|
| 72 |
+
except Exception as e:
|
| 73 |
+
errors.append(f"GRPO reward failed: {e}")
|
| 74 |
+
print(f" FAIL: {e}")
|
| 75 |
+
|
| 76 |
+
# --- 4. Karnataka Difficulty Gradient Test ---
|
| 77 |
+
print("\n[4/7] Karnataka Difficulty Gradient Test...")
|
| 78 |
+
ka_rewards = {}
|
| 79 |
+
for tid in ["karnataka_easy", "karnataka_medium", "karnataka_hard"]:
|
| 80 |
+
try:
|
| 81 |
+
cfg = get_task(tid)
|
| 82 |
+
env = OpenGridEnv(cfg)
|
| 83 |
+
obs = env.reset()
|
| 84 |
+
total_r = 0
|
| 85 |
+
for step_i in range(5):
|
| 86 |
+
action = GridAction.model_validate_json(
|
| 87 |
+
json.dumps({"bus_adjustments": [], "topology_actions": []})
|
| 88 |
+
)
|
| 89 |
+
obs, reward, done, info = env.step(action)
|
| 90 |
+
total_r += reward.value
|
| 91 |
+
if done:
|
| 92 |
+
break
|
| 93 |
+
ka_rewards[tid] = total_r
|
| 94 |
+
print(f" {tid}: 5-step reward={total_r:.2f}")
|
| 95 |
+
except Exception as e:
|
| 96 |
+
errors.append(f"Ka difficulty test failed for {tid}: {e}")
|
| 97 |
+
print(f" FAIL: {tid} - {e}")
|
| 98 |
+
|
| 99 |
+
if len(ka_rewards) == 3:
|
| 100 |
+
# Easy should generally give higher or equal rewards than hard
|
| 101 |
+
if ka_rewards["karnataka_easy"] >= ka_rewards["karnataka_hard"]:
|
| 102 |
+
print(f" OK: Difficulty gradient correct (easy >= hard)")
|
| 103 |
+
else:
|
| 104 |
+
print(f" WARN: easy ({ka_rewards['karnataka_easy']:.2f}) < hard ({ka_rewards['karnataka_hard']:.2f}) - may vary by seed")
|
| 105 |
+
|
| 106 |
+
# --- 5. Heuristic policy test ---
|
| 107 |
+
print("\n[5/7] Heuristic Policy Test...")
|
| 108 |
+
for tid in ["task_easy", "karnataka_easy", "task_karnataka"]:
|
| 109 |
+
try:
|
| 110 |
+
cfg = get_task(tid)
|
| 111 |
+
env = OpenGridEnv(cfg)
|
| 112 |
+
obs = env.reset()
|
| 113 |
+
total_r = 0
|
| 114 |
+
for step_i in range(10):
|
| 115 |
+
action = heuristic_policy(obs)
|
| 116 |
+
obs, reward, done, info = env.step(action)
|
| 117 |
+
total_r += reward.value
|
| 118 |
+
if done:
|
| 119 |
+
break
|
| 120 |
+
print(f" OK: {tid} - 10-step heuristic reward={total_r:.2f}")
|
| 121 |
+
except Exception as e:
|
| 122 |
+
errors.append(f"Heuristic policy failed for {tid}: {e}")
|
| 123 |
+
print(f" FAIL: {tid} - {e}")
|
| 124 |
+
|
| 125 |
+
# --- 6. Safety layer test ---
|
| 126 |
+
print("\n[6/7] Safety Layer Test...")
|
| 127 |
+
for tid in ["task_easy", "karnataka_easy", "karnataka_hard"]:
|
| 128 |
+
try:
|
| 129 |
+
cfg = get_task(tid)
|
| 130 |
+
layer = SafetyLayer(cfg)
|
| 131 |
+
action = GridAction.model_validate_json(
|
| 132 |
+
json.dumps({"bus_adjustments": [{"bus_id": 0, "delta": 100.0}], "topology_actions": []})
|
| 133 |
+
)
|
| 134 |
+
bus_state = [{"id": b["id"], "p": b.get("base_p", 0), "soc": b.get("init_soc", 0)} for b in cfg["buses"]]
|
| 135 |
+
line_state = [{"id": l["id"], "connected": True, "flow": 0} for l in cfg["lines"]]
|
| 136 |
+
safe_action, report = layer.validate_and_correct(0, action, line_state, bus_state, {})
|
| 137 |
+
print(f" OK: {tid} - corrected={report.was_corrected}, n1_violations={report.n1_violations_detected}")
|
| 138 |
+
except Exception as e:
|
| 139 |
+
errors.append(f"Safety layer failed for {tid}: {e}")
|
| 140 |
+
print(f" FAIL: {tid} - {e}")
|
| 141 |
+
|
| 142 |
+
# --- 7. Curriculum order test ---
|
| 143 |
+
print("\n[7/7] Curriculum Order Test...")
|
| 144 |
+
from training.train_grpo import CURRICULUM_ORDER
|
| 145 |
+
for tid in CURRICULUM_ORDER:
|
| 146 |
+
if tid in TASKS:
|
| 147 |
+
print(f" OK: {tid} available")
|
| 148 |
+
else:
|
| 149 |
+
errors.append(f"Curriculum task missing: {tid}")
|
| 150 |
+
print(f" FAIL: {tid} not in TASKS")
|
| 151 |
+
|
| 152 |
+
# --- Summary ---
|
| 153 |
+
print("\n" + "=" * 60)
|
| 154 |
+
if errors:
|
| 155 |
+
print(f" FAILED: {len(errors)} errors")
|
| 156 |
+
for e in errors:
|
| 157 |
+
print(f" - {e}")
|
| 158 |
+
sys.exit(1)
|
| 159 |
+
else:
|
| 160 |
+
print(" ALL CHECKS PASSED - Training pipeline ready")
|
| 161 |
+
print("=" * 60)
|
src/scenarios.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Karnataka Grid Scenarios
|
| 3 |
+
========================
|
| 4 |
+
Generates difficulty variants of the Karnataka 15-bus grid.
|
| 5 |
+
Same topology (KPTCL transmission map), different operating conditions.
|
| 6 |
+
|
| 7 |
+
Scenarios vary:
|
| 8 |
+
- Renewable penetration (solar/wind max_p)
|
| 9 |
+
- Load magnitude (base_p multiplier)
|
| 10 |
+
- Line capacity (tighter or relaxed limits)
|
| 11 |
+
- Battery capacity
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import copy
|
| 15 |
+
from typing import Dict
|
| 16 |
+
|
| 17 |
+
from src.tasks import generate_karnataka_task
|
| 18 |
+
|
| 19 |
+
__all__ = ['generate_karnataka_scenario', 'KARNATAKA_SCENARIOS']
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Difficulty profiles: multipliers applied to the base Karnataka grid
|
| 23 |
+
_DIFFICULTY_PROFILES = {
|
| 24 |
+
"easy": {
|
| 25 |
+
"description": "Low renewables, light load, relaxed lines",
|
| 26 |
+
"renewable_multiplier": 0.3, # Solar/wind max_p scaled down
|
| 27 |
+
"load_multiplier": 0.6, # Loads are lighter
|
| 28 |
+
"line_capacity_multiplier": 1.5, # Lines can carry more
|
| 29 |
+
"battery_capacity_multiplier": 1.5, # More storage headroom
|
| 30 |
+
"max_steps": 50,
|
| 31 |
+
},
|
| 32 |
+
"medium": {
|
| 33 |
+
"description": "Moderate renewables, normal load, standard lines",
|
| 34 |
+
"renewable_multiplier": 0.7,
|
| 35 |
+
"load_multiplier": 1.0,
|
| 36 |
+
"line_capacity_multiplier": 1.0,
|
| 37 |
+
"battery_capacity_multiplier": 1.0,
|
| 38 |
+
"max_steps": 50,
|
| 39 |
+
},
|
| 40 |
+
"hard": {
|
| 41 |
+
"description": "High renewables, peak demand, tight lines",
|
| 42 |
+
"renewable_multiplier": 1.3, # More volatile supply
|
| 43 |
+
"load_multiplier": 1.4, # Peak demand
|
| 44 |
+
"line_capacity_multiplier": 0.75, # Congested corridors
|
| 45 |
+
"battery_capacity_multiplier": 0.7, # Less storage
|
| 46 |
+
"max_steps": 50,
|
| 47 |
+
},
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def generate_karnataka_scenario(difficulty: str, seed: int = 808) -> Dict:
    """Generate a Karnataka grid scenario at a given difficulty level.

    The config returned by ``generate_karnataka_task`` is used as the base and
    only its operating conditions are rescaled by the chosen profile:
      - solar/wind buses: ``max_p`` times ``renewable_multiplier``
      - load buses: ``base_p`` times ``load_multiplier``
      - battery buses: ``max_p`` and ``capacity`` times
        ``battery_capacity_multiplier``; ``init_soc`` reset to half the new
        capacity
      - slack buses: ``max_p`` set to 80% of the total scaled load, but at
        least 200; ``min_p`` set to ``-max_p``
      - lines: ``capacity`` times ``line_capacity_multiplier``

    Scaled values are rounded to one decimal place.

    Args:
        difficulty: Key of ``_DIFFICULTY_PROFILES`` ("easy", "medium", "hard").
        seed: Seed forwarded to ``generate_karnataka_task``.

    Returns:
        The modified task config, with ``id`` and ``difficulty`` set to
        ``"karnataka_<difficulty>"`` and ``max_steps`` /
        ``scenario_description`` taken from the profile.

    Raises:
        ValueError: If ``difficulty`` is not a known profile name.
    """
    if difficulty not in _DIFFICULTY_PROFILES:
        raise ValueError(
            f"Unknown difficulty '{difficulty}'. "
            f"Available: {list(_DIFFICULTY_PROFILES.keys())}"
        )

    profile = _DIFFICULTY_PROFILES[difficulty]
    base = generate_karnataka_task(seed=seed)

    load_multiplier = profile["load_multiplier"]
    battery_multiplier = profile["battery_capacity_multiplier"]

    # Total scaled demand, taken from the still-unscaled base loads. It must be
    # computed before the loop below rewrites base_p in place; otherwise any
    # load bus listed ahead of the slack bus would have the multiplier applied
    # twice in the slack sizing.
    total_load = sum(
        b["base_p"] * load_multiplier
        for b in base["buses"] if b["type"] == "load"
    )

    # Apply multipliers to buses
    for bus in base["buses"]:
        bus_type = bus["type"]

        if bus_type in ("solar", "wind"):
            bus["max_p"] = round(bus["max_p"] * profile["renewable_multiplier"], 1)

        elif bus_type == "load":
            bus["base_p"] = round(bus["base_p"] * load_multiplier, 1)

        elif bus_type == "battery":
            bus["max_p"] = round(bus["max_p"] * battery_multiplier, 1)
            bus["capacity"] = round(bus["capacity"] * battery_multiplier, 1)
            # Start every battery at half of its rescaled capacity.
            bus["init_soc"] = round(bus["capacity"] * 0.5, 1)

        elif bus_type == "slack":
            # Scale slack to cover the adjusted load (floor of 200).
            bus["max_p"] = max(200, round(total_load * 0.8, 1))
            bus["min_p"] = -bus["max_p"]

    # Apply line capacity multiplier
    for line in base["lines"]:
        line["capacity"] = round(line["capacity"] * profile["line_capacity_multiplier"], 1)

    # Update metadata
    base["id"] = f"karnataka_{difficulty}"
    base["difficulty"] = f"karnataka_{difficulty}"
    base["max_steps"] = profile["max_steps"]
    base["scenario_description"] = profile["description"]

    return base
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# Pre-built scenario configs
|
| 110 |
+
KARNATAKA_SCENARIOS = {
|
| 111 |
+
f"karnataka_{diff}": generate_karnataka_scenario(diff)
|
| 112 |
+
for diff in _DIFFICULTY_PROFILES
|
| 113 |
+
}
|
src/tasks.py
CHANGED
|
@@ -18,18 +18,20 @@ from typing import Dict, List, Tuple
|
|
| 18 |
__all__ = ['generate_procedural_grid', 'generate_karnataka_task', 'TASKS', 'get_task']
|
| 19 |
|
| 20 |
|
| 21 |
-
#
|
| 22 |
def _get_zone_names(num_agents: int) -> List[str]:
|
| 23 |
-
"""Get human-readable zone names for a given agent count."""
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
]
|
| 28 |
-
if num_agents <= len(base_names):
|
| 29 |
-
return base_names[:num_agents]
|
| 30 |
return [f"Zone_{i}" for i in range(num_agents)]
|
| 31 |
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
def _partition_into_zones(G: nx.Graph, num_agents: int) -> Dict[int, int]:
|
| 34 |
"""Partition graph nodes into balanced, connected zones.
|
| 35 |
|
|
@@ -283,13 +285,13 @@ def generate_karnataka_task(seed: int = 808) -> Dict:
|
|
| 283 |
KPTCL transmission map. Nodes have real GPS coordinates for GIS rendering.
|
| 284 |
"""
|
| 285 |
nodes = [
|
| 286 |
-
{"id": 0, "name": "Raichur_TPS", "type": "slack", "lat": 16.
|
| 287 |
{"id": 1, "name": "Kalaburagi", "type": "load", "lat": 17.33, "lon": 76.83, "max_p": 0, "base_p": 40},
|
| 288 |
{"id": 2, "name": "Belagavi", "type": "load", "lat": 15.85, "lon": 74.50, "max_p": 0, "base_p": 35},
|
| 289 |
-
{"id": 3, "name": "Hubballi", "type": "load", "lat": 15.36, "lon": 75.
|
| 290 |
{"id": 4, "name": "Ballari_TPS", "type": "generator", "lat": 15.14, "lon": 76.92, "max_p": 150, "base_p": 0},
|
| 291 |
{"id": 5, "name": "Chitradurga_Wind", "type": "wind", "lat": 14.23, "lon": 76.40, "max_p": 80, "base_p": 0},
|
| 292 |
-
{"id": 6, "name": "Pavagada_Solar", "type": "solar", "lat": 14.10, "lon": 77.
|
| 293 |
{"id": 7, "name": "Sharavathi_Hydro", "type": "generator", "lat": 14.18, "lon": 74.83, "max_p": 100, "base_p": 0},
|
| 294 |
{"id": 8, "name": "Shivamogga", "type": "load", "lat": 13.93, "lon": 75.57, "max_p": 0, "base_p": 30},
|
| 295 |
{"id": 9, "name": "Mangaluru", "type": "load", "lat": 12.87, "lon": 74.88, "max_p": 0, "base_p": 50},
|
|
@@ -350,7 +352,7 @@ def generate_karnataka_task(seed: int = 808) -> Dict:
|
|
| 350 |
"difficulty": "karnataka",
|
| 351 |
"num_agents": 4,
|
| 352 |
"zone_assignments": zone_assignments,
|
| 353 |
-
"zone_names":
|
| 354 |
"zone_bus_ids": zone_bus_ids,
|
| 355 |
"internal_lines": internal_lines,
|
| 356 |
"boundary_lines": boundary_lines,
|
|
@@ -380,5 +382,12 @@ TASKS = {
|
|
| 380 |
"task_easy": generate_procedural_grid("easy", seed=101),
|
| 381 |
"task_medium": generate_procedural_grid("medium", seed=102),
|
| 382 |
"task_hard": generate_procedural_grid("hard", seed=103),
|
| 383 |
-
"task_karnataka": generate_karnataka_task()
|
| 384 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
__all__ = ['generate_procedural_grid', 'generate_karnataka_task', 'TASKS', 'get_task']
|
| 19 |
|
| 20 |
|
| 21 |
+
# Generic zone names for procedural grids
|
| 22 |
def _get_zone_names(num_agents: int) -> List[str]:
    """Return generic zone names, one per agent.

    Up to five agents get the Greek-letter names ("Zone_Alpha" ...
    "Zone_Epsilon"); larger counts fall back to numbered names
    "Zone_0" ... "Zone_<n-1>" for every zone.
    """
    greek_names = ("Zone_Alpha", "Zone_Beta", "Zone_Gamma", "Zone_Delta", "Zone_Epsilon")
    if num_agents > len(greek_names):
        return [f"Zone_{idx}" for idx in range(num_agents)]
    return list(greek_names[:num_agents])
|
| 28 |
|
| 29 |
|
| 30 |
+
# KPTCL-specific zone names (only for Karnataka tasks)
|
| 31 |
+
def _get_karnataka_zone_names() -> List[str]:
|
| 32 |
+
return ["Kalaburagi_Region", "Hubballi_Region", "Mysuru_Region", "Bengaluru_Region"]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
def _partition_into_zones(G: nx.Graph, num_agents: int) -> Dict[int, int]:
|
| 36 |
"""Partition graph nodes into balanced, connected zones.
|
| 37 |
|
|
|
|
| 285 |
KPTCL transmission map. Nodes have real GPS coordinates for GIS rendering.
|
| 286 |
"""
|
| 287 |
nodes = [
|
| 288 |
+
{"id": 0, "name": "Raichur_TPS", "type": "slack", "lat": 16.36, "lon": 77.34, "max_p": 200, "base_p": 0},
|
| 289 |
{"id": 1, "name": "Kalaburagi", "type": "load", "lat": 17.33, "lon": 76.83, "max_p": 0, "base_p": 40},
|
| 290 |
{"id": 2, "name": "Belagavi", "type": "load", "lat": 15.85, "lon": 74.50, "max_p": 0, "base_p": 35},
|
| 291 |
+
{"id": 3, "name": "Hubballi", "type": "load", "lat": 15.36, "lon": 75.12, "max_p": 0, "base_p": 45},
|
| 292 |
{"id": 4, "name": "Ballari_TPS", "type": "generator", "lat": 15.14, "lon": 76.92, "max_p": 150, "base_p": 0},
|
| 293 |
{"id": 5, "name": "Chitradurga_Wind", "type": "wind", "lat": 14.23, "lon": 76.40, "max_p": 80, "base_p": 0},
|
| 294 |
+
{"id": 6, "name": "Pavagada_Solar", "type": "solar", "lat": 14.10, "lon": 77.28, "max_p": 120, "base_p": 0},
|
| 295 |
{"id": 7, "name": "Sharavathi_Hydro", "type": "generator", "lat": 14.18, "lon": 74.83, "max_p": 100, "base_p": 0},
|
| 296 |
{"id": 8, "name": "Shivamogga", "type": "load", "lat": 13.93, "lon": 75.57, "max_p": 0, "base_p": 30},
|
| 297 |
{"id": 9, "name": "Mangaluru", "type": "load", "lat": 12.87, "lon": 74.88, "max_p": 0, "base_p": 50},
|
|
|
|
| 352 |
"difficulty": "karnataka",
|
| 353 |
"num_agents": 4,
|
| 354 |
"zone_assignments": zone_assignments,
|
| 355 |
+
"zone_names": _get_karnataka_zone_names(),
|
| 356 |
"zone_bus_ids": zone_bus_ids,
|
| 357 |
"internal_lines": internal_lines,
|
| 358 |
"boundary_lines": boundary_lines,
|
|
|
|
| 382 |
"task_easy": generate_procedural_grid("easy", seed=101),
|
| 383 |
"task_medium": generate_procedural_grid("medium", seed=102),
|
| 384 |
"task_hard": generate_procedural_grid("hard", seed=103),
|
| 385 |
+
"task_karnataka": generate_karnataka_task(),
|
| 386 |
+
}
|
| 387 |
+
|
| 388 |
+
# Register Karnataka scenario variants (same topology, different difficulty)
|
| 389 |
+
from src.scenarios import KARNATAKA_SCENARIOS, generate_karnataka_scenario # noqa: E402
|
| 390 |
+
|
| 391 |
+
for _sid, _cfg in KARNATAKA_SCENARIOS.items():
|
| 392 |
+
TASKS[_sid] = _cfg
|
| 393 |
+
_TASK_GENERATORS[_sid] = (lambda d=_sid.replace("karnataka_", ""): generate_karnataka_scenario(d))
|
static/app.js
CHANGED
|
@@ -15,17 +15,52 @@ let state = {
|
|
| 15 |
};
|
| 16 |
|
| 17 |
// --- Init ---
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
btn.addEventListener('click', () => {
|
| 21 |
document.querySelectorAll('.task-btn').forEach(b => b.classList.remove('active'));
|
| 22 |
btn.classList.add('active');
|
| 23 |
-
state.task =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
});
|
|
|
|
|
|
|
|
|
|
| 26 |
fetch(`${API}/tasks`).then(r=>r.json()).then(d=>{
|
| 27 |
d.forEach(t => state.taskConfigs[t.id] = t);
|
| 28 |
-
|
|
|
|
| 29 |
setTimeout(() => document.getElementById('loading').classList.add('hidden'), 800);
|
| 30 |
});
|
| 31 |
});
|
|
@@ -358,29 +393,41 @@ function initLeafletMap() {
|
|
| 358 |
const container = document.getElementById('gridMap');
|
| 359 |
if (leafletMap) return;
|
| 360 |
|
| 361 |
-
|
| 362 |
-
|
|
|
|
| 363 |
|
| 364 |
-
|
| 365 |
-
center: [14.5, 76.5],
|
| 366 |
-
zoom: 7,
|
| 367 |
zoomControl: true,
|
| 368 |
attributionControl: false,
|
| 369 |
-
minZoom:
|
| 370 |
maxZoom: 15,
|
| 371 |
preferCanvas: true,
|
| 372 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
|
|
|
| 379 |
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
|
|
|
|
|
|
|
|
|
| 384 |
|
| 385 |
// Layer groups for easy clearing
|
| 386 |
mapLayers.lines = L.layerGroup().addTo(leafletMap);
|
|
@@ -390,7 +437,7 @@ function initLeafletMap() {
|
|
| 390 |
// Fix Leaflet size after container is fully rendered
|
| 391 |
setTimeout(() => {
|
| 392 |
leafletMap.invalidateSize();
|
| 393 |
-
leafletMap.fitBounds(kaBounds, { padding: [20, 20] });
|
| 394 |
}, 200);
|
| 395 |
}
|
| 396 |
|
|
@@ -464,17 +511,40 @@ function updateGridMap() {
|
|
| 464 |
if (!from || !to) return;
|
| 465 |
|
| 466 |
const lc = !l.connected ? '#4a5568' : l.rho > 1 ? '#ff1744' : l.rho > 0.8 ? '#ff9100' : '#e91e63';
|
| 467 |
-
const w = !l.connected ?
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 468 |
|
| 469 |
const polyline = L.polyline(
|
| 470 |
[[from.lat, from.lon], [to.lat, to.lon]],
|
| 471 |
-
{ color: lc, weight: w, dashArray: l.connected ? '
|
| 472 |
);
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 476 |
});
|
|
|
|
| 477 |
}
|
|
|
|
| 478 |
mapLayers.lines.addLayer(polyline);
|
| 479 |
});
|
| 480 |
}
|
|
|
|
| 15 |
};
|
| 16 |
|
| 17 |
// --- Init ---
|
| 18 |
+
/**
 * Report whether a task id refers to a Karnataka-grid scenario.
 *
 * A task counts as "Karnataka" when its id contains the substring
 * 'karnataka' (e.g. 'task_karnataka', 'karnataka_easy').
 *
 * @param {string} taskId - Task identifier as returned by the /tasks API.
 * @returns {boolean} true for Karnataka tasks; false otherwise, including
 *   when taskId is missing or not a string (instead of throwing).
 */
function isKarnatakaTask(taskId) {
    // Guard against undefined/null/non-string ids so callers that run before
    // a task is selected do not crash with a TypeError.
    return typeof taskId === 'string' && taskId.includes('karnataka');
}
|
| 21 |
+
|
| 22 |
+
/**
 * Rebuild the toolbar's task selector buttons from the task list.
 *
 * Empties the `#proceduralTasks` and `#karnatakaTasks` containers and creates
 * one `.task-btn` per task. Karnataka tasks (per isKarnatakaTask) receive the
 * extra `ka` class and are placed in the Karnataka container; every other
 * task goes to the procedural container. Clicking a button selects that task,
 * tears down the Leaflet map so it is re-created with the right bounds, and
 * resets the episode.
 *
 * @param {Array<{id: string, num_buses: number, num_agents: number}>} tasks
 *   Task descriptors as returned by the /tasks endpoint.
 */
function buildTaskButtons(tasks) {
    const procContainer = document.getElementById('proceduralTasks');
    const kaContainer = document.getElementById('karnatakaTasks');
    if (!procContainer || !kaContainer) {
        // Without the toolbar containers there is nowhere to render buttons.
        console.warn('buildTaskButtons: task containers not found in DOM');
        return;
    }
    procContainer.innerHTML = '';
    kaContainer.innerHTML = '';

    // Display-friendly names
    const nameMap = {
        'task_easy': 'Easy', 'task_medium': 'Medium', 'task_hard': 'Hard',
        'task_karnataka': 'Full ★',
        'karnataka_easy': 'Easy', 'karnataka_medium': 'Medium', 'karnataka_hard': 'Hard',
    };

    tasks.forEach(t => {
        // Single source of truth for both styling and grouping, so a button
        // can never land in the Karnataka group without the `ka` accent.
        const isKa = isKarnatakaTask(t.id);

        const btn = document.createElement('button');
        btn.className = 'task-btn' + (t.id === state.task ? ' active' : '');
        if (isKa) btn.classList.add('ka');
        btn.dataset.task = t.id;

        // Build the label with textContent rather than innerHTML so values
        // coming from the API are never interpreted as markup.
        const label = nameMap[t.id] || t.id.replace('task_', '').replace('karnataka_', '');
        const nameSpan = document.createElement('span');
        nameSpan.className = 'task-name';
        nameSpan.textContent = label;
        const infoSpan = document.createElement('span');
        infoSpan.className = 'task-info';
        infoSpan.textContent = `${t.num_buses}b · ${t.num_agents}a`;
        btn.appendChild(nameSpan);
        btn.appendChild(infoSpan);

        btn.addEventListener('click', () => {
            document.querySelectorAll('.task-btn').forEach(b => b.classList.remove('active'));
            btn.classList.add('active');
            state.task = t.id;
            // Destroy map so it reinitializes with correct bounds
            if (leafletMap) { leafletMap.remove(); leafletMap = null; mapLayers = {lines:null,nodes:null,badges:null}; }
            mapFitted = false;
            resetEpisode();
        });

        (isKa ? kaContainer : procContainer).appendChild(btn);
    });
}
|
| 58 |
+
|
| 59 |
+
// Bootstrap: once the DOM is ready, load the task list from the API, build the
// task selector, start the first episode, and then dismiss the loading overlay.
document.addEventListener('DOMContentLoaded', () => {
    fetch(`${API}/tasks`)
        .then(r => {
            // fetch() only rejects on network failure; surface HTTP errors too.
            if (!r.ok) throw new Error(`GET ${API}/tasks failed with status ${r.status}`);
            return r.json();
        })
        .then(d => {
            d.forEach(t => state.taskConfigs[t.id] = t);
            buildTaskButtons(d);
            resetEpisode();
        })
        .catch(err => {
            // Log instead of leaving an unhandled rejection.
            console.error('Failed to initialise tasks:', err);
        })
        .then(() => {
            // Runs after success or failure so the overlay never stays up forever.
            setTimeout(() => document.getElementById('loading').classList.add('hidden'), 800);
        });
});
|
|
|
|
| 393 |
const container = document.getElementById('gridMap');
|
| 394 |
if (leafletMap) return;
|
| 395 |
|
| 396 |
+
const isKa = isKarnatakaTask(state.task);
|
| 397 |
+
// Karnataka bounds: tight crop around the state
|
| 398 |
+
const kaBounds = [[11.5, 73.8], [18.5, 79.0]];
|
| 399 |
|
| 400 |
+
const mapOpts = {
|
| 401 |
+
center: isKa ? [14.5, 76.5] : [15, 76],
|
| 402 |
+
zoom: isKa ? 7 : 6,
|
| 403 |
zoomControl: true,
|
| 404 |
attributionControl: false,
|
| 405 |
+
minZoom: isKa ? 6 : 3,
|
| 406 |
maxZoom: 15,
|
| 407 |
preferCanvas: true,
|
| 408 |
+
};
|
| 409 |
+
// Lock panning for Karnataka tasks
|
| 410 |
+
if (isKa) {
|
| 411 |
+
mapOpts.maxBounds = L.latLngBounds(kaBounds).pad(0.15);
|
| 412 |
+
mapOpts.maxBoundsViscosity = 1.0;
|
| 413 |
+
}
|
| 414 |
+
|
| 415 |
+
leafletMap = L.map(container, mapOpts);
|
| 416 |
|
| 417 |
+
if (isKa) {
|
| 418 |
+
// Real map tiles for Karnataka tasks
|
| 419 |
+
L.tileLayer('https://{s}.basemaps.cartocdn.com/dark_all/{z}/{x}/{y}{r}.png', {
|
| 420 |
+
subdomains: 'abcd',
|
| 421 |
+
maxZoom: 19,
|
| 422 |
+
}).addTo(leafletMap);
|
| 423 |
|
| 424 |
+
L.control.attribution({position: 'bottomright', prefix: false})
|
| 425 |
+
.addAttribution('© <a href="https://carto.com/">CARTO</a>')
|
| 426 |
+
.addTo(leafletMap);
|
| 427 |
+
|
| 428 |
+
leafletMap.fitBounds(kaBounds, { padding: [20, 20] });
|
| 429 |
+
}
|
| 430 |
+
// Procedural grids: no tiles — plain dark background via CSS
|
| 431 |
|
| 432 |
// Layer groups for easy clearing
|
| 433 |
mapLayers.lines = L.layerGroup().addTo(leafletMap);
|
|
|
|
| 437 |
// Fix Leaflet size after container is fully rendered
|
| 438 |
setTimeout(() => {
|
| 439 |
leafletMap.invalidateSize();
|
| 440 |
+
if (isKa) leafletMap.fitBounds(kaBounds, { padding: [20, 20] });
|
| 441 |
}, 200);
|
| 442 |
}
|
| 443 |
|
|
|
|
| 511 |
if (!from || !to) return;
|
| 512 |
|
| 513 |
const lc = !l.connected ? '#4a5568' : l.rho > 1 ? '#ff1744' : l.rho > 0.8 ? '#ff9100' : '#e91e63';
|
| 514 |
+
const w = !l.connected ? 2 : l.rho > 1 ? 6 : l.rho > 0.8 ? 5 : 3.5;
|
| 515 |
+
|
| 516 |
+
// Glow layer for overloaded/congested lines
|
| 517 |
+
if (l.connected && l.rho > 0.8) {
|
| 518 |
+
const glow = L.polyline(
|
| 519 |
+
[[from.lat, from.lon], [to.lat, to.lon]],
|
| 520 |
+
{ color: lc, weight: w + 6, opacity: 0.15, dashArray: null, interactive: false }
|
| 521 |
+
);
|
| 522 |
+
mapLayers.lines.addLayer(glow);
|
| 523 |
+
}
|
| 524 |
|
| 525 |
const polyline = L.polyline(
|
| 526 |
[[from.lat, from.lon], [to.lat, to.lon]],
|
| 527 |
+
{ color: lc, weight: w, dashArray: l.connected ? '12 6' : '4 6', opacity: 0.95 }
|
| 528 |
);
|
| 529 |
+
// Show tooltip with flow info
|
| 530 |
+
const flowStr = l.connected ? `${l.flow.toFixed(0)} MW · ${(l.rho*100).toFixed(0)}% load` : 'Disconnected';
|
| 531 |
+
polyline.bindTooltip(`<b>${l.id}</b><br>${flowStr}`, {
|
| 532 |
+
permanent: false, className: 'leaflet-tooltip-dark', direction: 'center'
|
| 533 |
+
});
|
| 534 |
+
|
| 535 |
+
// Permanent label for high-flow lines
|
| 536 |
+
if (l.connected && Math.abs(l.flow) > 10) {
|
| 537 |
+
const midLat = (from.lat + to.lat) / 2;
|
| 538 |
+
const midLon = (from.lon + to.lon) / 2;
|
| 539 |
+
const flowLabel = L.divIcon({
|
| 540 |
+
className: 'line-flow-label',
|
| 541 |
+
html: `<span style="color:${lc};text-shadow:0 0 4px #000,0 0 8px #000;font-size:9px;font-family:'JetBrains Mono',monospace;font-weight:600;white-space:nowrap;">${Math.abs(l.flow).toFixed(0)}MW</span>`,
|
| 542 |
+
iconSize: [40, 12],
|
| 543 |
+
iconAnchor: [20, 6],
|
| 544 |
});
|
| 545 |
+
L.marker([midLat, midLon], { icon: flowLabel, interactive: false }).addTo(mapLayers.lines);
|
| 546 |
}
|
| 547 |
+
|
| 548 |
mapLayers.lines.addLayer(polyline);
|
| 549 |
});
|
| 550 |
}
|
static/index.html
CHANGED
|
@@ -66,6 +66,46 @@
|
|
| 66 |
</div>
|
| 67 |
</header>
|
| 68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
<!-- ===== LEFT PANEL ===== -->
|
| 70 |
<aside class="left-panel">
|
| 71 |
|
|
@@ -128,25 +168,7 @@
|
|
| 128 |
<!-- Exception Log -->
|
| 129 |
<div class="card" style="flex:1; display:flex; flex-direction:column; overflow:hidden;">
|
| 130 |
<div class="card-title" style="color: var(--status-warning);">Exception Log</div>
|
| 131 |
-
<div class="alarm-log" id="alarmLog">
|
| 132 |
-
<!-- Populated by JS -->
|
| 133 |
-
</div>
|
| 134 |
-
</div>
|
| 135 |
-
|
| 136 |
-
<!-- Task Selector -->
|
| 137 |
-
<div class="card" style="flex-shrink:0;">
|
| 138 |
-
<div class="card-title">Task & Controls</div>
|
| 139 |
-
<div class="task-selector" id="taskSelector">
|
| 140 |
-
<button class="task-btn" data-task="task_easy">Easy</button>
|
| 141 |
-
<button class="task-btn" data-task="task_medium">Medium</button>
|
| 142 |
-
<button class="task-btn" data-task="task_hard">Hard</button>
|
| 143 |
-
<button class="task-btn active" data-task="task_karnataka" style="color: #ffeb3b; border-color: rgba(255,235,59,0.3);">Karnataka</button>
|
| 144 |
-
</div>
|
| 145 |
-
<div class="controls-row" style="margin-top: var(--gap-sm);">
|
| 146 |
-
<button class="ctrl-btn active" id="btnReset" onclick="resetEpisode()">Reset</button>
|
| 147 |
-
<button class="ctrl-btn" id="btnStep" onclick="stepEpisode()">Step</button>
|
| 148 |
-
<button class="ctrl-btn" id="btnAutoRun" onclick="toggleAutoRun()">Auto</button>
|
| 149 |
-
</div>
|
| 150 |
</div>
|
| 151 |
|
| 152 |
</aside>
|
|
@@ -154,6 +176,19 @@
|
|
| 154 |
<!-- ===== CENTER PANEL (Grid Map) ===== -->
|
| 155 |
<main class="center-panel" id="centerPanel">
|
| 156 |
<div class="grid-map" id="gridMap"></div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
<div class="bus-tooltip" id="busTooltip">
|
| 158 |
<div class="tt-title" id="ttTitle">Bus 0</div>
|
| 159 |
<div class="tt-row"><span>Type</span><span class="tt-val" id="ttType">--</span></div>
|
|
@@ -162,60 +197,29 @@
|
|
| 162 |
</div>
|
| 163 |
</main>
|
| 164 |
|
| 165 |
-
<!-- ===== RIGHT PANEL
|
| 166 |
<aside class="right-panel">
|
| 167 |
<div class="card">
|
| 168 |
<div class="card-title">Agent Leaderboard</div>
|
| 169 |
-
<ul class="leaderboard" id="leaderboard">
|
| 170 |
-
<!-- Populated by JS -->
|
| 171 |
-
</ul>
|
| 172 |
-
</div>
|
| 173 |
-
|
| 174 |
-
<div id="agentCards">
|
| 175 |
-
<!-- Populated by JS -->
|
| 176 |
</div>
|
|
|
|
| 177 |
</aside>
|
| 178 |
|
| 179 |
<!-- ===== BOTTOM PANEL ===== -->
|
| 180 |
<footer class="bottom-panel">
|
| 181 |
-
|
| 182 |
-
<!-- Reward History Chart -->
|
| 183 |
<div class="bottom-card">
|
| 184 |
<div class="card-title">Reward History</div>
|
| 185 |
<div class="chart-area" id="rewardChart"></div>
|
| 186 |
</div>
|
| 187 |
-
|
| 188 |
-
<!-- Frequency Trend -->
|
| 189 |
<div class="bottom-card">
|
| 190 |
<div class="card-title">Frequency Trend</div>
|
| 191 |
<div class="chart-area" id="freqChart"></div>
|
| 192 |
</div>
|
| 193 |
-
|
| 194 |
-
<!-- Generation Mix -->
|
| 195 |
<div class="bottom-card">
|
| 196 |
<div class="card-title">Generation Mix</div>
|
| 197 |
<div class="chart-area" id="genMixChart"></div>
|
| 198 |
</div>
|
| 199 |
-
|
| 200 |
-
<!-- Episode Score -->
|
| 201 |
-
<div class="bottom-card">
|
| 202 |
-
<div class="card-title">Episode Score</div>
|
| 203 |
-
<div class="coord-score" style="flex:1; display:flex; flex-direction:column; justify-content:center;">
|
| 204 |
-
<div class="big-value" id="episodeScore" style="color: var(--chart-reward); font-size: 36px;">--</div>
|
| 205 |
-
<div style="font-size:10px; color: var(--text-secondary); margin-top:4px;">Grader Score</div>
|
| 206 |
-
<div style="font-size:11px; margin-top:8px;">
|
| 207 |
-
<span style="color: var(--text-secondary);">Steps:</span>
|
| 208 |
-
<span id="totalSteps" style="font-family: 'JetBrains Mono'; font-weight:600;">0</span>
|
| 209 |
-
<span style="color: var(--text-secondary); margin-left:8px;">Blackout:</span>
|
| 210 |
-
<span id="blackoutStatus" style="font-family: 'JetBrains Mono'; font-weight:600; color: var(--status-normal);">No</span>
|
| 211 |
-
</div>
|
| 212 |
-
</div>
|
| 213 |
-
<div class="controls-row">
|
| 214 |
-
<button class="ctrl-btn" onclick="getGrade()">Grade</button>
|
| 215 |
-
<button class="ctrl-btn danger" onclick="resetEpisode()">New Episode</button>
|
| 216 |
-
</div>
|
| 217 |
-
</div>
|
| 218 |
-
|
| 219 |
</footer>
|
| 220 |
|
| 221 |
</div>
|
|
|
|
| 66 |
</div>
|
| 67 |
</header>
|
| 68 |
|
| 69 |
+
<!-- ===== TOOLBAR ===== -->
|
| 70 |
+
<div class="toolbar">
|
| 71 |
+
<div class="toolbar-section">
|
| 72 |
+
<span class="toolbar-label">PROCEDURAL</span>
|
| 73 |
+
<div class="task-group" id="proceduralTasks">
|
| 74 |
+
<!-- Populated by JS -->
|
| 75 |
+
</div>
|
| 76 |
+
</div>
|
| 77 |
+
<div class="toolbar-divider"></div>
|
| 78 |
+
<div class="toolbar-section">
|
| 79 |
+
<span class="toolbar-label">KARNATAKA</span>
|
| 80 |
+
<div class="task-group" id="karnatakaTasks">
|
| 81 |
+
<!-- Populated by JS -->
|
| 82 |
+
</div>
|
| 83 |
+
</div>
|
| 84 |
+
<div class="toolbar-divider"></div>
|
| 85 |
+
<div class="toolbar-section">
|
| 86 |
+
<span class="toolbar-label">CONTROLS</span>
|
| 87 |
+
<div class="controls-row">
|
| 88 |
+
<button class="ctrl-btn accent" id="btnReset" onclick="resetEpisode()">⟳ Reset</button>
|
| 89 |
+
<button class="ctrl-btn" id="btnStep" onclick="stepEpisode()">▶ Step</button>
|
| 90 |
+
<button class="ctrl-btn" id="btnAutoRun" onclick="toggleAutoRun()">⏩ Auto</button>
|
| 91 |
+
<button class="ctrl-btn" onclick="getGrade()">★ Grade</button>
|
| 92 |
+
</div>
|
| 93 |
+
</div>
|
| 94 |
+
<div class="toolbar-divider"></div>
|
| 95 |
+
<div class="toolbar-section">
|
| 96 |
+
<span class="toolbar-label">SCORE</span>
|
| 97 |
+
<div class="toolbar-score" id="episodeScore">--</div>
|
| 98 |
+
</div>
|
| 99 |
+
<div class="toolbar-section" style="margin-left:auto;">
|
| 100 |
+
<span class="toolbar-label">STEPS</span>
|
| 101 |
+
<span class="toolbar-value" id="totalSteps">0</span>
|
| 102 |
+
</div>
|
| 103 |
+
<div class="toolbar-section">
|
| 104 |
+
<span class="toolbar-label">BLACKOUT</span>
|
| 105 |
+
<span class="toolbar-value" id="blackoutStatus" style="color: var(--status-normal);">No</span>
|
| 106 |
+
</div>
|
| 107 |
+
</div>
|
| 108 |
+
|
| 109 |
<!-- ===== LEFT PANEL ===== -->
|
| 110 |
<aside class="left-panel">
|
| 111 |
|
|
|
|
| 168 |
<!-- Exception Log -->
|
| 169 |
<div class="card" style="flex:1; display:flex; flex-direction:column; overflow:hidden;">
|
| 170 |
<div class="card-title" style="color: var(--status-warning);">Exception Log</div>
|
| 171 |
+
<div class="alarm-log" id="alarmLog"></div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
</div>
|
| 173 |
|
| 174 |
</aside>
|
|
|
|
| 176 |
<!-- ===== CENTER PANEL (Grid Map) ===== -->
|
| 177 |
<main class="center-panel" id="centerPanel">
|
| 178 |
<div class="grid-map" id="gridMap"></div>
|
| 179 |
+
<!-- Map legend -->
|
| 180 |
+
<div class="map-legend">
|
| 181 |
+
<div class="legend-title">Legend</div>
|
| 182 |
+
<div class="legend-item"><span class="legend-dot" style="background:#00e5a0;"></span> Slack</div>
|
| 183 |
+
<div class="legend-item"><span class="legend-dot" style="background:#f5a623;"></span> Generator</div>
|
| 184 |
+
<div class="legend-item"><span class="legend-dot" style="background:#e94560;"></span> Load</div>
|
| 185 |
+
<div class="legend-item"><span class="legend-dot" style="background:#4a90d9;"></span> Battery</div>
|
| 186 |
+
<div class="legend-item"><span class="legend-dot" style="background:#ffeb3b;"></span> Solar</div>
|
| 187 |
+
<div class="legend-item"><span class="legend-dot" style="background:#64ffda;"></span> Wind</div>
|
| 188 |
+
<div class="legend-line"><span class="legend-line-sample normal"></span> Normal</div>
|
| 189 |
+
<div class="legend-line"><span class="legend-line-sample warn"></span> Congested</div>
|
| 190 |
+
<div class="legend-line"><span class="legend-line-sample crit"></span> Overloaded</div>
|
| 191 |
+
</div>
|
| 192 |
<div class="bus-tooltip" id="busTooltip">
|
| 193 |
<div class="tt-title" id="ttTitle">Bus 0</div>
|
| 194 |
<div class="tt-row"><span>Type</span><span class="tt-val" id="ttType">--</span></div>
|
|
|
|
| 197 |
</div>
|
| 198 |
</main>
|
| 199 |
|
| 200 |
+
<!-- ===== RIGHT PANEL ===== -->
|
| 201 |
<aside class="right-panel">
|
| 202 |
<div class="card">
|
| 203 |
<div class="card-title">Agent Leaderboard</div>
|
| 204 |
+
<ul class="leaderboard" id="leaderboard"></ul>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
</div>
|
| 206 |
+
<div id="agentCards"></div>
|
| 207 |
</aside>
|
| 208 |
|
| 209 |
<!-- ===== BOTTOM PANEL ===== -->
|
| 210 |
<footer class="bottom-panel">
|
|
|
|
|
|
|
| 211 |
<div class="bottom-card">
|
| 212 |
<div class="card-title">Reward History</div>
|
| 213 |
<div class="chart-area" id="rewardChart"></div>
|
| 214 |
</div>
|
|
|
|
|
|
|
| 215 |
<div class="bottom-card">
|
| 216 |
<div class="card-title">Frequency Trend</div>
|
| 217 |
<div class="chart-area" id="freqChart"></div>
|
| 218 |
</div>
|
|
|
|
|
|
|
| 219 |
<div class="bottom-card">
|
| 220 |
<div class="card-title">Generation Mix</div>
|
| 221 |
<div class="chart-area" id="genMixChart"></div>
|
| 222 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
</footer>
|
| 224 |
|
| 225 |
</div>
|
static/style.css
CHANGED
|
@@ -89,10 +89,11 @@ body::before {
|
|
| 89 |
/* ---------- Layout ---------- */
|
| 90 |
.control-room {
|
| 91 |
display: grid;
|
| 92 |
-
grid-template-rows:
|
| 93 |
-
grid-template-columns:
|
| 94 |
grid-template-areas:
|
| 95 |
"header header header"
|
|
|
|
| 96 |
"left center right"
|
| 97 |
"bottom bottom bottom";
|
| 98 |
height: 100vh;
|
|
@@ -100,6 +101,190 @@ body::before {
|
|
| 100 |
background: rgba(255,255,255,0.04);
|
| 101 |
}
|
| 102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
/* ---------- Header ---------- */
|
| 104 |
.header {
|
| 105 |
grid-area: header;
|
|
@@ -667,7 +852,7 @@ body::before {
|
|
| 667 |
grid-area: bottom;
|
| 668 |
background: var(--bg-secondary);
|
| 669 |
display: grid;
|
| 670 |
-
grid-template-columns: 2fr 1fr 1fr
|
| 671 |
gap: 1px;
|
| 672 |
border-top: 1px solid rgba(255,255,255,0.05);
|
| 673 |
}
|
|
@@ -695,73 +880,6 @@ body::before {
|
|
| 695 |
flex: 1;
|
| 696 |
}
|
| 697 |
|
| 698 |
-
/* Controls */
|
| 699 |
-
.controls-row {
|
| 700 |
-
display: flex;
|
| 701 |
-
gap: var(--gap-sm);
|
| 702 |
-
margin-top: var(--gap-sm);
|
| 703 |
-
}
|
| 704 |
-
|
| 705 |
-
.ctrl-btn {
|
| 706 |
-
flex: 1;
|
| 707 |
-
padding: 6px 10px;
|
| 708 |
-
background: rgba(255,255,255,0.04);
|
| 709 |
-
border: 1px solid rgba(255,255,255,0.1);
|
| 710 |
-
border-radius: var(--radius-sm);
|
| 711 |
-
color: var(--text-primary);
|
| 712 |
-
font-family: 'Inter', sans-serif;
|
| 713 |
-
font-size: 11px;
|
| 714 |
-
font-weight: 500;
|
| 715 |
-
cursor: pointer;
|
| 716 |
-
transition: all 0.2s;
|
| 717 |
-
text-align: center;
|
| 718 |
-
}
|
| 719 |
-
|
| 720 |
-
.ctrl-btn:hover {
|
| 721 |
-
background: rgba(0,229,160,0.1);
|
| 722 |
-
border-color: rgba(0,229,160,0.3);
|
| 723 |
-
}
|
| 724 |
-
|
| 725 |
-
.ctrl-btn.active {
|
| 726 |
-
background: rgba(0,229,160,0.15);
|
| 727 |
-
border-color: var(--status-normal);
|
| 728 |
-
color: var(--status-normal);
|
| 729 |
-
}
|
| 730 |
-
|
| 731 |
-
.ctrl-btn.danger {
|
| 732 |
-
border-color: rgba(255,61,61,0.3);
|
| 733 |
-
}
|
| 734 |
-
|
| 735 |
-
.ctrl-btn.danger:hover {
|
| 736 |
-
background: rgba(255,61,61,0.1);
|
| 737 |
-
border-color: rgba(255,61,61,0.5);
|
| 738 |
-
color: var(--status-critical);
|
| 739 |
-
}
|
| 740 |
-
|
| 741 |
-
/* Task selector */
|
| 742 |
-
.task-selector {
|
| 743 |
-
display: flex;
|
| 744 |
-
gap: 4px;
|
| 745 |
-
}
|
| 746 |
-
|
| 747 |
-
.task-btn {
|
| 748 |
-
flex: 1;
|
| 749 |
-
padding: 4px 8px;
|
| 750 |
-
background: rgba(255,255,255,0.03);
|
| 751 |
-
border: 1px solid rgba(255,255,255,0.08);
|
| 752 |
-
border-radius: var(--radius-sm);
|
| 753 |
-
color: var(--text-secondary);
|
| 754 |
-
font-size: 10px;
|
| 755 |
-
font-weight: 500;
|
| 756 |
-
cursor: pointer;
|
| 757 |
-
transition: all 0.2s;
|
| 758 |
-
text-transform: uppercase;
|
| 759 |
-
letter-spacing: 0.5px;
|
| 760 |
-
}
|
| 761 |
-
|
| 762 |
-
.task-btn:hover { border-color: rgba(0,229,160,0.3); color: var(--text-primary); }
|
| 763 |
-
.task-btn.active { background: rgba(0,229,160,0.1); border-color: var(--status-normal); color: var(--status-normal); }
|
| 764 |
-
|
| 765 |
/* Leaderboard */
|
| 766 |
.leaderboard {
|
| 767 |
list-style: none;
|
|
@@ -933,3 +1051,14 @@ body::before {
|
|
| 933 |
.leaflet-control-attribution a {
|
| 934 |
color: #666 !important;
|
| 935 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
/* ---------- Layout ---------- */
|
| 90 |
.control-room {
|
| 91 |
display: grid;
|
| 92 |
+
grid-template-rows: 48px 44px 1fr 160px;
|
| 93 |
+
grid-template-columns: 240px 1fr 280px;
|
| 94 |
grid-template-areas:
|
| 95 |
"header header header"
|
| 96 |
+
"toolbar toolbar toolbar"
|
| 97 |
"left center right"
|
| 98 |
"bottom bottom bottom";
|
| 99 |
height: 100vh;
|
|
|
|
| 101 |
background: rgba(255,255,255,0.04);
|
| 102 |
}
|
| 103 |
|
| 104 |
+
/* ---------- Toolbar ---------- */
|
| 105 |
+
.toolbar {
|
| 106 |
+
grid-area: toolbar;
|
| 107 |
+
background: linear-gradient(90deg, #0d1225, #111a33);
|
| 108 |
+
display: flex;
|
| 109 |
+
align-items: center;
|
| 110 |
+
padding: 0 var(--gap-md);
|
| 111 |
+
gap: var(--gap-md);
|
| 112 |
+
border-bottom: 1px solid rgba(0,229,160,0.1);
|
| 113 |
+
z-index: 10;
|
| 114 |
+
overflow-x: auto;
|
| 115 |
+
}
|
| 116 |
+
.toolbar-section {
|
| 117 |
+
display: flex;
|
| 118 |
+
align-items: center;
|
| 119 |
+
gap: 8px;
|
| 120 |
+
flex-shrink: 0;
|
| 121 |
+
}
|
| 122 |
+
.toolbar-label {
|
| 123 |
+
font-size: 9px;
|
| 124 |
+
font-weight: 600;
|
| 125 |
+
text-transform: uppercase;
|
| 126 |
+
letter-spacing: 1.2px;
|
| 127 |
+
color: var(--text-muted);
|
| 128 |
+
white-space: nowrap;
|
| 129 |
+
}
|
| 130 |
+
.toolbar-divider {
|
| 131 |
+
width: 1px;
|
| 132 |
+
height: 24px;
|
| 133 |
+
background: rgba(255,255,255,0.08);
|
| 134 |
+
flex-shrink: 0;
|
| 135 |
+
}
|
| 136 |
+
.toolbar-score {
|
| 137 |
+
font-family: 'JetBrains Mono', monospace;
|
| 138 |
+
font-size: 18px;
|
| 139 |
+
font-weight: 700;
|
| 140 |
+
color: var(--chart-reward);
|
| 141 |
+
}
|
| 142 |
+
.toolbar-value {
|
| 143 |
+
font-family: 'JetBrains Mono', monospace;
|
| 144 |
+
font-size: 13px;
|
| 145 |
+
font-weight: 600;
|
| 146 |
+
color: var(--text-primary);
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
/* Task buttons (toolbar version) */
|
| 150 |
+
.task-selector {
|
| 151 |
+
display: flex;
|
| 152 |
+
gap: 4px;
|
| 153 |
+
}
|
| 154 |
+
.task-btn {
|
| 155 |
+
display: flex;
|
| 156 |
+
flex-direction: column;
|
| 157 |
+
align-items: center;
|
| 158 |
+
padding: 4px 10px;
|
| 159 |
+
background: rgba(255,255,255,0.03);
|
| 160 |
+
border: 1px solid rgba(255,255,255,0.08);
|
| 161 |
+
border-radius: var(--radius-sm);
|
| 162 |
+
color: var(--text-secondary);
|
| 163 |
+
font-size: 11px;
|
| 164 |
+
font-weight: 600;
|
| 165 |
+
cursor: pointer;
|
| 166 |
+
transition: all 0.2s;
|
| 167 |
+
text-transform: uppercase;
|
| 168 |
+
letter-spacing: 0.5px;
|
| 169 |
+
line-height: 1.2;
|
| 170 |
+
}
|
| 171 |
+
.task-btn .task-name { font-size: 11px; }
|
| 172 |
+
.task-btn .task-info {
|
| 173 |
+
font-size: 8px;
|
| 174 |
+
font-weight: 400;
|
| 175 |
+
color: var(--text-muted);
|
| 176 |
+
letter-spacing: 0;
|
| 177 |
+
text-transform: none;
|
| 178 |
+
white-space: nowrap;
|
| 179 |
+
}
|
| 180 |
+
.task-btn:hover { border-color: rgba(0,229,160,0.3); color: var(--text-primary); }
|
| 181 |
+
.task-btn.active {
|
| 182 |
+
background: rgba(0,229,160,0.12);
|
| 183 |
+
border-color: var(--status-normal);
|
| 184 |
+
color: var(--status-normal);
|
| 185 |
+
}
|
| 186 |
+
.task-btn.active .task-info { color: rgba(0,229,160,0.6); }
|
| 187 |
+
|
| 188 |
+
/* Karnataka scenario gold accent */
|
| 189 |
+
.task-btn.ka:hover { border-color: rgba(255,235,59,0.4); color: #ffeb3b; }
|
| 190 |
+
.task-btn.ka.active {
|
| 191 |
+
background: rgba(255,235,59,0.1);
|
| 192 |
+
border-color: #ffeb3b;
|
| 193 |
+
color: #ffeb3b;
|
| 194 |
+
}
|
| 195 |
+
.task-btn.ka.active .task-info { color: rgba(255,235,59,0.5); }
|
| 196 |
+
|
| 197 |
+
/* Task group container */
|
| 198 |
+
.task-group {
|
| 199 |
+
display: flex;
|
| 200 |
+
gap: 4px;
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
/* Controls in toolbar */
|
| 204 |
+
.controls-row {
|
| 205 |
+
display: flex;
|
| 206 |
+
gap: 4px;
|
| 207 |
+
}
|
| 208 |
+
.ctrl-btn {
|
| 209 |
+
padding: 5px 12px;
|
| 210 |
+
background: rgba(255,255,255,0.04);
|
| 211 |
+
border: 1px solid rgba(255,255,255,0.1);
|
| 212 |
+
border-radius: var(--radius-sm);
|
| 213 |
+
color: var(--text-primary);
|
| 214 |
+
font-family: 'Inter', sans-serif;
|
| 215 |
+
font-size: 11px;
|
| 216 |
+
font-weight: 500;
|
| 217 |
+
cursor: pointer;
|
| 218 |
+
transition: all 0.2s;
|
| 219 |
+
text-align: center;
|
| 220 |
+
white-space: nowrap;
|
| 221 |
+
}
|
| 222 |
+
.ctrl-btn:hover {
|
| 223 |
+
background: rgba(0,229,160,0.1);
|
| 224 |
+
border-color: rgba(0,229,160,0.3);
|
| 225 |
+
}
|
| 226 |
+
.ctrl-btn.accent {
|
| 227 |
+
background: rgba(0,229,160,0.12);
|
| 228 |
+
border-color: rgba(0,229,160,0.3);
|
| 229 |
+
color: var(--status-normal);
|
| 230 |
+
}
|
| 231 |
+
.ctrl-btn.active {
|
| 232 |
+
background: rgba(0,229,160,0.2);
|
| 233 |
+
border-color: var(--status-normal);
|
| 234 |
+
color: var(--status-normal);
|
| 235 |
+
box-shadow: 0 0 8px rgba(0,229,160,0.15);
|
| 236 |
+
}
|
| 237 |
+
.ctrl-btn.danger { border-color: rgba(255,61,61,0.3); }
|
| 238 |
+
.ctrl-btn.danger:hover {
|
| 239 |
+
background: rgba(255,61,61,0.1);
|
| 240 |
+
border-color: rgba(255,61,61,0.5);
|
| 241 |
+
color: var(--status-critical);
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
/* ---------- Map Legend ---------- */
|
| 245 |
+
.map-legend {
|
| 246 |
+
position: absolute;
|
| 247 |
+
bottom: 12px;
|
| 248 |
+
left: 12px;
|
| 249 |
+
background: rgba(10,14,26,0.92);
|
| 250 |
+
border: 1px solid rgba(255,255,255,0.1);
|
| 251 |
+
border-radius: var(--radius-md);
|
| 252 |
+
padding: 8px 12px;
|
| 253 |
+
z-index: 5;
|
| 254 |
+
backdrop-filter: blur(8px);
|
| 255 |
+
font-size: 10px;
|
| 256 |
+
}
|
| 257 |
+
.legend-title {
|
| 258 |
+
font-size: 9px;
|
| 259 |
+
font-weight: 600;
|
| 260 |
+
text-transform: uppercase;
|
| 261 |
+
letter-spacing: 1px;
|
| 262 |
+
color: var(--text-muted);
|
| 263 |
+
margin-bottom: 4px;
|
| 264 |
+
}
|
| 265 |
+
.legend-item, .legend-line {
|
| 266 |
+
display: flex;
|
| 267 |
+
align-items: center;
|
| 268 |
+
gap: 6px;
|
| 269 |
+
padding: 1px 0;
|
| 270 |
+
color: var(--text-secondary);
|
| 271 |
+
}
|
| 272 |
+
.legend-dot {
|
| 273 |
+
width: 8px;
|
| 274 |
+
height: 8px;
|
| 275 |
+
border-radius: 50%;
|
| 276 |
+
flex-shrink: 0;
|
| 277 |
+
}
|
| 278 |
+
.legend-line-sample {
|
| 279 |
+
width: 20px;
|
| 280 |
+
height: 3px;
|
| 281 |
+
border-radius: 2px;
|
| 282 |
+
flex-shrink: 0;
|
| 283 |
+
}
|
| 284 |
+
.legend-line-sample.normal { background: #e91e63; }
|
| 285 |
+
.legend-line-sample.warn { background: #ff9100; }
|
| 286 |
+
.legend-line-sample.crit { background: #ff1744; box-shadow: 0 0 6px rgba(255,23,68,0.5); }
|
| 287 |
+
|
| 288 |
/* ---------- Header ---------- */
|
| 289 |
.header {
|
| 290 |
grid-area: header;
|
|
|
|
| 852 |
grid-area: bottom;
|
| 853 |
background: var(--bg-secondary);
|
| 854 |
display: grid;
|
| 855 |
+
grid-template-columns: 2fr 1fr 1fr;
|
| 856 |
gap: 1px;
|
| 857 |
border-top: 1px solid rgba(255,255,255,0.05);
|
| 858 |
}
|
|
|
|
| 880 |
flex: 1;
|
| 881 |
}
|
| 882 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 883 |
/* Leaderboard */
|
| 884 |
.leaderboard {
|
| 885 |
list-style: none;
|
|
|
|
| 1051 |
.leaflet-control-attribution a {
|
| 1052 |
color: #666 !important;
|
| 1053 |
}
|
| 1054 |
+
|
| 1055 |
+
.line-flow-label {
|
| 1056 |
+
background: none !important;
|
| 1057 |
+
border: none !important;
|
| 1058 |
+
text-align: center;
|
| 1059 |
+
}
|
| 1060 |
+
|
| 1061 |
+
/* Dark background for procedural grids (no map tiles) */
|
| 1062 |
+
.leaflet-container {
|
| 1063 |
+
background: #0a0e1a !important;
|
| 1064 |
+
}
|
training/train_grpo.py
CHANGED
|
@@ -786,6 +786,94 @@ def run_test_mode():
|
|
| 786 |
print("="*60)
|
| 787 |
|
| 788 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 789 |
# ============================================================================
|
| 790 |
# Main
|
| 791 |
# ============================================================================
|
|
@@ -795,7 +883,7 @@ def main():
|
|
| 795 |
parser.add_argument("--model", default="Qwen/Qwen2.5-1.5B-Instruct",
|
| 796 |
help="HuggingFace model name or path")
|
| 797 |
parser.add_argument("--task", default="task_easy", choices=list(TASKS.keys()),
|
| 798 |
-
help="Which task to train on")
|
| 799 |
parser.add_argument("--epochs", type=int, default=3, help="Number of training epochs")
|
| 800 |
parser.add_argument("--batch-size", type=int, default=2, help="Batch size per device")
|
| 801 |
parser.add_argument("--num-prompts", type=int, default=50,
|
|
@@ -806,6 +894,10 @@ def main():
|
|
| 806 |
help="Use Unsloth for 4-bit quantized training")
|
| 807 |
parser.add_argument("--test-mode", action="store_true",
|
| 808 |
help="Run pipeline verification without GPU")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 809 |
|
| 810 |
args = parser.parse_args()
|
| 811 |
|
|
@@ -816,11 +908,12 @@ def main():
|
|
| 816 |
# Create output directory
|
| 817 |
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
| 818 |
|
| 819 |
-
|
| 820 |
-
|
| 821 |
-
|
| 822 |
-
|
| 823 |
-
|
|
|
|
| 824 |
|
| 825 |
|
| 826 |
if __name__ == "__main__":
|
|
|
|
| 786 |
print("="*60)
|
| 787 |
|
| 788 |
|
| 789 |
+
# ============================================================================
|
| 790 |
+
# Curriculum Training
|
| 791 |
+
# ============================================================================
|
| 792 |
+
|
| 793 |
+
# Task ids in the order run_curriculum() trains on them, one phase per entry.
# The same list is reused as the set of tasks evaluated after each phase.
CURRICULUM_ORDER = ["karnataka_easy", "karnataka_medium", "karnataka_hard", "task_karnataka"]
|
| 794 |
+
|
| 795 |
+
|
| 796 |
+
def run_curriculum(args):
    """Run curriculum training: easy→medium→hard→full on Karnataka grid.

    Iterates over ``CURRICULUM_ORDER``. For each phase a shallow copy of
    ``args`` is made with ``task`` set to the phase's task ID and
    ``output_dir`` set to ``<args.output_dir>/phase_<n>_<task_id>``, and
    ``train_grpo`` is called with it. From the second phase on (or from the
    first, when ``args.resume_from`` is given) ``model`` is overridden with
    the previous phase's checkpoint path.

    Args:
        args: Parsed CLI namespace. Reads ``epochs``, ``output_dir`` and
            ``resume_from`` directly; everything else is forwarded to
            ``train_grpo`` unchanged.

    Returns:
        dict: The curriculum summary (phase list, epochs per phase, rounded
        average rewards per phase/task, final model path). The same dict is
        written to ``<args.output_dir>/curriculum_summary.json``.
    """
    # Function-scope import so this block does not depend on the module
    # header already importing ``copy``; it is a no-op if it does.
    import copy

    def heuristic_generate(prompt):
        """Proportional frequency-correction policy driven by the prompt text.

        Parses the reported frequency, computes a correction clamped to
        [-20, 20] and splits it evenly over every generator/battery bus
        mentioned in the prompt. Returns the action as a JSON string.
        """
        freq_match = re.search(r'Frequency:\s*([-+]?\d+(?:\.\d+)?)', prompt)
        freq = float(freq_match.group(1)) if freq_match else 50.0
        error = 50.0 - freq
        delta = max(-20, min(20, error * 10))
        bus_matches = re.findall(r'Bus (\d+) \((generator|battery)\)', prompt)
        if bus_matches:
            per_bus = delta / len(bus_matches)
            return json.dumps({"bus_adjustments": [{"bus_id": int(m[0]), "delta": round(per_bus, 1)} for m in bus_matches], "topology_actions": []})
        return json.dumps({"bus_adjustments": [], "topology_actions": []})

    print("\n" + "=" * 60)
    print(" OpenGrid Curriculum Training")
    print(f" Phases: {' → '.join(CURRICULUM_ORDER)}")
    print(f" Epochs per phase: {args.epochs}")
    print("=" * 60)

    checkpoint_path = args.resume_from
    all_results = {}
    total_phases = len(CURRICULUM_ORDER)

    for phase_num, task_id in enumerate(CURRICULUM_ORDER, start=1):
        print(f"\n{'─' * 60}")
        print(f" Phase {phase_num}/{total_phases}: {task_id}")
        if checkpoint_path:
            print(f" Resuming from: {checkpoint_path}")
        print(f"{'─' * 60}")

        # Override args for this phase. A shallow copy is enough because only
        # top-level attributes are reassigned; the caller's namespace is left
        # untouched.
        phase_args = copy.copy(args)
        phase_args.task = task_id
        phase_args.output_dir = str(Path(args.output_dir) / f"phase_{phase_num}_{task_id}")
        if checkpoint_path:
            phase_args.model = checkpoint_path

        Path(phase_args.output_dir).mkdir(parents=True, exist_ok=True)

        # Train this phase
        train_grpo(phase_args)

        # Set checkpoint for next phase.
        # NOTE(review): assumes train_grpo saves the model under
        # "<output_dir>/trained_model" — confirm against train_grpo.
        checkpoint_path = str(Path(phase_args.output_dir) / "trained_model")

        # Evaluate on all Karnataka tasks.
        # NOTE(review): the generator passed to evaluate_model is the
        # frequency heuristic above, not the checkpoint that was just trained,
        # so these numbers do not measure the trained model. Swap in a
        # model-backed generate function once one is available here.
        print(f"\n [EVAL] Phase {phase_num} evaluation...")
        phase_results = evaluate_model(heuristic_generate, task_ids=CURRICULUM_ORDER, n_episodes=2)
        all_results[f"phase_{phase_num}"] = phase_results
        for tid, res in phase_results.items():
            print(f" {tid}: {res['avg_reward']:.2f} ± {res['std_reward']:.2f}")

    # Summary
    print("\n" + "=" * 60)
    print(" CURRICULUM TRAINING COMPLETE")
    print("=" * 60)
    print(f" Final model: {checkpoint_path}")
    print(f" Phases completed: {total_phases}")

    # Save curriculum summary
    summary = {
        "phases": CURRICULUM_ORDER,
        "epochs_per_phase": args.epochs,
        "results": {
            phase: {task: {"avg": round(res["avg_reward"], 2)} for task, res in per_task.items()}
            for phase, per_task in all_results.items()
        },
        "final_model": checkpoint_path,
    }
    summary_path = Path(args.output_dir) / "curriculum_summary.json"
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)
    print(f" Summary: {summary_path}")

    return summary
|
| 875 |
+
|
| 876 |
+
|
| 877 |
# ============================================================================
|
| 878 |
# Main
|
| 879 |
# ============================================================================
|
|
|
|
| 883 |
parser.add_argument("--model", default="Qwen/Qwen2.5-1.5B-Instruct",
|
| 884 |
help="HuggingFace model name or path")
|
| 885 |
parser.add_argument("--task", default="task_easy", choices=list(TASKS.keys()),
|
| 886 |
+
help="Which task to train on (ignored if --curriculum)")
|
| 887 |
parser.add_argument("--epochs", type=int, default=3, help="Number of training epochs")
|
| 888 |
parser.add_argument("--batch-size", type=int, default=2, help="Batch size per device")
|
| 889 |
parser.add_argument("--num-prompts", type=int, default=50,
|
|
|
|
| 894 |
help="Use Unsloth for 4-bit quantized training")
|
| 895 |
parser.add_argument("--test-mode", action="store_true",
|
| 896 |
help="Run pipeline verification without GPU")
|
| 897 |
+
parser.add_argument("--curriculum", action="store_true",
|
| 898 |
+
help="Run curriculum training: karnataka_easy → medium → hard → full")
|
| 899 |
+
parser.add_argument("--resume-from", default=None,
|
| 900 |
+
help="Resume training from a checkpoint path")
|
| 901 |
|
| 902 |
args = parser.parse_args()
|
| 903 |
|
|
|
|
| 908 |
# Create output directory
|
| 909 |
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
| 910 |
|
| 911 |
+
if args.curriculum:
|
| 912 |
+
run_curriculum(args)
|
| 913 |
+
else:
|
| 914 |
+
train_result = train_grpo(args)
|
| 915 |
+
print("\n[DONE] Training complete!")
|
| 916 |
+
print(f" Output: {args.output_dir}")
|
| 917 |
|
| 918 |
|
| 919 |
if __name__ == "__main__":
|