| """Smoke tests for the training pipeline (rollout, dry-run trainers).""" |
import json
import math
import tempfile
from pathlib import Path

from forgeenv.env.forge_environment import ForgeEnvironment
from forgeenv.training.grpo_repair import run_grpo
from forgeenv.training.grpo_drift import run_drift_grpo_dry_run
from forgeenv.training.rollout import (
    baseline_oracle_repair_generate,
    rollout_one_episode,
)
|
|
|
|
def test_rollout_one_episode_baseline_no_op_repair():
    """Roll out one episode without passing a repair generator.

    This is the baseline (no-op repair) case named by the test. It checks
    that the episode result carries a task id, a primitive type, a float
    visible reward, and an ``executed_cleanly`` entry in its held-out
    breakdown.
    """
    episode = rollout_one_episode(ForgeEnvironment(seed=1))

    assert episode.task_id
    assert episode.primitive_type
    assert isinstance(episode.visible_reward, float)
    assert "executed_cleanly" in episode.held_out_breakdown
|
|
|
|
def test_rollout_one_episode_with_oracle_repair_succeeds():
    """An oracle repair should score well on held-out intent preservation.

    ``baseline_oracle_repair_generate(env)`` supplies the repair generator
    passed to ``rollout_one_episode``. The resulting held-out breakdown must
    contain an ``intent_preserved`` score strictly greater than 0.7.
    """
    env = ForgeEnvironment(seed=2)
    repair_gen = baseline_oracle_repair_generate(env)
    result = rollout_one_episode(env, repair_generate=repair_gen)

    breakdown = result.held_out_breakdown
    # Check presence explicitly instead of defaulting a missing key to 0.0.
    # Otherwise an absent metric would be reported as merely a low score.
    assert "intent_preserved" in breakdown, (
        f"held_out_breakdown has no 'intent_preserved': {breakdown!r}"
    )
    assert breakdown["intent_preserved"] > 0.7, breakdown
|
|
|
|
def test_grpo_repair_dry_run_smoke():
    """Run the GRPO repair trainer without unsloth and validate its reward dump.

    The base model is a placeholder string, so this presumably exercises the
    dry-run path of ``run_grpo``. The test expects ``dry_run_rewards.json``
    in ``output_dir``, holding one finite numeric reward per episode.
    """
    total_episodes = 5
    with tempfile.TemporaryDirectory() as tmp:
        run_grpo(
            base_model="(unused-in-dry-run)",
            adapter_path=None,
            output_dir=tmp,
            total_episodes=total_episodes,
            group_size=2,
            seed=0,
            use_unsloth=False,
        )
        rewards_path = Path(tmp) / "dry_run_rewards.json"
        assert rewards_path.exists()
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        rewards = json.loads(rewards_path.read_text(encoding="utf-8"))
        assert len(rewards) == total_episodes
        for r in rewards:
            # bool subclasses int, so JSON true/false would otherwise pass as
            # numbers. json.loads also accepts NaN/Infinity by default. Reject
            # both, since neither is a usable reward.
            assert isinstance(r, (int, float)) and not isinstance(r, bool), r
            assert math.isfinite(r), r
|
|
|
|
def test_grpo_drift_dry_run_smoke():
    """Run the drift GRPO dry run and sanity-check its per-episode log.

    ``drift_dry_run.json`` must hold one entry per episode. Each entry needs
    ``rewards`` and ``candidates`` keys, with every reward in [0.0, 2.0].
    """
    with tempfile.TemporaryDirectory() as out_dir:
        run_drift_grpo_dry_run(
            output_dir=out_dir,
            total_episodes=3,
            group_size=2,
            seed=0,
        )
        log_file = Path(out_dir).joinpath("drift_dry_run.json")
        assert log_file.exists()

        entries = json.loads(log_file.read_text())
        assert len(entries) == 3
        for entry in entries:
            assert "rewards" in entry
            assert "candidates" in entry
            for reward in entry["rewards"]:
                assert 0.0 <= reward <= 2.0
|
|