# test2 / tests / test_data / test_dataset.py
# mccaly's picture
# Upload 660 files
# b13b124
import os.path as osp
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from mmseg.core.evaluation import get_classes, get_palette
from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
ConcatDataset, CustomDataset, PascalVOCDataset,
RepeatDataset)
def test_classes():
    """Check each dataset's ``CLASSES`` against ``get_classes``.

    For every (dataset, alias) pair listed below, the dataset's class names
    must equal what ``get_classes`` returns for that alias. An unsupported
    name must raise ``ValueError``.
    """
    aliases_by_dataset = (
        (CityscapesDataset, ('cityscapes', )),
        (PascalVOCDataset, ('voc', 'pascal_voc')),
        (ADE20KDataset, ('ade', 'ade20k')),
    )
    for dataset_cls, aliases in aliases_by_dataset:
        class_names = list(dataset_cls.CLASSES)
        for alias in aliases:
            assert class_names == get_classes(alias)

    with pytest.raises(ValueError):
        get_classes('unsupported')
def test_palette():
    """Check each dataset's ``PALETTE`` against ``get_palette``.

    For every (dataset, alias) pair listed below, the dataset's palette must
    equal what ``get_palette`` returns for that alias. An unsupported name
    must raise ``ValueError``.
    """
    aliases_by_dataset = (
        (CityscapesDataset, ('cityscapes', )),
        (PascalVOCDataset, ('voc', 'pascal_voc')),
        (ADE20KDataset, ('ade', 'ade20k')),
    )
    for dataset_cls, aliases in aliases_by_dataset:
        for alias in aliases:
            assert dataset_cls.PALETTE == get_palette(alias)

    with pytest.raises(ValueError):
        get_palette('unsupported')
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
       MagicMock(side_effect=lambda idx: idx))
def test_dataset_wrapper():
    """Check indexing and length of ``ConcatDataset`` and ``RepeatDataset``.

    ``CustomDataset.__getitem__`` is patched with a mock that returns the
    index it is called with, so each wrapper lookup yields the index that the
    wrapper forwarded to the wrapped dataset.
    """

    def make_dataset(num_items):
        # Replace ``img_infos`` with a mock reporting ``num_items`` as its
        # length, so no real annotations are needed.
        dataset = CustomDataset(img_dir=MagicMock(), pipeline=[])
        dataset.img_infos = MagicMock()
        dataset.img_infos.__len__.return_value = num_items
        return dataset

    dataset_a = make_dataset(10)
    dataset_b = make_dataset(20)

    concat_dataset = ConcatDataset([dataset_a, dataset_b])
    for index, expected in ((5, 5), (25, 15)):
        assert concat_dataset[index] == expected
    assert len(concat_dataset) == len(dataset_a) + len(dataset_b)

    repeat_dataset = RepeatDataset(dataset_a, 10)
    for index, expected in ((5, 5), (15, 5), (27, 7)):
        assert repeat_dataset[index] == expected
    assert len(repeat_dataset) == 10 * len(dataset_a)
def _evaluate_and_check(dataset, results, metric):
    """Run ``dataset.evaluate`` and assert the expected keys are reported.

    Args:
        dataset: Dataset object exposing ``evaluate(results, metric=...)``.
        results: Per-image predictions passed through to ``evaluate``.
        metric: A metric name or a list of metric names.

    The returned value must be a dict containing every requested metric
    name plus ``'mAcc'`` and ``'aAcc'``.
    """
    eval_results = dataset.evaluate(results, metric=metric)
    assert isinstance(eval_results, dict)
    requested = [metric] if isinstance(metric, str) else list(metric)
    for key in requested + ['mAcc', 'aAcc']:
        assert key in eval_results


