# Source: eri2 / voxelnext_3d_box / models / voxelnext_head.py
# Hosting-page residue kept for provenance: uploader "mart9992", commit "m" (aede1d5),
# file size 7.28 kB, scan result "No virus".
import numpy as np
import torch
import torch.nn as nn
from voxelnext_3d_box.utils import centernet_utils
import spconv.pytorch as spconv
import copy
from spconv.core import ConvAlgo
class SeparateHead(nn.Module):
    """A set of independent sparse-convolution branches sharing one input.

    One branch is built for every key of ``sep_head_dict`` and registered on
    the module under that key. Each branch is ``num_conv - 1`` blocks of
    (SubMConv2d -> BatchNorm1d -> ReLU) that keep ``input_channels`` channels,
    followed by a single 1x1 SubMConv2d that maps to ``out_channels``.
    """

    def __init__(self, input_channels, sep_head_dict, kernel_size, use_bias=False):
        """Build and register one branch per entry of ``sep_head_dict``.

        Args:
            input_channels: Channel count of the incoming sparse features; also
                the width of every hidden block.
            sep_head_dict: Mapping ``branch name -> {'out_channels': int,
                'num_conv': int}``.
            kernel_size: Kernel size of the hidden convolutions (padding is
                ``kernel_size // 2``).
            use_bias: Whether the hidden convolutions carry a bias term. The
                final 1x1 convolution always has one.
        """
        super().__init__()
        self.sep_head_dict = sep_head_dict

        for branch_name, branch_cfg in self.sep_head_dict.items():
            out_channels = branch_cfg['out_channels']
            hidden_count = branch_cfg['num_conv'] - 1

            # Hidden blocks of one branch all share the same indice_key.
            hidden_blocks = [
                spconv.SparseSequential(
                    spconv.SubMConv2d(
                        input_channels,
                        input_channels,
                        kernel_size,
                        padding=int(kernel_size // 2),
                        bias=use_bias,
                        indice_key=branch_name,
                        algo=ConvAlgo.Native,
                    ),
                    nn.BatchNorm1d(input_channels),
                    nn.ReLU(),
                )
                for _ in range(hidden_count)
            ]
            # The 1x1 projection gets its own indice_key.
            projection = spconv.SubMConv2d(
                input_channels,
                out_channels,
                1,
                bias=True,
                indice_key=branch_name + 'out',
                algo=ConvAlgo.Native,
            )
            setattr(self, branch_name, nn.Sequential(*hidden_blocks, projection))

    def forward(self, x):
        """Run every branch on ``x`` and return ``{branch name: output.features}``."""
        return {
            branch_name: getattr(self, branch_name)(x).features
            for branch_name in self.sep_head_dict
        }
class VoxelNeXtHead(nn.Module):
    """Sparse detection head: one ``SeparateHead`` per class group plus decoding.

    ``forward`` runs every head on the sparse feature tensor and hands the raw
    outputs to ``centernet_utils.decode_bbox_from_voxels_nuscenes``. Only
    decoded predictions are returned; this class computes no loss.
    """

    def __init__(self, class_names, point_cloud_range, voxel_size, kernel_size_head,
                 CLASS_NAMES_EACH_HEAD, SEPARATE_HEAD_CFG, POST_PROCESSING,
                 input_channels=128, feature_map_stride=8):
        """Build the per-group heads and the class-id lookup tables.

        Args:
            class_names: Ordered list of all class names; a class id is the
                index of its name in this list.
            point_cloud_range: Numeric sequence stored as a tensor and passed
                to the decoder unchanged.
            voxel_size: Numeric sequence stored as a tensor and passed to the
                decoder unchanged.
            kernel_size_head: Kernel size forwarded to every ``SeparateHead``.
            CLASS_NAMES_EACH_HEAD: One collection of class names per head.
                Names that are not in ``class_names`` are dropped.
            SEPARATE_HEAD_CFG: Config object exposing ``HEAD_DICT`` (branch
                definitions, deep-copied per head) and ``HEAD_ORDER``.
            POST_PROCESSING: Config object exposing ``POST_CENTER_LIMIT_RANGE``,
                ``MAX_OBJ_PER_SAMPLE`` and ``SCORE_THRESH``.
            input_channels: Channel count of the incoming sparse features
                (defaults to the previously hard-coded 128).
            feature_map_stride: Stride value passed to the decoder (defaults
                to the previously hard-coded 8).

        Raises:
            AssertionError: If the heads together do not cover exactly
                ``len(class_names)`` classes.
        """
        super().__init__()
        self.point_cloud_range = torch.Tensor(point_cloud_range)
        self.voxel_size = torch.Tensor(voxel_size)
        self.feature_map_stride = feature_map_stride
        self.class_names = class_names
        self.class_names_each_head = []
        self.class_id_mapping_each_head = []
        self.POST_PROCESSING = POST_PROCESSING

        for cur_class_names in CLASS_NAMES_EACH_HEAD:
            # Filter once and reuse for both the name list and the id table.
            kept_names = [x for x in cur_class_names if x in class_names]
            self.class_names_each_head.append(kept_names)
            # Head-local label -> index into class_names.
            self.class_id_mapping_each_head.append(
                torch.from_numpy(np.array([self.class_names.index(x) for x in kept_names]))
            )

        total_classes = sum(len(x) for x in self.class_names_each_head)
        assert total_classes == len(self.class_names), f'class_names_each_head={self.class_names_each_head}'

        self.separate_head_cfg = SEPARATE_HEAD_CFG
        self.heads_list = nn.ModuleList()
        for cur_class_names in self.class_names_each_head:
            # Deep copy so adding 'hm' never mutates the shared config.
            cur_head_dict = copy.deepcopy(self.separate_head_cfg.HEAD_DICT)
            cur_head_dict['hm'] = dict(out_channels=len(cur_class_names), num_conv=2)
            self.heads_list.append(
                SeparateHead(
                    input_channels=input_channels,
                    sep_head_dict=cur_head_dict,
                    kernel_size=kernel_size_head,
                    use_bias=True,
                )
            )
        self.forward_ret_dict = {}

    def generate_predicted_boxes(self, batch_size, pred_dicts, voxel_indices, spatial_shape):
        """Decode the raw outputs of every head into per-sample predictions.

        Args:
            batch_size: Number of samples in the batch.
            pred_dicts: One dict per head with keys ``'hm'``, ``'center'``,
                ``'center_z'``, ``'dim'``, ``'rot'`` and, when ``'vel'`` is
                listed in ``HEAD_ORDER``, ``'vel'``.
            voxel_indices: Index tensor of the active voxels; its first
                dimension is the voxel count.
            spatial_shape: Unused; kept for interface compatibility.

        Returns:
            A list of ``batch_size`` dicts. ``'pred_boxes'``, ``'pred_scores'``,
            ``'pred_labels'`` and ``'voxel_ids'`` are concatenated over all
            heads. ``'pred_labels'`` are 1-based indices into ``class_names``.
            ``'pred_ious'`` is left as a list with one entry per head.
        """
        device = pred_dicts[0]['hm'].device
        post_process_cfg = self.POST_PROCESSING
        post_center_limit_range = torch.tensor(post_process_cfg.POST_CENTER_LIMIT_RANGE).float().to(device)

        # Loop-invariant values: compute once instead of once per head.
        point_cloud_range = self.point_cloud_range.to(device)
        voxel_size = self.voxel_size.to(device)
        use_vel = 'vel' in self.separate_head_cfg.HEAD_ORDER

        ret_dict = [{
            'pred_boxes': [],
            'pred_scores': [],
            'pred_labels': [],
            'pred_ious': [],
            'voxel_ids': []
        } for _ in range(batch_size)]

        for idx, pred_dict in enumerate(pred_dicts):
            rot = pred_dict['rot']
            # NOTE(review): presumably one list entry per sample - confirm
            # against centernet_utils.
            final_pred_dicts = centernet_utils.decode_bbox_from_voxels_nuscenes(
                batch_size=batch_size, indices=voxel_indices,
                obj=pred_dict['hm'].sigmoid(),
                rot_cos=rot[:, 0].unsqueeze(dim=1),
                rot_sin=rot[:, 1].unsqueeze(dim=1),
                center=pred_dict['center'], center_z=pred_dict['center_z'],
                dim=pred_dict['dim'].exp(),
                vel=pred_dict['vel'] if use_vel else None,
                iou=None,
                point_cloud_range=point_cloud_range, voxel_size=voxel_size,
                feature_map_stride=self.feature_map_stride,
                K=post_process_cfg.MAX_OBJ_PER_SAMPLE,
                score_thresh=post_process_cfg.SCORE_THRESH,
                post_center_limit_range=post_center_limit_range,
                # Running voxel index, so each detection can be traced back to
                # the voxel it came from (returned as 'add_features').
                add_features=torch.arange(voxel_indices.shape[0], device=voxel_indices.device).unsqueeze(-1)
            )

            # The mapping depends only on the head, not on the sample.
            class_id_mapping = self.class_id_mapping_each_head[idx].to(device)
            for k, final_dict in enumerate(final_pred_dicts):
                # Translate head-local labels into indices of class_names.
                final_dict['pred_labels'] = class_id_mapping[final_dict['pred_labels'].long()]
                ret_dict[k]['pred_boxes'].append(final_dict['pred_boxes'])
                ret_dict[k]['pred_scores'].append(final_dict['pred_scores'])
                ret_dict[k]['pred_labels'].append(final_dict['pred_labels'])
                ret_dict[k]['pred_ious'].append(final_dict['pred_ious'])
                ret_dict[k]['voxel_ids'].append(final_dict['add_features'])

        for sample in ret_dict:
            sample['pred_boxes'] = torch.cat(sample['pred_boxes'], dim=0)
            sample['pred_scores'] = torch.cat(sample['pred_scores'], dim=0)
            # Shift to 1-based labels.
            sample['pred_labels'] = torch.cat(sample['pred_labels'], dim=0) + 1
            sample['voxel_ids'] = torch.cat(sample['voxel_ids'], dim=0)

        return ret_dict

    def _get_voxel_infos(self, x):
        """Split the indices of sparse tensor ``x`` by batch sample.

        Args:
            x: Object exposing ``spatial_shape``, ``indices`` and
                ``batch_size``; column 0 of ``indices`` is the batch index.

        Returns:
            Tuple ``(spatial_shape, batch_index, voxel_indices,
            spatial_indices, num_voxels)``. The last two are per-sample lists:
            columns ``[2, 1]`` of that sample's index rows, and its voxel
            count as a 0-dim tensor.
        """
        spatial_shape = x.spatial_shape
        voxel_indices = x.indices
        spatial_indices = []
        num_voxels = []
        batch_size = x.batch_size
        batch_index = voxel_indices[:, 0]

        for bs_idx in range(batch_size):
            batch_inds = batch_index == bs_idx
            # Columns are swapped to (col 2, col 1) - presumably (x, y); confirm.
            spatial_indices.append(voxel_indices[batch_inds][:, [2, 1]])
            num_voxels.append(batch_inds.sum())

        return spatial_shape, batch_index, voxel_indices, spatial_indices, num_voxels

    def forward(self, data_dict):
        """Run all heads on ``data_dict['encoded_spconv_tensor']`` and decode.

        Args:
            data_dict: Dict with ``'encoded_spconv_tensor'`` (sparse feature
                tensor) and ``'batch_size'``.

        Returns:
            The list of per-sample prediction dicts produced by
            ``generate_predicted_boxes``.
        """
        x = data_dict['encoded_spconv_tensor']
        spatial_shape, _, voxel_indices, _, _ = self._get_voxel_infos(x)

        pred_dicts = [head(x) for head in self.heads_list]

        return self.generate_predicted_boxes(
            data_dict['batch_size'],
            pred_dicts, voxel_indices, spatial_shape
        )