import os
from pathlib import Path

import pytest
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, open_dict

from src.eval import evaluate
from src.train import train


@pytest.mark.slow
def test_train_eval(tmp_path: Path, cfg_train: DictConfig, cfg_eval: DictConfig) -> None:
    """Tests training and evaluation by training for 1 epoch with `train.py` then evaluating with
    `eval.py`.

    :param tmp_path: The temporary logging path.
    :param cfg_train: A DictConfig containing a valid training configuration.
    :param cfg_eval: A DictConfig containing a valid evaluation configuration.
    """
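    # Both config fixtures should point at the same temporary output directory.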
    assert str(tmp_path) == cfg_train.paths.output_dir == cfg_eval.paths.output_dir

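    # Shorten the run to a single epoch and enable the test loop after fitting.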
    with open_dict(cfg_train):
        cfg_train.trainer.max_epochs = 1
        cfg_train.test = True

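    # Register the composed config with Hydra, then run training.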
    HydraConfig().set_config(cfg_train)
    train_metric_dict, _ = train(cfg_train)

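    # Training should have saved a checkpoint for the last epoch.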
assert "last.ckpt" in os.listdir(tmp_path / "checkpoints") |
|
|
|
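    # Point evaluation at the checkpoint produced by the training run.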
    with open_dict(cfg_eval):
        cfg_eval.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt")

    HydraConfig().set_config(cfg_eval)
    test_metric_dict, _ = evaluate(cfg_eval)

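    # Evaluating the saved checkpoint should reproduce the test accuracy logged during training.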
    assert test_metric_dict["test/acc"] > 0.0
    assert abs(train_metric_dict["test/acc"].item() - test_metric_dict["test/acc"].item()) < 0.001