stevengrove
initial commit
186701e
raw
history blame
No virus
2.43 kB
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
from mmyolo.datasets import YOLOv5CocoDataset
class TestYOLOv5CocoDataset(unittest.TestCase):
    """Unit tests for ``YOLOv5CocoDataset`` using the COCO sample annotations."""

    def test_batch_shapes_cfg(self):
        """Check sample order and ``batch_shape`` under a batch-shape policy.

        The same expectations are verified with ``serialize_data`` switched
        on and off, since both settings are expected to yield identical
        samples.
        """
        batch_shapes_cfg = dict(
            type='BatchShapePolicy',
            batch_size=2,
            img_size=640,
            size_divisor=32,
            extra_pad_ratio=0.5)
        expected_img_ids = [3, 0, 2, 1]
        expected_batch_shapes = [[512, 672], [512, 672], [672, 672],
                                 [672, 672]]

        # Run both serialization modes; subTest reports which one failed and
        # lets the other mode still run.
        for serialize_data in (True, False):
            with self.subTest(serialize_data=serialize_data):
                dataset = YOLOv5CocoDataset(
                    data_prefix=dict(img='imgs'),
                    ann_file='tests/data/coco_sample.json',
                    filter_cfg=dict(filter_empty_gt=False, min_size=0),
                    pipeline=[],
                    serialize_data=serialize_data,
                    batch_shapes_cfg=batch_shapes_cfg,
                )
                for i, data in enumerate(dataset):
                    self.assertEqual(data['img_id'], expected_img_ids[i])
                    self.assertEqual(data['batch_shape'].tolist(),
                                     expected_batch_shapes[i])

    def test_prepare_data(self):
        """Check that samples carry a ``'dataset'`` key only outside test mode."""
        # test with test_mode left at its default
        dataset = YOLOv5CocoDataset(
            data_prefix=dict(img='imgs'),
            ann_file='tests/data/coco_sample.json',
            filter_cfg=dict(filter_empty_gt=False, min_size=0),
            pipeline=[],
            serialize_data=True,
            batch_shapes_cfg=None,
        )
        for data in dataset:
            self.assertIn('dataset', data)

        # test with test_mode = True
        dataset = YOLOv5CocoDataset(
            data_prefix=dict(img='imgs'),
            ann_file='tests/data/coco_sample.json',
            test_mode=True,
            pipeline=[])
        for data in dataset:
            self.assertNotIn('dataset', data)