| import dataclasses |
| import os |
| import pathlib |
|
|
| import pytest |
|
|
# Force JAX onto the CPU backend for this test module. This is placed before the
# openpi import, presumably so the setting is in effect before anything that
# openpi pulls in initializes a JAX backend — confirm if the import order changes.
os.environ.update(JAX_PLATFORMS="cpu")
|
|
| from openpi.training import config as _config |
|
|
| from . import train |
|
|
|
|
@pytest.mark.parametrize("config_name", ["debug"])
def test_train(tmp_path: pathlib.Path, config_name: str):
    """Run ``train.main`` on a tiny configuration, then run it again with ``resume=True``.

    The registered config named ``config_name`` is shrunk as follows:
    batch size 2, 2 training steps, logging every step, ``exp_name="test"``,
    no overwrite, no resume. Its checkpoint base directory is pointed at
    ``tmp_path / "checkpoint"``.

    A second invocation reuses those settings with ``resume=True`` and
    ``num_train_steps=4``.
    """
    # noqa: SLF001 -- the test reads the private config registry directly.
    base_config = _config._CONFIGS_DICT[config_name]  # noqa: SLF001
    overrides = dict(
        batch_size=2,
        checkpoint_base_dir=str(tmp_path / "checkpoint"),
        exp_name="test",
        overwrite=False,
        resume=False,
        num_train_steps=2,
        log_interval=1,
    )
    fresh_run = dataclasses.replace(base_config, **overrides)
    train.main(fresh_run)

    # Second invocation: identical settings except resume=True and a larger
    # step budget (4 instead of 2), exercising the resume path.
    resumed_run = dataclasses.replace(fresh_run, resume=True, num_train_steps=4)
    train.main(resumed_run)
|
|