# Source: compvis/test/x/test_image_classification.py
# (uploaded by Dexter via huggingface_hub; verified commit 36c95ba)
import pytest
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from kornia.contrib import ClassificationHead, VisionTransformer
from kornia.x import Configuration, ImageClassifierTrainer
class DummyDatasetClassification(Dataset):
    """Tiny synthetic dataset for smoke-testing the image-classification trainer.

    It always holds 10 samples. Every sample is the same all-ones float image
    of shape ``(3, 32, 32)`` paired with the integer class label ``1``.
    """

    def __len__(self):
        return 10

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
                ``torch.utils.data.Dataset`` defines no ``__iter__``, so plain
                iteration (``for sample in dataset``) falls back to calling
                ``__getitem__`` with 0, 1, 2, ... until ``IndexError``. Without
                this check that loop would never terminate.
        """
        size = len(self)
        if not -size <= index < size:
            raise IndexError(f"index {index} out of range for dataset of length {size}")
        return torch.ones(3, 32, 32), torch.tensor(1)
@pytest.fixture
def model():
    """Provide a ViT backbone followed by a 10-class classification head.

    ``image_size=32`` matches the 32x32 images produced by
    ``DummyDatasetClassification``.
    """
    # Build the backbone before the head, as the original one-liner did,
    # so parameter initialisation consumes the RNG in the same order.
    backbone = VisionTransformer(image_size=32)
    head = ClassificationHead(num_classes=10)
    return nn.Sequential(backbone, head)
@pytest.fixture
def dataloader():
    """Provide a DataLoader over the synthetic dataset, one sample per batch."""
    return torch.utils.data.DataLoader(DummyDatasetClassification(), batch_size=1)
@pytest.fixture
def criterion():
    """Provide the classification loss (standard cross-entropy)."""
    loss_fn = nn.CrossEntropyLoss()
    return loss_fn
@pytest.fixture
def optimizer(model):
    """Provide an AdamW optimizer with library-default hyperparameters.

    It optimizes all parameters of the ``model`` fixture.
    """
    params = model.parameters()
    return torch.optim.AdamW(params)
@pytest.fixture
def scheduler(optimizer, dataloader):
    """Provide a cosine-annealing LR schedule for ``optimizer``.

    ``T_max`` is set to the number of batches in ``dataloader``. Presumably the
    trainer steps the scheduler once per batch, which would make one cosine
    half-period span one epoch. TODO: confirm against the trainer.
    """
    batches_per_epoch = len(dataloader)
    return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=batches_per_epoch)
@pytest.fixture
def configuration():
    """Provide a default ``Configuration`` that trains for a single epoch.

    One epoch keeps the smoke test short.
    """
    cfg = Configuration()
    cfg.num_epochs = 1
    return cfg
class TestImageClassifierTrainer:
    """Smoke tests for ``kornia.x.ImageClassifierTrainer``."""

    def test_fit(self, model, dataloader, criterion, optimizer, scheduler, configuration):
        """A full ``fit()`` run on the dummy data completes without raising."""
        # The same loader is passed twice. Presumably these are the training
        # and validation loaders; confirm against the trainer's signature.
        trainer = ImageClassifierTrainer(
            model,
            dataloader,
            dataloader,
            criterion,
            optimizer,
            scheduler,
            configuration,
        )
        trainer.fit()

    def test_exception(self, model, dataloader, criterion, optimizer, scheduler, configuration):
        """Constructing the trainer with ``callbacks={'frodo': None}`` raises ``ValueError``."""
        # Presumably rejected because 'frodo' is not a callback key the trainer
        # recognises. TODO: confirm.
        bad_callbacks = {'frodo': None}
        with pytest.raises(ValueError):
            ImageClassifierTrainer(
                model,
                dataloader,
                dataloader,
                criterion,
                optimizer,
                scheduler,
                configuration,
                callbacks=bad_callbacks,
            )