CreativeEngineer committed on
Commit
ebd0ff3
·
1 Parent(s): 729c711

feat: add llm rollout contract and simplify ppo smoke

Browse files
fusion_lab/llm_agent.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import re
5
+ from dataclasses import asdict, dataclass
6
+ from typing import Final, Sequence
7
+
8
+ from fusion_lab.models import (
9
+ DirectionName,
10
+ MagnitudeName,
11
+ ParameterName,
12
+ StellaratorAction,
13
+ StellaratorObservation,
14
+ )
15
+ from server.environment import BUDGET, StellaratorEnvironment
16
+
17
# Closed vocabularies for the "run" action fields; parsed LLM plans are
# validated against these tuples.
RUN_PARAMETERS: Final[tuple[ParameterName, ...]] = (
    "aspect_ratio",
    "elongation",
    "rotational_transform",
    "triangularity_scale",
)
RUN_DIRECTIONS: Final[tuple[DirectionName, ...]] = ("increase", "decrease")
RUN_MAGNITUDES: Final[tuple[MagnitudeName, ...]] = ("small", "medium", "large")

# System turn shown verbatim to the LLM. Keep the action schema described here
# in sync with what _parse_action_item accepts.
SYSTEM_PROMPT: Final[str] = """You are an expert stellarator designer.

Goal:
- satisfy the P1 physics constraints
- then improve the design score by lowering max elongation

You control a 4-knob low-dimensional design:
- aspect_ratio
- elongation
- rotational_transform
- triangularity_scale

Action rules:
- output a JSON array
- each item must be either:
- {"intent":"run","parameter":"<parameter>","direction":"increase|decrease","magnitude":"small|medium|large"}
- {"intent":"restore_best"}
- {"intent":"submit"}
- keep the plan short and within the remaining budget
- use "submit" only when the design looks ready

Constraint directions:
- aspect_ratio <= 4.0
- average_triangularity <= -0.5
- edge_iota_over_nfp >= 0.3"""

# Greedy bracket match: grabs from the first '[' to the last ']' so a JSON
# action array can be lifted out of surrounding completion text.
ACTION_ARRAY_PATTERN: Final[re.Pattern[str]] = re.compile(r"\[[\s\S]*\]")
53
+
54
+
55
@dataclass(frozen=True)
class LLMStepTrace:
    """Immutable record of one environment step taken during an LLM rollout replay."""

    # 1-based index of the step within the episode.
    step: int
    # Human-readable action description (see action_label()).
    action_label: str
    reward: float
    p1_score: float
    p1_feasibility: float
    constraints_satisfied: bool
    # Fidelity tier the evaluator used for this step.
    evaluation_fidelity: str
    evaluation_failed: bool
    budget_remaining: int
    diagnostics_text: str
67
+
68
+
69
+ @dataclass(frozen=True)
70
+ class LLMEpisodeTrace:
71
+ seed: int
72
+ total_reward: float
73
+ final_score: float
74
+ final_feasibility: float
75
+ constraints_satisfied: bool
76
+ evaluation_failed: bool
77
+ steps: list[LLMStepTrace]
78
+
79
+ def asdict(self) -> dict[str, object]:
80
+ return asdict(self)
81
+
82
+
83
def action_label(action: StellaratorAction) -> str:
    """Return a compact human-readable label for *action*.

    Non-"run" intents label as the intent itself; "run" actions include the
    parameter, direction, and magnitude.
    """
    intent = action.intent
    if intent == "run":
        return f"{intent} {action.parameter} {action.direction} {action.magnitude}"
    return intent
87
+
88
+
89
def format_observation(observation: StellaratorObservation) -> str:
    """Render *observation* as the bullet list placed in the LLM user turn."""
    bullet_lines = [
        "Current stellarator state:",
        f"- max_elongation: {observation.max_elongation:.4f}",
        f"- aspect_ratio: {observation.aspect_ratio:.4f} (must stay <= 4.0)",
        f"- average_triangularity: {observation.average_triangularity:.6f} (must stay <= -0.5)",
        f"- edge_iota_over_nfp: {observation.edge_iota_over_nfp:.4f} (must stay >= 0.3)",
        f"- p1_score: {observation.p1_score:.4f}",
        f"- p1_feasibility: {observation.p1_feasibility:.6f}",
        f"- constraints_satisfied: {observation.constraints_satisfied}",
        f"- evaluation_fidelity: {observation.evaluation_fidelity}",
        f"- evaluation_failed: {observation.evaluation_failed}",
        f"- budget_remaining: {observation.budget_remaining}",
        f"- best_low_fidelity_score: {observation.best_low_fidelity_score:.4f}",
        f"- best_low_fidelity_feasibility: {observation.best_low_fidelity_feasibility:.6f}",
        f"- diagnostics: {observation.diagnostics_text}",
    ]
    # The trailing newline preserves the terminator the LLM prompt expects.
    return "\n".join(bullet_lines) + "\n"
