import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import normal_
from torchvision.transforms.functional import rotate
from typing import List

from mmcv.cnn import Linear, bias_init_with_prob, xavier_init, constant_init
from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence, build_positional_encoding
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmdet.models.utils.builder import TRANSFORMER

from det_map.det.dal.mmdet3d.models.builder import FUSERS, build_fuser
from det_map.det.dal.mmdet3d.models.bevformer_modules.temporal_self_attention import TemporalSelfAttention
from det_map.det.dal.mmdet3d.models.bevformer_modules.spatial_cross_attention import MSDeformableAttention3D
from det_map.det.dal.mmdet3d.models.bevformer_modules.decoder import CustomMSDeformableAttention


@FUSERS.register_module()
class ConvFuser(nn.Sequential):
    """Fuse concatenated BEV feature maps with a 3x3 conv + BN + ReLU block."""

    def __init__(self, in_channels: List[int], out_channels: int) -> None:
        self.in_channels = in_channels
        self.out_channels = out_channels
        super().__init__(
            nn.Conv2d(sum(in_channels), out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        # Concatenate the per-modality BEV maps along the channel dim, then fuse.
        return super().forward(torch.cat(inputs, dim=1))
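
# Usage sketch for ConvFuser (channel counts and BEV size below are illustrative
# assumptions, not values taken from a config):
#   fuser = ConvFuser(in_channels=[256, 256], out_channels=256)
#   cam_bev = torch.rand(1, 256, 200, 100)
#   lidar_bev = torch.rand(1, 256, 200, 100)
#   fused = fuser([cam_bev, lidar_bev])   # -> (1, 256, 200, 100)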


@TRANSFORMER.register_module()
class MapTRPerceptionTransformer(BaseModule):
    """Implements the MapTR perception transformer (a Detr3D-style transformer).

    Args:
        as_two_stage (bool): Generate query from encoder features.
            Default: False.
        num_feature_levels (int): Number of feature maps from FPN.
            Default: 4.
        two_stage_num_proposals (int): Number of proposals when
            `as_two_stage` is True. Default: 300.
    """

    def __init__(self,
                 bev_h, bev_w,
                 num_feature_levels=1,
                 num_cams=2,
                 z_cfg=dict(
                     pred_z_flag=False,
                     gt_z_flag=False,
                 ),
                 two_stage_num_proposals=300,
                 fuser=None,
                 encoder=None,
                 decoder=None,
                 embed_dims=256,
                 rotate_prev_bev=True,
                 use_shift=True,
                 use_can_bus=True,
                 can_bus_norm=True,
                 use_cams_embeds=True,
                 rotate_center=[100, 100],
                 modality='vision',
                 feat_down_sample_indice=-1,
                 **kwargs):
        super(MapTRPerceptionTransformer, self).__init__(**kwargs)
        if modality == 'fusion':
            self.fuser = build_fuser(fuser)

        self.use_attn_bev = True
        self.bev_h = bev_h
        self.bev_w = bev_w
        self.bev_embedding = nn.Embedding(self.bev_h * self.bev_w, embed_dims)
        self.positional_encoding = build_positional_encoding(
            dict(
                type='CustomLearnedPositionalEncoding',
                num_feats=embed_dims // 2,
                row_num_embed=self.bev_h,
                col_num_embed=self.bev_w,
            )
        )
        self.encoder = build_transformer_layer_sequence(encoder)
        self.decoder = build_transformer_layer_sequence(decoder)
        self.embed_dims = embed_dims
        self.num_feature_levels = num_feature_levels
        self.num_cams = num_cams
        self.fp16_enabled = False

        self.rotate_prev_bev = rotate_prev_bev
        self.use_shift = use_shift
        self.use_can_bus = use_can_bus
        self.can_bus_norm = can_bus_norm
        self.use_cams_embeds = use_cams_embeds

        self.two_stage_num_proposals = two_stage_num_proposals
        self.z_cfg = z_cfg
        self.init_layers()
        self.rotate_center = rotate_center
        self.feat_down_sample_indice = feat_down_sample_indice

    def init_layers(self):
        """Initialize layers of the MapTRPerceptionTransformer."""
        # `level_embeds` and `cams_embeds` are expected by init_weights() (and by
        # format_feats() when `use_cams_embeds` is set), so they are created here.
        self.level_embeds = nn.Parameter(
            torch.Tensor(self.num_feature_levels, self.embed_dims))
        self.cams_embeds = nn.Parameter(
            torch.Tensor(self.num_cams, self.embed_dims))
        self.reference_points = nn.Linear(self.embed_dims, 2) if not self.z_cfg['gt_z_flag'] \
            else nn.Linear(self.embed_dims, 3)

    def init_weights(self):
        """Initialize the transformer weights."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, (MSDeformableAttention3D, TemporalSelfAttention,
                              CustomMSDeformableAttention)):
                try:
                    m.init_weight()
                except AttributeError:
                    m.init_weights()
        normal_(self.level_embeds)
        normal_(self.cams_embeds)
        xavier_init(self.reference_points, distribution='uniform', bias=0.)

    def attn_bev_encode(
            self,
            mlvl_feats,
            cam_params=None,
            gt_bboxes_3d=None,
            pred_img_depth=None,
            prev_bev=None,
            bev_mask=None,
            **kwargs):
        """Build BEV features with the attention-based (BEVFormer-style) encoder."""
        bs = mlvl_feats[0].size(0)
        dtype = mlvl_feats[0].dtype

        # Flatten each feature level to (num_cam, bs, h*w, c).
        feat_flatten = []
        spatial_shapes = []
        for lvl, feat in enumerate(mlvl_feats):
            bs, num_cam, c, h, w = feat.shape
            spatial_shape = (h, w)
            feat = feat.flatten(3).permute(1, 0, 3, 2)
            spatial_shapes.append(spatial_shape)
            feat_flatten.append(feat)

        feat_flatten = torch.cat(feat_flatten, 2)
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=mlvl_feats[0].device)
        level_start_index = torch.cat((spatial_shapes.new_zeros(
            (1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))

        # (num_cam, bs, h*w, c) -> (num_cam, h*w, bs, c), as expected by the encoder.
        feat_flatten = feat_flatten.permute(0, 2, 1, 3)

        bev_queries = self.bev_embedding.weight.to(dtype)
        bev_queries = bev_queries.unsqueeze(1).repeat(1, bs, 1)
        bev_pos = self.positional_encoding(bs, self.bev_h, self.bev_w, bev_queries.device).to(dtype)
        bev_pos = bev_pos.flatten(2).permute(2, 0, 1)

        bev_embed = self.encoder(
            bev_queries,
            feat_flatten,
            feat_flatten,
            bev_h=self.bev_h,
            bev_w=self.bev_w,
            bev_pos=bev_pos,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            cam_params=cam_params,
            gt_bboxes_3d=gt_bboxes_3d,
            pred_img_depth=pred_img_depth,
            prev_bev=prev_bev,
            bev_mask=bev_mask,
            **kwargs
        )

        return bev_embed
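
    # Worked example of `level_start_index` for two feature levels with spatial sizes
    # (15, 25) and (8, 13) (illustrative sizes, not taken from a config):
    #   spatial_shapes         = [[15, 25], [8, 13]]
    #   spatial_shapes.prod(1) -> [375, 104]
    #   level_start_index      -> [0, 375]
    # i.e. the start offset of each level inside the flattened feature sequence.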

    def lss_bev_encode(
            self,
            mlvl_feats,
            prev_bev=None,
            **kwargs):
        """Build BEV features with a lift-splat (LSS-style) encoder."""
        images = mlvl_feats[self.feat_down_sample_indice]
        img_metas = kwargs['img_metas']
        encoder_outputdict = self.encoder(images, img_metas)
        bev_embed = encoder_outputdict['bev']
        depth = encoder_outputdict['depth']
        bs, c, _, _ = bev_embed.shape
        # (bs, c, h, w) -> (bs, h*w, c)
        bev_embed = bev_embed.view(bs, c, -1).permute(0, 2, 1).contiguous()
        ret_dict = dict(
            bev=bev_embed,
            depth=depth
        )
        return ret_dict

    def get_bev_features(
            self,
            mlvl_feats,
            lidar_feat,
            bev_queries,
            bev_h,
            bev_w,
            grid_length=[0.512, 0.512],
            bev_pos=None,
            prev_bev=None,
            **kwargs):
        """Obtain BEV features, optionally fused with lidar BEV features."""
        # Only the attention-based BEV encoder is supported by this transformer.
        assert self.use_attn_bev
        if self.use_attn_bev:
            img_metas = kwargs['img_metas']
            rot = img_metas['sensor2lidar_rotation']
            B, T, N, _, _ = rot.shape
            # Camera parameters of the current (last) frame.
            cam_params = (img_metas['sensor2lidar_rotation'][:, -1],
                          img_metas['sensor2lidar_translation'][:, -1],
                          img_metas['intrinsics'][:, -1],
                          img_metas['post_rot'][:, -1],
                          img_metas['post_tran'][:, -1],
                          torch.eye(3, device=rot.device, dtype=rot.dtype)[None].repeat(B, 1, 1))
            bev_embed = self.attn_bev_encode(
                mlvl_feats,
                cam_params=cam_params,
                **kwargs)
        else:
            ret_dict = self.lss_bev_encode(
                mlvl_feats,
                prev_bev=prev_bev,
                **kwargs)
            bev_embed = ret_dict['bev']
            depth = ret_dict['depth']
        if lidar_feat is not None:
            bs = mlvl_feats[0].size(0)
            # (bs, bev_h*bev_w, c) -> (bs, c, bev_h, bev_w) for the conv fuser.
            bev_embed = bev_embed.view(bs, bev_h, bev_w, -1).permute(0, 3, 1, 2).contiguous()
            lidar_feat = lidar_feat.permute(0, 1, 3, 2).contiguous()

            fused_bev = self.fuser([bev_embed, lidar_feat])
            # Back to (bs, bev_h*bev_w, c) for the decoder.
            fused_bev = fused_bev.flatten(2).permute(0, 2, 1).contiguous()
            bev_embed = fused_bev
        ret_dict = dict(
            bev=bev_embed,
            depth=None
        )
        return ret_dict

    def format_feats(self, mlvl_feats):
        """Flatten multi-level camera features for the deformable decoder."""
        bs = mlvl_feats[0].size(0)
        feat_flatten = []
        spatial_shapes = []
        for lvl, feat in enumerate(mlvl_feats):
            bs, num_cam, c, h, w = feat.shape
            spatial_shape = (h, w)
            # (bs, num_cam, c, h, w) -> (num_cam, bs, h*w, c)
            feat = feat.flatten(3).permute(1, 0, 3, 2)
            # Add per-camera and per-level embeddings (BEVFormer-style).
            if self.use_cams_embeds:
                feat = feat + self.cams_embeds[:, None, None, :].to(feat.dtype)
            feat = feat + self.level_embeds[None, None, lvl:lvl + 1, :].to(feat.dtype)
            spatial_shapes.append(spatial_shape)
            feat_flatten.append(feat)

        feat_flatten = torch.cat(feat_flatten, 2)
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=feat.device)
        level_start_index = torch.cat((spatial_shapes.new_zeros(
            (1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))

        # (num_cam, bs, h*w, c) -> (num_cam, h*w, bs, c)
        feat_flatten = feat_flatten.permute(0, 2, 1, 3)
        return feat_flatten, spatial_shapes, level_start_index

    def forward(self,
                mlvl_feats,
                lidar_feat,
                bev_queries,
                object_query_embed,
                bev_h,
                bev_w,
                grid_length=[0.512, 0.512],
                bev_pos=None,
                reg_branches=None,
                cls_branches=None,
                prev_bev=None,
                **kwargs):
        """Forward function for `MapTRPerceptionTransformer`.

        Args:
            mlvl_feats (list(Tensor)): Input features from different levels.
                Each element has shape [bs, num_cams, embed_dims, h, w].
            bev_queries (Tensor): (bev_h*bev_w, c).
            bev_pos (Tensor): (bs, embed_dims, bev_h, bev_w).
            object_query_embed (Tensor): The query embedding for the decoder,
                with shape [num_query, c].
            reg_branches (obj:`nn.ModuleList`): Regression heads for feature
                maps from each decoder layer. Only passed when
                `with_box_refine` is True. Default to None.
        Returns:
            tuple[Tensor]: results of decoder containing the following tensors.

                - bev_embed: BEV features.
                - inter_states: Outputs from the decoder. If
                    `return_intermediate_dec` is True, it has shape
                    (num_dec_layers, bs, num_query, embed_dims), else
                    (1, bs, num_query, embed_dims).
                - init_reference_out: The initial value of the reference
                    points, has shape (bs, num_queries, 4).
                - inter_references_out: The intermediate values of the
                    reference points in the decoder, has shape
                    (num_dec_layers, bs, num_query, embed_dims).
                - enc_outputs_class: The classification scores of proposals
                    generated from the encoder's feature maps, has shape
                    (batch, h*w, num_classes). Only returned when
                    `as_two_stage` is True, otherwise None.
                - enc_outputs_coord_unact: The regression results generated
                    from the encoder's feature maps, has shape (batch, h*w, 4).
                    Only returned when `as_two_stage` is True, otherwise None.
        """
        output_dict = self.get_bev_features(
            mlvl_feats,
            lidar_feat,
            bev_queries,
            bev_h,
            bev_w,
            grid_length=grid_length,
            bev_pos=bev_pos,
            prev_bev=prev_bev,
            **kwargs)
        bev_embed = output_dict['bev']
        depth = output_dict['depth']
        bs = mlvl_feats[0].size(0)
        query_pos, query = torch.split(
            object_query_embed, self.embed_dims, dim=1)
        query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
        query = query.unsqueeze(0).expand(bs, -1, -1)
        reference_points = self.reference_points(query_pos)
        reference_points = reference_points.sigmoid()
        init_reference_out = reference_points

        # (bs, num_query, c) -> (num_query, bs, c) for the decoder.
        query = query.permute(1, 0, 2)
        query_pos = query_pos.permute(1, 0, 2)
        # (bs, bev_h*bev_w, c) -> (bev_h*bev_w, bs, c)
        bev_embed = bev_embed.permute(1, 0, 2)

        feat_flatten, feat_spatial_shapes, feat_level_start_index \
            = self.format_feats(mlvl_feats)

        inter_states, inter_references = self.decoder(
            query=query,
            key=None,
            value=bev_embed,
            query_pos=query_pos,
            reference_points=reference_points,
            reg_branches=reg_branches,
            cls_branches=cls_branches,
            spatial_shapes=torch.tensor([[bev_h, bev_w]], device=query.device),
            level_start_index=torch.tensor([0], device=query.device),
            mlvl_feats=mlvl_feats,
            feat_flatten=None,
            feat_spatial_shapes=feat_spatial_shapes,
            feat_level_start_index=feat_level_start_index,
            **kwargs)

        inter_references_out = inter_references

        return bev_embed, depth, inter_states, init_reference_out, inter_references_out
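

# Usage sketch for MapTRPerceptionTransformer (all shapes and config values below are
# illustrative assumptions, not taken from a MapTR config; `img_metas` must provide the
# sensor2lidar_rotation/translation, intrinsics, post_rot and post_tran entries used in
# get_bev_features):
#   transformer = MapTRPerceptionTransformer(
#       bev_h=200, bev_w=100, embed_dims=256,
#       encoder=encoder_cfg, decoder=decoder_cfg)
#   mlvl_feats = [torch.rand(1, 2, 256, 15, 25)]    # (bs, num_cams, embed_dims, h, w)
#   object_query_embed = torch.rand(1000, 512)      # (num_query, 2 * embed_dims)
#   bev_embed, depth, inter_states, init_ref, inter_refs = transformer(
#       mlvl_feats, None, None, object_query_embed, 200, 100, img_metas=img_metas)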