compvis/test/x/test_x.py
Dexter's picture
Upload folder using huggingface_hub
36c95ba verified
import pytest
import torch.nn as nn
from kornia.metrics import AverageMeter
from kornia.x import EarlyStopping, ModelCheckpoint
from kornia.x.utils import TrainerState
@pytest.fixture
def model():
    """Provide a small ``nn.Conv2d`` (3 -> 10 channels, 1x1 kernel) for the callback tests."""
    conv = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=1)
    return conv
def test_callback_modelcheckpoint(tmp_path, model):
    """Check that ``ModelCheckpoint`` tracks the monitored value and writes a file.

    After one call at epoch 0 with a monitored average of 1.0, the test
    expects ``best_metric`` to equal 1.0 and ``model_0.pt`` to exist
    under ``tmp_path``.
    """
    monitor = 'test_monitor'
    checkpoint = ModelCheckpoint(tmp_path, monitor)
    assert checkpoint is not None

    # Only the ``avg`` attribute of the meter is set before the callback runs.
    meter = AverageMeter()
    valid_metric = {monitor: meter}
    meter.avg = 1.0

    checkpoint(model, epoch=0, valid_metric=valid_metric)

    assert checkpoint.best_metric == 1.0
    saved_file = tmp_path.joinpath("model_0.pt")
    assert saved_file.is_file()
def test_callback_earlystopping(model):
    """Walk ``EarlyStopping`` (patience=2) through three calls up to termination.

    Expected sequence, as asserted below:
      1. monitored value 1 -> TRAINING, ``best_score == -1``, ``counter == 0``
      2. monitored value 2 -> TRAINING, ``best_score`` unchanged, ``counter == 1``
      3. same value again  -> TERMINATE
    """
    monitor = 'test_monitor'
    stopper = EarlyStopping(monitor, patience=2)
    assert stopper is not None
    assert stopper.counter == 0

    meter = AverageMeter()
    valid_metric = {monitor: meter}

    def run_callback():
        # Every invocation in this test uses the same model, epoch and metric dict.
        return stopper(model, epoch=0, valid_metric=valid_metric)

    # First call: the callback records its initial best score.
    meter.avg = 1
    assert run_callback() == TrainerState.TRAINING
    assert stopper.best_score == -1
    assert stopper.counter == 0

    # Second call with a larger monitored value: the counter advances.
    meter.avg = 2
    assert run_callback() == TrainerState.TRAINING
    assert stopper.best_score == -1
    assert stopper.counter == 1

    # Third call with the same value: the callback signals termination.
    assert run_callback() == TrainerState.TERMINATE