| from pathlib import Path | |
| from omegaconf import OmegaConf | |
| import train.run_experiment as run_experiment | |
def test_adapter_dataset_rebuilds_when_existing_bundle_lacks_proposal_targets(monkeypatch, tmp_path):
    """A cached bundle without proposal targets is rebuilt when a builder is given.

    The stubbed ``load_teacher_dataset`` returns samples with no
    ``proposal_target_*`` keys. With ``require_proposal_targets=True`` and a
    builder supplied, ``_build_dataset_from_config`` is expected to fall back to
    the stubbed ``collect_teacher_dataset`` and forward the builder to it. The
    returned bundle must then satisfy ``_bundle_has_proposal_targets``.
    """
    cached_file = tmp_path / "proxy_train.pt"
    # The contents are never parsed: load_teacher_dataset is stubbed below.
    cached_file.write_text("stub", encoding="utf-8")
    sentinel_builder = object()
    collect_kwargs: dict[str, object] = {}

    def _stale_load_teacher_dataset(path):
        # Cached bundle whose sample lacks every proposal-aligned target field.
        return {"samples": [{"task_name": "bag"}]}

    def _fake_collect_teacher_dataset(**kwargs):
        # Record the keyword arguments so the forwarded builder can be checked.
        collect_kwargs.update(kwargs)
        aligned_sample = {
            "task_name": "bag",
            **dict.fromkeys(
                (
                    "proposal_target_action_chunks",
                    "proposal_target_retrieval_success",
                    "proposal_target_risk",
                    "proposal_target_utility",
                ),
                1,
            ),
        }
        return {"samples": [aligned_sample]}

    def _fake_save_teacher_dataset(path, bundle):
        # Skip real serialization; just echo the target path back.
        return Path(path)

    monkeypatch.setattr(run_experiment, "load_teacher_dataset", _stale_load_teacher_dataset)
    monkeypatch.setattr(run_experiment, "collect_teacher_dataset", _fake_collect_teacher_dataset)
    monkeypatch.setattr(run_experiment, "save_teacher_dataset", _fake_save_teacher_dataset)

    cfg_values = dict(
        proxies=["bag_proxy"],
        resolution=16,
        seed=17,
        chunk_horizon=2,
        rollout_horizon=2,
        history_steps=1,
        planner_candidates=2,
        dataset_version="proxy_test",
        train_episodes_per_proxy=1,
        train_dataset_path=str(cached_file),
        # Do not force a rebuild: the rebuild must be triggered by the missing targets.
        rebuild_dataset=False,
    )
    data_cfg = OmegaConf.create(cfg_values)

    rebuilt = run_experiment._build_dataset_from_config(
        data_cfg,
        "train",
        proposal_target_builder=sentinel_builder,
        require_proposal_targets=True,
    )

    assert run_experiment._bundle_has_proposal_targets(rebuilt)
    assert collect_kwargs["proposal_target_builder"] is sentinel_builder
def test_adapter_dataset_missing_proposal_targets_raises_without_builder(monkeypatch, tmp_path):
    """Missing proposal targets with no builder must raise ``RuntimeError``.

    The stubbed ``load_teacher_dataset`` returns samples without any
    ``proposal_target_*`` keys. Because ``proposal_target_builder`` is ``None``
    while ``require_proposal_targets=True``, ``_build_dataset_from_config`` is
    expected to raise a ``RuntimeError``. Its message must mention
    "proposal-aligned targets".
    """
    cached_file = tmp_path / "proxy_train.pt"
    # The contents are never parsed: load_teacher_dataset is stubbed below.
    cached_file.write_text("stub", encoding="utf-8")

    def _stale_load_teacher_dataset(path):
        # Cached bundle whose sample lacks every proposal-aligned target field.
        return {"samples": [{"task_name": "bag"}]}

    monkeypatch.setattr(run_experiment, "load_teacher_dataset", _stale_load_teacher_dataset)

    cfg_values = dict(
        proxies=["bag_proxy"],
        resolution=16,
        seed=17,
        chunk_horizon=2,
        rollout_horizon=2,
        history_steps=1,
        planner_candidates=2,
        dataset_version="proxy_test",
        train_episodes_per_proxy=1,
        train_dataset_path=str(cached_file),
        rebuild_dataset=False,
    )
    data_cfg = OmegaConf.create(cfg_values)

    caught = None
    try:
        run_experiment._build_dataset_from_config(
            data_cfg,
            "train",
            proposal_target_builder=None,
            require_proposal_targets=True,
        )
    except RuntimeError as exc:
        caught = exc

    if caught is None:
        raise AssertionError("Expected a RuntimeError for unaligned adapter dataset.")
    assert "proposal-aligned targets" in str(caught)
