File size: 6,949 Bytes
34d1f8b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import numpy as np
import pytest
import torch
from mmengine.structures import InstanceData
from mmdet3d.structures import Det3DDataSample, PointData
def _equal(a, b):
if isinstance(a, (torch.Tensor, np.ndarray)):
return (a == b).all()
else:
return a == b
class TestDet3DDataSample(TestCase):
    """Tests for the metainfo and data-field accessors of Det3DDataSample."""

    # InstanceData-typed prediction fields; the tests give them all the same
    # (bboxes_3d, labels_3d, scores_3d) layout.
    _PRED_INSTANCE_FIELDS = ('pred_instances_3d', 'pts_pred_instances_3d',
                             'img_pred_instances_3d')
    # PointData-typed point-wise segmentation fields.
    _PTS_SEG_FIELDS = ('gt_pts_seg', 'pred_pts_seg')

    def _assert_fields_equal(self, element, expected):
        """Assert every ``expected[key]`` is readable from ``element``.

        Args:
            element: Object whose attributes are checked.
            expected (dict): Attribute name -> value that was stored.
        """
        for key, value in expected.items():
            assert _equal(getattr(element, key), value), \
                f'field {key!r} does not match the value that was set'

    def test_init(self):
        meta_info = dict(
            img_size=[256, 256],
            scale_factor=np.array([1.5, 1.5]),
            img_shape=torch.rand(4))
        det3d_data_sample = Det3DDataSample(metainfo=meta_info)
        assert 'img_size' in det3d_data_sample
        assert det3d_data_sample.img_size == [256, 256]
        assert det3d_data_sample.get('img_size') == [256, 256]
        # Array- and tensor-valued metainfo must be exposed as well.
        for key in ('scale_factor', 'img_shape'):
            assert key in det3d_data_sample
        self._assert_fields_equal(
            det3d_data_sample,
            dict(
                scale_factor=meta_info['scale_factor'],
                img_shape=meta_info['img_shape']))

    def test_setter(self):
        det3d_data_sample = Det3DDataSample()

        # gt_instances_3d carries boxes and labels only.
        gt_instances_3d_data = dict(
            bboxes_3d=torch.rand(4, 7), labels_3d=torch.rand(4))
        det3d_data_sample.gt_instances_3d = InstanceData(
            **gt_instances_3d_data)
        assert 'gt_instances_3d' in det3d_data_sample
        self._assert_fields_equal(det3d_data_sample.gt_instances_3d,
                                  gt_instances_3d_data)

        # Every prediction field additionally carries scores.
        for field in self._PRED_INSTANCE_FIELDS:
            pred_data = dict(
                bboxes_3d=torch.rand(2, 7),
                labels_3d=torch.rand(2),
                scores_3d=torch.rand(2))
            setattr(det3d_data_sample, field, InstanceData(**pred_data))
            assert field in det3d_data_sample
            self._assert_fields_equal(
                getattr(det3d_data_sample, field), pred_data)

        # Point-wise segmentation fields.
        for field in self._PTS_SEG_FIELDS:
            seg_data = dict(
                pts_instance_mask=torch.rand(20),
                pts_semantic_mask=torch.rand(20))
            setattr(det3d_data_sample, field, PointData(**seg_data))
            assert field in det3d_data_sample
            self._assert_fields_equal(
                getattr(det3d_data_sample, field), seg_data)

        # Assigning a bare tensor instead of the expected container type
        # is expected to raise AssertionError.
        with self.assertRaises(AssertionError):
            det3d_data_sample.pred_instances_3d = torch.rand(2, 4)
        with self.assertRaises(AssertionError):
            det3d_data_sample.pred_pts_seg = torch.rand(20)

    def test_deleter(self):
        det3d_data_sample = Det3DDataSample()

        # Unpack the dict so InstanceData gets real ``bboxes_3d`` and
        # ``labels_3d`` fields; passing ``data=...`` would instead store the
        # whole dict under one field named ``data``. 7-dim boxes match the
        # layout used in test_setter.
        instances_3d_data = dict(
            bboxes_3d=torch.rand(4, 7), labels_3d=torch.rand(4))
        for field in ('gt_instances_3d', ) + self._PRED_INSTANCE_FIELDS:
            setattr(det3d_data_sample, field,
                    InstanceData(**instances_3d_data))
            assert field in det3d_data_sample
            delattr(det3d_data_sample, field)
            assert field not in det3d_data_sample

        pts_seg_data = dict(
            pts_instance_mask=torch.rand(20), pts_semantic_mask=torch.rand(20))
        for field in self._PTS_SEG_FIELDS:
            setattr(det3d_data_sample, field, PointData(**pts_seg_data))
            assert field in det3d_data_sample
            delattr(det3d_data_sample, field)
            assert field not in det3d_data_sample
|