|
|
|
from typing import List, Optional |
|
|
|
import mmengine |
|
import torch |
|
import torch.nn as nn |
|
from mmcv.cnn import ConvModule |
|
from mmcv.ops.furthest_point_sample import furthest_point_sample |
|
from mmengine.model import BaseModule |
|
from torch import Tensor |
|
|
|
from mmdet3d.registry import MODELS |
|
from mmdet3d.utils import InstanceList |
|
|
|
|
|
def bilinear_interpolate_torch(inputs: Tensor, x: Tensor, y: Tensor) -> Tensor: |
|
"""Bilinear interpolate for inputs.""" |
|
x0 = torch.floor(x).long() |
|
x1 = x0 + 1 |
|
|
|
y0 = torch.floor(y).long() |
|
y1 = y0 + 1 |
|
|
|
x0 = torch.clamp(x0, 0, inputs.shape[1] - 1) |
|
x1 = torch.clamp(x1, 0, inputs.shape[1] - 1) |
|
y0 = torch.clamp(y0, 0, inputs.shape[0] - 1) |
|
y1 = torch.clamp(y1, 0, inputs.shape[0] - 1) |
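    # Gather the feature vectors at the four integer neighbours and blend
    # them with the standard bilinear weights computed below.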
|
|
|
Ia = inputs[y0, x0] |
|
Ib = inputs[y1, x0] |
|
Ic = inputs[y0, x1] |
|
Id = inputs[y1, x1] |
|
|
|
wa = (x1.type_as(x) - x) * (y1.type_as(y) - y) |
|
wb = (x1.type_as(x) - x) * (y - y0.type_as(y)) |
|
wc = (x - x0.type_as(x)) * (y1.type_as(y) - y) |
|
wd = (x - x0.type_as(x)) * (y - y0.type_as(y)) |
|
    ans = (torch.t(torch.t(Ia) * wa) + torch.t(torch.t(Ib) * wb) +
           torch.t(torch.t(Ic) * wc) + torch.t(torch.t(Id) * wd))
|
return ans |
|
|
|
|
|
@MODELS.register_module() |
|
class VoxelSetAbstraction(BaseModule): |
|
"""Voxel set abstraction module for PVRCNN and PVRCNN++. |
|
|
|
Args: |
|
num_keypoints (int): The number of key points sampled from |
|
raw points cloud. |
|
fused_out_channel (int): Key points feature output channels |
|
num after fused. Default to 128. |
|
voxel_size (list[float]): Size of voxels. Defaults to |
|
[0.05, 0.05, 0.1]. |
|
point_cloud_range (list[float]): Point cloud range. Defaults to |
|
[0, -40, -3, 70.4, 40, 1]. |
|
voxel_sa_cfgs_list (List[dict or ConfigDict], optional): List of SA |
|
module cfg. Used to gather key points features from multi-wise |
|
voxel features. Default to None. |
|
rawpoints_sa_cfgs (dict or ConfigDict, optional): SA module cfg. |
|
Used to gather key points features from raw points. Default to |
|
None. |
|
bev_feat_channel (int): Bev features channels num. |
|
Default to 256. |
|
bev_scale_factor (int): Bev features scale factor. Default to 8. |
|
voxel_center_as_source (bool): Whether used voxel centers as points |
|
cloud key points. Defaults to False. |
|
norm_cfg (dict[str]): Config of normalization layer. Default |
|
used dict(type='BN1d', eps=1e-5, momentum=0.1). |
|
bias (bool | str, optional): If specified as `auto`, it will be |
|
decided by `norm_cfg`. `bias` will be set as True if |
|
`norm_cfg` is None, otherwise False. Default: 'auto'. |
|
""" |
|
|
|
def __init__(self, |
|
num_keypoints: int, |
|
fused_out_channel: int = 128, |
|
voxel_size: list = [0.05, 0.05, 0.1], |
|
point_cloud_range: list = [0, -40, -3, 70.4, 40, 1], |
|
voxel_sa_cfgs_list: Optional[list] = None, |
|
rawpoints_sa_cfgs: Optional[dict] = None, |
|
bev_feat_channel: int = 256, |
|
bev_scale_factor: int = 8, |
|
voxel_center_as_source: bool = False, |
|
norm_cfg: dict = dict(type='BN2d', eps=1e-5, momentum=0.1), |
|
bias: str = 'auto') -> None: |
|
super().__init__() |
|
self.num_keypoints = num_keypoints |
|
self.fused_out_channel = fused_out_channel |
|
self.voxel_size = voxel_size |
|
self.point_cloud_range = point_cloud_range |
|
self.voxel_center_as_source = voxel_center_as_source |
|
|
|
gathered_channel = 0 |
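        # ``gathered_channel`` accumulates the number of channels contributed
        # by every enabled source (raw points, multi-scale voxels, BEV) and
        # becomes the input width of the fusion layer built below.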
|
|
|
if rawpoints_sa_cfgs is not None: |
|
self.rawpoints_sa_layer = MODELS.build(rawpoints_sa_cfgs) |
|
gathered_channel += sum( |
|
[x[-1] for x in rawpoints_sa_cfgs.mlp_channels]) |
|
else: |
|
self.rawpoints_sa_layer = None |
|
|
|
if voxel_sa_cfgs_list is not None: |
|
self.voxel_sa_configs_list = voxel_sa_cfgs_list |
|
self.voxel_sa_layers = nn.ModuleList() |
|
for voxel_sa_config in voxel_sa_cfgs_list: |
|
cur_layer = MODELS.build(voxel_sa_config) |
|
self.voxel_sa_layers.append(cur_layer) |
|
gathered_channel += sum( |
|
[x[-1] for x in voxel_sa_config.mlp_channels]) |
|
else: |
|
self.voxel_sa_layers = None |
|
|
|
if bev_feat_channel is not None and bev_scale_factor is not None: |
|
self.bev_cfg = mmengine.Config( |
|
dict( |
|
bev_feat_channels=bev_feat_channel, |
|
bev_scale_factor=bev_scale_factor)) |
|
gathered_channel += bev_feat_channel |
|
else: |
|
self.bev_cfg = None |
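        # Fuse the concatenated key point features from all sources down to
        # ``fused_out_channel`` channels with a 1x1 ConvModule.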
|
self.point_feature_fusion_layer = nn.Sequential( |
|
ConvModule( |
|
gathered_channel, |
|
fused_out_channel, |
|
kernel_size=(1, 1), |
|
stride=(1, 1), |
|
conv_cfg=dict(type='Conv2d'), |
|
norm_cfg=norm_cfg, |
|
bias=bias)) |
|
|
|
def interpolate_from_bev_features(self, keypoints: torch.Tensor, |
|
bev_features: torch.Tensor, |
|
batch_size: int, |
|
bev_scale_factor: int) -> torch.Tensor: |
|
"""Gather key points features from bev feature map by interpolate. |
|
|
|
Args: |
|
keypoints (torch.Tensor): Sampled key points with shape |
|
(N1 + N2 + ..., NDim). |
|
bev_features (torch.Tensor): Bev feature map from the first |
|
stage with shape (B, C, H, W). |
|
batch_size (int): Input batch size. |
|
bev_scale_factor (int): Bev feature map scale factor. |
|
|
|
Returns: |
|
torch.Tensor: Key points features gather from bev feature |
|
map with shape (N1 + N2 + ..., C) |
|
""" |
|
x_idxs = (keypoints[..., 0] - |
|
self.point_cloud_range[0]) / self.voxel_size[0] |
|
y_idxs = (keypoints[..., 1] - |
|
self.point_cloud_range[1]) / self.voxel_size[1] |
|
|
|
x_idxs = x_idxs / bev_scale_factor |
|
y_idxs = y_idxs / bev_scale_factor |
|
|
|
point_bev_features_list = [] |
|
for k in range(batch_size): |
|
cur_x_idxs = x_idxs[k, ...] |
|
cur_y_idxs = y_idxs[k, ...] |
|
cur_bev_features = bev_features[k].permute(1, 2, 0) |
|
point_bev_features = bilinear_interpolate_torch( |
|
cur_bev_features, cur_x_idxs, cur_y_idxs) |
|
point_bev_features_list.append(point_bev_features) |
|
|
|
point_bev_features = torch.cat( |
|
point_bev_features_list, dim=0) |
|
return point_bev_features.view(batch_size, keypoints.shape[1], -1) |
|
|
|
def get_voxel_centers(self, coors: torch.Tensor, |
|
scale_factor: float) -> torch.Tensor: |
|
"""Get voxel centers coordinate. |
|
|
|
Args: |
|
coors (torch.Tensor): Coordinates of voxels shape is Nx(1+NDim), |
|
where 1 represents the batch index. |
|
scale_factor (float): Scale factor. |
|
|
|
Returns: |
|
torch.Tensor: Voxel centers coordinate with shape (N, 3). |
|
""" |
|
assert coors.shape[1] == 4 |
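        # ``coors`` columns are (batch_idx, z, y, x); reorder them to
        # (x, y, z) before converting voxel indices to metric centers.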
|
voxel_centers = coors[:, [3, 2, 1]].float() |
|
voxel_size = torch.tensor( |
|
self.voxel_size, |
|
device=voxel_centers.device).float() * scale_factor |
|
pc_range = torch.tensor( |
|
self.point_cloud_range[0:3], device=voxel_centers.device).float() |
|
voxel_centers = (voxel_centers + 0.5) * voxel_size + pc_range |
|
return voxel_centers |
|
|
|
def sample_key_points(self, points: List[torch.Tensor], |
|
coors: torch.Tensor) -> torch.Tensor: |
|
"""Sample key points from raw points cloud. |
|
|
|
Args: |
|
points (List[torch.Tensor]): Point cloud of each sample. |
|
coors (torch.Tensor): Coordinates of voxels shape is Nx(1+NDim), |
|
where 1 represents the batch index. |
|
|
|
Returns: |
|
torch.Tensor: (B, M, 3) Key points of each sample. |
|
M is num_keypoints. |
|
""" |
|
assert points is not None or coors is not None |
|
if self.voxel_center_as_source: |
|
_src_points = self.get_voxel_centers(coors=coors, scale_factor=1) |
|
batch_size = coors[-1, 0].item() + 1 |
|
src_points = [ |
|
_src_points[coors[:, 0] == b] for b in range(batch_size) |
|
] |
|
else: |
|
src_points = [p[..., :3] for p in points] |
|
|
|
keypoints_list = [] |
|
for points_to_sample in src_points: |
|
num_points = points_to_sample.shape[0] |
|
cur_pt_idxs = furthest_point_sample( |
|
points_to_sample.unsqueeze(dim=0).contiguous(), |
|
self.num_keypoints).long()[0] |
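            # When a sample holds fewer points than ``num_keypoints``, keep
            # only the valid FPS indices and tile them until enough key
            # points are available.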
|
|
|
if num_points < self.num_keypoints: |
|
times = int(self.num_keypoints / num_points) + 1 |
|
non_empty = cur_pt_idxs[:num_points] |
|
cur_pt_idxs = non_empty.repeat(times)[:self.num_keypoints] |
|
|
|
keypoints = points_to_sample[cur_pt_idxs] |
|
|
|
keypoints_list.append(keypoints) |
|
keypoints = torch.stack(keypoints_list, dim=0) |
|
return keypoints |
|
|
|
def forward(self, batch_inputs_dict: dict, feats_dict: dict, |
|
rpn_results_list: InstanceList) -> dict: |
|
"""Extract point-wise features from multi-input. |
|
|
|
Args: |
|
batch_inputs_dict (dict): The model input dict which include |
|
'points', 'voxels' keys. |
|
|
|
- points (list[torch.Tensor]): Point cloud of each sample. |
|
- voxels (dict[torch.Tensor]): Voxels of the batch sample. |
|
feats_dict (dict): Contains features from the first |
|
stage. |
|
rpn_results_list (List[:obj:`InstanceData`]): Detection results |
|
of rpn head. |
|
|
|
Returns: |
|
dict: Contain Point-wise features, include: |
|
- keypoints (torch.Tensor): Sampled key points. |
|
- keypoint_features (torch.Tensor): Gathered key points |
|
features from multi input. |
|
- fusion_keypoint_features (torch.Tensor): Fusion |
|
keypoint_features by point_feature_fusion_layer. |
|
""" |
|
points = batch_inputs_dict['points'] |
|
voxel_encode_features = feats_dict['multi_scale_3d_feats'] |
|
bev_encode_features = feats_dict['spatial_feats'] |
|
if self.voxel_center_as_source: |
|
voxels_coors = batch_inputs_dict['voxels']['coors'] |
|
else: |
|
voxels_coors = None |
|
keypoints = self.sample_key_points(points, voxels_coors) |
|
|
|
point_features_list = [] |
|
batch_size = len(points) |
|
|
|
if self.bev_cfg is not None: |
|
point_bev_features = self.interpolate_from_bev_features( |
|
keypoints, bev_encode_features, batch_size, |
|
self.bev_cfg.bev_scale_factor) |
|
point_features_list.append(point_bev_features.contiguous()) |
|
|
|
batch_size, num_keypoints, _ = keypoints.shape |
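        # Flatten key points into the stacked (B * M, 3) layout expected by
        # the stacked SA modules; every sample contributes exactly
        # ``num_keypoints`` query points.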
|
key_xyz = keypoints.view(-1, 3) |
|
key_xyz_batch_cnt = key_xyz.new_zeros(batch_size).int().fill_( |
|
num_keypoints) |
|
|
|
if self.rawpoints_sa_layer is not None: |
|
batch_points = torch.cat(points, dim=0) |
|
batch_cnt = [len(p) for p in points] |
|
xyz = batch_points[:, :3].contiguous() |
|
            features = None
            if batch_points.size(1) > 3:
                features = batch_points[:, 3:].contiguous()
|
xyz_batch_cnt = xyz.new_tensor(batch_cnt, dtype=torch.int32) |
|
|
|
pooled_points, pooled_features = self.rawpoints_sa_layer( |
|
xyz=xyz.contiguous(), |
|
xyz_batch_cnt=xyz_batch_cnt, |
|
new_xyz=key_xyz.contiguous(), |
|
new_xyz_batch_cnt=key_xyz_batch_cnt, |
|
                features=features,
|
) |
|
|
|
point_features_list.append(pooled_features.contiguous().view( |
|
batch_size, num_keypoints, -1)) |
|
if self.voxel_sa_layers is not None: |
|
for k, voxel_sa_layer in enumerate(self.voxel_sa_layers): |
|
cur_coords = voxel_encode_features[k].indices |
|
xyz = self.get_voxel_centers( |
|
coors=cur_coords, |
|
scale_factor=self.voxel_sa_configs_list[k].scale_factor |
|
).contiguous() |
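                # Count how many voxels of this scale belong to each sample.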
|
xyz_batch_cnt = xyz.new_zeros(batch_size).int() |
|
for bs_idx in range(batch_size): |
|
xyz_batch_cnt[bs_idx] = (cur_coords[:, 0] == bs_idx).sum() |
|
|
|
pooled_points, pooled_features = voxel_sa_layer( |
|
xyz=xyz.contiguous(), |
|
xyz_batch_cnt=xyz_batch_cnt, |
|
new_xyz=key_xyz.contiguous(), |
|
new_xyz_batch_cnt=key_xyz_batch_cnt, |
|
features=voxel_encode_features[k].features.contiguous(), |
|
) |
|
point_features_list.append(pooled_features.contiguous().view( |
|
batch_size, num_keypoints, -1)) |
|
|
|
point_features = torch.cat( |
|
point_features_list, dim=-1).view(batch_size * num_keypoints, -1, |
|
1) |
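        # The fusion layer uses 2D convs with 1x1 kernels, so add a dummy
        # spatial dimension before fusing and squeeze it away afterwards.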
|
|
|
fusion_point_features = self.point_feature_fusion_layer( |
|
point_features.unsqueeze(dim=-1)).squeeze(dim=-1) |
|
|
|
batch_idxs = torch.arange( |
|
batch_size * num_keypoints, device=keypoints.device |
|
) // num_keypoints |
|
batch_keypoints_xyz = torch.cat( |
|
(batch_idxs.to(key_xyz.dtype).unsqueeze(dim=-1), key_xyz), dim=-1) |
|
|
|
return dict( |
|
keypoint_features=point_features.squeeze(dim=-1), |
|
fusion_keypoint_features=fusion_point_features.squeeze(dim=-1), |
|
keypoints=batch_keypoints_xyz) |
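

# A minimal CPU sketch of the ``bilinear_interpolate_torch`` helper above.
# The feature map and query coordinates are made-up values for illustration
# only; this block is not part of the module's API.
if __name__ == '__main__':
    # A 2x3 feature map with 4 channels, queried at two fractional locations.
    feats = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4)
    xs = torch.tensor([0.5, 1.25])
    ys = torch.tensor([0.0, 0.75])
    # Expected output shape: torch.Size([2, 4]) -- one feature row per query.
    print(bilinear_interpolate_torch(feats, xs, ys).shape)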
|
|