# Opengrid / scripts / verify_training.py
# K446's picture
# feat: curriculum training + Karnataka scenarios + repo cleanup
# 8a02303
"""
Comprehensive training pipeline verification.
Tests: scenarios, reward functions, policies, GRPO integration, safety.
"""
import json
import copy
import sys
sys.path.insert(0, ".")
from src.tasks import TASKS, get_task
from src.environment import OpenGridEnv
from src.models import GridAction, GridObservation
from src.grader import RobustnessGrader
from src.baseline import heuristic_policy
from src.safety import SafetyLayer
# Emit the three-line run banner in one call (rule, title, rule).
print(
    "=" * 60,
    " COMPREHENSIVE TRAINING PIPELINE VERIFICATION",
    "=" * 60,
    sep="\n",
)
# Failure log: each check below appends a message here when it fails, and the
# summary at the end of the script reports them.
errors = []
# --- 1. Scenario loading ---
# Confirm that every expected scenario id is a key of TASKS and, for those that
# are, echo the bus count, agent count and zone names from its config.
print("\n[1/7] Scenario Loading...")
expected_tasks = [
    "task_easy",
    "task_medium",
    "task_hard",
    "task_karnataka",
    "karnataka_easy",
    "karnataka_medium",
    "karnataka_hard",
]
for task_id in expected_tasks:
    if task_id in TASKS:
        task_cfg = TASKS[task_id]
        size_tag = f"{task_cfg['num_buses']}b/{task_cfg['num_agents']}a"
        print(f" OK: {task_id} - {size_tag} zones={task_cfg['zone_names']}")
        continue
    # Unregistered scenario: record it and keep checking the remaining ids.
    errors.append(f"Missing task: {task_id}")
    print(f" FAIL: {task_id} not in TASKS")
# --- 2. Environment step for each scenario ---
# Build an environment per scenario, reset it, and apply a single action that
# carries no bus adjustments and no topology actions. Any exception raised
# along the way is recorded as a failure for that scenario.
print("\n[2/7] Environment Step Test...")
for task_id in expected_tasks:
    try:
        env = OpenGridEnv(get_task(task_id))
        env.reset()
        empty_payload = json.dumps({"bus_adjustments": [], "topology_actions": []})
        empty_action = GridAction.model_validate_json(empty_payload)
        next_obs, step_reward, _done, _info = env.step(empty_action)
        print(
            f" OK: {task_id} - freq={next_obs.grid_frequency:.2f}Hz "
            f"reward={step_reward.value:.2f}"
        )
    except Exception as exc:
        errors.append(f"Env step failed for {task_id}: {exc}")
        print(f" FAIL: {task_id} - {exc}")
# --- 3. Reward function (GRPO) test ---
# Feed three sample completions (one bus adjustment, one empty action, one
# string that is not JSON) and three sample observations to
# compute_grpo_reward_env on the "karnataka_easy" config and print each reward.
print("\n[3/7] GRPO Reward Function Test...")
test_completions = [
    '{"bus_adjustments": [{"bus_id": 0, "delta": 5.0}], "topology_actions": []}',
    '{"bus_adjustments": [], "topology_actions": []}',
    'not valid json',
]
test_observations = [
    {"grid_frequency": 49.5, "buses": [], "lines": []},
    {"grid_frequency": 50.0, "buses": [], "lines": []},
    {"grid_frequency": 48.0, "buses": [], "lines": []},
]
try:
    # Import inside the guarded region: if the training module cannot be
    # imported, the problem is recorded as a failed check and the remaining
    # checks plus the summary still run, instead of the script dying with a
    # raw traceback.
    from training.train_grpo import compute_grpo_reward_env

    cfg = get_task("karnataka_easy")
    rewards = compute_grpo_reward_env(test_completions, test_observations, cfg, horizon=1)
    for i, r in enumerate(rewards):
        print(f" Completion {i}: reward={r:.3f}")
    print(f" OK: GRPO rewards computed for {len(rewards)} completions")
except Exception as e:
    errors.append(f"GRPO reward failed: {e}")
    print(f" FAIL: {e}")
# --- 4. Karnataka Difficulty Gradient Test ---
# For each Karnataka tier, step the environment up to five times with an
# action that has no adjustments, stopping early if the episode reports done,
# and accumulate the reward. The easy and hard totals are then compared.
print("\n[4/7] Karnataka Difficulty Gradient Test...")
ka_rewards = {}
for tier_id in ("karnataka_easy", "karnataka_medium", "karnataka_hard"):
    try:
        env = OpenGridEnv(get_task(tier_id))
        env.reset()
        cumulative = 0
        for _ in range(5):
            # A fresh action object is built for every step, as before.
            empty_action = GridAction.model_validate_json(
                json.dumps({"bus_adjustments": [], "topology_actions": []})
            )
            _obs, step_reward, finished, _info = env.step(empty_action)
            cumulative += step_reward.value
            if finished:
                break
        ka_rewards[tier_id] = cumulative
        print(f" {tier_id}: 5-step reward={cumulative:.2f}")
    except Exception as exc:
        errors.append(f"Ka difficulty test failed for {tier_id}: {exc}")
        print(f" FAIL: {tier_id} - {exc}")
