File size: 1,431 Bytes
34d1f8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch
from mmengine import ConfigDict, DefaultScope

from mmdet3d.models import Seg3DTTAModel
from mmdet3d.registry import MODELS
from mmdet3d.structures import Det3DDataSample
from mmdet3d.testing import get_detector_cfg


class TestSeg3DTTAModel(TestCase):
    """Smoke test for the ``Seg3DTTAModel`` wrapper built via the registry."""

    def test_seg3d_tta_model(self):
        """Build a ``Seg3DTTAModel`` around Cylinder3D and run ``test_step``.

        The model is always built from the config. The ``test_step`` call on
        four augmented views (every combination of horizontal/vertical flip)
        is only executed when CUDA is available; otherwise the test ends
        after the inputs have been assembled.
        """
        import mmdet3d.models

        assert hasattr(mmdet3d.models, 'Cylinder3D')
        DefaultScope.get_instance('test_cylinder3d', scope_name='mmdet3d')

        # Wrap the base segmentor config inside the TTA model config.
        base_cfg = get_detector_cfg(
            'cylinder3d/cylinder3d_4xb4-3x_semantickitti.py')
        tta_cfg = ConfigDict(type='Seg3DTTAModel', module=base_cfg)
        tta_model: Seg3DTTAModel = MODELS.build(tta_cfg)

        # (pcd_horizontal_flip, pcd_vertical_flip) for each augmented view.
        flip_flags = (
            (False, False),
            (False, True),
            (True, False),
            (True, True),
        )

        inputs = []
        samples = []
        for horizontal, vertical in flip_flags:
            # Each view carries one random point cloud of shape (200, 4).
            inputs.append({'points': [torch.randn(200, 4)]})
            meta = dict(
                pcd_horizontal_flip=horizontal, pcd_vertical_flip=vertical)
            samples.append([Det3DDataSample(metainfo=meta)])

        # Without a GPU only the model construction above is exercised.
        if not torch.cuda.is_available():
            return

        tta_model.eval().cuda()
        tta_model.test_step(dict(inputs=inputs, data_samples=samples))