camenduru's picture
thanks to show ❤
3bbb319
raw
history blame
No virus
4.78 kB
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import unittest
from unittest.mock import MagicMock, patch
import pytest
from mmdet.datasets import DATASETS
@patch('mmdet.datasets.CocoDataset.load_annotations', MagicMock())
@patch('mmdet.datasets.CustomDataset.load_annotations', MagicMock())
@patch('mmdet.datasets.XMLDataset.load_annotations', MagicMock())
@patch('mmdet.datasets.CityscapesDataset.load_annotations', MagicMock())
@patch('mmdet.datasets.CocoDataset._filter_imgs', MagicMock)
@patch('mmdet.datasets.CustomDataset._filter_imgs', MagicMock)
@patch('mmdet.datasets.XMLDataset._filter_imgs', MagicMock)
@patch('mmdet.datasets.CityscapesDataset._filter_imgs', MagicMock)
@pytest.mark.parametrize('dataset',
                         ['CocoDataset', 'VOCDataset', 'CityscapesDataset'])
def test_custom_classes_override_default(dataset):
    """Check that the ``classes`` argument overrides a dataset's ``CLASSES``.

    ``load_annotations`` and ``_filter_imgs`` are patched out on the
    dataset base classes, so only class-name handling is exercised. The
    cases covered are:

    * a tuple;
    * a list;
    * a list that is not a subset of the defaults;
    * ``None``, which keeps the defaults;
    * a path to a text file with one class name per line.

    Args:
        dataset (str): Registry name of the dataset class under test.
    """
    import tempfile

    dataset_class = DATASETS.get(dataset)
    if dataset in ['CocoDataset', 'CityscapesDataset']:
        # NOTE(review): these attributes are assigned on the class itself
        # and are not restored when the test finishes.
        dataset_class.coco = MagicMock()
        dataset_class.cat_ids = MagicMock()
    original_classes = dataset_class.CLASSES
    img_prefix = 'VOC2007' if dataset == 'VOCDataset' else ''

    def build(classes):
        """Instantiate ``dataset_class`` in test mode with ``classes``."""
        return dataset_class(
            ann_file=MagicMock(),
            pipeline=[],
            classes=classes,
            test_mode=True,
            img_prefix=img_prefix)

    # Test setting classes as a tuple
    custom_dataset = build(('bus', 'car'))
    assert custom_dataset.CLASSES != original_classes
    assert custom_dataset.CLASSES == ('bus', 'car')
    print(custom_dataset)

    # Test setting classes as a list
    custom_dataset = build(['bus', 'car'])
    assert custom_dataset.CLASSES != original_classes
    assert custom_dataset.CLASSES == ['bus', 'car']
    print(custom_dataset)

    # Test overriding not a subset
    custom_dataset = build(['foo'])
    assert custom_dataset.CLASSES != original_classes
    assert custom_dataset.CLASSES == ['foo']
    print(custom_dataset)

    # Test default behavior
    custom_dataset = build(None)
    assert custom_dataset.CLASSES == original_classes
    print(custom_dataset)

    # Test sending file path
    with tempfile.TemporaryDirectory() as tmpdir:
        # Join with a path separator so the file is created *inside*
        # tmpdir and removed with it. Plain string concatenation produced
        # a sibling file next to tmpdir that was never cleaned up.
        path = osp.join(tmpdir, 'classes.txt')
        with open(path, 'w') as f:
            f.write('bus\ncar\n')
        custom_dataset = build(path)
    assert custom_dataset.CLASSES != original_classes
    assert custom_dataset.CLASSES == ['bus', 'car']
    print(custom_dataset)
class CustomDatasetTests(unittest.TestCase):
    """Check how ``XMLDataset`` builds ``data_infos`` from on-disk fixtures."""

    def setUp(self):
        super().setUp()
        # Fixtures live in a ``data`` directory located two levels above
        # the directory that contains this file.
        tests_root = osp.dirname(osp.dirname(osp.dirname(__file__)))
        self.data_dir = osp.join(tests_root, 'data')
        self.dataset_class = DATASETS.get('XMLDataset')

    @staticmethod
    def _single_image_infos(img_subdir):
        """Return the ``data_infos`` expected for the one-image fixture.

        Args:
            img_subdir (str): Sub-directory prefixed to the image filename.

        Returns:
            list[dict]: A single info dict for image ``000001``.
        """
        return [{
            'id': '000001',
            'filename': osp.join(img_subdir, '000001.jpg'),
            'width': 353,
            'height': 500
        }]

    def test_data_infos__default_db_directories(self):
        """Test correct data read having a Pascal-VOC directory structure."""
        voc_root = osp.join(self.data_dir, 'VOCdevkit', 'VOC2007')
        split_file = osp.join(voc_root, 'ImageSets', 'Main', 'trainval.txt')
        dataset = self.dataset_class(
            data_root=voc_root,
            ann_file=split_file,
            pipeline=[],
            classes=('person', 'dog'),
            test_mode=True)
        expected = self._single_image_infos('JPEGImages')
        self.assertListEqual(expected, dataset.data_infos)

    def test_data_infos__overridden_db_subdirectories(self):
        """Test data read with ``img_subdir``/``ann_subdir`` overridden."""
        custom_root = osp.join(self.data_dir, 'custom_dataset')
        dataset = self.dataset_class(
            data_root=custom_root,
            ann_file=osp.join(custom_root, 'trainval.txt'),
            pipeline=[],
            classes=('person', 'dog'),
            test_mode=True,
            img_prefix='',
            img_subdir='images',
            ann_subdir='images')
        expected = self._single_image_infos('images')
        self.assertListEqual(expected, dataset.data_infos)