"""Smoke tests for ``kornia.x.ImageClassifierTrainer``."""

import pytest
import torch
import torch.nn as nn
from torch.utils.data import Dataset

from kornia.contrib import ClassificationHead, VisionTransformer
from kornia.x import Configuration, ImageClassifierTrainer


class DummyDatasetClassification(Dataset):
    """Fixed-size synthetic dataset yielding one constant (image, label) pair."""

    def __len__(self) -> int:
        """Return the number of samples (always 10)."""
        return 10

    def __getitem__(self, index):
        """Return a ``(3, 32, 32)`` tensor of ones and the scalar label ``1``.

        ``index`` is ignored: every sample is identical.
        """
        return torch.ones(3, 32, 32), torch.tensor(1)


@pytest.fixture
def model():
    """Build a ``VisionTransformer`` followed by a 10-class ``ClassificationHead``."""
    return nn.Sequential(
        VisionTransformer(image_size=32),
        ClassificationHead(num_classes=10),
    )


@pytest.fixture
def dataloader():
    """Wrap the dummy dataset in a ``DataLoader`` with a batch size of 1."""
    dataset = DummyDatasetClassification()
    return torch.utils.data.DataLoader(dataset, batch_size=1)


@pytest.fixture
def criterion():
    """Return the cross-entropy loss used for training."""
    return nn.CrossEntropyLoss()


@pytest.fixture
def optimizer(model):
    """Return an ``AdamW`` optimizer over the parameters of ``model``."""
    return torch.optim.AdamW(model.parameters())


@pytest.fixture
def scheduler(optimizer, dataloader):
    """Return a cosine-annealing schedule whose ``T_max`` is ``len(dataloader)``."""
    return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(dataloader))


@pytest.fixture
def configuration():
    """Return a ``Configuration`` with ``num_epochs`` set to 1 to keep the test short."""
    config = Configuration()
    config.num_epochs = 1
    return config


class TestImageClassifierTrainer:
    """Exercise construction and fitting of ``ImageClassifierTrainer``."""

    def test_fit(self, model, dataloader, criterion, optimizer, scheduler, configuration):
        """Check that ``fit()`` runs to completion without raising.

        The same dataloader is passed twice, for the trainer's two dataloader
        arguments.
        """
        trainer = ImageClassifierTrainer(
            model, dataloader, dataloader, criterion, optimizer, scheduler, configuration
        )
        trainer.fit()

    def test_exception(self, model, dataloader, criterion, optimizer, scheduler, configuration):
        """Check that constructing the trainer with a bogus callback raises ``ValueError``."""
        # NOTE(review): presumably 'frodo' is rejected because it is not a
        # recognised callback name -- confirm against the Trainer implementation.
        with pytest.raises(ValueError):
            ImageClassifierTrainer(
                model,
                dataloader,
                dataloader,
                criterion,
                optimizer,
                scheduler,
                configuration,
                callbacks={'frodo': None},
            )