def test_custom_dataset():
    """Build ``CustomDataset`` in several path configurations and evaluate.

    Covers construction with/without ``data_root``, with a split file, with
    absolute directories and in test mode, then item access, ground-truth
    loading and ``evaluate`` with single and multiple metrics, both before
    and after ``CLASSES`` is set on the instance.
    """
    img_norm_cfg = dict(
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True)
    crop_size = (512, 1024)
    train_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations'),
        dict(type='Resize', img_scale=(128, 256), ratio_range=(0.5, 2.0)),
        dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
        dict(type='RandomFlip', prob=0.5),
        dict(type='PhotoMetricDistortion'),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
        dict(type='DefaultFormatBundle'),
        dict(type='Collect', keys=['img', 'gt_semantic_seg']),
    ]
    test_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(
            type='MultiScaleFlipAug',
            img_scale=(128, 256),
            # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
            flip=False,
            transforms=[
                dict(type='Resize', keep_ratio=True),
                dict(type='RandomFlip'),
                dict(type='Normalize', **img_norm_cfg),
                dict(type='ImageToTensor', keys=['img']),
                dict(type='Collect', keys=['img']),
            ])
    ]

    # Fixture locations, computed once. The strings are identical to the
    # ones the individual constructor calls used to build inline.
    here = osp.dirname(__file__)
    data_root = osp.join(here, '../data/pseudo_dataset')
    img_dir = osp.join(here, '../data/pseudo_dataset/imgs')
    ann_dir = osp.join(here, '../data/pseudo_dataset/gts')

    # with img_dir and ann_dir
    train_dataset = CustomDataset(
        train_pipeline,
        data_root=data_root,
        img_dir='imgs/',
        ann_dir='gts/',
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # with img_dir, ann_dir, split
    train_dataset = CustomDataset(
        train_pipeline,
        data_root=data_root,
        img_dir='imgs/',
        ann_dir='gts/',
        img_suffix='img.jpg',
        seg_map_suffix='gt.png',
        split='splits/train.txt')
    assert len(train_dataset) == 4

    # no data_root
    train_dataset = CustomDataset(
        train_pipeline,
        img_dir=img_dir,
        ann_dir=ann_dir,
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # with data_root but img_dir/ann_dir are abs path
    train_dataset = CustomDataset(
        train_pipeline,
        data_root=data_root,
        img_dir=osp.abspath(img_dir),
        ann_dir=osp.abspath(ann_dir),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # test_mode=True
    test_dataset = CustomDataset(
        test_pipeline,
        img_dir=img_dir,
        img_suffix='img.jpg',
        test_mode=True)
    assert len(test_dataset) == 5

    # training data get
    train_data = train_dataset[0]
    assert isinstance(train_data, dict)

    # test data get
    test_data = test_dataset[0]
    assert isinstance(test_data, dict)

    # get gt seg map
    gt_seg_maps = train_dataset.get_gt_seg_maps()
    assert len(gt_seg_maps) == 5

    # evaluation: build pseudo predictions from a seeded generator so that a
    # failing run can be reproduced exactly.
    rng = np.random.RandomState(0)
    pseudo_results = []
    for gt_seg_map in gt_seg_maps:
        h, w = gt_seg_map.shape
        pseudo_results.append(rng.randint(low=0, high=7, size=(h, w)))

    for metric in ('mIoU', 'mDice', ['mDice', 'mIoU']):
        _evaluate_and_check(train_dataset, pseudo_results, metric)

    # evaluation with CLASSES
    train_dataset.CLASSES = tuple(['a'] * 7)
    for metric in ('mIoU', 'mDice', ['mIoU', 'mDice']):
        _evaluate_and_check(train_dataset, pseudo_results, metric)
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
       MagicMock(side_effect=lambda idx: idx))
@pytest.mark.parametrize('dataset, classes', [
    ('ADE20KDataset', ('wall', 'building')),
    ('CityscapesDataset', ('road', 'sidewalk')),
    ('CustomDataset', ('bus', 'car')),
    ('PascalVOCDataset', ('aeroplane', 'bicycle')),
])
def test_custom_classes_override_default(dataset, classes):
    """Check that the ``classes`` argument overrides the class-level default.

    Args:
        dataset: Registry name looked up through ``DATASETS.get``.
        classes: Tuple of class names used as the override.
    """
    dataset_class = DATASETS.get(dataset)
    original_classes = dataset_class.CLASSES

    def build(requested_classes):
        # Fresh mocks for ``img_dir`` and ``split`` on every construction.
        return dataset_class(
            pipeline=[],
            img_dir=MagicMock(),
            split=MagicMock(),
            classes=requested_classes,
            test_mode=True)

    # Overrides given as a tuple, as a list, and as a one-element list.
    for requested in (classes, list(classes), [classes[0]]):
        custom_dataset = build(requested)
        assert custom_dataset.CLASSES != original_classes
        assert custom_dataset.CLASSES == requested

    # Default behavior: ``classes=None`` keeps the class-level CLASSES.
    custom_dataset = build(None)
    assert custom_dataset.CLASSES == original_classes
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
       MagicMock(side_effect=lambda idx: idx))
def test_custom_dataset_random_palette_is_generated():
    """Check the palette of a ``CustomDataset`` built without ``palette``.

    With two custom classes and no palette argument, ``PALETTE`` must hold
    two entries, each made of three values within ``[0, 255]``.
    """
    class_names = ('bus', 'car')
    dataset = CustomDataset(
        pipeline=[],
        img_dir=MagicMock(),
        split=MagicMock(),
        classes=class_names,
        test_mode=True)

    palette = dataset.PALETTE
    assert len(palette) == 2
    for rgb in palette:
        assert len(rgb) == 3
        assert all(0 <= channel <= 255 for channel in rgb)
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
       MagicMock(side_effect=lambda idx: idx))
def test_custom_dataset_custom_palette():
    """Check that an explicit ``palette`` argument ends up in ``PALETTE``."""
    expected_palette = ([100, 100, 100], [200, 200, 200])
    dataset = CustomDataset(
        pipeline=[],
        img_dir=MagicMock(),
        split=MagicMock(),
        classes=('bus', 'car'),
        # Pass fresh lists so the expected value stays independent of
        # whatever object the dataset keeps.
        palette=[list(rgb) for rgb in expected_palette],
        test_mode=True)
    assert tuple(dataset.PALETTE) == expected_palette