import math import os.path as osp import pytest from torch.utils.data import (DistributedSampler, RandomSampler, SequentialSampler) from mmseg.datasets import (DATASETS, ConcatDataset, build_dataloader, build_dataset) @DATASETS.register_module() class ToyDataset(object): def __init__(self, cnt=0): self.cnt = cnt def __item__(self, idx): return idx def __len__(self): return 100 def test_build_dataset(): cfg = dict(type='ToyDataset') dataset = build_dataset(cfg) assert isinstance(dataset, ToyDataset) assert dataset.cnt == 0 dataset = build_dataset(cfg, default_args=dict(cnt=1)) assert isinstance(dataset, ToyDataset) assert dataset.cnt == 1 data_root = osp.join(osp.dirname(__file__), '../data/pseudo_dataset') img_dir = 'imgs/' ann_dir = 'gts/' # We use same dir twice for simplicity # with ann_dir cfg = dict( type='CustomDataset', pipeline=[], data_root=data_root, img_dir=[img_dir, img_dir], ann_dir=[ann_dir, ann_dir]) dataset = build_dataset(cfg) assert isinstance(dataset, ConcatDataset) assert len(dataset) == 10 # with ann_dir, split cfg = dict( type='CustomDataset', pipeline=[], data_root=data_root, img_dir=img_dir, ann_dir=ann_dir, split=['splits/train.txt', 'splits/val.txt']) dataset = build_dataset(cfg) assert isinstance(dataset, ConcatDataset) assert len(dataset) == 5 # with ann_dir, split cfg = dict( type='CustomDataset', pipeline=[], data_root=data_root, img_dir=img_dir, ann_dir=[ann_dir, ann_dir], split=['splits/train.txt', 'splits/val.txt']) dataset = build_dataset(cfg) assert isinstance(dataset, ConcatDataset) assert len(dataset) == 5 # test mode cfg = dict( type='CustomDataset', pipeline=[], data_root=data_root, img_dir=[img_dir, img_dir], test_mode=True) dataset = build_dataset(cfg) assert isinstance(dataset, ConcatDataset) assert len(dataset) == 10 # test mode with splits cfg = dict( type='CustomDataset', pipeline=[], data_root=data_root, img_dir=[img_dir, img_dir], split=['splits/val.txt', 'splits/val.txt'], test_mode=True) dataset = build_dataset(cfg) assert isinstance(dataset, ConcatDataset) assert len(dataset) == 2 # len(ann_dir) should be zero or len(img_dir) when len(img_dir) > 1 with pytest.raises(AssertionError): cfg = dict( type='CustomDataset', pipeline=[], data_root=data_root, img_dir=[img_dir, img_dir], ann_dir=[ann_dir, ann_dir, ann_dir]) build_dataset(cfg) # len(splits) should be zero or len(img_dir) when len(img_dir) > 1 with pytest.raises(AssertionError): cfg = dict( type='CustomDataset', pipeline=[], data_root=data_root, img_dir=[img_dir, img_dir], split=['splits/val.txt', 'splits/val.txt', 'splits/val.txt']) build_dataset(cfg) # len(splits) == len(ann_dir) when only len(img_dir) == 1 and len( # ann_dir) > 1 with pytest.raises(AssertionError): cfg = dict( type='CustomDataset', pipeline=[], data_root=data_root, img_dir=img_dir, ann_dir=[ann_dir, ann_dir], split=['splits/val.txt', 'splits/val.txt', 'splits/val.txt']) build_dataset(cfg) def test_build_dataloader(): dataset = ToyDataset() samples_per_gpu = 3 # dist=True, shuffle=True, 1GPU dataloader = build_dataloader( dataset, samples_per_gpu=samples_per_gpu, workers_per_gpu=2) assert dataloader.batch_size == samples_per_gpu assert len(dataloader) == int(math.ceil(len(dataset) / samples_per_gpu)) assert isinstance(dataloader.sampler, DistributedSampler) assert dataloader.sampler.shuffle # dist=True, shuffle=False, 1GPU dataloader = build_dataloader( dataset, samples_per_gpu=samples_per_gpu, workers_per_gpu=2, shuffle=False) assert dataloader.batch_size == samples_per_gpu assert len(dataloader) == int(math.ceil(len(dataset) / samples_per_gpu)) assert isinstance(dataloader.sampler, DistributedSampler) assert not dataloader.sampler.shuffle # dist=True, shuffle=True, 8GPU dataloader = build_dataloader( dataset, samples_per_gpu=samples_per_gpu, workers_per_gpu=2, num_gpus=8) assert dataloader.batch_size == samples_per_gpu assert len(dataloader) == int(math.ceil(len(dataset) / samples_per_gpu)) assert dataloader.num_workers == 2 # dist=False, shuffle=True, 1GPU dataloader = build_dataloader( dataset, samples_per_gpu=samples_per_gpu, workers_per_gpu=2, dist=False) assert dataloader.batch_size == samples_per_gpu assert len(dataloader) == int(math.ceil(len(dataset) / samples_per_gpu)) assert isinstance(dataloader.sampler, RandomSampler) assert dataloader.num_workers == 2 # dist=False, shuffle=False, 1GPU dataloader = build_dataloader( dataset, samples_per_gpu=3, workers_per_gpu=2, shuffle=False, dist=False) assert dataloader.batch_size == samples_per_gpu assert len(dataloader) == int(math.ceil(len(dataset) / samples_per_gpu)) assert isinstance(dataloader.sampler, SequentialSampler) assert dataloader.num_workers == 2 # dist=False, shuffle=True, 8GPU dataloader = build_dataloader( dataset, samples_per_gpu=3, workers_per_gpu=2, num_gpus=8, dist=False) assert dataloader.batch_size == samples_per_gpu * 8 assert len(dataloader) == int( math.ceil(len(dataset) / samples_per_gpu / 8)) assert isinstance(dataloader.sampler, RandomSampler) assert dataloader.num_workers == 16