# navsim/agents/transfuser/transfuser_backbone_conv.py
"""
Implements the TransFuser vision backbone.
"""
import timm
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint
from navsim.agents.backbones.internimage import InternImage
from navsim.agents.backbones.swin import SwinTransformerBEVFT
from navsim.agents.backbones.vov import VoVNet
from navsim.agents.transfuser.transfuser_backbone import GPT
from navsim.agents.utils.vit import DAViT
class TransfuserBackboneConv(nn.Module):
"""
Multi-scale Fusion Transformer for image + LiDAR feature fusion
"""
def __init__(self, config):
super().__init__()
self.config = config
self.backbone_type = config.backbone_type
if config.backbone_type == 'intern':
self.image_encoder = InternImage(init_cfg=dict(type='Pretrained',
checkpoint=config.intern_ckpt
),
frozen_stages=2)
# scale_4_c = 2560
vit_channels = 2560
self.image_encoder.init_weights()
elif config.backbone_type == 'vov':
self.image_encoder = VoVNet(
spec_name='V-99-eSE',
out_features=['stage4', 'stage5'],
norm_eval=True,
with_cp=True,
init_cfg=dict(
type='Pretrained',
checkpoint=config.vov_ckpt,
prefix='img_backbone.'
)
)
# scale_4_c = 1024
vit_channels = 1024
self.image_encoder.init_weights()
elif config.backbone_type == 'swin':
self.image_encoder = SwinTransformerBEVFT(
with_cp=True,
convert_weights=False,
depths=[2,2,18,2],
drop_path_rate=0.35,
embed_dims=192,
init_cfg=dict(
checkpoint=config.swin_ckpt,
type='Pretrained'
),
num_heads=[6,12,24,48],
out_indices=[3],
patch_norm=True,
window_size=[16,16,16,16],
use_abs_pos_embed=True,
return_stereo_feat=False,
output_missing_index_as_none=False
)
vit_channels = 1536
else:
            raise ValueError(f"Unsupported backbone_type: {config.backbone_type}")
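        # Config fields consumed by the backbone selection above (names taken
        # from this file; the checkpoint paths are expected to exist on disk):
        #   backbone_type: one of {'intern', 'vov', 'swin'}
        #   intern_ckpt / vov_ckpt / swin_ckpt: pretrained weights for the
        #       corresponding image encoder
        # The resulting last-stage channel width is 2560 (intern), 1024 (vov),
        # or 1536 (swin), stored in vit_channels.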
# self.lateral_3 = nn.Sequential(*[
# nn.Conv2d(vit_channels,
# vit_channels,
# kernel_size=1),
# nn.ReLU(inplace=True)
# ])
# self.lateral_4 = nn.Sequential(*[
# nn.Conv2d(scale_4_c,
# vit_channels,
# kernel_size=1),
# nn.ReLU(inplace=True)
# ])
# self.fpn_out = nn.Sequential(*[
# nn.Conv2d(vit_channels,
# vit_channels,
# kernel_size=3, padding=1),
# nn.ReLU(inplace=True)
# ])
if config.use_ground_plane:
in_channels = 2 * config.lidar_seq_len
else:
in_channels = config.lidar_seq_len
self.avgpool_img = nn.AdaptiveAvgPool2d(
(self.config.img_vert_anchors, self.config.img_horz_anchors)
)
self.lidar_encoder = timm.create_model(
config.lidar_architecture,
pretrained=False,
in_chans=in_channels,
features_only=True,
)
self.global_pool_lidar = nn.AdaptiveAvgPool2d(output_size=1)
self.avgpool_lidar = nn.AdaptiveAvgPool2d(
(self.config.lidar_vert_anchors, self.config.lidar_horz_anchors)
)
lidar_time_frames = [1, 1, 1, 1]
self.global_pool_img = nn.AdaptiveAvgPool2d(output_size=1)
start_index = 0
# Some networks have a stem layer
if len(self.lidar_encoder.return_layers) > 4:
start_index += 1
self.transformers = nn.ModuleList(
[
GPT(
n_embd=vit_channels,
config=config,
# lidar_video=self.lidar_video,
lidar_time_frames=lidar_time_frames[i],
)
for i in range(4)
]
)
self.lidar_channel_to_img = nn.ModuleList(
[
nn.Conv2d(
self.lidar_encoder.feature_info.info[start_index + i]["num_chs"],
vit_channels,
kernel_size=1,
)
for i in range(4)
]
)
self.img_channel_to_lidar = nn.ModuleList(
[
nn.Conv2d(
vit_channels,
self.lidar_encoder.feature_info.info[start_index + i]["num_chs"],
kernel_size=1,
)
for i in range(4)
]
)
self.num_features = self.lidar_encoder.feature_info.info[start_index + 3]["num_chs"]
# FPN fusion
channel = self.config.bev_features_channels
self.relu = nn.ReLU(inplace=True)
# top down
if self.config.detect_boxes or self.config.use_bev_semantic:
self.upsample = nn.Upsample(
scale_factor=self.config.bev_upsample_factor, mode="bilinear", align_corners=False
)
self.upsample2 = nn.Upsample(
size=(
self.config.lidar_resolution_height // self.config.bev_down_sample_factor,
self.config.lidar_resolution_width // self.config.bev_down_sample_factor,
),
mode="bilinear",
align_corners=False,
)
self.up_conv5 = nn.Conv2d(channel, channel, (3, 3), padding=1)
self.up_conv4 = nn.Conv2d(channel, channel, (3, 3), padding=1)
# lateral
self.c5_conv = nn.Conv2d(
self.lidar_encoder.feature_info.info[start_index + 3]["num_chs"], channel, (1, 1)
)
def top_down(self, x):
p5 = self.relu(self.c5_conv(x))
p4 = self.relu(self.up_conv5(self.upsample(p5)))
p3 = self.relu(self.up_conv4(self.upsample2(p4)))
return p3
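    # Illustrative shape flow through top_down (sizes depend on the config and
    # are given here only schematically):
    #   x  : (B, C_lidar_stage4, H, W)            last LiDAR encoder stage
    #   p5 : (B, bev_features_channels, H, W)     after c5_conv
    #   p4 : p5 upsampled by bev_upsample_factor, then up_conv5
    #   p3 : p4 resized to lidar_resolution / bev_down_sample_factor, then up_conv4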
# def fpn(self, xs):
# x_4 = xs[-1]
# x_3 = xs[-2]
# out = self.fpn_out(
# F.interpolate(self.lateral_4(x_4), scale_factor=self.config.bev_upsample_factor, mode='bilinear', align_corners=False)
# + self.lateral_3(x_3)
# )
#
# return out
def forward(self, image, lidar):
"""
Image + LiDAR feature fusion using transformers
Args:
image_list (list): list of input images
lidar_list (list): list of input LiDAR BEV
"""
image_features, lidar_features = image, lidar
# Generate an iterator for all the layers in the network that one can loop through.
lidar_layers = iter(self.lidar_encoder.items())
# Stem layer.
# In some architectures the stem is not a return layer, so we need to skip it.
if len(self.lidar_encoder.return_layers) > 4:
lidar_features = self.forward_layer_block(
lidar_layers, self.lidar_encoder.return_layers, lidar_features
)
# Loop through the 4 blocks of the network.
# FPN
# image_features = self.fpn(self.image_encoder(image_features))
image_features = self.image_encoder(image_features)[-1]
# print(image_features.shape)
for i in range(4):
lidar_features = self.forward_layer_block(
lidar_layers, self.lidar_encoder.return_layers, lidar_features
)
image_features, lidar_features = self.fuse_features(image_features, lidar_features, i)
if self.config.detect_boxes or self.config.use_bev_semantic:
x4 = lidar_features
# image_feature_grid = None
# if self.config.use_semantic or self.config.use_depth:
# image_feature_grid = image_features
if self.config.transformer_decoder_join:
fused_features = lidar_features
else:
image_features = self.global_pool_img(image_features)
image_features = torch.flatten(image_features, 1)
lidar_features = self.global_pool_lidar(lidar_features)
lidar_features = torch.flatten(lidar_features, 1)
if self.config.add_features:
lidar_features = self.lidar_to_img_features_end(lidar_features)
fused_features = image_features + lidar_features
else:
fused_features = torch.cat((image_features, lidar_features), dim=1)
if self.config.detect_boxes or self.config.use_bev_semantic:
features = self.top_down(x4)
else:
features = None
return features, fused_features, image_features
def forward_layer_block(self, layers, return_layers, features, if_ckpt=False):
"""
Run one forward pass to a block of layers from a TIMM neural network and returns the result.
Advances the whole network by just one block
:param layers: Iterator starting at the current layer block
:param return_layers: TIMM dictionary describing at which intermediate layers features are returned.
:param features: Input features
:return: Processed features
"""
for name, module in layers:
if if_ckpt:
features = checkpoint(module, features)
else:
features = module(features)
if name in return_layers:
break
return features
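    # Illustrative sketch of the iteration pattern used by forward_layer_block
    # (assumption: any timm features_only encoder; the model name is arbitrary):
    #
    #     encoder = timm.create_model("resnet34", pretrained=False, features_only=True)
    #     layers = iter(encoder.items())
    #     x = torch.zeros(1, 3, 224, 224)
    #     for _ in range(len(encoder.return_layers)):
    #         x = self.forward_layer_block(layers, encoder.return_layers, x)
    #
    # Each call consumes modules from `layers` until the next return layer is
    # reached, so repeated calls advance the encoder one feature stage at a time.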
def fuse_features(self, image_features, lidar_features, layer_idx):
"""
Perform a TransFuser feature fusion block using a Transformer module.
:param image_features: Features from the image branch
:param lidar_features: Features from the LiDAR branch
:param layer_idx: Transformer layer index.
:return: image_features and lidar_features with added features from the other branch.
"""
image_embd_layer = self.avgpool_img(image_features)
lidar_embd_layer = self.avgpool_lidar(lidar_features)
lidar_embd_layer = self.lidar_channel_to_img[layer_idx](lidar_embd_layer)
image_features_layer, lidar_features_layer = self.transformers[layer_idx](
image_embd_layer, lidar_embd_layer
)
lidar_features_layer = self.img_channel_to_lidar[layer_idx](lidar_features_layer)
image_features_layer = F.interpolate(
image_features_layer,
size=(image_features.shape[2], image_features.shape[3]),
mode="bilinear",
align_corners=False,
)
lidar_features_layer = F.interpolate(
lidar_features_layer,
size=(lidar_features.shape[2], lidar_features.shape[3]),
mode="bilinear",
align_corners=False,
)
image_features = image_features + image_features_layer
lidar_features = lidar_features + lidar_features_layer
return image_features, lidar_features
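
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). `config` is a
# hypothetical object exposing the fields read in __init__ (backbone_type,
# lidar_architecture, lidar_seq_len, anchor sizes, checkpoint paths, ...),
# and the input sizes below are illustrative only.
# ---------------------------------------------------------------------------
# if __name__ == "__main__":
#     backbone = TransfuserBackboneConv(config)
#     image = torch.zeros(1, 3, 256, 1024)
#     lidar = torch.zeros(1, config.lidar_seq_len, 256, 256)
#     bev_features, fused_features, image_features = backbone(image, lidar)
#     # bev_features is None unless detect_boxes or use_bev_semantic is set.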