# The comparison only runs when all three tiers produced a total.
if len(ka_rewards) == 3:
    easy_total = ka_rewards["karnataka_easy"]
    hard_total = ka_rewards["karnataka_hard"]
    # Easy should generally give higher or equal rewards than hard. A violation
    # is printed as a warning only; it is not added to the error list.
    if easy_total >= hard_total:
        print(" OK: Difficulty gradient correct (easy >= hard)")
    else:
        print(f" WARN: easy ({easy_total:.2f}) < hard ({hard_total:.2f}) - may vary by seed")
# --- 5. Heuristic policy test ---
# Drive three scenarios with heuristic_policy for up to ten steps each
# (stopping early when the episode reports done) and print the summed reward.
print("\n[5/7] Heuristic Policy Test...")
for task_id in ("task_easy", "karnataka_easy", "task_karnataka"):
    try:
        env = OpenGridEnv(get_task(task_id))
        current_obs = env.reset()
        episode_return = 0
        for _ in range(10):
            chosen_action = heuristic_policy(current_obs)
            current_obs, step_reward, finished, _info = env.step(chosen_action)
            episode_return += step_reward.value
            if finished:
                break
        print(f" OK: {task_id} - 10-step heuristic reward={episode_return:.2f}")
    except Exception as exc:
        errors.append(f"Heuristic policy failed for {task_id}: {exc}")
        print(f" FAIL: {task_id} - {exc}")
# --- 6. Safety layer test ---
# Pass an action with a +100.0 delta on bus 0 through SafetyLayer for three
# scenarios, using bus/line state synthesised from the task config, and print
# what the returned report says about corrections and N-1 violations.
print("\n[6/7] Safety Layer Test...")
for task_id in ("task_easy", "karnataka_easy", "karnataka_hard"):
    try:
        task_cfg = get_task(task_id)
        safety = SafetyLayer(task_cfg)
        probe_payload = json.dumps(
            {"bus_adjustments": [{"bus_id": 0, "delta": 100.0}], "topology_actions": []}
        )
        probe_action = GridAction.model_validate_json(probe_payload)
        # Bus state: id plus configured base power / initial SoC, defaulting to 0.
        bus_state = [
            {"id": bus["id"], "p": bus.get("base_p", 0), "soc": bus.get("init_soc", 0)}
            for bus in task_cfg["buses"]
        ]
        # Line state: every line marked connected with zero flow.
        line_state = [
            {"id": line["id"], "connected": True, "flow": 0}
            for line in task_cfg["lines"]
        ]
        _safe_action, report = safety.validate_and_correct(
            0, probe_action, line_state, bus_state, {}
        )
        print(
            f" OK: {task_id} - corrected={report.was_corrected}, "
            f"n1_violations={report.n1_violations_detected}"
        )
    except Exception as exc:
        errors.append(f"Safety layer failed for {task_id}: {exc}")
        print(f" FAIL: {task_id} - {exc}")
# --- 7. Curriculum order test ---
# Check that every task id listed in CURRICULUM_ORDER is registered in TASKS.
print("\n[7/7] Curriculum Order Test...")
try:
    # Guard the import: if the training module cannot be imported, record it
    # as a failed check so the summary below still runs and the script exits
    # through its normal failure path rather than with a raw traceback.
    from training.train_grpo import CURRICULUM_ORDER
except Exception as e:
    errors.append(f"Curriculum import failed: {e}")
    print(f" FAIL: {e}")
else:
    for tid in CURRICULUM_ORDER:
        if tid in TASKS:
            print(f" OK: {tid} available")
        else:
            errors.append(f"Curriculum task missing: {tid}")
            print(f" FAIL: {tid} not in TASKS")
# --- Summary ---
# Report the outcome. On failure every collected message is listed and the
# process exits with status 1; on success a confirmation and a closing rule
# are printed and the script ends normally.
print("\n" + "=" * 60)
if not errors:
    print(" ALL CHECKS PASSED - Training pipeline ready")
    print("=" * 60)
else:
    print(f" FAILED: {len(errors)} errors")
    for message in errors:
        print(f" - {message}")
    sys.exit(1)