""" Implements the TransFuser vision backbone. """ import math import torch from torch import nn import torch.nn.functional as F import timm import copy from torch.utils.checkpoint import checkpoint from navsim.agents.backbones.eva import EVAViT from navsim.agents.transfuser.transfuser_backbone import GPT from timm.models.vision_transformer import VisionTransformer from navsim.agents.utils.vit import DAViT class TransfuserBackboneViT(nn.Module): """ Multi-scale Fusion Transformer for image + LiDAR feature fusion """ def __init__(self, config): super().__init__() self.config = config # debug # vit-l if config.backbone_type == 'vit': self.image_encoder = DAViT(ckpt=config.vit_ckpt) elif config.backbone_type == 'eva': self.image_encoder = EVAViT( img_size=512, # img_size for short side patch_size=16, window_size=16, global_window_size=32, # If use square image (e.g., set global_window_size=0, else global_window_size=img_size // 16) in_chans=3, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4 * 2 / 3, window_block_indexes=( list(range(0, 2)) + list(range(3, 5)) + list(range(6, 8)) + list(range(9, 11)) + list( range(12, 14)) + list(range(15, 17)) + list(range(18, 20)) + list(range(21, 23)) ), qkv_bias=True, drop_path_rate=0.3, with_cp=True, flash_attn=False, xformers_attn=True ) self.image_encoder.init_weights(config.eva_ckpt) else: raise ValueError('unsupported vit backbones') if config.use_ground_plane: in_channels = 2 * config.lidar_seq_len else: in_channels = config.lidar_seq_len self.avgpool_img = nn.AdaptiveAvgPool2d( (self.config.img_vert_anchors, self.config.img_horz_anchors) ) self.lidar_encoder = timm.create_model( config.lidar_architecture, pretrained=False, in_chans=in_channels, features_only=True, ) self.global_pool_lidar = nn.AdaptiveAvgPool2d(output_size=1) self.avgpool_lidar = nn.AdaptiveAvgPool2d( (self.config.lidar_vert_anchors, self.config.lidar_horz_anchors) ) lidar_time_frames = [1, 1, 1, 1] self.global_pool_img = nn.AdaptiveAvgPool2d(output_size=1) start_index = 0 # Some networks have a stem layer vit_channels = 1024 if len(self.lidar_encoder.return_layers) > 4: start_index += 1 self.transformers = nn.ModuleList( [ GPT( n_embd=vit_channels, config=config, # lidar_video=self.lidar_video, lidar_time_frames=lidar_time_frames[i], ) for i in range(4) ] ) self.lidar_channel_to_img = nn.ModuleList( [ nn.Conv2d( self.lidar_encoder.feature_info.info[start_index + i]["num_chs"], vit_channels, kernel_size=1, ) for i in range(4) ] ) self.img_channel_to_lidar = nn.ModuleList( [ nn.Conv2d( vit_channels, self.lidar_encoder.feature_info.info[start_index + i]["num_chs"], kernel_size=1, ) for i in range(4) ] ) self.num_features = self.lidar_encoder.feature_info.info[start_index + 3]["num_chs"] # FPN fusion channel = self.config.bev_features_channels self.relu = nn.ReLU(inplace=True) # top down if self.config.detect_boxes or self.config.use_bev_semantic: self.upsample = nn.Upsample( scale_factor=self.config.bev_upsample_factor, mode="bilinear", align_corners=False ) self.upsample2 = nn.Upsample( size=( self.config.lidar_resolution_height // self.config.bev_down_sample_factor, self.config.lidar_resolution_width // self.config.bev_down_sample_factor, ), mode="bilinear", align_corners=False, ) self.up_conv5 = nn.Conv2d(channel, channel, (3, 3), padding=1) self.up_conv4 = nn.Conv2d(channel, channel, (3, 3), padding=1) # lateral self.c5_conv = nn.Conv2d( self.lidar_encoder.feature_info.info[start_index + 3]["num_chs"], channel, (1, 1) ) def top_down(self, x): p5 = self.relu(self.c5_conv(x)) p4 
    def top_down(self, x):
        p5 = self.relu(self.c5_conv(x))
        p4 = self.relu(self.up_conv5(self.upsample(p5)))
        p3 = self.relu(self.up_conv4(self.upsample2(p4)))
        return p3

    def forward(self, image, lidar):
        """
        Image + LiDAR feature fusion using transformers.
        :param image: Input camera image tensor.
        :param lidar: Input LiDAR BEV raster tensor.
        :return: BEV features for the detection/semantic heads (or None), the fused
            features, and the image branch features.
        """
        image_features, lidar_features = image, lidar

        # Generate an iterator for all the layers in the network that one can loop through.
        lidar_layers = iter(self.lidar_encoder.items())

        # Stem layer.
        # In some architectures the stem is not a return layer, so we need to skip it.
        if len(self.lidar_encoder.return_layers) > 4:
            lidar_features = self.forward_layer_block(
                lidar_layers, self.lidar_encoder.return_layers, lidar_features
            )

        # Loop through the 4 blocks of the network.
        image_features = self.image_encoder(image_features)[0]
        for i in range(4):
            lidar_features = self.forward_layer_block(
                lidar_layers, self.lidar_encoder.return_layers, lidar_features
            )
            image_features, lidar_features = self.fuse_features(image_features, lidar_features, i)

        if self.config.detect_boxes or self.config.use_bev_semantic:
            x4 = lidar_features

        # image_feature_grid = None
        # if self.config.use_semantic or self.config.use_depth:
        #     image_feature_grid = image_features

        if self.config.transformer_decoder_join:
            fused_features = lidar_features
        else:
            image_features = self.global_pool_img(image_features)
            image_features = torch.flatten(image_features, 1)
            lidar_features = self.global_pool_lidar(lidar_features)
            lidar_features = torch.flatten(lidar_features, 1)

            if self.config.add_features:
                lidar_features = self.lidar_to_img_features_end(lidar_features)
                fused_features = image_features + lidar_features
            else:
                fused_features = torch.cat((image_features, lidar_features), dim=1)

        if self.config.detect_boxes or self.config.use_bev_semantic:
            features = self.top_down(x4)
        else:
            features = None
        return features, fused_features, image_features

    def forward_layer_block(self, layers, return_layers, features, if_ckpt=False):
        """
        Runs one forward pass through a block of layers from a TIMM neural network and returns the result.
        Advances the whole network by just one block.
        :param layers: Iterator starting at the current layer block.
        :param return_layers: TIMM dictionary describing at which intermediate layers features are returned.
        :param features: Input features.
        :param if_ckpt: Whether to run this block with gradient checkpointing.
        :return: Processed features.
        """
        for name, module in layers:
            if if_ckpt:
                features = checkpoint(module, features)
            else:
                features = module(features)
            if name in return_layers:
                break
        return features
""" image_embd_layer = self.avgpool_img(image_features) lidar_embd_layer = self.avgpool_lidar(lidar_features) lidar_embd_layer = self.lidar_channel_to_img[layer_idx](lidar_embd_layer) image_features_layer, lidar_features_layer = self.transformers[layer_idx]( image_embd_layer, lidar_embd_layer ) lidar_features_layer = self.img_channel_to_lidar[layer_idx](lidar_features_layer) image_features_layer = F.interpolate( image_features_layer, size=(image_features.shape[2], image_features.shape[3]), mode="bilinear", align_corners=False, ) lidar_features_layer = F.interpolate( lidar_features_layer, size=(lidar_features.shape[2], lidar_features.shape[3]), mode="bilinear", align_corners=False, ) image_features = image_features + image_features_layer lidar_features = lidar_features + lidar_features_layer return image_features, lidar_features