# forgeenv-source / tests/test_training.py
# Author: akhiilll
# Commit a15535e (verified): "forgeenv source snapshot for training job"
"""Smoke tests for the training pipeline (rollout, dry-run trainers)."""
import json
import math
import tempfile
from pathlib import Path

from forgeenv.env.forge_environment import ForgeEnvironment
from forgeenv.training.grpo_drift import run_drift_grpo_dry_run
from forgeenv.training.grpo_repair import run_grpo
from forgeenv.training.rollout import (
    baseline_oracle_repair_generate,
    rollout_one_episode,
)
def test_rollout_one_episode_baseline_no_op_repair():
    """Roll out one episode without passing a ``repair_generate`` callable.

    Per the test name this exercises the baseline (no-op) repair path. Only
    checks that the episode result is populated: a truthy task id and
    primitive type, a float visible reward, and an ``executed_cleanly`` entry
    in the held-out breakdown.
    """
    episode = rollout_one_episode(ForgeEnvironment(seed=1))
    assert episode.task_id
    assert episode.primitive_type
    assert isinstance(episode.visible_reward, float)
    assert "executed_cleanly" in episode.held_out_breakdown
def test_rollout_one_episode_with_oracle_repair_succeeds():
    """Roll out one episode using the baseline oracle repair generator.

    The oracle is built from the same environment instance it is rolled out
    in. The held-out ``intent_preserved`` score (treated as 0.0 if absent)
    must exceed 0.7.
    """
    forge_env = ForgeEnvironment(seed=2)
    oracle = baseline_oracle_repair_generate(forge_env)
    episode = rollout_one_episode(forge_env, repair_generate=oracle)
    # The oracle's repaired script should be identical to the original, so it
    # is expected to score well on intent preservation.
    intent_score = episode.held_out_breakdown.get("intent_preserved", 0.0)
    assert intent_score > 0.7
def test_grpo_repair_dry_run_smoke():
    """Smoke-test ``run_grpo`` on its dry-run path and check its rewards log.

    ``base_model`` is only a placeholder here. The test expects the run to
    write ``dry_run_rewards.json`` into ``output_dir``, containing one finite,
    numeric (non-bool) reward per episode.
    """
    total_episodes = 5
    with tempfile.TemporaryDirectory() as tmp:
        run_grpo(
            base_model="(unused-in-dry-run)",
            adapter_path=None,
            output_dir=tmp,
            total_episodes=total_episodes,
            group_size=2,
            seed=0,
            use_unsloth=False,
        )
        # Inspect the output before leaving the `with`: the temporary
        # directory and everything in it is removed on exit.
        rewards_path = Path(tmp) / "dry_run_rewards.json"
        assert rewards_path.exists()
        rewards = json.loads(rewards_path.read_text())
        assert len(rewards) == total_episodes
        # bool is a subclass of int, and json.loads accepts NaN/Infinity as
        # floats by default; neither is a meaningful reward, so reject both.
        assert all(
            isinstance(r, (int, float))
            and not isinstance(r, bool)
            and math.isfinite(r)
            for r in rewards
        )
def test_grpo_drift_dry_run_smoke():
    """Smoke-test the drift GRPO dry run and validate its per-episode log.

    Expects ``drift_dry_run.json`` in ``output_dir`` with one entry per
    episode. Each entry must carry ``rewards`` and ``candidates`` keys, and
    every reward must lie in the closed range [0.0, 2.0].
    """
    episodes = 3
    with tempfile.TemporaryDirectory() as tmp:
        run_drift_grpo_dry_run(
            output_dir=tmp, total_episodes=episodes, group_size=2, seed=0
        )
        # All checks run inside the `with` because the directory is deleted
        # when the context manager exits.
        log_file = Path(tmp) / "drift_dry_run.json"
        assert log_file.exists()
        entries = json.loads(log_file.read_text())
        assert len(entries) == episodes
        for item in entries:
            assert all(key in item for key in ("rewards", "candidates"))
            assert all(0.0 <= reward <= 2.0 for reward in item["rewards"])