| """ |
| Comprehensive training pipeline verification. |
| Tests: scenarios, reward functions, policies, GRPO integration, safety. |
| """ |
| import json |
| import copy |
| import sys |
| sys.path.insert(0, ".") |
|
|
| from src.tasks import TASKS, get_task |
| from src.environment import OpenGridEnv |
| from src.models import GridAction, GridObservation |
| from src.grader import RobustnessGrader |
| from src.baseline import heuristic_policy |
| from src.safety import SafetyLayer |
|
|
| print("=" * 60) |
| print(" COMPREHENSIVE TRAINING PIPELINE VERIFICATION") |
| print("=" * 60) |
|
|
| errors = [] |
|
|
| |
| print("\n[1/7] Scenario Loading...") |
| expected_tasks = ["task_easy", "task_medium", "task_hard", |
| "task_karnataka", "karnataka_easy", "karnataka_medium", "karnataka_hard"] |
| for tid in expected_tasks: |
| if tid not in TASKS: |
| errors.append(f"Missing task: {tid}") |
| print(f" FAIL: {tid} not in TASKS") |
| else: |
| cfg = TASKS[tid] |
| print(f" OK: {tid} - {cfg['num_buses']}b/{cfg['num_agents']}a zones={cfg['zone_names']}") |
|
|
| |
| print("\n[2/7] Environment Step Test...") |
| for tid in expected_tasks: |
| try: |
| cfg = get_task(tid) |
| env = OpenGridEnv(cfg) |
| obs = env.reset() |
| action = GridAction.model_validate_json( |
| json.dumps({"bus_adjustments": [], "topology_actions": []}) |
| ) |
| obs2, reward, done, info = env.step(action) |
| freq = obs2.grid_frequency |
| r = reward.value |
| print(f" OK: {tid} - freq={freq:.2f}Hz reward={r:.2f}") |
| except Exception as e: |
| errors.append(f"Env step failed for {tid}: {e}") |
| print(f" FAIL: {tid} - {e}") |
|
|
| |
| print("\n[3/7] GRPO Reward Function Test...") |
| from training.train_grpo import compute_grpo_reward_env |
| test_completions = [ |
| '{"bus_adjustments": [{"bus_id": 0, "delta": 5.0}], "topology_actions": []}', |
| '{"bus_adjustments": [], "topology_actions": []}', |
| 'not valid json', |
| ] |
| test_observations = [ |
| {"grid_frequency": 49.5, "buses": [], "lines": []}, |
| {"grid_frequency": 50.0, "buses": [], "lines": []}, |
| {"grid_frequency": 48.0, "buses": [], "lines": []}, |
| ] |
| try: |
| cfg = get_task("karnataka_easy") |
| rewards = compute_grpo_reward_env(test_completions, test_observations, cfg, horizon=1) |
| for i, r in enumerate(rewards): |
| print(f" Completion {i}: reward={r:.3f}") |
| print(f" OK: GRPO rewards computed for {len(rewards)} completions") |
| except Exception as e: |
| errors.append(f"GRPO reward failed: {e}") |
| print(f" FAIL: {e}") |
|
|
| |
| print("\n[4/7] Karnataka Difficulty Gradient Test...") |
| ka_rewards = {} |
| for tid in ["karnataka_easy", "karnataka_medium", "karnataka_hard"]: |
| try: |
| cfg = get_task(tid) |
| env = OpenGridEnv(cfg) |
| obs = env.reset() |
| total_r = 0 |
| for step_i in range(5): |
| action = GridAction.model_validate_json( |
| json.dumps({"bus_adjustments": [], "topology_actions": []}) |
| ) |
| obs, reward, done, info = env.step(action) |
| total_r += reward.value |
| if done: |
| break |
| ka_rewards[tid] = total_r |
| print(f" {tid}: 5-step reward={total_r:.2f}") |
| except Exception as e: |
| errors.append(f"Ka difficulty test failed for {tid}: {e}") |
| print(f" FAIL: {tid} - {e}") |
|
|
| if len(ka_rewards) == 3: |
| |
| if ka_rewards["karnataka_easy"] >= ka_rewards["karnataka_hard"]: |
| print(f" OK: Difficulty gradient correct (easy >= hard)") |
| else: |
| print(f" WARN: easy ({ka_rewards['karnataka_easy']:.2f}) < hard ({ka_rewards['karnataka_hard']:.2f}) - may vary by seed") |
|
|
| |
| print("\n[5/7] Heuristic Policy Test...") |
| for tid in ["task_easy", "karnataka_easy", "task_karnataka"]: |
| try: |
| cfg = get_task(tid) |
| env = OpenGridEnv(cfg) |
| obs = env.reset() |
| total_r = 0 |
| for step_i in range(10): |
| action = heuristic_policy(obs) |
| obs, reward, done, info = env.step(action) |
| total_r += reward.value |
| if done: |
| break |
| print(f" OK: {tid} - 10-step heuristic reward={total_r:.2f}") |
| except Exception as e: |
| errors.append(f"Heuristic policy failed for {tid}: {e}") |
| print(f" FAIL: {tid} - {e}") |
|
|
| |
| print("\n[6/7] Safety Layer Test...") |
| for tid in ["task_easy", "karnataka_easy", "karnataka_hard"]: |
| try: |
| cfg = get_task(tid) |
| layer = SafetyLayer(cfg) |
| action = GridAction.model_validate_json( |
| json.dumps({"bus_adjustments": [{"bus_id": 0, "delta": 100.0}], "topology_actions": []}) |
| ) |
| bus_state = [{"id": b["id"], "p": b.get("base_p", 0), "soc": b.get("init_soc", 0)} for b in cfg["buses"]] |
| line_state = [{"id": l["id"], "connected": True, "flow": 0} for l in cfg["lines"]] |
| safe_action, report = layer.validate_and_correct(0, action, line_state, bus_state, {}) |
| print(f" OK: {tid} - corrected={report.was_corrected}, n1_violations={report.n1_violations_detected}") |
| except Exception as e: |
| errors.append(f"Safety layer failed for {tid}: {e}") |
| print(f" FAIL: {tid} - {e}") |
|
|
| |
| print("\n[7/7] Curriculum Order Test...") |
| from training.train_grpo import CURRICULUM_ORDER |
| for tid in CURRICULUM_ORDER: |
| if tid in TASKS: |
| print(f" OK: {tid} available") |
| else: |
| errors.append(f"Curriculum task missing: {tid}") |
| print(f" FAIL: {tid} not in TASKS") |
|
|
| |
| print("\n" + "=" * 60) |
| if errors: |
| print(f" FAILED: {len(errors)} errors") |
| for e in errors: |
| print(f" - {e}") |
| sys.exit(1) |
| else: |
| print(" ALL CHECKS PASSED - Training pipeline ready") |
| print("=" * 60) |
|
|