# Copyright (c) OpenMMLab. All rights reserved.

from typing import List, Optional

import mmengine
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.ops.furthest_point_sample import furthest_point_sample
from mmengine.model import BaseModule
from torch import Tensor

from mmdet3d.registry import MODELS
from mmdet3d.utils import InstanceList


def bilinear_interpolate_torch(inputs: Tensor, x: Tensor, y: Tensor) -> Tensor:
    """Bilinearly interpolate ``inputs`` at continuous pixel coordinates.

    Args:
        inputs (Tensor): Feature map with shape (H, W, C).
        x (Tensor): Continuous x coordinates (along W) of the query points.
        y (Tensor): Continuous y coordinates (along H) of the query points.

    Returns:
        Tensor: Interpolated features with shape (num_queries, C).
    """
x0 = torch.floor(x).long()
x1 = x0 + 1
y0 = torch.floor(y).long()
y1 = y0 + 1
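    # clamp the four neighbor indices to the map border (border replication)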
x0 = torch.clamp(x0, 0, inputs.shape[1] - 1)
x1 = torch.clamp(x1, 0, inputs.shape[1] - 1)
y0 = torch.clamp(y0, 0, inputs.shape[0] - 1)
y1 = torch.clamp(y1, 0, inputs.shape[0] - 1)
Ia = inputs[y0, x0]
Ib = inputs[y1, x0]
Ic = inputs[y0, x1]
Id = inputs[y1, x1]
wa = (x1.type_as(x) - x) * (y1.type_as(y) - y)
wb = (x1.type_as(x) - x) * (y - y0.type_as(y))
wc = (x - x0.type_as(x)) * (y1.type_as(y) - y)
wd = (x - x0.type_as(x)) * (y - y0.type_as(y))
    # transpose so the per-point corner weights broadcast over the channels
    ans = torch.t(torch.t(Ia) * wa) + torch.t(torch.t(Ib) * wb) + torch.t(
        torch.t(Ic) * wc) + torch.t(torch.t(Id) * wd)
return ans
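

# A quick sanity check for `bilinear_interpolate_torch` (an illustrative
# sketch, not part of the module): with feats[y, x, 0] = 4 * y + x, a query
# at (x=1.5, y=2.0) should return 4 * 2.0 + 1.5 = 9.5:
#
#   feats = torch.arange(16.).view(4, 4, 1)
#   out = bilinear_interpolate_torch(
#       feats, torch.tensor([1.5]), torch.tensor([2.0]))
#   assert torch.allclose(out, torch.tensor([[9.5]]))
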
@MODELS.register_module()
class VoxelSetAbstraction(BaseModule):
"""Voxel set abstraction module for PVRCNN and PVRCNN++.
Args:
num_keypoints (int): The number of key points sampled from
raw points cloud.
fused_out_channel (int): Key points feature output channels
num after fused. Default to 128.
voxel_size (list[float]): Size of voxels. Defaults to
[0.05, 0.05, 0.1].
point_cloud_range (list[float]): Point cloud range. Defaults to
[0, -40, -3, 70.4, 40, 1].
voxel_sa_cfgs_list (List[dict or ConfigDict], optional): List of SA
module cfg. Used to gather key points features from multi-wise
voxel features. Default to None.
rawpoints_sa_cfgs (dict or ConfigDict, optional): SA module cfg.
Used to gather key points features from raw points. Default to
None.
bev_feat_channel (int): Bev features channels num.
Default to 256.
bev_scale_factor (int): Bev features scale factor. Default to 8.
voxel_center_as_source (bool): Whether used voxel centers as points
cloud key points. Defaults to False.
norm_cfg (dict[str]): Config of normalization layer. Default
used dict(type='BN1d', eps=1e-5, momentum=0.1).
bias (bool | str, optional): If specified as `auto`, it will be
decided by `norm_cfg`. `bias` will be set as True if
`norm_cfg` is None, otherwise False. Default: 'auto'.
"""
def __init__(self,
num_keypoints: int,
fused_out_channel: int = 128,
voxel_size: list = [0.05, 0.05, 0.1],
point_cloud_range: list = [0, -40, -3, 70.4, 40, 1],
voxel_sa_cfgs_list: Optional[list] = None,
rawpoints_sa_cfgs: Optional[dict] = None,
bev_feat_channel: int = 256,
bev_scale_factor: int = 8,
voxel_center_as_source: bool = False,
norm_cfg: dict = dict(type='BN2d', eps=1e-5, momentum=0.1),
bias: str = 'auto') -> None:
super().__init__()
self.num_keypoints = num_keypoints
self.fused_out_channel = fused_out_channel
self.voxel_size = voxel_size
self.point_cloud_range = point_cloud_range
self.voxel_center_as_source = voxel_center_as_source
gathered_channel = 0
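        # running total of per-keypoint feature channels gathered from the
        # enabled sources (raw points, multi-scale voxels, BEV map); it
        # becomes the input width of the fusion layer below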
if rawpoints_sa_cfgs is not None:
self.rawpoints_sa_layer = MODELS.build(rawpoints_sa_cfgs)
gathered_channel += sum(
[x[-1] for x in rawpoints_sa_cfgs.mlp_channels])
else:
self.rawpoints_sa_layer = None
if voxel_sa_cfgs_list is not None:
self.voxel_sa_configs_list = voxel_sa_cfgs_list
self.voxel_sa_layers = nn.ModuleList()
for voxel_sa_config in voxel_sa_cfgs_list:
cur_layer = MODELS.build(voxel_sa_config)
self.voxel_sa_layers.append(cur_layer)
gathered_channel += sum(
[x[-1] for x in voxel_sa_config.mlp_channels])
else:
self.voxel_sa_layers = None
if bev_feat_channel is not None and bev_scale_factor is not None:
self.bev_cfg = mmengine.Config(
dict(
bev_feat_channels=bev_feat_channel,
bev_scale_factor=bev_scale_factor))
gathered_channel += bev_feat_channel
else:
self.bev_cfg = None
self.point_feature_fusion_layer = nn.Sequential(
ConvModule(
gathered_channel,
fused_out_channel,
kernel_size=(1, 1),
stride=(1, 1),
conv_cfg=dict(type='Conv2d'),
norm_cfg=norm_cfg,
bias=bias))

    def interpolate_from_bev_features(self, keypoints: torch.Tensor,
bev_features: torch.Tensor,
batch_size: int,
bev_scale_factor: int) -> torch.Tensor:
"""Gather key points features from bev feature map by interpolate.
Args:
keypoints (torch.Tensor): Sampled key points with shape
(N1 + N2 + ..., NDim).
bev_features (torch.Tensor): Bev feature map from the first
stage with shape (B, C, H, W).
batch_size (int): Input batch size.
bev_scale_factor (int): Bev feature map scale factor.
Returns:
torch.Tensor: Key points features gather from bev feature
map with shape (N1 + N2 + ..., C)
"""
x_idxs = (keypoints[..., 0] -
self.point_cloud_range[0]) / self.voxel_size[0]
y_idxs = (keypoints[..., 1] -
self.point_cloud_range[1]) / self.voxel_size[1]
x_idxs = x_idxs / bev_scale_factor
y_idxs = y_idxs / bev_scale_factor
point_bev_features_list = []
for k in range(batch_size):
cur_x_idxs = x_idxs[k, ...]
cur_y_idxs = y_idxs[k, ...]
cur_bev_features = bev_features[k].permute(1, 2, 0) # (H, W, C)
point_bev_features = bilinear_interpolate_torch(
cur_bev_features, cur_x_idxs, cur_y_idxs)
point_bev_features_list.append(point_bev_features)
point_bev_features = torch.cat(
point_bev_features_list, dim=0) # (N1 + N2 + ..., C)
return point_bev_features.view(batch_size, keypoints.shape[1], -1)

    def get_voxel_centers(self, coors: torch.Tensor,
scale_factor: float) -> torch.Tensor:
"""Get voxel centers coordinate.
Args:
coors (torch.Tensor): Coordinates of voxels shape is Nx(1+NDim),
where 1 represents the batch index.
scale_factor (float): Scale factor.
Returns:
torch.Tensor: Voxel centers coordinate with shape (N, 3).
"""
assert coors.shape[1] == 4
        voxel_centers = coors[:, [3, 2, 1]].float()  # (z, y, x) -> (x, y, z)
voxel_size = torch.tensor(
self.voxel_size,
device=voxel_centers.device).float() * scale_factor
pc_range = torch.tensor(
self.point_cloud_range[0:3], device=voxel_centers.device).float()
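        # shift indices by half a voxel to the voxel center, scale to metric
        # units, then translate by the point cloud range origin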
voxel_centers = (voxel_centers + 0.5) * voxel_size + pc_range
return voxel_centers

    def sample_key_points(self, points: List[torch.Tensor],
coors: torch.Tensor) -> torch.Tensor:
"""Sample key points from raw points cloud.
Args:
points (List[torch.Tensor]): Point cloud of each sample.
coors (torch.Tensor): Coordinates of voxels shape is Nx(1+NDim),
where 1 represents the batch index.
Returns:
torch.Tensor: (B, M, 3) Key points of each sample.
M is num_keypoints.
"""
assert points is not None or coors is not None
if self.voxel_center_as_source:
_src_points = self.get_voxel_centers(coors=coors, scale_factor=1)
batch_size = coors[-1, 0].item() + 1
src_points = [
_src_points[coors[:, 0] == b] for b in range(batch_size)
]
else:
src_points = [p[..., :3] for p in points]
keypoints_list = []
for points_to_sample in src_points:
num_points = points_to_sample.shape[0]
cur_pt_idxs = furthest_point_sample(
points_to_sample.unsqueeze(dim=0).contiguous(),
self.num_keypoints).long()[0]
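            # when a sample has fewer points than num_keypoints, keep the
            # first num_points FPS indices and tile them to the target length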
if num_points < self.num_keypoints:
times = int(self.num_keypoints / num_points) + 1
non_empty = cur_pt_idxs[:num_points]
cur_pt_idxs = non_empty.repeat(times)[:self.num_keypoints]
keypoints = points_to_sample[cur_pt_idxs]
keypoints_list.append(keypoints)
keypoints = torch.stack(keypoints_list, dim=0) # (B, M, 3)
return keypoints

    def forward(self, batch_inputs_dict: dict, feats_dict: dict,
rpn_results_list: InstanceList) -> dict:
"""Extract point-wise features from multi-input.
Args:
batch_inputs_dict (dict): The model input dict which include
'points', 'voxels' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- voxels (dict[torch.Tensor]): Voxels of the batch sample.
feats_dict (dict): Contains features from the first
stage.
rpn_results_list (List[:obj:`InstanceData`]): Detection results
of rpn head.
Returns:
dict: Contain Point-wise features, include:
- keypoints (torch.Tensor): Sampled key points.
- keypoint_features (torch.Tensor): Gathered key points
features from multi input.
- fusion_keypoint_features (torch.Tensor): Fusion
keypoint_features by point_feature_fusion_layer.
"""
points = batch_inputs_dict['points']
voxel_encode_features = feats_dict['multi_scale_3d_feats']
bev_encode_features = feats_dict['spatial_feats']
if self.voxel_center_as_source:
voxels_coors = batch_inputs_dict['voxels']['coors']
else:
voxels_coors = None
keypoints = self.sample_key_points(points, voxels_coors)
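        # gather per-keypoint features from up to three sources: the BEV
        # feature map, the raw point cloud and the multi-scale voxel
        # features, then concatenate them channel-wise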
point_features_list = []
batch_size = len(points)
if self.bev_cfg is not None:
point_bev_features = self.interpolate_from_bev_features(
keypoints, bev_encode_features, batch_size,
self.bev_cfg.bev_scale_factor)
point_features_list.append(point_bev_features.contiguous())
batch_size, num_keypoints, _ = keypoints.shape
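        # flatten key points to the stacked (B * num_keypoints, 3) layout
        # expected by the stacked SA modules; every sample contributes
        # exactly num_keypoints query points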
key_xyz = keypoints.view(-1, 3)
key_xyz_batch_cnt = key_xyz.new_zeros(batch_size).int().fill_(
num_keypoints)
if self.rawpoints_sa_layer is not None:
batch_points = torch.cat(points, dim=0)
batch_cnt = [len(p) for p in points]
xyz = batch_points[:, :3].contiguous()
            features = None
            if batch_points.size(1) > 3:
                # only the columns after (x, y, z) carry point-wise features
                features = batch_points[:, 3:].contiguous()
xyz_batch_cnt = xyz.new_tensor(batch_cnt, dtype=torch.int32)
pooled_points, pooled_features = self.rawpoints_sa_layer(
xyz=xyz.contiguous(),
xyz_batch_cnt=xyz_batch_cnt,
new_xyz=key_xyz.contiguous(),
new_xyz_batch_cnt=key_xyz_batch_cnt,
                features=features,  # may be None for xyz-only point clouds
)
point_features_list.append(pooled_features.contiguous().view(
batch_size, num_keypoints, -1))
if self.voxel_sa_layers is not None:
for k, voxel_sa_layer in enumerate(self.voxel_sa_layers):
cur_coords = voxel_encode_features[k].indices
xyz = self.get_voxel_centers(
coors=cur_coords,
scale_factor=self.voxel_sa_configs_list[k].scale_factor
).contiguous()
xyz_batch_cnt = xyz.new_zeros(batch_size).int()
for bs_idx in range(batch_size):
xyz_batch_cnt[bs_idx] = (cur_coords[:, 0] == bs_idx).sum()
pooled_points, pooled_features = voxel_sa_layer(
xyz=xyz.contiguous(),
xyz_batch_cnt=xyz_batch_cnt,
new_xyz=key_xyz.contiguous(),
new_xyz_batch_cnt=key_xyz_batch_cnt,
features=voxel_encode_features[k].features.contiguous(),
)
point_features_list.append(pooled_features.contiguous().view(
batch_size, num_keypoints, -1))
        point_features = torch.cat(point_features_list, dim=-1).view(
            batch_size * num_keypoints, -1, 1)
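        # fuse the concatenated features with the 1x1 ConvModule: unsqueeze
        # makes each key point a (C, 1, 1) map so Conv2d + BN2d apply
        # point-wise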
fusion_point_features = self.point_feature_fusion_layer(
point_features.unsqueeze(dim=-1)).squeeze(dim=-1)
        batch_idxs = torch.arange(
            batch_size * num_keypoints, device=keypoints.device
        ) // num_keypoints  # batch index of each key point
batch_keypoints_xyz = torch.cat(
(batch_idxs.to(key_xyz.dtype).unsqueeze(dim=-1), key_xyz), dim=-1)
return dict(
keypoint_features=point_features.squeeze(dim=-1),
fusion_keypoint_features=fusion_point_features.squeeze(dim=-1),
keypoints=batch_keypoints_xyz)
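

# Minimal usage sketch (hypothetical values, loosely following the PV-RCNN
# configs in mmdet3d; with no SA cfgs given, only BEV interpolation
# contributes features, so the fusion layer maps 256 -> 128 channels):
#
#   vsa = MODELS.build(
#       dict(
#           type='VoxelSetAbstraction',
#           num_keypoints=2048,
#           fused_out_channel=128,
#           bev_feat_channel=256,
#           bev_scale_factor=8))
#   out = vsa(batch_inputs_dict, feats_dict, rpn_results_list)
#   # out['keypoints']: (B * 2048, 4), batch index prepended to xyz
#   # out['fusion_keypoint_features']: (B * 2048, 128)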