|
"""
|
|
Implements a TransFuser-style image + LiDAR fusion backbone whose image branch is an InternImage, VoVNet or Swin encoder.
|
|
"""
|
|
|
|
import timm
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch import nn
|
|
from torch.utils.checkpoint import checkpoint
|
|
|
|
from navsim.agents.backbones.internimage import InternImage
|
|
from navsim.agents.backbones.swin import SwinTransformerBEVFT
|
|
from navsim.agents.backbones.vov import VoVNet
|
|
from navsim.agents.transfuser.transfuser_backbone import GPT
|
|
from navsim.agents.utils.vit import DAViT
|
|
|
|
|
|
class TransfuserBackboneConv(nn.Module):
    """
    Multi-scale Fusion Transformer for image + LiDAR feature fusion.

    The image branch is a pretrained backbone selected by ``config.backbone_type``
    ('intern', 'vov' or 'swin'); only the last feature map it returns is used.
    The LiDAR branch is a timm ``features_only`` model that is advanced one block at a
    time. After each of its four fused feature stages, a GPT block exchanges information
    between the (fixed) image feature map and the current LiDAR feature map.
    """

    def __init__(self, config):
        """
        :param config: Agent configuration. Attributes read here include backbone_type,
            the checkpoint path of the chosen image backbone (intern_ckpt / vov_ckpt /
            swin_ckpt), use_ground_plane, lidar_seq_len, img/lidar anchor sizes,
            lidar_architecture, bev_features_channels, detect_boxes, use_bev_semantic,
            bev_upsample_factor, lidar_resolution_height/width, bev_down_sample_factor,
            transformer_decoder_join and (optionally) add_features.
        :raises ValueError: if ``config.backbone_type`` is not supported.
        """
        super().__init__()
        self.config = config
        self.backbone_type = config.backbone_type

        self.image_encoder, vit_channels = self._build_image_encoder(config)

        # The ground-plane option doubles the number of input channels per LiDAR frame.
        in_channels = 2 * config.lidar_seq_len if config.use_ground_plane else config.lidar_seq_len

        self.avgpool_img = nn.AdaptiveAvgPool2d(
            (self.config.img_vert_anchors, self.config.img_horz_anchors)
        )

        self.lidar_encoder = timm.create_model(
            config.lidar_architecture,
            pretrained=False,
            in_chans=in_channels,
            features_only=True,
        )
        self.global_pool_lidar = nn.AdaptiveAvgPool2d(output_size=1)
        self.avgpool_lidar = nn.AdaptiveAvgPool2d(
            (self.config.lidar_vert_anchors, self.config.lidar_horz_anchors)
        )
        lidar_time_frames = [1, 1, 1, 1]

        self.global_pool_img = nn.AdaptiveAvgPool2d(output_size=1)

        # With more than 4 feature levels, forward() runs the first level without fusing it,
        # so the four fused stages start one level later.
        start_index = 1 if len(self.lidar_encoder.return_layers) > 4 else 0
        lidar_channels = [
            self.lidar_encoder.feature_info.info[start_index + i]["num_chs"] for i in range(4)
        ]

        self.transformers = nn.ModuleList(
            [
                GPT(
                    n_embd=vit_channels,
                    config=config,
                    lidar_time_frames=lidar_time_frames[i],
                )
                for i in range(4)
            ]
        )
        # 1x1 convs mapping each LiDAR stage into the image channel space and back.
        self.lidar_channel_to_img = nn.ModuleList(
            [nn.Conv2d(lidar_channels[i], vit_channels, kernel_size=1) for i in range(4)]
        )
        self.img_channel_to_lidar = nn.ModuleList(
            [nn.Conv2d(vit_channels, lidar_channels[i], kernel_size=1) for i in range(4)]
        )

        self.num_features = lidar_channels[3]

        channel = self.config.bev_features_channels
        self.relu = nn.ReLU(inplace=True)

        # Layers used by top_down(); only built when a BEV output head needs them.
        if self.config.detect_boxes or self.config.use_bev_semantic:
            self.upsample = nn.Upsample(
                scale_factor=self.config.bev_upsample_factor, mode="bilinear", align_corners=False
            )
            self.upsample2 = nn.Upsample(
                size=(
                    self.config.lidar_resolution_height // self.config.bev_down_sample_factor,
                    self.config.lidar_resolution_width // self.config.bev_down_sample_factor,
                ),
                mode="bilinear",
                align_corners=False,
            )

            self.up_conv5 = nn.Conv2d(channel, channel, (3, 3), padding=1)
            self.up_conv4 = nn.Conv2d(channel, channel, (3, 3), padding=1)

            self.c5_conv = nn.Conv2d(self.num_features, channel, (1, 1))

        # forward() needs this projection when pooled features are added instead of
        # concatenated. It was previously never created, so that path raised AttributeError.
        end_projection = self._make_lidar_to_img_end(config, self.num_features, vit_channels)
        if end_projection is not None:
            self.lidar_to_img_features_end = end_projection

    @staticmethod
    def _build_image_encoder(config):
        """
        Construct the image backbone selected by ``config.backbone_type``.

        :param config: Agent configuration (reads backbone_type and the matching *_ckpt).
        :return: (encoder, vit_channels), where vit_channels is the channel count the
            fusion modules are built with for that backbone's last feature map.
        :raises ValueError: if ``config.backbone_type`` is not 'intern', 'vov' or 'swin'.
        """
        backbone_type = config.backbone_type
        if backbone_type == 'intern':
            encoder = InternImage(
                init_cfg=dict(type='Pretrained', checkpoint=config.intern_ckpt),
                frozen_stages=2,
            )
            encoder.init_weights()
            return encoder, 2560
        if backbone_type == 'vov':
            encoder = VoVNet(
                spec_name='V-99-eSE',
                out_features=['stage4', 'stage5'],
                norm_eval=True,
                with_cp=True,
                init_cfg=dict(
                    type='Pretrained',
                    checkpoint=config.vov_ckpt,
                    prefix='img_backbone.',
                ),
            )
            encoder.init_weights()
            return encoder, 1024
        if backbone_type == 'swin':
            # NOTE(review): unlike the other branches, init_weights() is not called here, so the
            # 'Pretrained' init_cfg is presumably applied elsewhere (or not at all) — confirm.
            encoder = SwinTransformerBEVFT(
                with_cp=True,
                convert_weights=False,
                depths=[2, 2, 18, 2],
                drop_path_rate=0.35,
                embed_dims=192,
                init_cfg=dict(
                    checkpoint=config.swin_ckpt,
                    type='Pretrained',
                ),
                num_heads=[6, 12, 24, 48],
                out_indices=[3],
                patch_norm=True,
                window_size=[16, 16, 16, 16],
                use_abs_pos_embed=True,
                return_stereo_feat=False,
                output_missing_index_as_none=False,
            )
            return encoder, 1536
        raise ValueError(
            f"Unsupported backbone_type {backbone_type!r}; expected 'intern', 'vov' or 'swin'."
        )

    @staticmethod
    def _make_lidar_to_img_end(config, lidar_channels, img_channels):
        """
        Build the linear projection that forward() applies to the pooled LiDAR vector
        before adding it to the pooled image vector, or return None if it is unused.

        It is only needed when ``transformer_decoder_join`` is False and ``add_features``
        is True. No module is created otherwise, so every other configuration keeps the
        same parameters (and state dict) as before.

        :param config: Agent configuration.
        :param lidar_channels: Size of the pooled LiDAR feature vector.
        :param img_channels: Size of the pooled image feature vector.
        :return: ``nn.Linear(lidar_channels, img_channels)`` or None.
        """
        if config.transformer_decoder_join or not getattr(config, "add_features", False):
            return None
        return nn.Linear(lidar_channels, img_channels)

    def top_down(self, x):
        """
        Decode the last fused LiDAR feature map into a BEV feature map.

        Only valid when detect_boxes or use_bev_semantic is set, since the layers used
        here are created only in that case.

        :param x: LiDAR feature map with ``self.num_features`` channels.
        :return: Feature map with ``config.bev_features_channels`` channels, resized to
            (lidar_resolution_height // bev_down_sample_factor,
             lidar_resolution_width // bev_down_sample_factor).
        """
        p5 = self.relu(self.c5_conv(x))
        p4 = self.relu(self.up_conv5(self.upsample(p5)))
        p3 = self.relu(self.up_conv4(self.upsample2(p4)))
        return p3

    def forward(self, image, lidar):
        """
        Image + LiDAR feature fusion using transformers.

        :param image: Image batch passed directly to the image encoder.
        :param lidar: LiDAR BEV input passed to the LiDAR encoder.
        :return: Tuple (features, fused_features, image_features):
            features -- BEV feature map from top_down() when detect_boxes or
                use_bev_semantic is set, else None.
            fused_features -- the final LiDAR feature map when transformer_decoder_join
                is set; otherwise the pooled image and LiDAR vectors combined by addition
                (add_features) or by concatenation along dim 1.
            image_features -- the final image feature map, or its pooled and flattened
                vector when transformer_decoder_join is not set.
        """
        image_features, lidar_features = image, lidar

        # A single iterator is shared across calls so each forward_layer_block call
        # resumes where the previous one stopped.
        lidar_layers = iter(self.lidar_encoder.items())

        # The first of more than 4 feature levels is computed but not fused.
        if len(self.lidar_encoder.return_layers) > 4:
            lidar_features = self.forward_layer_block(
                lidar_layers, self.lidar_encoder.return_layers, lidar_features
            )

        # Only the last image feature map takes part in the fusion.
        image_features = self.image_encoder(image_features)[-1]

        for i in range(4):
            lidar_features = self.forward_layer_block(
                lidar_layers, self.lidar_encoder.return_layers, lidar_features
            )
            image_features, lidar_features = self.fuse_features(image_features, lidar_features, i)

        if self.config.detect_boxes or self.config.use_bev_semantic:
            # Keep the un-pooled LiDAR map for the BEV decoder.
            x4 = lidar_features

        if self.config.transformer_decoder_join:
            fused_features = lidar_features
        else:
            image_features = self.global_pool_img(image_features)
            image_features = torch.flatten(image_features, 1)
            lidar_features = self.global_pool_lidar(lidar_features)
            lidar_features = torch.flatten(lidar_features, 1)

            if self.config.add_features:
                lidar_features = self.lidar_to_img_features_end(lidar_features)
                fused_features = image_features + lidar_features
            else:
                fused_features = torch.cat((image_features, lidar_features), dim=1)

        if self.config.detect_boxes or self.config.use_bev_semantic:
            features = self.top_down(x4)
        else:
            features = None

        return features, fused_features, image_features

    def forward_layer_block(self, layers, return_layers, features, if_ckpt=False):
        """
        Run the layers of a TIMM network up to (and including) the next feature-returning
        layer and return the result. Advances the shared iterator by exactly one block.

        :param layers: Iterator of (name, module) pairs starting at the current layer block.
        :param return_layers: TIMM mapping of the intermediate layers at which features are returned.
        :param features: Input features.
        :param if_ckpt: If True, run each module through torch.utils.checkpoint.checkpoint.
        :return: Processed features.
        """
        for name, module in layers:
            if if_ckpt:
                features = checkpoint(module, features)
            else:
                features = module(features)
            if name in return_layers:
                break
        return features

    def fuse_features(self, image_features, lidar_features, layer_idx):
        """
        Perform a TransFuser feature fusion block using a Transformer module.

        :param image_features: Features from the image branch.
        :param lidar_features: Features from the LiDAR branch.
        :param layer_idx: Transformer layer index.
        :return: image_features and lidar_features with added features from the other branch.
        """
        # Pool both branches to their fixed anchor grids.
        image_embd_layer = self.avgpool_img(image_features)
        lidar_embd_layer = self.avgpool_lidar(lidar_features)

        # The transformer works in the image channel space.
        lidar_embd_layer = self.lidar_channel_to_img[layer_idx](lidar_embd_layer)

        image_features_layer, lidar_features_layer = self.transformers[layer_idx](
            image_embd_layer, lidar_embd_layer
        )
        lidar_features_layer = self.img_channel_to_lidar[layer_idx](lidar_features_layer)

        # Bring the transformer outputs back to each branch's spatial size.
        image_features_layer = F.interpolate(
            image_features_layer,
            size=(image_features.shape[2], image_features.shape[3]),
            mode="bilinear",
            align_corners=False,
        )
        lidar_features_layer = F.interpolate(
            lidar_features_layer,
            size=(lidar_features.shape[2], lidar_features.shape[3]),
            mode="bilinear",
            align_corners=False,
        )

        # Residual connection.
        image_features = image_features + image_features_layer
        lidar_features = lidar_features + lidar_features_layer

        return image_features, lidar_features
|
|
|