|
|
|
|
|
import random |
|
|
from unittest import TestCase |
|
|
|
|
|
import numpy as np |
|
|
import pytest |
|
|
import torch |
|
|
|
|
|
from mmdet3d.structures import PointData |
|
|
|
|
|
|
|
|
class TestPointData(TestCase): |
|
|
|
|
|
def setup_data(self): |
|
|
metainfo = dict(sample_idx=random.randint(0, 100)) |
|
|
points = torch.rand((5, 3)) |
|
|
point_data = PointData(metainfo=metainfo, points=points) |
|
|
return point_data |
|
|
|
|
|
def test_set_data(self): |
|
|
point_data = self.setup_data() |
|
|
|
|
|
|
|
|
with self.assertRaises(AttributeError): |
|
|
point_data._metainfo_fields = 1 |
|
|
with self.assertRaises(AttributeError): |
|
|
point_data._data_fields = 1 |
|
|
|
|
|
point_data.keypoints = torch.rand((5, 2)) |
|
|
assert 'keypoints' in point_data |
|
|
|
|
|
def test_getitem(self): |
|
|
point_data = PointData() |
|
|
|
|
|
with self.assertRaises(IndexError): |
|
|
point_data[1] |
|
|
|
|
|
point_data = self.setup_data() |
|
|
assert len(point_data) == 5 |
|
|
slice_point_data = point_data[:2] |
|
|
assert len(slice_point_data) == 2 |
|
|
slice_point_data = point_data[1] |
|
|
assert len(slice_point_data) == 1 |
|
|
|
|
|
with pytest.raises(IndexError): |
|
|
point_data[5] |
|
|
|
|
|
|
|
|
item = torch.Tensor([1, 2, 3, 4]) |
|
|
with pytest.raises(AssertionError): |
|
|
point_data[item] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with pytest.raises(AssertionError): |
|
|
point_data[item.bool()] |
|
|
|
|
|
|
|
|
long_tensor = torch.randint(5, (2, )) |
|
|
long_index_point_data = point_data[long_tensor] |
|
|
assert len(long_index_point_data) == len(long_tensor) |
|
|
|
|
|
|
|
|
bool_tensor = torch.rand(5) > 0.5 |
|
|
bool_index_point_data = point_data[bool_tensor] |
|
|
assert len(bool_index_point_data) == bool_tensor.sum() |
|
|
bool_tensor = torch.rand(5) > 1 |
|
|
empty_point_data = point_data[bool_tensor] |
|
|
assert len(empty_point_data) == bool_tensor.sum() |
|
|
|
|
|
|
|
|
list_index = [1, 2] |
|
|
list_index_point_data = point_data[list_index] |
|
|
assert len(list_index_point_data) == len(list_index) |
|
|
|
|
|
|
|
|
list_bool = [True, False, True, False, False] |
|
|
list_bool_point_data = point_data[list_bool] |
|
|
assert len(list_bool_point_data) == 2 |
|
|
|
|
|
|
|
|
long_numpy = np.random.randint(5, size=2) |
|
|
long_numpy_point_data = point_data[long_numpy] |
|
|
assert len(long_numpy_point_data) == len(long_numpy) |
|
|
|
|
|
bool_numpy = np.random.rand(5) > 0.5 |
|
|
bool_numpy_point_data = point_data[bool_numpy] |
|
|
assert len(bool_numpy_point_data) == bool_numpy.sum() |
|
|
|
|
|
def test_len(self): |
|
|
point_data = self.setup_data() |
|
|
assert len(point_data) == 5 |
|
|
point_data = PointData() |
|
|
assert len(point_data) == 0 |
|
|
|