108
+
109
+
110
def build_prompt(observation: StellaratorObservation) -> str:
    """Assemble a ChatML-style prompt: system turn, user turn, open assistant turn."""
    system_turn = f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n"
    user_turn = f"<|im_start|>user\n{format_observation(observation)}<|im_end|>\n"
    return system_turn + user_turn + "<|im_start|>assistant\n"
116
+
117
+
118
def extract_json_plan(text: str) -> str | None:
    """Return the bracketed JSON-array span inside *text*, or None when absent."""
    match = ACTION_ARRAY_PATTERN.search(text)
    return None if match is None else match.group()
123
+
124
+
125
def _parse_action_item(item: object) -> StellaratorAction | None:
    """Validate one decoded plan entry and build its action.

    Returns None for anything malformed: non-dict entries, unknown intents, or
    "run" entries whose fields fall outside the allowed vocabularies.
    """
    if not isinstance(item, dict):
        return None

    intent = item.get("intent")
    if intent in ("submit", "restore_best"):
        return StellaratorAction(intent=intent)
    if intent != "run":
        return None

    parameter = item.get("parameter")
    direction = item.get("direction")
    # An omitted magnitude defaults to the smallest step.
    magnitude = item.get("magnitude", "small")
    valid = (
        parameter in RUN_PARAMETERS
        and direction in RUN_DIRECTIONS
        and magnitude in RUN_MAGNITUDES
    )
    if not valid:
        return None

    return StellaratorAction(
        intent="run",
        parameter=parameter,
        direction=direction,
        magnitude=magnitude,
    )
153
+
154
+
155
def parse_action_plan(text: str) -> list[StellaratorAction]:
    """Parse an LLM completion into a validated list of actions.

    Missing/invalid JSON or a non-list payload yields an empty plan. Malformed
    items are skipped, and the plan is truncated after the first "submit".
    """
    raw_plan = extract_json_plan(text)
    if raw_plan is None:
        return []
    try:
        decoded = json.loads(raw_plan)
    except json.JSONDecodeError:
        return []
    if not isinstance(decoded, list):
        return []

    actions: list[StellaratorAction] = []
    for entry in decoded:
        candidate = _parse_action_item(entry)
        if candidate is None:
            continue
        actions.append(candidate)
        # Nothing after a submit can execute, so stop parsing there.
        if candidate.intent == "submit":
            break
    return actions
175
+
176
+
177
def run_episode_with_actions(
    actions: Sequence[StellaratorAction],
    *,
    seed_idx: int,
) -> LLMEpisodeTrace:
    """Replay *actions* in a fresh environment and return the recorded trace.

    At most BUDGET actions are applied, and the loop stops early as soon as the
    environment reports the episode as done.
    """
    env = StellaratorEnvironment()
    obs = env.reset(seed=seed_idx)
    steps: list[LLMStepTrace] = []
    cumulative_reward = 0.0

    for step_number, action in enumerate(actions[:BUDGET], start=1):
        obs = env.step(action)
        # A missing reward (None) counts as zero.
        step_reward = float(obs.reward or 0.0)
        cumulative_reward += step_reward
        steps.append(
            LLMStepTrace(
                step=step_number,
                action_label=action_label(action),
                reward=step_reward,
                p1_score=obs.p1_score,
                p1_feasibility=obs.p1_feasibility,
                constraints_satisfied=obs.constraints_satisfied,
                evaluation_fidelity=obs.evaluation_fidelity,
                evaluation_failed=obs.evaluation_failed,
                budget_remaining=obs.budget_remaining,
                diagnostics_text=obs.diagnostics_text,
            )
        )
        if obs.done:
            break

    return LLMEpisodeTrace(
        seed=seed_idx,
        total_reward=round(cumulative_reward, 4),
        final_score=obs.p1_score,
        final_feasibility=obs.p1_feasibility,
        constraints_satisfied=obs.constraints_satisfied,
        evaluation_failed=obs.evaluation_failed,
        steps=steps,
    )
