|
import pytest |
|
import torch |
|
import torch.nn as nn |
|
from torch.utils.data import Dataset |
|
|
|
from kornia.contrib import ClassificationHead, VisionTransformer |
|
from kornia.x import Configuration, ImageClassifierTrainer |
|
|
|
|
|
class DummyDatasetClassification(Dataset):
    """Tiny synthetic classification dataset for smoke-testing the trainer.

    It always reports 10 samples. Every sample, regardless of ``index``, is an
    all-ones float tensor of shape ``(3, 32, 32)`` paired with the scalar
    integer label ``1``.
    """

    def __len__(self):
        # Fixed size: keeps a full training epoch very short.
        return 10

    def __getitem__(self, index):
        # ``index`` is ignored; every item is identical.
        image = torch.ones(3, 32, 32)
        label = torch.tensor(1)
        return image, label
|
|
|
|
|
@pytest.fixture
def model():
    """Build a ViT backbone for 32x32 inputs followed by a 10-class head."""
    # Constructed in the same order as nn.Sequential's arguments
    # (backbone first, then head).
    backbone = VisionTransformer(image_size=32)
    head = ClassificationHead(num_classes=10)
    return nn.Sequential(backbone, head)
|
|
|
|
|
@pytest.fixture
def dataloader():
    """Wrap the dummy dataset in a DataLoader that yields one sample per batch."""
    return torch.utils.data.DataLoader(DummyDatasetClassification(), batch_size=1)
|
|
|
|
|
@pytest.fixture
def criterion():
    """Provide a cross-entropy loss with PyTorch's default settings."""
    loss_fn = nn.CrossEntropyLoss()
    return loss_fn
|
|
|
|
|
@pytest.fixture
def optimizer(model):
    """Create an AdamW optimizer, with default hyper-parameters, over all of the model's parameters."""
    params = model.parameters()
    return torch.optim.AdamW(params)
|
|
|
|
|
@pytest.fixture
def scheduler(optimizer, dataloader):
    """Create a cosine-annealing LR scheduler.

    The period ``T_max`` equals the number of batches in ``dataloader``.
    """
    steps = len(dataloader)
    return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=steps)
|
|
|
|
|
@pytest.fixture
def configuration():
    """Return a default-constructed Configuration with ``num_epochs`` set to 1."""
    cfg = Configuration()
    # A single epoch is enough for a smoke test.
    cfg.num_epochs = 1
    return cfg
|
|
|
|
|
class TestImageClassifierTrainer:
    """Smoke tests for constructing and fitting ImageClassifierTrainer."""

    def test_fit(self, model, dataloader, criterion, optimizer, scheduler, configuration):
        # The same loader is passed twice; presumably these are the train and
        # validation loaders (confirm against the ImageClassifierTrainer signature).
        trainer = ImageClassifierTrainer(
            model, dataloader, dataloader, criterion, optimizer, scheduler, configuration
        )
        trainer.fit()

    def test_exception(self, model, dataloader, criterion, optimizer, scheduler, configuration):
        # A callbacks dict keyed by 'frodo' must make the constructor raise
        # ValueError. Presumably 'frodo' is not a recognised callback name.
        bogus_callbacks = {'frodo': None}
        with pytest.raises(ValueError):
            ImageClassifierTrainer(
                model,
                dataloader,
                dataloader,
                criterion,
                optimizer,
                scheduler,
                configuration,
                callbacks=bogus_callbacks,
            )
|
|