def test_adapter_dataset_rebuilds_when_transition_rollout_targets_are_missing(monkeypatch, tmp_path):
    """A cached bundle with proposal targets but no rollout targets is rebuilt.

    The stubbed ``load_teacher_dataset`` returns samples that carry the four
    base ``proposal_target_*`` fields but none of the
    ``proposal_target_rollout_*`` fields. With
    ``require_proposal_rollout_targets=True`` and a builder supplied,
    ``_build_dataset_from_config`` is expected to call the stubbed
    ``collect_teacher_dataset`` with that builder. The result must then satisfy
    ``_bundle_has_proposal_rollout_targets``.
    """
    base_target_keys = (
        "proposal_target_action_chunks",
        "proposal_target_retrieval_success",
        "proposal_target_risk",
        "proposal_target_utility",
    )
    rollout_target_keys = (
        "proposal_target_rollout_support_mode",
        "proposal_target_rollout_corridor_feasible",
        "proposal_target_rollout_persistence_horizon",
        "proposal_target_rollout_disturbance_cost",
    )

    cached_file = tmp_path / "proxy_train.pt"
    # The contents are never parsed: load_teacher_dataset is stubbed below.
    cached_file.write_text("stub", encoding="utf-8")
    sentinel_builder = object()
    collect_kwargs: dict[str, object] = {}

    def _partial_load_teacher_dataset(path):
        # Cached bundle with base proposal targets only, no rollout targets.
        return {"samples": [{"task_name": "bag", **dict.fromkeys(base_target_keys, 1)}]}

    def _fake_collect_teacher_dataset(**kwargs):
        # Record the keyword arguments so the forwarded builder can be checked.
        collect_kwargs.update(kwargs)
        full_sample = {
            "task_name": "bag",
            **dict.fromkeys(base_target_keys + rollout_target_keys, 1),
        }
        return {"samples": [full_sample]}

    def _fake_save_teacher_dataset(path, bundle):
        # Skip real serialization; just echo the target path back.
        return Path(path)

    monkeypatch.setattr(run_experiment, "load_teacher_dataset", _partial_load_teacher_dataset)
    monkeypatch.setattr(run_experiment, "collect_teacher_dataset", _fake_collect_teacher_dataset)
    monkeypatch.setattr(run_experiment, "save_teacher_dataset", _fake_save_teacher_dataset)

    cfg_values = dict(
        proxies=["bag_proxy"],
        resolution=16,
        seed=17,
        chunk_horizon=2,
        rollout_horizon=2,
        history_steps=1,
        planner_candidates=2,
        dataset_version="proxy_test",
        train_episodes_per_proxy=1,
        train_dataset_path=str(cached_file),
        # Do not force a rebuild: the rebuild must be triggered by the missing rollout targets.
        rebuild_dataset=False,
    )
    data_cfg = OmegaConf.create(cfg_values)

    rebuilt = run_experiment._build_dataset_from_config(
        data_cfg,
        "train",
        proposal_target_builder=sentinel_builder,
        require_proposal_targets=True,
        require_proposal_rollout_targets=True,
    )

    assert run_experiment._bundle_has_proposal_rollout_targets(rebuilt)
    assert collect_kwargs["proposal_target_builder"] is sentinel_builder