import pytest import torch.nn as nn from kornia.metrics import AverageMeter from kornia.x import EarlyStopping, ModelCheckpoint from kornia.x.utils import TrainerState @pytest.fixture def model(): return nn.Conv2d(3, 10, kernel_size=1) def test_callback_modelcheckpoint(tmp_path, model): cb = ModelCheckpoint(tmp_path, 'test_monitor') assert cb is not None metric = {'test_monitor': AverageMeter()} metric['test_monitor'].avg = 1.0 cb(model, epoch=0, valid_metric=metric) assert cb.best_metric == 1.0 assert (tmp_path / "model_0.pt").is_file() def test_callback_earlystopping(model): cb = EarlyStopping('test_monitor', patience=2) assert cb is not None assert cb.counter == 0 metric = {'test_monitor': AverageMeter()} metric['test_monitor'].avg = 1 state = cb(model, epoch=0, valid_metric=metric) assert state == TrainerState.TRAINING assert cb.best_score == -1 assert cb.counter == 0 metric['test_monitor'].avg = 2 state = cb(model, epoch=0, valid_metric=metric) assert state == TrainerState.TRAINING assert cb.best_score == -1 assert cb.counter == 1 state = cb(model, epoch=0, valid_metric=metric) assert state == TrainerState.TERMINATE