File size: 1,638 Bytes
36c95ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import pytest
import torch
import torch.nn as nn
from torch.utils.data import Dataset

from kornia.contrib import ClassificationHead, VisionTransformer
from kornia.x import Configuration, ImageClassifierTrainer


class DummyDatasetClassification(Dataset):
    def __len__(self):
        return 10

    def __getitem__(self, index):
        return torch.ones(3, 32, 32), torch.tensor(1)


@pytest.fixture
def model():
    return nn.Sequential(VisionTransformer(image_size=32), ClassificationHead(num_classes=10))


@pytest.fixture
def dataloader():
    dataset = DummyDatasetClassification()
    return torch.utils.data.DataLoader(dataset, batch_size=1)


@pytest.fixture
def criterion():
    return nn.CrossEntropyLoss()


@pytest.fixture
def optimizer(model):
    return torch.optim.AdamW(model.parameters())


@pytest.fixture
def scheduler(optimizer, dataloader):
    return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(dataloader))


@pytest.fixture
def configuration():
    config = Configuration()
    config.num_epochs = 1
    return config


class TestImageClassifierTrainer:
    def test_fit(self, model, dataloader, criterion, optimizer, scheduler, configuration):
        trainer = ImageClassifierTrainer(model, dataloader, dataloader, criterion, optimizer, scheduler, configuration)
        trainer.fit()

    def test_exception(self, model, dataloader, criterion, optimizer, scheduler, configuration):
        with pytest.raises(ValueError):
            ImageClassifierTrainer(
                model, dataloader, dataloader, criterion, optimizer, scheduler, configuration, callbacks={'frodo': None}
            )