import copy

import numpy as np
import torch
import torch.nn as nn
import spconv.pytorch as spconv
from spconv.core import ConvAlgo

from voxelnext_3d_box.utils import centernet_utils


class SeparateHead(nn.Module):
    """Group of independent sparse-conv branches applied to one shared input.

    One branch is built per key of ``sep_head_dict``. Each branch consists of
    ``num_conv - 1`` blocks of (SubMConv2d -> BatchNorm1d -> ReLU) that keep
    the channel count, followed by a 1x1 SubMConv2d that projects to
    ``out_channels``. The branch is registered as a submodule under its key.

    Args:
        input_channels: Channel count of the incoming features.
        sep_head_dict: Mapping ``name -> {'out_channels': int, 'num_conv': int}``.
        kernel_size: Kernel size of the intermediate convolutions.
        use_bias: Whether the intermediate convolutions carry a bias. The
            final 1x1 projection always has one.
    """

    def __init__(self, input_channels, sep_head_dict, kernel_size, use_bias=False):
        super().__init__()
        self.sep_head_dict = sep_head_dict

        for cur_name in self.sep_head_dict:
            head_cfg = self.sep_head_dict[cur_name]
            output_channels = head_cfg['out_channels']
            num_conv = head_cfg['num_conv']

            layers = [
                spconv.SparseSequential(
                    spconv.SubMConv2d(
                        input_channels, input_channels, kernel_size,
                        padding=int(kernel_size // 2), bias=use_bias,
                        indice_key=cur_name, algo=ConvAlgo.Native,
                    ),
                    nn.BatchNorm1d(input_channels),
                    nn.ReLU(),
                )
                for _ in range(num_conv - 1)
            ]
            # The 1x1 projection uses its own indice_key, distinct from the
            # one shared by the intermediate convolutions of this branch.
            layers.append(
                spconv.SubMConv2d(
                    input_channels, output_channels, 1, bias=True,
                    indice_key=cur_name + 'out', algo=ConvAlgo.Native,
                )
            )
            # Assigning an nn.Module attribute registers it as a submodule;
            # the attribute name and layer order define the checkpoint keys.
            setattr(self, cur_name, nn.Sequential(*layers))

    def forward(self, x):
        """Run every branch on ``x`` and return ``{name: output.features}``.

        ``x`` is passed unchanged to each branch; presumably it is a spconv
        sparse tensor (only the ``.features`` of each output is returned).
        """
        return {
            cur_name: getattr(self, cur_name)(x).features
            for cur_name in self.sep_head_dict
        }


class VoxelNeXtHead(nn.Module):
    """Sparse detection head that turns per-voxel predictions into boxes.

    The classes are split into groups (``CLASS_NAMES_EACH_HEAD``); one
    ``SeparateHead`` is created per group, each with an extra ``'hm'`` branch
    that has one output channel per class of the group.

    Args:
        class_names: All class names handled by the detector.
        point_cloud_range: Point cloud range, stored as a float tensor and
            forwarded to the box decoder.
        voxel_size: Voxel size, stored as a float tensor and forwarded to the
            box decoder.
        kernel_size_head: Kernel size used inside every ``SeparateHead``.
        CLASS_NAMES_EACH_HEAD: Iterable of class-name groups, one per head.
            Names that are not in ``class_names`` are ignored.
        SEPARATE_HEAD_CFG: Config object exposing ``HEAD_DICT`` (branch
            definitions) and ``HEAD_ORDER`` (checked for ``'vel'``).
        POST_PROCESSING: Config object exposing ``POST_CENTER_LIMIT_RANGE``,
            ``MAX_OBJ_PER_SAMPLE`` and ``SCORE_THRESH``.
        input_channels: Channel count of the features fed to the heads.
        feature_map_stride: Stride value forwarded to the box decoder.
    """

    def __init__(self, class_names, point_cloud_range, voxel_size, kernel_size_head,
                 CLASS_NAMES_EACH_HEAD, SEPARATE_HEAD_CFG, POST_PROCESSING,
                 input_channels=128, feature_map_stride=8):
        super().__init__()
        self.point_cloud_range = torch.Tensor(point_cloud_range)
        self.voxel_size = torch.Tensor(voxel_size)
        self.feature_map_stride = feature_map_stride
        self.input_channels = input_channels

        self.class_names = class_names
        self.class_names_each_head = []
        self.class_id_mapping_each_head = []
        self.POST_PROCESSING = POST_PROCESSING

        for cur_class_names in CLASS_NAMES_EACH_HEAD:
            kept_names = [x for x in cur_class_names if x in class_names]
            self.class_names_each_head.append(kept_names)
            # Explicit int64 keeps the mapping usable as an index source even
            # when a group is empty (NumPy would otherwise infer float64).
            cur_class_id_mapping = torch.from_numpy(np.array(
                [self.class_names.index(x) for x in kept_names], dtype=np.int64
            ))
            self.class_id_mapping_each_head.append(cur_class_id_mapping)

        total_classes = sum(len(x) for x in self.class_names_each_head)
        assert total_classes == len(self.class_names), \
            f'class_names_each_head={self.class_names_each_head}'

        self.heads_list = nn.ModuleList()
        self.separate_head_cfg = SEPARATE_HEAD_CFG
        for cur_class_names in self.class_names_each_head:
            # Deep copy so that adding 'hm' never leaks into the shared config.
            cur_head_dict = copy.deepcopy(self.separate_head_cfg.HEAD_DICT)
            cur_head_dict['hm'] = dict(out_channels=len(cur_class_names), num_conv=2)
            self.heads_list.append(
                SeparateHead(
                    input_channels=input_channels,
                    sep_head_dict=cur_head_dict,
                    kernel_size=kernel_size_head,
                    use_bias=True,
                )
            )
        self.forward_ret_dict = {}

    def generate_predicted_boxes(self, batch_size, pred_dicts, voxel_indices, spatial_shape):
        """Decode the raw outputs of every head into per-sample detections.

        Args:
            batch_size: Number of samples in the batch.
            pred_dicts: One dict per head with the keys ``'hm'``, ``'center'``,
                ``'center_z'``, ``'dim'``, ``'rot'`` and, when ``'vel'`` is in
                ``HEAD_ORDER``, ``'vel'``.
            voxel_indices: Index tensor of the sparse input, forwarded to the
                decoder; its row count defines the voxel ids.
            spatial_shape: Accepted for interface compatibility; not used here.

        Returns:
            A list with one dict per sample. ``'pred_boxes'``,
            ``'pred_scores'``, ``'pred_labels'`` and ``'voxel_ids'`` are
            tensors concatenated over all heads; labels are global class
            indices shifted by +1. ``'pred_ious'`` is left as the list of
            per-head values returned by the decoder.
        """
        device = pred_dicts[0]['hm'].device
        post_process_cfg = self.POST_PROCESSING
        post_center_limit_range = torch.tensor(
            post_process_cfg.POST_CENTER_LIMIT_RANGE
        ).float().to(device)

        # Loop invariants: move the geometry tensors to the device once.
        point_cloud_range = self.point_cloud_range.to(device)
        voxel_size = self.voxel_size.to(device)
        use_vel = 'vel' in self.separate_head_cfg.HEAD_ORDER

        ret_dict = [{
            'pred_boxes': [],
            'pred_scores': [],
            'pred_labels': [],
            'pred_ious': [],
            'voxel_ids': [],
        } for _ in range(batch_size)]

        for idx, pred_dict in enumerate(pred_dicts):
            batch_hm = pred_dict['hm'].sigmoid()
            batch_center = pred_dict['center']
            batch_center_z = pred_dict['center_z']
            batch_dim = pred_dict['dim'].exp()
            # Columns 0 and 1 of 'rot' are treated as cosine and sine.
            batch_rot_cos = pred_dict['rot'][:, 0].unsqueeze(dim=1)
            batch_rot_sin = pred_dict['rot'][:, 1].unsqueeze(dim=1)
            batch_iou = None
            batch_vel = pred_dict['vel'] if use_vel else None

            final_pred_dicts = centernet_utils.decode_bbox_from_voxels_nuscenes(
                batch_size=batch_size, indices=voxel_indices,
                obj=batch_hm,
                rot_cos=batch_rot_cos,
                rot_sin=batch_rot_sin,
                center=batch_center, center_z=batch_center_z,
                dim=batch_dim, vel=batch_vel, iou=batch_iou,
                point_cloud_range=point_cloud_range, voxel_size=voxel_size,
                feature_map_stride=self.feature_map_stride,
                K=post_process_cfg.MAX_OBJ_PER_SAMPLE,
                score_thresh=post_process_cfg.SCORE_THRESH,
                post_center_limit_range=post_center_limit_range,
                # Row number of every voxel, so that each detection can be
                # traced back to the voxel it came from. Built per head in
                # case the decoder modifies it.
                add_features=torch.arange(
                    voxel_indices.shape[0], device=voxel_indices.device
                ).unsqueeze(-1),
            )

            # Head-local label -> index into self.class_names.
            class_id_mapping = self.class_id_mapping_each_head[idx].to(device)
            for k, final_dict in enumerate(final_pred_dicts):
                final_dict['pred_labels'] = class_id_mapping[final_dict['pred_labels'].long()]

                ret_dict[k]['pred_boxes'].append(final_dict['pred_boxes'])
                ret_dict[k]['pred_scores'].append(final_dict['pred_scores'])
                ret_dict[k]['pred_labels'].append(final_dict['pred_labels'])
                ret_dict[k]['pred_ious'].append(final_dict['pred_ious'])
                ret_dict[k]['voxel_ids'].append(final_dict['add_features'])

        for sample in ret_dict:
            pred_boxes = torch.cat(sample['pred_boxes'], dim=0)
            pred_scores = torch.cat(sample['pred_scores'], dim=0)
            pred_labels = torch.cat(sample['pred_labels'], dim=0)
            voxel_ids = torch.cat(sample['voxel_ids'], dim=0)

            sample['pred_boxes'] = pred_boxes
            sample['pred_scores'] = pred_scores
            sample['pred_labels'] = pred_labels + 1
            sample['voxel_ids'] = voxel_ids

        return ret_dict

    def _get_voxel_infos(self, x):
        """Split the indices of sparse tensor ``x`` by sample.

        Column 0 of ``x.indices`` is used as the batch index.

        Returns:
            Tuple ``(spatial_shape, batch_index, voxel_indices,
            spatial_indices, num_voxels)`` where ``spatial_indices`` holds,
            per sample, columns 2 and 1 (in that order) of the sample's rows
            and ``num_voxels`` holds the per-sample row counts as 0-dim
            tensors.
        """
        spatial_shape = x.spatial_shape
        voxel_indices = x.indices
        batch_size = x.batch_size
        batch_index = voxel_indices[:, 0]

        masks = [batch_index == bs_idx for bs_idx in range(batch_size)]
        # Columns [2, 1] swap the last two index columns - presumably into
        # (x, y) order; confirm against the consumer of these indices.
        spatial_indices = [voxel_indices[mask][:, [2, 1]] for mask in masks]
        num_voxels = [mask.sum() for mask in masks]

        return spatial_shape, batch_index, voxel_indices, spatial_indices, num_voxels

    def forward(self, data_dict):
        """Run all heads on the encoded sparse tensor and decode detections.

        Args:
            data_dict: Must contain ``'encoded_spconv_tensor'`` (the sparse
                feature tensor) and ``'batch_size'``.

        Returns:
            The per-sample list produced by ``generate_predicted_boxes``.
        """
        x = data_dict['encoded_spconv_tensor']

        # Only the raw indices and the spatial shape are needed here, so the
        # per-sample masking done by _get_voxel_infos is skipped.
        pred_dicts = [head(x) for head in self.heads_list]

        return self.generate_predicted_boxes(
            data_dict['batch_size'], pred_dicts, x.indices, x.spatial_shape
        )