training/README.md CHANGED
@@ -19,3 +19,17 @@ Training policy:
19
 
20
  - install the training dependencies: `uv sync --extra training`
21
  - tiny low-fi PPO smoke run: `uv run --extra training python training/ppo_smoke.py`
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  - install the training dependencies: `uv sync --extra training`
21
  - tiny low-fi PPO smoke run: `uv run --extra training python training/ppo_smoke.py`
22
+ - generate an LLM-ready prompt payload: `uv run python training/llm_rollout.py prompt --seed 0`
23
+ - replay an LLM completion or action plan: `uv run python training/llm_rollout.py replay --seed 0 --completion-file <path>`
24
+
25
+ ## Shared LLM Contract
26
+
27
+ The prompt/action/replay contract for LLM training lives in:
28
+
29
+ - `fusion_lab/llm_agent.py`
30
+
31
+ Use that module as the source of truth for:
32
+
33
+ - prompt formatting
34
+ - action-plan parsing
35
+ - local rollout replay
training/llm_rollout.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ from datetime import UTC, datetime
6
+ from pathlib import Path
7
+ from typing import Final
8
+
9
+ from fusion_lab.llm_agent import (
10
+ build_prompt,
11
+ parse_action_plan,
12
+ run_episode_with_actions,
13
+ )
14
+ from fusion_lab.models import StellaratorAction
15
+ from server.environment import StellaratorEnvironment
16
+
17
# Default destination for rollout artifacts written by the "replay" subcommand.
DEFAULT_OUTPUT_DIR: Final[Path] = Path("training/artifacts/llm_rollout")
18
+
19
+
20
def parse_args() -> argparse.Namespace:
    """Build and evaluate the CLI: a "prompt" and a "replay" subcommand."""
    parser = argparse.ArgumentParser(
        description=(
            "Generate an LLM-ready prompt or replay an LLM completion against the live "
            "Fusion Design Lab environment."
        )
    )
    commands = parser.add_subparsers(dest="command", required=True)

    # "prompt": emit a reset-state prompt payload.
    prompt_cmd = commands.add_parser("prompt", help="Print or save an LLM prompt.")
    prompt_cmd.add_argument("--seed", type=int, default=0, help="Reset seed index.")
    prompt_cmd.add_argument(
        "--output-file",
        type=Path,
        default=None,
        help="Optional JSON file path for the prompt payload.",
    )

    # "replay": run a completion or explicit plan through the environment.
    replay_cmd = commands.add_parser(
        "replay",
        help="Replay a completion or action-plan file and save a rollout artifact.",
    )
    replay_cmd.add_argument("--seed", type=int, default=0, help="Reset seed index.")
    replay_cmd.add_argument(
        "--completion-file",
        type=Path,
        default=None,
        help="Path to a raw LLM completion containing a JSON action array.",
    )
    replay_cmd.add_argument(
        "--action-plan-file",
        type=Path,
        default=None,
        help="Path to a JSON array of actions.",
    )
    replay_cmd.add_argument(
        "--output-dir",
        type=Path,
        default=DEFAULT_OUTPUT_DIR,
        help="Directory for rollout artifacts.",
    )
    return parser.parse_args()
62
+
63
+
64
def prompt_payload(seed: int) -> dict[str, object]:
    """Reset a fresh environment at *seed* and package the initial prompt payload."""
    env = StellaratorEnvironment()
    obs = env.reset(seed=seed)
    payload: dict[str, object] = {
        "created_at_utc": datetime.now(UTC).isoformat(),
        "seed": seed,
        "prompt": build_prompt(obs),
        "target_spec": obs.target_spec,
        "budget_remaining": obs.budget_remaining,
        "diagnostics_text": obs.diagnostics_text,
    }
    return payload
75
+
76
+
77
def parse_actions(args: argparse.Namespace) -> tuple[str, list[StellaratorAction]]:
    """Load the plan/completion text and parse it into actions.

    An explicit --action-plan-file takes precedence over --completion-file;
    raises ValueError when neither was provided.
    """
    if args.action_plan_file is not None:
        source_path = args.action_plan_file
    elif args.completion_file is not None:
        source_path = args.completion_file
    else:
        raise ValueError("replay requires --completion-file or --action-plan-file")

    return str(source_path), parse_action_plan(source_path.read_text())
