| import pytest
|
|
|
| from tests.helpers.run_if import RunIf
|
| from tests.helpers.run_sh_command import run_sh_command
|
|
|
# Script that every test in this module launches through run_sh_command.
startfile = "runner/src/train.py"

# Overrides appended to most commands; sets the logger override to an empty list.
overrides = [
    "logger=[]",
]

# Config keys that each test redirects to pytest's per-test tmp_path.
dir_overrides = [
    "paths.data_dir",
    "hydra.sweep.dir",
]
|
|
|
|
|
@RunIf(sh=True)
@pytest.mark.slow
@pytest.mark.xfail(
    reason="Currently failing experiments with fast_dev_run which messes with gradients"
)
def test_xfail_fast_dev_experiments(tmp_path):
    """Launch a multirun over every experiment config matched by ``glob(*)``.

    Every run is started with ``trainer.fast_dev_run=true``. The test is
    marked xfail; see the marker's reason above.

    Args:
        tmp_path: pytest temporary directory used for each key in
            ``dir_overrides``.
    """
    sweep_args = ["-m", "experiment=glob(*)", "++trainer.fast_dev_run=true"]
    dir_args = [f"{key}={tmp_path}" for key in dir_overrides]
    # Argument order matches the command line: script, sweep flags, shared
    # overrides, then the directory redirections.
    run_sh_command([startfile, *sweep_args, *overrides, *dir_args])
|
|
|
|
|
@RunIf(sh=True)
@pytest.mark.slow
def test_experiments(tmp_path):
    """Launch a multirun of the ``cfm`` experiment across four model configs.

    The sweep covers ``model=cfm,otcfm,sbcfm,fm`` with
    ``trainer.fast_dev_run=true`` and ``trainer.limit_val_batches=0.25``.

    Args:
        tmp_path: pytest temporary directory used for each key in
            ``dir_overrides``.
    """
    sweep_args = [
        "-m",
        "experiment=cfm",
        "model=cfm,otcfm,sbcfm,fm",
        "++trainer.fast_dev_run=true",
        "++trainer.limit_val_batches=0.25",
    ]
    dir_args = [f"{key}={tmp_path}" for key in dir_overrides]
    run_sh_command([startfile, *sweep_args, *overrides, *dir_args])
|
|
|
|
|
@RunIf(sh=True)
@pytest.mark.slow
def test_hydra_sweep(tmp_path):
    """Test default hydra sweep.

    Launches a multirun over ``model.optimizer.lr=0.005,0.01`` with
    ``trainer.fast_dev_run=true``.

    Args:
        tmp_path: pytest temporary directory used for each key in
            ``dir_overrides``.
    """
    command = (
        [
            startfile,
            "-m",
            # hydra.sweep.dir is not listed here: dir_overrides already
            # contains it, so it is supplied exactly once below.
            "model.optimizer.lr=0.005,0.01",
            "++trainer.fast_dev_run=true",
        ]
        + overrides
        + [f"{d}={tmp_path}" for d in dir_overrides]
    )
    run_sh_command(command)
|
|
|
|
|
@RunIf(sh=True)
@pytest.mark.slow
@pytest.mark.xfail(reason="DDP is not working yet")
def test_hydra_sweep_ddp_sim(tmp_path):
    """Launch a learning-rate multirun using the ``ddp_sim`` trainer config.

    Sweeps ``model.optimizer.lr=0.005,0.01,0.02`` for three epochs with the
    train/val/test batch limits given in the command. Marked xfail; see the
    marker's reason above.

    Args:
        tmp_path: pytest temporary directory used for each key in
            ``dir_overrides``.
    """
    sweep_args = [
        "-m",
        "trainer=ddp_sim",
        "trainer.max_epochs=3",
        "+trainer.limit_train_batches=0.01",
        "+trainer.limit_val_batches=0.1",
        "+trainer.limit_test_batches=0.1",
        "model.optimizer.lr=0.005,0.01,0.02",
    ]
    dir_args = [f"{key}={tmp_path}" for key in dir_overrides]
    run_sh_command([startfile, *sweep_args, *overrides, *dir_args])
|
|
|
|
|
@RunIf(sh=True)
@pytest.mark.slow
@pytest.mark.skip(reason="Too slow for easy testing, pathway currently not used")
def test_optuna_sweep(tmp_path):
    """Test optuna sweep.

    Launches a multirun with ``hparams_search=optuna``, three trials and two
    sampler startup trials. Currently skipped; see the marker's reason above.

    Args:
        tmp_path: pytest temporary directory used for each key in
            ``dir_overrides``.
    """
    command = (
        [
            startfile,
            "-m",
            "hparams_search=optuna",
            # hydra.sweep.dir is not listed here: dir_overrides already
            # contains it, so it is supplied exactly once below.
            "hydra.sweeper.n_trials=3",
            "hydra.sweeper.sampler.n_startup_trials=2",
        ]
        + overrides
        + [f"{d}={tmp_path}" for d in dir_overrides]
    )
    run_sh_command(command)
|
|
|
|
|
@RunIf(wandb=True, sh=True)
@pytest.mark.slow
@pytest.mark.xfail(reason="wandb import is still bad without API key")
def test_optuna_sweep_ddp_sim_wandb(tmp_path):
    """Launch an optuna multirun with the ``ddp_sim`` trainer and wandb logger.

    Runs five trials for three epochs with the train/val/test batch limits
    given in the command. The shared ``overrides`` list is not appended here;
    the command sets ``logger=wandb`` itself. Marked xfail; see the marker's
    reason above.

    Args:
        tmp_path: pytest temporary directory used for each key in
            ``dir_overrides``.
    """
    sweep_args = [
        "-m",
        "hparams_search=optuna",
        "hydra.sweeper.n_trials=5",
        "trainer=ddp_sim",
        "trainer.max_epochs=3",
        "+trainer.limit_train_batches=0.01",
        "+trainer.limit_val_batches=0.1",
        "+trainer.limit_test_batches=0.1",
        "logger=wandb",
    ]
    dir_args = [f"{key}={tmp_path}" for key in dir_overrides]
    run_sh_command([startfile, *sweep_args, *dir_args])
|
|
|