| """Heuristic baseline agent for the stellarator design environment.""" |
|
|
| from __future__ import annotations |
|
|
| import sys |
|
|
| from fusion_lab.models import StellaratorAction, StellaratorObservation |
| from server.environment import StellaratorEnvironment |
|
|
# Thresholds read by _choose_action when ranking repair/submit moves.

# With constraints satisfied, submit as soon as max_elongation is at or below this.
FEASIBLE_SUBMIT_ELONGATION_MAX = 7.45
# While average_triangularity is above this value, triangularity is repaired first.
TRIANGULARITY_TARGET_MAX = -0.5
# Step 0 only, while triangularity is still above target: an edge_iota_over_nfp
# below this value gets a medium rotational_transform increase instead.
LOW_IOTA_RESET_THRESHOLD = 0.305
# An edge_iota_over_nfp below this value gets a small rotational_transform increase.
IOTA_RECOVERY_THRESHOLD = 0.3
# An aspect_ratio above this value gets a small aspect_ratio decrease.
ASPECT_RATIO_TARGET_MAX = 4.0
|
|
|
|
def _observation_snapshot(obs: StellaratorObservation) -> dict[str, object]:
    """Collect the observation fields that every trace entry records."""
    return {
        "score": obs.p1_score,
        "evaluation_fidelity": obs.evaluation_fidelity,
        "constraints_satisfied": obs.constraints_satisfied,
        "feasibility": obs.p1_feasibility,
        "max_elongation": obs.max_elongation,
        "average_triangularity": obs.average_triangularity,
        "edge_iota_over_nfp": obs.edge_iota_over_nfp,
    }


def heuristic_episode(
    env: StellaratorEnvironment, seed: int | None = None
) -> tuple[float, list[dict[str, object]]]:
    """Play one episode with the heuristic policy.

    The environment is reset with ``seed`` and stepped until the observation
    reports ``done``. A submit is forced once ``budget_remaining`` is 1 or
    less; otherwise the action comes from ``_choose_action``.

    Args:
        env: Environment to reset and step.
        seed: Seed forwarded to ``env.reset``.

    Returns:
        ``(total_reward, trace)``. ``trace[0]`` describes the observation
        returned by the reset; each later entry describes one step and also
        carries the action label, the step reward and the failure flag.
    """
    obs = env.reset(seed=seed)
    trace: list[dict[str, object]] = [{"step": 0, **_observation_snapshot(obs)}]
    total_reward = 0.0

    while not obs.done:
        if obs.budget_remaining <= 1:
            action = StellaratorAction(intent="submit")
        else:
            action = _choose_action(obs)

        obs = env.step(action)
        # A falsy reward (e.g. None) contributes nothing to the running total.
        total_reward += obs.reward or 0.0

        entry: dict[str, object] = {
            "step": len(trace),
            "action": _action_label(action),
            **_observation_snapshot(obs),
            "reward": obs.reward,
            "evaluation_failed": obs.evaluation_failed,
        }
        trace.append(entry)

    return total_reward, trace
|
|
|
|
def _choose_action(obs: StellaratorObservation) -> StellaratorAction:
    """Pick the next action from ``obs`` using fixed, ordered rules.

    The first matching rule wins:

    1. The last evaluation failed: restore the best design.
    2. Constraints are satisfied: submit if ``max_elongation`` is at or below
       ``FEASIBLE_SUBMIT_ELONGATION_MAX``, ``budget_remaining`` is 2 or less,
       or ``step_number`` is 3 or more; otherwise decrease elongation (small).
    3. ``average_triangularity`` is above ``TRIANGULARITY_TARGET_MAX``:
       on step 0 with ``edge_iota_over_nfp`` below
       ``LOW_IOTA_RESET_THRESHOLD`` increase rotational_transform (medium);
       otherwise increase triangularity_scale (medium).
    4. ``edge_iota_over_nfp`` is below ``IOTA_RECOVERY_THRESHOLD``:
       increase rotational_transform (small).
    5. ``aspect_ratio`` is above ``ASPECT_RATIO_TARGET_MAX``:
       decrease aspect_ratio (small).
    6. Fallback: decrease elongation (small).
    """

    def run(parameter: str, direction: str, magnitude: str) -> StellaratorAction:
        # Every non-terminal move is a "run" intent that differs only in these three fields.
        return StellaratorAction(
            intent="run",
            parameter=parameter,
            direction=direction,
            magnitude=magnitude,
        )

    if obs.evaluation_failed:
        return StellaratorAction(intent="restore_best")

    if obs.constraints_satisfied:
        ready_to_submit = (
            obs.max_elongation <= FEASIBLE_SUBMIT_ELONGATION_MAX
            or obs.budget_remaining <= 2
            or obs.step_number >= 3
        )
        if ready_to_submit:
            return StellaratorAction(intent="submit")
        return run("elongation", "decrease", "small")

    if obs.average_triangularity > TRIANGULARITY_TARGET_MAX:
        low_iota_at_start = (
            obs.step_number == 0 and obs.edge_iota_over_nfp < LOW_IOTA_RESET_THRESHOLD
        )
        if low_iota_at_start:
            return run("rotational_transform", "increase", "medium")
        return run("triangularity_scale", "increase", "medium")

    if obs.edge_iota_over_nfp < IOTA_RECOVERY_THRESHOLD:
        return run("rotational_transform", "increase", "small")

    if obs.aspect_ratio > ASPECT_RATIO_TARGET_MAX:
        return run("aspect_ratio", "decrease", "small")

    return run("elongation", "decrease", "small")
|
|
|
|
def _action_label(action: StellaratorAction) -> str:
    """Return a short label for ``action`` to store in the episode trace.

    "run" actions are rendered as "<parameter> <direction> <magnitude>";
    any other action is labelled with its intent alone.
    """
    if action.intent == "run":
        return f"{action.parameter} {action.direction} {action.magnitude}"
    return action.intent
|
|
|
|
def main(n_episodes: int = 20) -> None:
    """Run the heuristic baseline and print per-episode and summary results.

    A single environment instance is reused; episode ``i`` is run with
    ``seed=i``. One line is printed per episode, followed by the mean total
    reward across all episodes.

    Args:
        n_episodes: Number of episodes to run. For values of 0 or less no
            episode is run and a summary line says so, instead of failing
            on a division by zero.
    """
    env = StellaratorEnvironment()
    rewards: list[float] = []

    for i in range(n_episodes):
        total_reward, trace = heuristic_episode(env, seed=i)
        final = trace[-1]
        rewards.append(total_reward)
        print(
            # trace[0] is the reset observation, so the step count is len - 1.
            f"Episode {i:3d}: steps={len(trace) - 1} "
            f"final_score={final['score']:.6f} fidelity={final['evaluation_fidelity']} "
            f"constraints={'yes' if final['constraints_satisfied'] else 'no'} "
            f"reward={total_reward:+.4f}"
        )

    if not rewards:
        # The mean is undefined without episodes; dividing would raise ZeroDivisionError.
        print(f"\nHeuristic baseline ({n_episodes} episodes): no episodes run")
        return

    mean_reward = sum(rewards) / len(rewards)
    print(f"\nHeuristic baseline ({n_episodes} episodes): mean_reward={mean_reward:+.4f}")
|
|
|
|
if __name__ == "__main__":
    # The optional first command-line argument overrides the default of 20 episodes.
    cli_args = sys.argv[1:]
    main(int(cli_args[0]) if cli_args else 20)
|
|