stevengrove
initial commit
186701e
raw
history blame
No virus
5.09 kB
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
import unittest
import numpy as np
from mmdet.structures import DetDataSample
from mmdet.structures.mask import BitmapMasks
from mmengine.structures import InstanceData, PixelData
from mmyolo.datasets.transforms import PackDetInputs
class TestPackDetInputs(unittest.TestCase):
    """Unit tests for the ``PackDetInputs`` transform."""

    def setUp(self):
        """Build the synthetic pipeline results shared by the tests.

        Three result dicts are drawn from a single seeded random generator
        so the fixture data is reproducible:

        - ``results1``: three instances, the last one flagged as ignored.
        - ``results2``: three instances and no ``gt_ignore_flags`` key.
        - ``results3``: same as ``results2`` plus a ``gt_panoptic_seg`` map.

        ``unittest`` runs this method before every test method.
        """
        data_prefix = osp.join(osp.dirname(__file__), '../../data')
        img_path = osp.join(data_prefix, 'color.jpg')
        rng = np.random.RandomState(0)
        # The three dicts must be created in this order: they share ``rng``,
        # so the order of creation determines the values each one receives.
        self.results1 = self._make_results(
            rng, img_path, ignore_flags=np.array([0, 0, 1], dtype=bool))
        self.results2 = self._make_results(rng, img_path)
        self.results3 = self._make_results(rng, img_path, with_panoptic=True)
        self.meta_keys = ('img_id', 'img_path', 'ori_shape', 'scale_factor',
                          'flip')

    def _make_results(self,
                      rng,
                      img_path,
                      *,
                      ignore_flags=None,
                      with_panoptic=False):
        """Create one synthetic results dict to feed into the transform.

        Args:
            rng (np.random.RandomState): Generator the random arrays are
                drawn from.
            img_path (str): Value stored under ``'img_path'``.
            ignore_flags (np.ndarray, optional): When given, stored under
                ``'gt_ignore_flags'``. Defaults to None, which leaves the
                key out of the dict.
            with_panoptic (bool): Whether to add a random
                ``'gt_panoptic_seg'`` array of shape (1, 300, 400).
                Defaults to False.

        Returns:
            dict: The results dict.
        """
        results = {
            'img_id': 1,
            'img_path': img_path,
            'ori_shape': (300, 400),
            'img_shape': (600, 800),
            'scale_factor': 2.0,
            'flip': False,
            'img': rng.rand(300, 400),
            'gt_seg_map': rng.rand(300, 400),
            'gt_masks':
            BitmapMasks(rng.rand(3, 300, 400), height=300, width=400),
        }
        # Optional keys are inserted at fixed positions so that both the
        # key order and the sequence of random draws stay the same for every
        # fixture variant.
        if with_panoptic:
            results['gt_panoptic_seg'] = rng.rand(1, 300, 400)
        results['gt_bboxes_labels'] = rng.rand(3, )
        if ignore_flags is not None:
            results['gt_ignore_flags'] = ignore_flags
        results['proposals'] = rng.rand(2, 4)
        results['proposals_scores'] = rng.rand(2, )
        return results

    def _assert_packed(self, results, num_gt, num_ignored):
        """Run the checks common to every packed output.

        Args:
            results (dict): Output of the ``PackDetInputs`` transform.
            num_gt (int): Expected length of ``gt_instances``.
            num_ignored (int): Expected length of ``ignored_instances``.

        Returns:
            DetDataSample: The packed data sample, for further checks.
        """
        self.assertIn('data_samples', results)
        data_sample = results['data_samples']
        self.assertIsInstance(data_sample, DetDataSample)
        self.assertIsInstance(data_sample.gt_instances, InstanceData)
        self.assertIsInstance(data_sample.ignored_instances, InstanceData)
        self.assertEqual(len(data_sample.gt_instances), num_gt)
        self.assertEqual(len(data_sample.ignored_instances), num_ignored)
        self.assertIsInstance(data_sample.gt_sem_seg, PixelData)
        return data_sample

    def test_transform(self):
        """Instances flagged in ``gt_ignore_flags`` are split off."""
        transform = PackDetInputs(meta_keys=self.meta_keys)
        # Deep-copy so the transform cannot modify the shared fixture.
        results = transform(copy.deepcopy(self.results1))
        self._assert_packed(results, num_gt=2, num_ignored=1)

    def test_transform_without_ignore(self):
        """Without ``gt_ignore_flags`` every instance is kept as GT."""
        transform = PackDetInputs(meta_keys=self.meta_keys)
        results = transform(copy.deepcopy(self.results2))
        self._assert_packed(results, num_gt=3, num_ignored=0)

    def test_transform_with_panoptic_seg(self):
        """A ``gt_panoptic_seg`` input is packed as ``PixelData``."""
        transform = PackDetInputs(meta_keys=self.meta_keys)
        results = transform(copy.deepcopy(self.results3))
        data_sample = self._assert_packed(results, num_gt=3, num_ignored=0)
        self.assertIsInstance(data_sample.gt_panoptic_seg, PixelData)

    def test_repr(self):
        """``repr`` reports the class name and the configured meta keys."""
        transform = PackDetInputs(meta_keys=self.meta_keys)
        self.assertEqual(
            repr(transform), f'PackDetInputs(meta_keys={self.meta_keys})')