|
|
|
import os.path as osp |
|
import unittest |
|
from unittest.mock import MagicMock, patch |
|
|
|
import pytest |
|
|
|
from mmdet.datasets import DATASETS |
|
|
|
|
|
@patch('mmdet.datasets.CocoDataset.load_annotations', MagicMock())
@patch('mmdet.datasets.CustomDataset.load_annotations', MagicMock())
@patch('mmdet.datasets.XMLDataset.load_annotations', MagicMock())
@patch('mmdet.datasets.CityscapesDataset.load_annotations', MagicMock())
@patch('mmdet.datasets.CocoDataset._filter_imgs', MagicMock())
@patch('mmdet.datasets.CustomDataset._filter_imgs', MagicMock())
@patch('mmdet.datasets.XMLDataset._filter_imgs', MagicMock())
@patch('mmdet.datasets.CityscapesDataset._filter_imgs', MagicMock())
@pytest.mark.parametrize('dataset',
                         ['CocoDataset', 'VOCDataset', 'CityscapesDataset'])
def test_custom_classes_override_default(dataset):
    """Check that a user-supplied ``classes`` overrides the default CLASSES.

    ``load_annotations`` and ``_filter_imgs`` are replaced with mocks, so no
    annotation file is read. ``classes`` is exercised as a tuple, a list, a
    single-element list, ``None`` (defaults are kept) and a path to a text
    file holding one class name per line.

    Args:
        dataset (str): Registry name of the dataset class under test.
    """
    import contextlib
    import tempfile

    dataset_class = DATASETS.get(dataset)
    # NOTE(review): presumably VOCDataset needs a 'VOC2007'-style prefix to
    # determine the dataset year — confirm against VOCDataset.__init__.
    img_prefix = 'VOC2007' if dataset == 'VOCDataset' else ''

    def build(classes):
        # Construct the dataset under test in test mode with ``classes``.
        return dataset_class(
            ann_file=MagicMock(),
            pipeline=[],
            classes=classes,
            test_mode=True,
            img_prefix=img_prefix)

    with contextlib.ExitStack() as stack:
        if dataset in ['CocoDataset', 'CityscapesDataset']:
            # Install mock ``coco`` / ``cat_ids`` class attributes only for
            # the duration of this test. patch.object restores (or deletes,
            # since create=True) them on exit, so the stubs do not leak into
            # other tests sharing these registered classes.
            for attr in ('coco', 'cat_ids'):
                stack.enter_context(
                    patch.object(dataset_class, attr, MagicMock(),
                                 create=True))

        original_classes = dataset_class.CLASSES

        # Tuple of class names.
        custom_dataset = build(('bus', 'car'))
        assert custom_dataset.CLASSES != original_classes
        assert custom_dataset.CLASSES == ('bus', 'car')
        print(custom_dataset)

        # List of class names.
        custom_dataset = build(['bus', 'car'])
        assert custom_dataset.CLASSES != original_classes
        assert custom_dataset.CLASSES == ['bus', 'car']
        print(custom_dataset)

        # Single-element list.
        custom_dataset = build(['foo'])
        assert custom_dataset.CLASSES != original_classes
        assert custom_dataset.CLASSES == ['foo']
        print(custom_dataset)

        # None keeps the class-level default CLASSES.
        custom_dataset = build(None)
        assert custom_dataset.CLASSES == original_classes
        print(custom_dataset)

        # Path to a text file listing one class name per line.
        with tempfile.TemporaryDirectory() as tmpdir:
            # osp.join (not ``tmpdir + name``) places the file inside tmpdir,
            # so it is removed together with the directory.
            path = osp.join(tmpdir, 'classes.txt')
            with open(path, 'w') as f:
                f.write('bus\ncar\n')
            custom_dataset = build(path)

        assert custom_dataset.CLASSES != original_classes
        assert custom_dataset.CLASSES == ['bus', 'car']
        print(custom_dataset)
|
|
|
|
|
class CustomDatasetTests(unittest.TestCase):
    """Check ``XMLDataset.data_infos`` against the repository test fixtures."""

    def setUp(self):
        super().setUp()
        # Three directory levels above this file, then 'data'.
        repo_tests_root = osp.dirname(osp.dirname(osp.dirname(__file__)))
        self.data_dir = osp.join(repo_tests_root, 'data')
        self.dataset_class = DATASETS.get('XMLDataset')

    def _make_dataset(self, root, ann_file, **extra_kwargs):
        """Build the dataset in test mode with the ('person', 'dog') classes."""
        return self.dataset_class(
            data_root=root,
            ann_file=ann_file,
            pipeline=[],
            classes=('person', 'dog'),
            test_mode=True,
            **extra_kwargs)

    @staticmethod
    def _expected_infos(filename):
        """Return the single-image info list both fixtures should produce."""
        return [dict(id='000001', filename=filename, width=353, height=500)]

    def test_data_infos__default_db_directories(self):
        """Test correct data read having a Pascal-VOC directory structure."""
        voc_root = osp.join(self.data_dir, 'VOCdevkit', 'VOC2007')
        split_file = osp.join(voc_root, 'ImageSets', 'Main', 'trainval.txt')

        ds = self._make_dataset(voc_root, split_file)

        expected = self._expected_infos(osp.join('JPEGImages', '000001.jpg'))
        self.assertListEqual(expected, ds.data_infos)

    def test_data_infos__overridden_db_subdirectories(self):
        """Test correct data read having a customized directory structure."""
        custom_root = osp.join(self.data_dir, 'custom_dataset')
        split_file = osp.join(custom_root, 'trainval.txt')

        ds = self._make_dataset(
            custom_root,
            split_file,
            img_prefix='',
            img_subdir='images',
            ann_subdir='images')

        expected = self._expected_infos(osp.join('images', '000001.jpg'))
        self.assertListEqual(expected, ds.data_infos)
|
|