# Source path: compvis/test/x/test_segmentation.py
# Uploaded by: Dexter
# Commit: 36c95ba (verified) - "Upload folder using huggingface_hub"
import pytest
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from kornia.x import Configuration, SemanticSegmentationTrainer
class DummyDatasetSegmentation(Dataset):
    """Tiny synthetic segmentation dataset made of identical samples.

    Every item is the same ``(image, mask)`` pair: a 3x32x32 tensor of ones
    and a 32x32 ``long`` tensor of ones. The index passed in is ignored.
    """

    def __len__(self):
        # Fixed dataset size.
        return 10

    def __getitem__(self, index):
        image = torch.ones(3, 32, 32)
        # Integer (int64) mask so it can serve as a class-index target.
        mask = torch.ones(32, 32, dtype=torch.long)
        return image, mask
@pytest.fixture
def model():
    """Provide a 1x1 convolution mapping 3 input channels to 10 output channels."""
    return nn.Conv2d(in_channels=3, out_channels=10, kernel_size=1)
@pytest.fixture
def dataloader():
    """Provide a DataLoader over ``DummyDatasetSegmentation`` yielding one sample per batch."""
    return torch.utils.data.DataLoader(DummyDatasetSegmentation(), batch_size=1)
@pytest.fixture
def criterion():
    """Provide a default-configured cross-entropy loss."""
    loss_fn = nn.CrossEntropyLoss()
    return loss_fn
@pytest.fixture
def optimizer(model):
    """Provide an AdamW optimizer (default hyper-parameters) over the ``model`` fixture's parameters."""
    params = model.parameters()
    return torch.optim.AdamW(params)
@pytest.fixture
def scheduler(optimizer, dataloader):
    """Provide a cosine-annealing LR schedule whose period equals the number of batches."""
    num_batches = len(dataloader)
    return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_batches)
@pytest.fixture
def configuration():
    """Provide a default ``Configuration`` overridden to run a single epoch."""
    cfg = Configuration()
    # Keep the trainer smoke test short.
    cfg.num_epochs = 1
    return cfg
class TestsemanticSegmentationTrainer:
    """Smoke tests for ``SemanticSegmentationTrainer`` built from the module fixtures."""

    def test_fit(self, model, dataloader, criterion, optimizer, scheduler, configuration):
        """Construct the trainer and run ``fit`` end-to-end, expecting no exception."""
        # The same loader is passed in both the training and validation positions.
        trainer_args = (model, dataloader, dataloader, criterion, optimizer, scheduler, configuration)
        SemanticSegmentationTrainer(*trainer_args).fit()

    def test_exception(self, model, dataloader, criterion, optimizer, scheduler, configuration):
        """Constructing the trainer with this ``callbacks`` mapping must raise ``ValueError``."""
        # NOTE(review): presumably 'frodo' is rejected because it is not a supported callback name.
        bogus_callbacks = {'frodo': None}
        with pytest.raises(ValueError):
            SemanticSegmentationTrainer(
                model,
                dataloader,
                dataloader,
                criterion,
                optimizer,
                scheduler,
                configuration,
                callbacks=bogus_callbacks,
            )