import pytest
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

from kornia.x import Configuration, SemanticSegmentationTrainer


class DummyDatasetSegmentation(Dataset):
    """Tiny in-memory segmentation dataset used to smoke-test the trainer.

    It always has 10 samples, and every sample is the same pair: an all-ones
    ``(3, 32, 32)`` float image and an all-ones ``(32, 32)`` ``long`` label map.
    """

    def __len__(self):
        return 10

    def __getitem__(self, index):
        # ``index`` is intentionally unused: all samples are identical.
        return torch.ones(3, 32, 32), torch.ones(32, 32).long()


@pytest.fixture
def model():
    """1x1 convolution mapping the 3 image channels to 10 class scores per pixel."""
    return nn.Conv2d(3, 10, kernel_size=1)


@pytest.fixture
def dataloader():
    """Loader over ``DummyDatasetSegmentation`` yielding one sample per batch (10 batches)."""
    dataset = DummyDatasetSegmentation()
    return DataLoader(dataset, batch_size=1)


@pytest.fixture
def criterion():
    """Cross-entropy between the ``(B, 10, 32, 32)`` model output and ``(B, 32, 32)`` labels."""
    return nn.CrossEntropyLoss()


@pytest.fixture
def optimizer(model):
    """AdamW with default hyper-parameters over the ``model`` fixture's parameters."""
    return torch.optim.AdamW(model.parameters())


@pytest.fixture
def scheduler(optimizer, dataloader):
    """Cosine-annealing schedule whose ``T_max`` is the number of batches in ``dataloader``."""
    # NOTE(review): whether this spans one epoch depends on the trainer stepping
    # the scheduler once per batch — confirm in SemanticSegmentationTrainer.
    return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(dataloader))


@pytest.fixture
def configuration():
    """Default ``Configuration`` limited to a single epoch to keep the test fast."""
    config = Configuration()
    config.num_epochs = 1
    return config


class TestSemanticSegmentationTrainer:
    """Smoke tests for ``kornia.x.SemanticSegmentationTrainer``."""

    def test_fit(self, model, dataloader, criterion, optimizer, scheduler, configuration):
        """Run ``fit()`` end to end on the dummy data; passes if nothing raises."""
        # The same loader is passed twice — presumably as the train and
        # validation loaders; confirm against the trainer's signature.
        trainer = SemanticSegmentationTrainer(
            model, dataloader, dataloader, criterion, optimizer, scheduler, configuration
        )
        trainer.fit()

    def test_exception(self, model, dataloader, criterion, optimizer, scheduler, configuration):
        """Constructing the trainer with ``callbacks={'frodo': None}`` raises ``ValueError``."""
        # NOTE(review): presumably 'frodo' is rejected as an unrecognised
        # callback name; confirm in SemanticSegmentationTrainer.
        with pytest.raises(ValueError):
            SemanticSegmentationTrainer(
                model,
                dataloader,
                dataloader,
                criterion,
                optimizer,
                scheduler,
                configuration,
                callbacks={'frodo': None},
            )