File size: 3,002 Bytes
186701e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
from mmengine.dataset import ConcatDataset
from mmyolo.datasets import YOLOv5VOCDataset
from mmyolo.utils import register_all_modules
# Runs at import time, before any test builds a dataset. Presumably this
# populates the registries that resolve the string ``type=`` entries used
# below (e.g. 'YOLOv5VOCDataset', 'BatchShapePolicy') -- TODO confirm
# against mmyolo.utils.
register_all_modules()
class TestYOLOv5VocDataset(unittest.TestCase):
    """Tests for ``YOLOv5VOCDataset`` built from ``tests/data/VOCdevkit/``.

    All checks use ``unittest`` assertion methods rather than bare
    ``assert`` so they still run under ``python -O`` and report the
    compared values on failure.
    """

    def test_batch_shapes_cfg(self):
        """Check ``img_id`` and ``batch_shape`` of each sample when a
        ``BatchShapePolicy`` config is passed with ``test_mode=True``."""
        batch_shapes_cfg = dict(
            type='BatchShapePolicy',
            batch_size=2,
            img_size=640,
            size_divisor=32,
            extra_pad_ratio=0.5)

        # ``serialize_data`` is not passed, so the dataset default applies.
        # NOTE(review): the original comment said this covers
        # serialize_data=True -- confirm that is the base-class default.
        dataset = YOLOv5VOCDataset(
            data_root='tests/data/VOCdevkit/',
            ann_file='VOC2007/ImageSets/Main/trainval.txt',
            data_prefix=dict(sub_data_root='VOC2007/'),
            test_mode=True,
            pipeline=[],
            batch_shapes_cfg=batch_shapes_cfg,
        )

        expected_img_ids = ['000001']
        expected_batch_shapes = [[672, 480]]

        # Without this check an empty dataset would skip the loop below and
        # the test would pass without verifying anything.
        self.assertEqual(len(dataset), len(expected_img_ids))

        for i, data in enumerate(dataset):
            self.assertEqual(data['img_id'], expected_img_ids[i])
            self.assertEqual(data['batch_shape'].tolist(),
                             expected_batch_shapes[i])

    def test_prepare_data(self):
        """Check the presence of the ``'dataset'`` key in each sample.

        Asserts the key is present when ``test_mode`` is not passed and
        absent when ``test_mode=True``.
        """
        dataset = YOLOv5VOCDataset(
            data_root='tests/data/VOCdevkit/',
            ann_file='VOC2007/ImageSets/Main/trainval.txt',
            data_prefix=dict(sub_data_root='VOC2007/'),
            filter_cfg=dict(filter_empty_gt=False, min_size=0),
            pipeline=[],
            serialize_data=True,
            batch_shapes_cfg=None,
        )
        # Make sure the loop has at least one sample to check.
        self.assertGreater(len(dataset), 0)
        for data in dataset:
            self.assertIn('dataset', data)

        # test with test_mode = True
        dataset = YOLOv5VOCDataset(
            data_root='tests/data/VOCdevkit/',
            ann_file='VOC2007/ImageSets/Main/trainval.txt',
            data_prefix=dict(sub_data_root='VOC2007/'),
            filter_cfg=dict(
                filter_empty_gt=True, min_size=32, bbox_min_size=None),
            pipeline=[],
            test_mode=True,
            batch_shapes_cfg=None)
        # Make sure the loop has at least one sample to check.
        self.assertGreater(len(dataset), 0)
        for data in dataset:
            self.assertNotIn('dataset', data)

    def test_concat_dataset(self):
        """Build a ``ConcatDataset`` from the VOC2007 and VOC2012 fixtures
        and check that its length is 2 after ``full_init()``."""
        dataset = ConcatDataset(
            datasets=[
                dict(
                    type='YOLOv5VOCDataset',
                    data_root='tests/data/VOCdevkit/',
                    ann_file='VOC2007/ImageSets/Main/trainval.txt',
                    data_prefix=dict(sub_data_root='VOC2007/'),
                    filter_cfg=dict(filter_empty_gt=False, min_size=32),
                    pipeline=[]),
                dict(
                    type='YOLOv5VOCDataset',
                    data_root='tests/data/VOCdevkit/',
                    ann_file='VOC2012/ImageSets/Main/trainval.txt',
                    data_prefix=dict(sub_data_root='VOC2012/'),
                    filter_cfg=dict(filter_empty_gt=False, min_size=32),
                    pipeline=[])
            ],
            ignore_keys='dataset_type')

        dataset.full_init()
        self.assertEqual(len(dataset), 2)
|