# File: mmdet3d/testing/model_utils.py (commit c2ca15f)
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import random
from os.path import dirname, exists, join
import numpy as np
import torch
from mmengine.structures import InstanceData
from mmdet3d.structures import (CameraInstance3DBoxes, DepthInstance3DBoxes,
Det3DDataSample, LiDARInstance3DBoxes,
PointData)
def setup_seed(seed):
    """Seed every RNG used by the tests so runs are reproducible.

    Args:
        seed (int): Seed applied to torch (CPU and all CUDA devices),
            numpy and the stdlib ``random`` module.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all,
                    np.random.seed, random.seed):
        seed_fn(seed)
    # cuDNN may otherwise auto-select nondeterministic kernels.
    torch.backends.cudnn.deterministic = True
def _get_config_directory():
    """Find the predefined detector config directory.

    Returns:
        str: Absolute path of the ``configs`` directory.

    Raises:
        Exception: If no ``configs`` directory can be located.
    """
    try:
        # Running from a source checkout: walk three levels up from here.
        repo_root = dirname(dirname(dirname(__file__)))
    except NameError:
        # ``__file__`` is undefined in some interactive sessions (IPython);
        # fall back to locating the installed package instead.
        import mmdet3d
        repo_root = dirname(dirname(mmdet3d.__file__))
    cfg_dir = join(repo_root, 'configs')
    if exists(cfg_dir):
        return cfg_dir
    raise Exception('Cannot find config path')
def _get_config_module(fname):
    """Load a configuration as a python module.

    Args:
        fname (str): Config filename relative to the ``configs`` directory.

    Returns:
        mmengine.Config: The parsed configuration object.
    """
    from mmengine import Config
    cfg_fpath = join(_get_config_directory(), fname)
    return Config.fromfile(cfg_fpath)
def get_model_cfg(fname):
    """Grab configs necessary to create a model.

    The returned cfg is a deep copy, so tests may mutate it freely
    without leaking changes into other tests that load the same file.

    Args:
        fname (str): Config filename relative to the ``configs`` directory.

    Returns:
        The ``model`` section of the loaded config, deep-copied.
    """
    return copy.deepcopy(_get_config_module(fname).model)
def get_detector_cfg(fname):
    """Grab configs necessary to create a detector.

    The returned cfg is a deep copy, so tests may mutate it freely
    without leaking changes into other tests that load the same file.

    Args:
        fname (str): Config filename relative to the ``configs`` directory.

    Returns:
        The ``model`` section of the loaded config, deep-copied, with its
        ``train_cfg``/``test_cfg`` entries wrapped in ``mmengine.Config``.
    """
    import mmengine
    cfg = _get_config_module(fname)
    model = copy.deepcopy(cfg.model)
    model.update(
        train_cfg=mmengine.Config(copy.deepcopy(cfg.model.train_cfg)),
        test_cfg=mmengine.Config(copy.deepcopy(cfg.model.test_cfg)))
    return model
def create_detector_inputs(seed=0,
                           with_points=True,
                           with_img=False,
                           img_size=10,
                           num_gt_instance=20,
                           num_points=10,
                           points_feat_dim=4,
                           num_classes=3,
                           gt_bboxes_dim=7,
                           with_pts_semantic_mask=False,
                           with_pts_instance_mask=False,
                           with_eval_ann_info=False,
                           bboxes_3d_type='lidar'):
    """Build randomized fake inputs and annotations for testing a detector.

    Args:
        seed (int): RNG seed so generated data is reproducible.
        with_points (bool): Include a random point cloud.
        with_img (bool): Include a random image plus its shape metainfo.
        img_size (int | tuple): Image size; an int means a square image.
        num_gt_instance (int): Number of ground-truth instances.
        num_points (int): Number of points in the point cloud.
        points_feat_dim (int): Per-point feature dimension.
        num_classes (int): Number of classes for random labels.
        gt_bboxes_dim (int): Dimensionality of each 3D box.
        with_pts_semantic_mask (bool): Add a per-point semantic mask.
        with_pts_instance_mask (bool): Add a per-point instance mask.
        with_eval_ann_info (bool): Attach an (empty) ``eval_ann_info`` dict;
            otherwise ``eval_ann_info`` is set to ``None``.
        bboxes_3d_type (str): One of ``'lidar'``, ``'depth'`` or ``'cam'``.

    Returns:
        dict: ``inputs`` (points and/or image) and a one-element
        ``data_samples`` list holding the ground-truth annotations.
    """
    setup_seed(seed)
    assert bboxes_3d_type in ('lidar', 'depth', 'cam')
    box_type = {
        'lidar': LiDARInstance3DBoxes,
        'depth': DepthInstance3DBoxes,
        'cam': CameraInstance3DBoxes
    }[bboxes_3d_type]

    # Fixed projection matrix used for both depth->img and lidar->img.
    proj = [[5.23289349e+02, 3.68831943e+02, 6.10469439e+01],
            [1.09560138e+02, 1.97404735e+02, -5.47377738e+02],
            [1.25930002e-02, 9.92229998e-01, -1.23769999e-01]]
    meta_info = dict(depth2img=np.array(proj), lidar2img=np.array(proj))

    inputs_dict = dict()
    if with_points:
        inputs_dict['points'] = [torch.rand([num_points, points_feat_dim])]
    if with_img:
        shape = img_size if isinstance(img_size, tuple) else (img_size,
                                                              img_size)
        img = torch.rand(3, shape[0], shape[1])
        meta_info['img_shape'] = shape
        meta_info['ori_shape'] = shape
        meta_info['scale_factor'] = np.array([1., 1.])
        inputs_dict['img'] = [img]

    gt_instances_3d = InstanceData()
    gt_instances_3d.bboxes_3d = box_type(
        torch.rand([num_gt_instance, gt_bboxes_dim]), box_dim=gt_bboxes_dim)
    gt_instances_3d.labels_3d = torch.randint(0, num_classes,
                                              [num_gt_instance])
    data_sample = Det3DDataSample(metainfo=dict(box_type_3d=box_type))
    data_sample.set_metainfo(meta_info)
    data_sample.gt_instances_3d = gt_instances_3d

    gt_instances = InstanceData()
    gt_instances.labels = torch.randint(0, num_classes, [num_gt_instance])
    gt_instances.bboxes = torch.rand(num_gt_instance, 4)
    # Turn random (x, y, w, h)-style values into valid (x1, y1, x2, y2).
    gt_instances.bboxes[:, 2:] = (
        gt_instances.bboxes[:, :2] + gt_instances.bboxes[:, 2:])
    data_sample.gt_instances = gt_instances

    data_sample.gt_pts_seg = PointData()
    if with_pts_instance_mask:
        data_sample.gt_pts_seg['pts_instance_mask'] = torch.randint(
            0, num_gt_instance, [num_points])
    if with_pts_semantic_mask:
        data_sample.gt_pts_seg['pts_semantic_mask'] = torch.randint(
            0, num_classes, [num_points])
    data_sample.eval_ann_info = dict() if with_eval_ann_info else None
    return dict(inputs=inputs_dict, data_samples=[data_sample])