88
+
89
+
90
def write_json(path: Path, payload: dict[str, object]) -> None:
    """Write *payload* as pretty-printed, key-sorted JSON at *path*.

    Parent directories are created as needed; the file ends with a newline.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(payload, indent=2, sort_keys=True)
    path.write_text(serialized + "\n")
93
+
94
+
95
def run_prompt(args: argparse.Namespace) -> None:
    """Handle the "prompt" subcommand: emit the payload to a file or stdout."""
    payload = prompt_payload(args.seed)
    if args.output_file is None:
        print(json.dumps(payload, indent=2))
        return
    write_json(args.output_file, payload)
    # Echo the destination so callers can pick up the artifact path.
    print(args.output_file)
102
+
103
+
104
def run_replay(args: argparse.Namespace) -> None:
    """Handle the "replay" subcommand: replay the plan and persist the rollout artifact."""
    source, actions = parse_actions(args)
    trace = run_episode_with_actions(actions, seed_idx=args.seed)
    timestamp = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ")
    output_path = args.output_dir / f"llm_rollout_{timestamp}.json"
    write_json(
        output_path,
        {
            "created_at_utc": datetime.now(UTC).isoformat(),
            "seed": args.seed,
            "source": source,
            "parsed_action_count": len(actions),
            # exclude_none keeps run-only fields out of submit/restore entries.
            "actions": [action.model_dump(exclude_none=True) for action in actions],
            "trace": trace.asdict(),
        },
    )
    # Echo the artifact path so callers can pick it up.
    print(output_path)
119
+
120
+
121
def main() -> None:
    """CLI entry point: dispatch to the selected subcommand."""
    args = parse_args()
    handler = run_prompt if args.command == "prompt" else run_replay
    handler(args)


if __name__ == "__main__":
    main()
training/ppo_smoke.py CHANGED
@@ -17,37 +17,16 @@ from server.contract import RESET_SEEDS
17
  from server.environment import BUDGET, StellaratorEnvironment
18
 
19
  DEFAULT_OUTPUT_DIR: Final[Path] = Path("training/artifacts/ppo_smoke")
20
- DEFAULT_TOTAL_TIMESTEPS: Final[int] = 128
21
  DEFAULT_EVAL_EPISODES: Final[int] = 3
 
22
 
23
- RUN_ACTION_SPECS: Final[tuple[tuple[str, str, str], ...]] = (
24
- ("aspect_ratio", "increase", "small"),
25
- ("aspect_ratio", "increase", "medium"),
26
- ("aspect_ratio", "increase", "large"),
27
- ("aspect_ratio", "decrease", "small"),
28
- ("aspect_ratio", "decrease", "medium"),
29
- ("aspect_ratio", "decrease", "large"),
30
- ("elongation", "increase", "small"),
31
- ("elongation", "increase", "medium"),
32
- ("elongation", "increase", "large"),
33
- ("elongation", "decrease", "small"),
34
- ("elongation", "decrease", "medium"),
35
- ("elongation", "decrease", "large"),
36
- ("rotational_transform", "increase", "small"),
37
  ("rotational_transform", "increase", "medium"),
38
- ("rotational_transform", "increase", "large"),
39
- ("rotational_transform", "decrease", "small"),
40
- ("rotational_transform", "decrease", "medium"),
41
- ("rotational_transform", "decrease", "large"),
42
- ("triangularity_scale", "increase", "small"),
43
  ("triangularity_scale", "increase", "medium"),
44
- ("triangularity_scale", "increase", "large"),
45
- ("triangularity_scale", "decrease", "small"),
46
- ("triangularity_scale", "decrease", "medium"),
47
- ("triangularity_scale", "decrease", "large"),
48
  )
49
- LOW_FI_ACTION_COUNT: Final[int] = len(RUN_ACTION_SPECS) + 1
50
- LOW_FI_RESTORE_ACTION_INDEX: Final[int] = len(RUN_ACTION_SPECS)
51
 
52
 
53
  @dataclass(frozen=True)
@@ -61,6 +40,7 @@ class TraceStep:
61
  constraints_satisfied: bool
62
  evaluation_failed: bool
63
  budget_remaining: int
 
64
  max_elongation: float
65
  average_triangularity: float
66
  edge_iota_over_nfp: float
@@ -75,9 +55,25 @@ class EpisodeTrace:
75
  final_feasibility: float
76
  constraints_satisfied: bool
77
  evaluation_failed: bool
 
78
  steps: list[TraceStep]
79
 
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  class LowFiSmokeEnv(gym.Env[np.ndarray, int]):
82
  metadata = {"render_modes": []}
83
 
@@ -89,7 +85,8 @@ class LowFiSmokeEnv(gym.Env[np.ndarray, int]):
89
  self.observation_space = spaces.Box(
90
  low=-np.inf,
91
  high=np.inf,
92
- shape=(12,),
 
93
  dtype=np.float32,
94
  )
95
  self.action_space = spaces.Discrete(LOW_FI_ACTION_COUNT)
@@ -109,7 +106,9 @@ class LowFiSmokeEnv(gym.Env[np.ndarray, int]):
109
  if seed is not None:
110
  self._episode_index = 0
111
  return seed % len(RESET_SEEDS)
112
- next_seed = self._episode_index % len(RESET_SEEDS)
 
 
113
  self._episode_index += 1
114
  return next_seed
115
 
@@ -120,30 +119,20 @@ class LowFiSmokeEnv(gym.Env[np.ndarray, int]):
120
  obs = self._env.step(self._decode_action(action))
121
  return (
122
  self._encode_observation(obs),
123
- float(obs.reward or 0.0),
124
  bool(obs.done),
125
  False,
126
  self._info(obs),
127
  )
128
 
129
  def _decode_action(self, action: int) -> StellaratorAction:
130
- if action == LOW_FI_RESTORE_ACTION_INDEX:
131
- return StellaratorAction(intent="restore_best")
132
- parameter, direction, magnitude = RUN_ACTION_SPECS[action]
133
- return StellaratorAction(
134
- intent="run",
135
- parameter=parameter,
136
- direction=direction,
137
- magnitude=magnitude,
138
- )
139
 
140
  def action_label(self, action: int) -> str:
141
- if action == LOW_FI_RESTORE_ACTION_INDEX:
142
- return "restore_best"
143
- parameter, direction, magnitude = RUN_ACTION_SPECS[action]
144
- return f"{parameter} {direction} {magnitude}"
145
 
146
  def _encode_observation(self, obs: StellaratorObservation) -> np.ndarray:
 
147
  budget_fraction = obs.budget_remaining / BUDGET
148
  step_fraction = obs.step_number / BUDGET
149
  return np.array(
@@ -155,11 +144,16 @@ class LowFiSmokeEnv(gym.Env[np.ndarray, int]):
155
  obs.p1_score,
156
  obs.p1_feasibility,
157
  obs.vacuum_well,
 
 
 
 
158
  budget_fraction,
159
  step_fraction,
160
  obs.best_low_fidelity_score,
161
  obs.best_low_fidelity_feasibility,
162
- float(obs.constraints_satisfied) - float(obs.evaluation_failed),
 
163
  ],
164
  dtype=np.float32,
165
  )
@@ -172,8 +166,22 @@ class LowFiSmokeEnv(gym.Env[np.ndarray, int]):
172
  "evaluation_failed": obs.evaluation_failed,
173
  "p1_score": obs.p1_score,
174
  "p1_feasibility": obs.p1_feasibility,
 
 
 
 
 
175
  }
176
 
 
 
 
 
 
 
 
 
 
177
 
178
  def parse_args() -> argparse.Namespace:
179
  parser = argparse.ArgumentParser(
@@ -216,26 +224,30 @@ def build_model(env: LowFiSmokeEnv, seed: int) -> PPO:
216
  seed=seed,
217
  verbose=0,
218
  device="cpu",
219
- n_steps=32,
220
- batch_size=32,
221
- n_epochs=4,
222
- gamma=0.98,
223
  learning_rate=3e-4,
224
  ent_coef=0.01,
225
  )
226
 
227
 
228
- def evaluate_policy(model: PPO, *, eval_episodes: int, base_seed: int) -> list[EpisodeTrace]:
 
 
229
  traces: list[EpisodeTrace] = []
 
 
230
  for episode in range(eval_episodes):
231
- env = LowFiSmokeEnv()
232
  seed = base_seed + episode
233
- obs, _ = env.reset(seed=seed)
 
234
  done = False
235
  total_reward = 0.0
236
  steps: list[TraceStep] = []
237
  step_index = 0
238
- final_info: dict[str, object] = {}
239
 
240
  while not done:
241
  action, _ = model.predict(obs, deterministic=True)
@@ -256,9 +268,10 @@ def evaluate_policy(model: PPO, *, eval_episodes: int, base_seed: int) -> list[E
256
  constraints_satisfied=bool(info["constraints_satisfied"]),
257
  evaluation_failed=bool(info["evaluation_failed"]),
258
  budget_remaining=int(info["budget_remaining"]),
259
- max_elongation=float(obs[0]),
260
- average_triangularity=float(obs[2]),
261
- edge_iota_over_nfp=float(obs[3]),
 
262
  )
263
  )
264
 
@@ -271,10 +284,11 @@ def evaluate_policy(model: PPO, *, eval_episodes: int, base_seed: int) -> list[E
271
  final_feasibility=float(final_info["p1_feasibility"]),
272
  constraints_satisfied=bool(final_info["constraints_satisfied"]),
273
  evaluation_failed=bool(final_info["evaluation_failed"]),
 
274
  steps=steps,
275
  )
276
  )
277
- return traces
278
 
279
 
280
  def artifact_payload(
@@ -282,6 +296,7 @@ def artifact_payload(
282
  total_timesteps: int,
283
  eval_episodes: int,
284
  seed: int,
 
285
  traces: list[EpisodeTrace],
286
  ) -> dict[str, object]:
287
  mean_reward = sum(trace.total_reward for trace in traces) / max(len(traces), 1)
@@ -292,12 +307,16 @@ def artifact_payload(
292
  "total_timesteps": total_timesteps,
293
  "eval_episodes": eval_episodes,
294
  "seed": seed,
295
- "train_reset_seed_indices": list(range(len(RESET_SEEDS))),
 
296
  "action_space_size": LOW_FI_ACTION_COUNT,
 
 
 
297
  "notes": (
298
- "Diagnostic-only PPO smoke run. Submit is intentionally excluded here so the "
299
- "smoke loop stays low-fidelity and fast. Training resets cycle through the "
300
- "frozen low-fidelity reset seeds to surface positive repair signal sooner."
301
  ),
302
  "summary": {
303
  "mean_eval_reward": round(mean_reward, 4),
@@ -320,7 +339,7 @@ def main() -> None:
320
  env = LowFiSmokeEnv()
321
  model = build_model(env, seed=args.seed)
322
  model.learn(total_timesteps=args.total_timesteps, progress_bar=False)
323
- traces = evaluate_policy(
324
  model,
325
  eval_episodes=args.eval_episodes,
326
  base_seed=args.seed,
@@ -329,10 +348,14 @@ def main() -> None:
329
  total_timesteps=args.total_timesteps,
330
  eval_episodes=args.eval_episodes,
331
  seed=args.seed,
 
332
  traces=traces,
333
  )
334
  output_path = write_artifact(args.output_dir, payload)
 
335
  print(output_path)
 
 
336
 
337
 
338
  if __name__ == "__main__":
 
17
  from server.environment import BUDGET, StellaratorEnvironment
18
 
19
  DEFAULT_OUTPUT_DIR: Final[Path] = Path("training/artifacts/ppo_smoke")
20
+ DEFAULT_TOTAL_TIMESTEPS: Final[int] = 32
21
  DEFAULT_EVAL_EPISODES: Final[int] = 3
22
+ ENCODED_OBSERVATION_DIM: Final[int] = 17
23
 
24
+ DIAGNOSTIC_RUN_ACTION_SPECS: Final[tuple[tuple[str, str, str], ...]] = (
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  ("rotational_transform", "increase", "medium"),
 
 
 
 
 
26
  ("triangularity_scale", "increase", "medium"),
 
 
 
 
27
  )
28
+ TRAIN_RESET_SEED_INDICES: Final[tuple[int, ...]] = (2,)
29
+ LOW_FI_ACTION_COUNT: Final[int] = len(DIAGNOSTIC_RUN_ACTION_SPECS)
30
 
31
 
32
  @dataclass(frozen=True)
 
40
  constraints_satisfied: bool
41
  evaluation_failed: bool
42
  budget_remaining: int
43
+ termination_reason: str
44
  max_elongation: float
45
  average_triangularity: float
46
  edge_iota_over_nfp: float
 
55
  final_feasibility: float
56
  constraints_satisfied: bool
57
  evaluation_failed: bool
58
+ termination_reason: str
59
  steps: list[TraceStep]
60
 
61
 
62
def diagnostic_action(action_index: int) -> StellaratorAction:
    """Map a discrete action index onto its fixed low-fidelity "run" action."""
    spec = DIAGNOSTIC_RUN_ACTION_SPECS[action_index]
    parameter, direction, magnitude = spec
    return StellaratorAction(
        intent="run",
        parameter=parameter,
        direction=direction,
        magnitude=magnitude,
    )
70
+
71
+
72
def diagnostic_action_label(action_index: int) -> str:
    """Human-readable label for the diagnostic action at *action_index*."""
    act = diagnostic_action(action_index)
    return f"{act.parameter} {act.direction} {act.magnitude}"
75
+
76
+
77
  class LowFiSmokeEnv(gym.Env[np.ndarray, int]):
78
  metadata = {"render_modes": []}
79
 
 
85
  self.observation_space = spaces.Box(
86
  low=-np.inf,
87
  high=np.inf,
88
+ # Keep this aligned with _encode_observation feature count.
89
+ shape=(ENCODED_OBSERVATION_DIM,),
90
  dtype=np.float32,
91
  )
92
  self.action_space = spaces.Discrete(LOW_FI_ACTION_COUNT)
 
106
  if seed is not None:
107
  self._episode_index = 0
108
  return seed % len(RESET_SEEDS)
109
+ if not TRAIN_RESET_SEED_INDICES:
110
+ raise ValueError("TRAIN_RESET_SEED_INDICES must define at least one seed index.")
111
+ next_seed = TRAIN_RESET_SEED_INDICES[self._episode_index % len(TRAIN_RESET_SEED_INDICES)]
112
  self._episode_index += 1
113
  return next_seed
114
 
 
119
  obs = self._env.step(self._decode_action(action))
120
  return (
121
  self._encode_observation(obs),
122
+ float(obs.reward if obs.reward is not None else 0.0),
123
  bool(obs.done),
124
  False,
125
  self._info(obs),
126
  )
127
 
128
  def _decode_action(self, action: int) -> StellaratorAction:
129
+ return diagnostic_action(action)
 
 
 
 
 
 
 
 
130
 
131
  def action_label(self, action: int) -> str:
132
+ return diagnostic_action_label(action)
 
 
 
133
 
134
  def _encode_observation(self, obs: StellaratorObservation) -> np.ndarray:
135
+ params = self._env.state.current_params
136
  budget_fraction = obs.budget_remaining / BUDGET
137
  step_fraction = obs.step_number / BUDGET
138
  return np.array(
 
144
  obs.p1_score,
145
  obs.p1_feasibility,
146
  obs.vacuum_well,
147
+ params.aspect_ratio,
148
+ params.elongation,
149
+ params.rotational_transform,
150
+ params.triangularity_scale,
151
  budget_fraction,
152
  step_fraction,
153
  obs.best_low_fidelity_score,
154
  obs.best_low_fidelity_feasibility,
155
+ float(obs.constraints_satisfied),
156
+ float(obs.evaluation_failed),
157
  ],
158
  dtype=np.float32,
159
  )
 
166
  "evaluation_failed": obs.evaluation_failed,
167
  "p1_score": obs.p1_score,
168
  "p1_feasibility": obs.p1_feasibility,
169
+ "max_elongation": obs.max_elongation,
170
+ "average_triangularity": obs.average_triangularity,
171
+ "edge_iota_over_nfp": obs.edge_iota_over_nfp,
172
+ "termination_reason": self._termination_reason(obs),
173
+ "current_seed": self._seed,
174
  }
175
 
176
def _termination_reason(self, obs: StellaratorObservation) -> str:
    """Classify how (or whether) the episode ended, checked in priority order.

    Evaluation failure dominates, then constraint satisfaction, then budget
    exhaustion; otherwise the episode is still running.
    """
    if obs.evaluation_failed:
        return "evaluation_failed"
    if obs.constraints_satisfied:
        return "constraints_satisfied"
    return "budget_exhausted" if obs.done else "in_progress"
184
+
185
 
186
  def parse_args() -> argparse.Namespace:
187
  parser = argparse.ArgumentParser(
 
224
  seed=seed,
225
  verbose=0,
226
  device="cpu",
227
+ n_steps=16,
228
+ batch_size=16,
229
+ n_epochs=8,
230
+ gamma=0.995,
231
  learning_rate=3e-4,
232
  ent_coef=0.01,
233
  )
234
 
235
 
236
+ def evaluate_policy(
237
+ model: PPO, *, eval_episodes: int, base_seed: int
238
+ ) -> tuple[list[EpisodeTrace], list[int]]:
239
  traces: list[EpisodeTrace] = []
240
+ eval_reset_seed_indices: list[int] = []
241
+ env = LowFiSmokeEnv()
242
  for episode in range(eval_episodes):
 
243
  seed = base_seed + episode
244
+ eval_reset_seed_indices.append(seed % len(RESET_SEEDS))
245
+ obs, info = env.reset(seed=seed)
246
  done = False
247
  total_reward = 0.0
248
  steps: list[TraceStep] = []
249
  step_index = 0
250
+ final_info = dict[str, object](info)
251
 
252
  while not done:
253
  action, _ = model.predict(obs, deterministic=True)
 
268
  constraints_satisfied=bool(info["constraints_satisfied"]),
269
  evaluation_failed=bool(info["evaluation_failed"]),
270
  budget_remaining=int(info["budget_remaining"]),
271
+ termination_reason=str(info["termination_reason"]),
272
+ max_elongation=float(info["max_elongation"]),
273
+ average_triangularity=float(info["average_triangularity"]),
274
+ edge_iota_over_nfp=float(info["edge_iota_over_nfp"]),
275
  )
276
  )
277
 
 
284
  final_feasibility=float(final_info["p1_feasibility"]),
285
  constraints_satisfied=bool(final_info["constraints_satisfied"]),
286
  evaluation_failed=bool(final_info["evaluation_failed"]),
287
+ termination_reason=str(final_info["termination_reason"]),
288
  steps=steps,
289
  )
290
  )
291
+ return traces, eval_reset_seed_indices
292
 
293
 
294
  def artifact_payload(
 
296
  total_timesteps: int,
297
  eval_episodes: int,
298
  seed: int,
299
+ eval_reset_seed_indices: list[int],
300
  traces: list[EpisodeTrace],
301
  ) -> dict[str, object]:
302
  mean_reward = sum(trace.total_reward for trace in traces) / max(len(traces), 1)
 
307
  "total_timesteps": total_timesteps,
308
  "eval_episodes": eval_episodes,
309
  "seed": seed,
310
+ "train_reset_seed_indices": list(TRAIN_RESET_SEED_INDICES),
311
+ "eval_reset_seed_indices": eval_reset_seed_indices,
312
  "action_space_size": LOW_FI_ACTION_COUNT,
313
+ "diagnostic_run_actions": [
314
+ diagnostic_action_label(action_index) for action_index in range(LOW_FI_ACTION_COUNT)
315
+ ],
316
  "notes": (
317
+ "Diagnostics-only low-fidelity PPO smoke; submit is excluded and the action "
318
+ "space is narrowed to a two-step repair arc. Evaluation runs across "
319
+ "frozen seeds and records full low-fi traces."
320
  ),
321
  "summary": {
322
  "mean_eval_reward": round(mean_reward, 4),
 
339
  env = LowFiSmokeEnv()
340
  model = build_model(env, seed=args.seed)
341
  model.learn(total_timesteps=args.total_timesteps, progress_bar=False)
342
+ traces, eval_reset_seed_indices = evaluate_policy(
343
  model,
344
  eval_episodes=args.eval_episodes,
345
  base_seed=args.seed,
 
348
  total_timesteps=args.total_timesteps,
349
  eval_episodes=args.eval_episodes,
350
  seed=args.seed,
351
+ eval_reset_seed_indices=eval_reset_seed_indices,
352
  traces=traces,
353
  )
354
  output_path = write_artifact(args.output_dir, payload)
355
+ summary = payload["summary"]
356
  print(output_path)
357
+ print(f"constraint_satisfaction_rate={summary['constraint_satisfaction_rate']}")
358
+ print(f"mean_eval_reward={summary['mean_eval_reward']}")
359
 
360
 
361
  if __name__ == "__main__":