File size: 1,638 Bytes
36c95ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import pytest
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from kornia.contrib import ClassificationHead, VisionTransformer
from kornia.x import Configuration, ImageClassifierTrainer
class DummyDatasetClassification(Dataset):
    """Tiny constant dataset used to smoke-test the classification trainer.

    It always reports 10 samples. Every sample, whatever its index, is an
    all-ones float tensor of shape ``(3, 32, 32)`` paired with the scalar
    integer label ``1``.
    """

    def __len__(self):
        return 10

    def __getitem__(self, index):
        # ``index`` is intentionally ignored: every sample is identical.
        image = torch.ones(3, 32, 32)
        label = torch.tensor(1)
        return image, label
@pytest.fixture
def model():
    """Build a ``VisionTransformer`` backbone (32x32 inputs) followed by a 10-class head."""
    # The backbone is built before the head, as in the single-expression form.
    # This keeps parameter-initialization order (and RNG consumption) unchanged.
    backbone = VisionTransformer(image_size=32)
    head = ClassificationHead(num_classes=10)
    return nn.Sequential(backbone, head)
@pytest.fixture
def dataloader():
    """Wrap ``DummyDatasetClassification`` in a ``DataLoader`` that yields one sample per batch."""
    return torch.utils.data.DataLoader(DummyDatasetClassification(), batch_size=1)
@pytest.fixture
def criterion():
    """Provide the cross-entropy loss used as the training objective."""
    loss_fn = nn.CrossEntropyLoss()
    return loss_fn
@pytest.fixture
def optimizer(model):
    """Create an ``AdamW`` optimizer with default hyper-parameters over ``model``'s parameters."""
    params = model.parameters()
    return torch.optim.AdamW(params)
@pytest.fixture
def scheduler(optimizer, dataloader):
    """Attach a cosine-annealing LR schedule to ``optimizer``.

    ``T_max`` is set to the number of batches in ``dataloader``.
    """
    steps_per_epoch = len(dataloader)
    return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=steps_per_epoch)
@pytest.fixture
def configuration():
    """Return a default trainer ``Configuration`` with ``num_epochs`` overridden to 1."""
    cfg = Configuration()
    # A single epoch is enough for a smoke test and keeps the run short.
    cfg.num_epochs = 1
    return cfg
class TestImageClassifierTrainer:
    """Smoke tests for ``kornia.x.ImageClassifierTrainer``."""

    def test_fit(self, model, dataloader, criterion, optimizer, scheduler, configuration):
        """Construct a trainer and run ``fit()``; the test passes if nothing raises."""
        # The same loader fills both dataloader positions
        # (presumably train and validation; confirm against the trainer signature).
        trainer = ImageClassifierTrainer(
            model, dataloader, dataloader, criterion, optimizer, scheduler, configuration
        )
        trainer.fit()

    def test_exception(self, model, dataloader, criterion, optimizer, scheduler, configuration):
        """Passing ``callbacks={'frodo': None}`` must make the constructor raise ``ValueError``.

        Presumably this is because ``'frodo'`` is not a supported callback name;
        confirm against the trainer implementation.
        """
        bad_callbacks = {'frodo': None}
        with pytest.raises(ValueError):
            ImageClassifierTrainer(
                model,
                dataloader,
                dataloader,
                criterion,
                optimizer,
                scheduler,
                configuration,
                callbacks=bad_callbacks,
            )
|