import torch
import torch.nn as nn
from llava.model.multimodal_encoder.three_detr_model.models.transformer import (
    TransformerEncoder,
    TransformerEncoderLayer,
    TransformerDecoder,
    TransformerDecoderLayer,
)
from torch.nn.utils.rnn import pad_sequence
from llava.model.multimodal_encoder.mask3d_model.position_embedding import (
    PositionEmbeddingCoordsSine,
)
from torch.nn.init import xavier_uniform_


class SimpleBBoxHead(nn.Module):
    """A simple MLP head for bounding box regression: for each ground token, the flattened
    vision features are concatenated with the token's hidden state and regressed to a
    6-dim box."""

    def __init__(
        self,
        lm_feat_dim_in: int,
        vision_feat_dim_in: int,
        num_vision_feat: int,
        dim_feedforward: int = 1024,
    ):
        super().__init__()
        self.activation = nn.ReLU()
        # # round up to the nearest multiple of 4
        # new_vision_feat_dim_in = (vision_feat_dim_in + 3) // 4 * 4
        # self.vision_projection_mlp = nn.Sequential(
        #     nn.Linear(vision_feat_dim_in, new_vision_feat_dim_in),
        #     self.activation,
        #     nn.Linear(new_vision_feat_dim_in, new_vision_feat_dim_in),
        #     self.activation,
        #     nn.Linear(new_vision_feat_dim_in, new_vision_feat_dim_in),
        # )
        # encoder_layer = TransformerEncoderLayer(
        #     d_model=new_vision_feat_dim_in,
        #     nhead=4,
        #     dim_feedforward=dim_feedforward,
        #     dropout=0.0,
        #     activation="relu",
        #     normalize_before=False,
        # )
        # self.encoder = TransformerEncoder(encoder_layer=encoder_layer, num_layers=1)
        self.box_mlp = nn.Sequential(
            nn.Linear(vision_feat_dim_in * num_vision_feat + lm_feat_dim_in, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, 6),
        )

    def forward(
        self,
        grd_token_hidden_states_list: list[torch.Tensor],
        vision_features_before_mm_projection: torch.Tensor,
    ):
        """Regress one bounding box per ground token.

        Args:
            grd_token_hidden_states_list (list[torch.Tensor]): each element in this list contains
                the hidden states of the ground tokens in one sample, list[[varying N, D]]
            vision_features_before_mm_projection (torch.Tensor): [B, num_latents, d_latents]

        Returns:
            torch.Tensor: (total_num_ground_tokens_in_batch, 6) box predictions
        """
        # pre_encoder_vision_feat = self.vision_projection_mlp(
        #     vision_features_before_mm_projection
        # )  # (B, num_latents, new_vision_feat_dim_in)
        # # get padding mask by checking where zero vectors are
        # src_key_padding_mask = vision_features_before_mm_projection.eq(0).all(
        #     dim=-1
        # )  # (B, num_latents)
        # # nn.MultiheadAttention in encoder expects npoints x batch x channel features
        # # note that vision_features_before_mm_projection already contains positional embeddings
        # _, encoder_output, _ = self.encoder(
        #     src=pre_encoder_vision_feat.permute(1, 0, 2),
        #     src_key_padding_mask=src_key_padding_mask,
        # )  # [num_latents, B, d_latents]
        # encoder_output = encoder_output.permute(1, 0, 2)  # [B, num_latents, d_latents]
        bbox_preds = []
        for batch_idx, grd_token_hidden_states in enumerate(grd_token_hidden_states_list):
            # vision_feat = encoder_output[batch_idx].flatten()  # (num_latents * d_latents,)
            vision_feat = vision_features_before_mm_projection[
                batch_idx
            ].flatten()  # (num_latents * d_latents,)
            for i in range(len(grd_token_hidden_states)):
                language_feat = grd_token_hidden_states[i]  # (D,)
                concat_feat = torch.cat((vision_feat, language_feat), dim=-1)
                bbox_pred = self.box_mlp(concat_feat)
                bbox_preds.append(bbox_pred)
        bbox_preds = torch.stack(bbox_preds, dim=0)  # (N, 6)
        return bbox_preds

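# Illustrative usage sketch for SimpleBBoxHead (not part of the model code). The
# feature sizes below are made-up toy values; in the real model they come from the
# language-model hidden size and the 3D vision encoder config.
def _example_simple_bbox_head():
    head = SimpleBBoxHead(lm_feat_dim_in=32, vision_feat_dim_in=8, num_vision_feat=16)
    vision = torch.randn(2, 16, 8)  # [B, num_latents, d_latents]
    grd_tokens = [torch.randn(3, 32), torch.randn(1, 32)]  # varying ground tokens per sample
    boxes = head(grd_tokens, vision)  # (3 + 1, 6), one box per ground token
    return boxes
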
class BBoxHead(nn.Module):
    """A transformer encoder-decoder head for bounding box regression: projected ground
    tokens query the encoded vision features, and an MLP regresses one box per token."""

    def __init__(self, lm_feat_dim_in: int, vision_feat_dim_in: int, dim_feedforward: int = 128):
        super().__init__()
        encoder_layer = TransformerEncoderLayer(
            d_model=vision_feat_dim_in,
            nhead=4,
            dim_feedforward=dim_feedforward,
            dropout=0.0,
            activation="relu",
            # normalize_before=False,
        )
        self.encoder = TransformerEncoder(encoder_layer=encoder_layer, num_layers=4)
        decoder_layer = TransformerDecoderLayer(
            d_model=vision_feat_dim_in,
            nhead=4,
            dim_feedforward=dim_feedforward,
            dropout=0.0,
            normalize_before=False,
        )
        self.decoder = TransformerDecoder(
            decoder_layer=decoder_layer, num_layers=4, return_intermediate=False
        )
        self.language_projection = nn.Sequential(
            nn.Linear(lm_feat_dim_in, vision_feat_dim_in),
            # nn.ReLU(),
            # nn.Linear(256, 256),
            # nn.ReLU(),
            # nn.Linear(256, vision_feat_dim_in),
        )
        self.activation = nn.GELU()
        self.box_mlp = nn.Sequential(
            nn.Linear(vision_feat_dim_in, 256),
            self.activation,
            nn.Linear(256, 256),
            self.activation,
            nn.Linear(256, 6),
        )

    def forward(
        self,
        grd_token_hidden_states_list: list[torch.Tensor],
        vision_features_before_mm_projection: torch.Tensor,
    ):
        """Regress one bounding box per ground token.

        Args:
            grd_token_hidden_states_list (list[torch.Tensor]): each element in this list contains
                the hidden states of the ground tokens in one sample, list[[varying N, D]]
            vision_features_before_mm_projection (torch.Tensor): [B, num_latents, d_latents]

        Returns:
            torch.Tensor: (num_boxes_in_batch, 6) box predictions, padding removed
        """
        # nn.MultiheadAttention in encoder expects npoints x batch x channel features
        # note that vision_features_before_mm_projection already contains positional embeddings
        _, encoder_output, _ = self.encoder(
            src=vision_features_before_mm_projection.permute(1, 0, 2)
        )  # [num_latents, B, d_latents]

        # we need to mask out the attention between different ground tokens
        # because each ground token is independent of the others.
        # Pad the list of hidden states to the longest sample
        grd_token_hidden_states_padded = pad_sequence(
            grd_token_hidden_states_list, batch_first=True, padding_value=0
        )  # (B, N', D), where N' is the number of ground tokens in the sample with the most ground tokens in the batch
        # Create a mask for the padding tokens; True means the position is not attended to
        tgt_key_padding_mask = grd_token_hidden_states_padded.eq(0).all(dim=-1)  # (B, N')
        tgt_mask = self.create_diag_mask(grd_token_hidden_states_padded.shape[1]).to(
            grd_token_hidden_states_padded.device
        )  # (N', N')

        # decoder expects: npoints x batch x channel
        language_projected = self.language_projection(
            grd_token_hidden_states_padded
        )  # (B, N', d_latents)
        decoder_output, decoder_attns = self.decoder(
            tgt=language_projected.permute(1, 0, 2),  # [N', B, d_latents]
            memory=encoder_output,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
        )  # decoder_output shape: [N', B, d_latents]

        # predict the bounding boxes
        bbox_preds = self.box_mlp(decoder_output)  # (N', B, 6)
        # flatten the first two dimensions, remove padded locations
        bbox_preds = bbox_preds.permute(1, 0, 2)  # (B, N', 6)
        # discard the padded locations
        bbox_preds = bbox_preds[~tgt_key_padding_mask]  # (num_boxes_in_batch, 6)
        return bbox_preds

    @staticmethod
    def create_diag_mask(size):
        # for the transformer, a binary ``True`` value indicates that the corresponding position
        # is NOT allowed to attend, while a ``False`` value indicates that the position is
        # allowed to attend.
        mask = torch.ones(size, size, dtype=torch.bool)
        mask.fill_diagonal_(0)
        return mask

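# Illustrative sketch (not part of the model code): with three ground tokens the
# diagonal target mask only lets each query attend to itself, keeping the box
# queries independent of one another.
def _example_create_diag_mask():
    mask = BBoxHead.create_diag_mask(3)
    # mask is:
    # tensor([[False,  True,  True],
    #         [ True, False,  True],
    #         [ True,  True, False]])
    return mask
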
class BBoxHeadForGroundTruthBboxRegressionV2(nn.Module):
    """A bounding box regression head: each ground token is projected, prepended to the
    projected vision features, and fused through a transformer encoder before an MLP
    regresses the box."""

    def __init__(
        self,
        lm_feat_dim_in: int,
        vision_feat_dim_in: int,
        num_vision_feat: int,
        dim_feedforward: int = 1024,
    ):
        super().__init__()
        # round up to the nearest multiple of 4
        new_vision_feat_dim_in = (vision_feat_dim_in + 3) // 4 * 4
        self.vision_projection_mlp = nn.Sequential(
            nn.Linear(vision_feat_dim_in, new_vision_feat_dim_in),
        )
        self.activation = nn.ReLU()
        self.language_projection_mlp = nn.Sequential(
            nn.Linear(lm_feat_dim_in, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, new_vision_feat_dim_in),
        )
        encoder_layer = TransformerEncoderLayer(
            d_model=new_vision_feat_dim_in,
            nhead=4,
            dim_feedforward=dim_feedforward,
            dropout=0.0,
            activation="relu",
            normalize_before=True,
        )
        self.encoder = TransformerEncoder(encoder_layer=encoder_layer, num_layers=4)
        self.activation = nn.ReLU()
        self.box_mlp = nn.Sequential(
            nn.Linear(new_vision_feat_dim_in, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, 6),
        )

    def forward(
        self,
        grd_token_hidden_states_list: list[torch.Tensor],
        vision_features_before_mm_projection: torch.Tensor,
    ):
        """Regress one bounding box per ground token.

        Args:
            grd_token_hidden_states_list (list[torch.Tensor]): each element in this list contains
                the hidden states of the ground tokens in one sample, list[[varying N, D]]
            vision_features_before_mm_projection (torch.Tensor): [B, num_latents, d_latents]

        Returns:
            torch.Tensor: (total_num_ground_tokens_in_batch, 6) box predictions
        """
        bbox_preds = []
        for batch_idx, grd_token_hidden_states in enumerate(grd_token_hidden_states_list):
            # vision_feat = encoder_output[batch_idx].flatten()  # (1024 * 96,)
            vision_feat = vision_features_before_mm_projection[batch_idx].unsqueeze(
                0
            )  # (1, num_vision_feat, vision_feat_dim_in)
            vision_feat = self.vision_projection_mlp(
                vision_feat
            )  # (1, num_vision_feat, new_vision_feat_dim_in)
            for i in range(len(grd_token_hidden_states)):
                language_feat = grd_token_hidden_states[i]  # (D,)
                language_feat = self.language_projection_mlp(
                    language_feat
                )  # (new_vision_feat_dim_in,)
                language_feat = language_feat[None, None, :]  # (1, 1, new_vision_feat_dim_in)
                language_concat_vision_feat = torch.cat(
                    (language_feat, vision_feat), dim=1
                )  # (1, 1 + num_vision_feat, new_vision_feat_dim_in)
                # nn.MultiheadAttention in encoder expects seqlen x batch x channel features
                _, encoder_output, _ = self.encoder(
                    src=language_concat_vision_feat.permute(1, 0, 2)
                )  # [1 + num_vision_feat, 1, new_vision_feat_dim_in]
                fused_feat = encoder_output[0][0]  # (new_vision_feat_dim_in,)
                bbox_pred = self.box_mlp(fused_feat)  # (6,)
                bbox_preds.append(bbox_pred)
        bbox_preds = torch.stack(bbox_preds, dim=0)  # (N, 6)
        return bbox_preds

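# Illustrative sketch (not part of the model code) of the fusion pattern used by
# BBoxHeadForGroundTruthBboxRegressionV2: the projected language token is prepended
# to the vision tokens, the sequence is run through an encoder, and the output at
# the language position is taken as the fused feature. This sketch uses
# torch.nn.TransformerEncoder with toy sizes; the class above uses the 3DETR encoder.
def _example_prepend_token_fusion():
    d_model, num_vision_tokens = 16, 8
    encoder = nn.TransformerEncoder(
        nn.TransformerEncoderLayer(d_model=d_model, nhead=4, dim_feedforward=32),
        num_layers=1,
    )
    language_token = torch.randn(1, 1, d_model)  # (seq=1, batch=1, d_model)
    vision_tokens = torch.randn(num_vision_tokens, 1, d_model)
    sequence = torch.cat((language_token, vision_tokens), dim=0)  # (1 + num_vision_tokens, 1, d_model)
    fused = encoder(sequence)[0, 0]  # fused feature at the language position
    return fused  # (d_model,)
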
class BBoxHeadForGroundTruthBboxRegressionV1(nn.Module):
    """A bounding box regression head over ground-truth box features: box corners and class
    ids are embedded, encoded with a transformer, flattened, and concatenated with each
    ground token for MLP regression."""

    def __init__(
        self,
        lm_feat_dim_in: int,
        vision_feat_dim_in: int,
        num_vision_feat: int,
        dim_feedforward: int = 1024,
    ):
        super().__init__()
        self.bbox_pos_embedding = PositionEmbeddingCoordsSine(
            d_pos=10,
            pos_type="fourier",
        )
        self.obj_class_embedding = nn.Embedding(
            265, 64
        )  # 265 classes in ScanNet, learnable embedding of size 64
        self.activation = nn.ReLU()
        encoder_layer = TransformerEncoderLayer(
            d_model=10 * 2 + 64,
            nhead=4,
            dim_feedforward=dim_feedforward,
            dropout=0.0,
            activation="relu",
            normalize_before=False,
        )
        self.encoder = TransformerEncoder(encoder_layer=encoder_layer, num_layers=2)
        self.box_mlp = nn.Sequential(
            nn.Linear((10 * 2 + 64) * num_vision_feat + lm_feat_dim_in, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, 6),
        )

    def forward(
        self,
        grd_token_hidden_states_list: list[torch.Tensor],
        vision_features_before_mm_projection: torch.Tensor,
    ):
        """Regress one bounding box per ground token from ground-truth box features.

        Args:
            grd_token_hidden_states_list (list[torch.Tensor]): each element in this list contains
                the hidden states of the ground tokens in one sample, list[[varying N, D]]
            vision_features_before_mm_projection (torch.Tensor): [B, num_latents, 6 + 1],
                box corners plus an object class id per box

        Returns:
            torch.Tensor: (total_num_ground_tokens_in_batch, 6) box predictions
        """
        # get bbox position embeddings
        # xyz is batch x npoints x 3
        min_xyz_pos_embeddings = self.bbox_pos_embedding(
            xyz=vision_features_before_mm_projection[:, :, 0:3]
        )  # (B, 10, num_latents)
        min_xyz_pos_embeddings = min_xyz_pos_embeddings.permute(0, 2, 1)  # (B, num_latents, 10)
        max_xyz_pos_embeddings = self.bbox_pos_embedding(
            xyz=vision_features_before_mm_projection[:, :, 3:6]
        )  # (B, 10, num_latents)
        max_xyz_pos_embeddings = max_xyz_pos_embeddings.permute(0, 2, 1)  # (B, num_latents, 10)

        # get the object class embeddings
        obj_classes = vision_features_before_mm_projection[:, :, -1].long()
        obj_class_embeddings = self.obj_class_embedding(obj_classes)  # (B, num_latents, 64)

        vision_feat = torch.concat(
            (min_xyz_pos_embeddings, max_xyz_pos_embeddings, obj_class_embeddings), dim=-1
        )  # (B, num_latents, 10 * 2 + 64)

        # get padding mask by checking where zero vectors are
        src_key_padding_mask = vision_features_before_mm_projection.eq(0).all(
            dim=-1
        )  # (B, num_latents)

        # nn.MultiheadAttention in encoder expects npoints x batch x channel features
        _, encoder_output, _ = self.encoder(
            src=vision_feat.permute(1, 0, 2),
            src_key_padding_mask=src_key_padding_mask,
        )  # [num_latents, B, 10 * 2 + 64]
        encoder_output = encoder_output.permute(1, 0, 2)  # [B, num_latents, 10 * 2 + 64]

        bbox_preds = []
        for batch_idx, grd_token_hidden_states in enumerate(grd_token_hidden_states_list):
            vision_feat = encoder_output[batch_idx].flatten()  # (num_latents * (10 * 2 + 64),)
            # vision_feat = vision_features_before_mm_projection[batch_idx].flatten()
            for i in range(len(grd_token_hidden_states)):
                language_feat = grd_token_hidden_states[i]  # (D,)
                concat_feat = torch.cat((vision_feat, language_feat), dim=-1)
                bbox_pred = self.box_mlp(concat_feat)
                bbox_preds.append(bbox_pred)
        bbox_preds = torch.stack(bbox_preds, dim=0)  # (N, 6)
        return bbox_preds

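# Illustrative sketch (not part of the model code) of the ground-truth box input
# format consumed by the "ForGroundTruthBbox*" heads, as read off the code above:
# each candidate box is a 7-dim row of box corner coordinates (columns 0:3 and 3:6)
# plus a class id in the last column, and missing boxes are all-zero rows that the
# padding mask detects. Values below are made up.
def _example_ground_truth_box_features():
    num_latents = 4
    boxes = torch.zeros(1, num_latents, 7)  # (B, num_latents, 6 + 1)
    boxes[0, 0] = torch.tensor([0.1, 0.2, 0.0, 1.0, 1.0, 1.0, 3.0])  # box 0, class id 3
    boxes[0, 1] = torch.tensor([0.5, 0.5, 0.0, 2.0, 1.5, 1.0, 17.0])  # box 1, class id 17
    padding_mask = boxes.eq(0).all(dim=-1)  # (1, num_latents): [False, False, True, True]
    return boxes, padding_mask
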
class BBoxHeadForGroundTruthBboxSelectionTransformerLateFusion(nn.Module):
    """A bounding box selection head (late fusion), for training with a CE loss: candidate
    boxes are encoded with a transformer, then fused with each ground token by an MLP and
    scored."""

    def __init__(
        self,
        lm_feat_dim_in: int,
        vision_feat_dim_in: int,
        num_vision_feat: int,
        dim_feedforward: int = 2048,
    ):
        super().__init__()
        class_emb_dim = 256
        pos_emb_dim = 16
        self.bbox_pos_embedding = PositionEmbeddingCoordsSine(
            d_pos=pos_emb_dim,
            pos_type="fourier",
        )
        self.obj_class_embedding = nn.Embedding(
            265, class_emb_dim
        )  # 265 classes in ScanNet, learnable embedding of size class_emb_dim
        self.activation = nn.GELU()
        self.language_vision_fusion_mlp = nn.Sequential(
            nn.Linear(class_emb_dim + pos_emb_dim + lm_feat_dim_in, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
        )
        # encoder_layer = TransformerEncoderLayer(
        #     d_model=dim_feedforward,
        #     nhead=8,
        #     dim_feedforward=dim_feedforward,
        #     dropout=0.0,
        #     activation="relu",
        #     normalize_before=True,
        # )
        # self.encoder = TransformerEncoder(encoder_layer=encoder_layer, num_layers=2)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=class_emb_dim + pos_emb_dim,
            nhead=8,
            dim_feedforward=dim_feedforward,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer=encoder_layer, num_layers=2)
        self.scoring_mlp = nn.Sequential(
            nn.Linear(dim_feedforward, 1),
        )
        self._reset_parameters()

    def _reset_parameters(self):
        r"""Initiate parameters in the transformer model."""
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)

    def forward(
        self,
        grd_token_hidden_states_list: list[torch.Tensor],
        vision_features_before_mm_projection: torch.Tensor,
    ):
        """Score every candidate box for every ground token.

        Args:
            grd_token_hidden_states_list (list[torch.Tensor]): each element in this list contains
                the hidden states of the ground tokens in one sample, list[[varying N, D]]
            vision_features_before_mm_projection (torch.Tensor): [B, num_latents, 6 + 1]

        Returns:
            torch.Tensor: (total_num_ground_tokens_in_batch, num_latents) box scores
        """
        # get bbox position embeddings
        # xyz is batch x npoints x 3
        # get the center of the bbox
        bbox_center = (
            vision_features_before_mm_projection[:, :, 0:3]
            + vision_features_before_mm_projection[:, :, 3:6]
        ) / 2.0
        bbox_pos_embeddings = self.bbox_pos_embedding(
            xyz=bbox_center
        )  # (B, pos_emb_dim, num_latents)
        bbox_pos_embeddings = bbox_pos_embeddings.permute(0, 2, 1)  # (B, num_latents, pos_emb_dim)

        # get the object class embeddings
        obj_classes = vision_features_before_mm_projection[:, :, -1].long()
        obj_class_embeddings = self.obj_class_embedding(
            obj_classes
        )  # (B, num_latents, class_emb_dim)

        vision_feat = torch.concat(
            (obj_class_embeddings, bbox_pos_embeddings), dim=-1
        )  # (B, num_latents, class_emb_dim + pos_emb_dim)

        # get padding mask by checking where zero vectors are
        src_key_padding_mask = vision_features_before_mm_projection.eq(0).all(
            dim=-1
        )  # (B, num_latents)

        bbox_scores = []
        for batch_idx, grd_token_hidden_states in enumerate(grd_token_hidden_states_list):
            # vision_feat = vision_features_before_mm_projection[
            #     batch_idx
            # ]  # (num_latents, d_latents)
            cur_vision_feat = vision_feat[batch_idx]  # (num_latents, class_emb_dim + pos_emb_dim)
            cur_vision_feat = cur_vision_feat.unsqueeze(
                0
            )  # (1, num_latents, class_emb_dim + pos_emb_dim)
            cur_vision_feat = cur_vision_feat.permute(
                1, 0, 2
            )  # (num_latents, 1, class_emb_dim + pos_emb_dim)
            # nn.MultiheadAttention in encoder expects seqlen x batch x channel features
            cur_encoder_output = self.encoder(
                cur_vision_feat,
                src_key_padding_mask=src_key_padding_mask[batch_idx].unsqueeze(0),
            )  # [num_latents, 1, class_emb_dim + pos_emb_dim]
            cur_encoder_output = cur_encoder_output.squeeze(
                1
            )  # (num_latents, class_emb_dim + pos_emb_dim)
            for i in range(len(grd_token_hidden_states)):
                language_feat = grd_token_hidden_states[i]  # (lm_feat_dim_in,)
                # concat the language feat with each vision feat
                language_feat_repeat = language_feat.repeat(
                    cur_encoder_output.shape[0], 1
                )  # (num_latents, lm_feat_dim_in)
                concat_feat = torch.cat(
                    (cur_encoder_output, language_feat_repeat), dim=-1
                )  # (num_latents, class_emb_dim + pos_emb_dim + lm_feat_dim_in)
                fused_feat = self.language_vision_fusion_mlp(
                    concat_feat
                )  # (num_latents, dim_feedforward)
                bbox_score = self.scoring_mlp(fused_feat).squeeze(-1)  # (num_latents,)
                bbox_scores.append(bbox_score)  # (num_latents,)
        bbox_scores = torch.stack(bbox_scores, dim=0)  # (N, num_latents)
        return bbox_scores

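# Illustrative sketch (not part of the model code): the selection heads output one
# score per candidate box per ground token. Per their docstrings they are trained
# with a CE loss; a natural reading is cross-entropy over the candidates with the
# index of the correct box as the target, as sketched below with toy values. The
# actual training loop lives elsewhere.
def _example_selection_ce_loss():
    scores = torch.tensor([[2.0, 0.1, -1.0, 0.5]])  # (N=1 ground token, num_latents=4 candidates)
    target = torch.tensor([0])  # index of the ground-truth box
    loss = nn.functional.cross_entropy(scores, target)
    predicted_box_idx = scores.argmax(dim=-1)  # tensor([0])
    return loss, predicted_box_idx
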
class BBoxHeadForGroundTruthBboxSelectionTransformerEarlyFusion(nn.Module):
    """A bounding box selection head (early fusion), for training with a CE loss: the
    projected ground token is concatenated to every candidate box feature before the
    transformer encoder, and each encoded candidate is scored."""

    def __init__(
        self,
        lm_feat_dim_in: int,
        vision_feat_dim_in: int,
        num_vision_feat: int,
        dim_feedforward: int = 2048,
    ):
        super().__init__()
        class_emb_dim = 256
        pos_emb_dim = 16
        self.bbox_pos_embedding = PositionEmbeddingCoordsSine(
            d_pos=pos_emb_dim,
            pos_type="fourier",
        )
        self.obj_class_embedding = nn.Embedding(
            265, class_emb_dim
        )  # 265 classes in ScanNet, learnable embedding of size class_emb_dim
        self.activation = nn.GELU()
        self.language_projection_mlp = nn.Sequential(
            nn.Linear(lm_feat_dim_in, class_emb_dim),
        )
        # encoder_layer = TransformerEncoderLayer(
        #     d_model=dim_feedforward,
        #     nhead=8,
        #     dim_feedforward=dim_feedforward,
        #     dropout=0.0,
        #     activation="relu",
        #     normalize_before=True,
        # )
        # self.encoder = TransformerEncoder(encoder_layer=encoder_layer, num_layers=2)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=class_emb_dim + pos_emb_dim + class_emb_dim,
            nhead=8,
            dim_feedforward=dim_feedforward,
            norm_first=False,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer=encoder_layer, num_layers=2)
        self.scoring_mlp = nn.Sequential(
            nn.Linear(class_emb_dim + pos_emb_dim + class_emb_dim, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, 1),
        )
        self._reset_parameters()

    def _reset_parameters(self):
        r"""Initiate parameters in the transformer model."""
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)

    def forward(
        self,
        grd_token_hidden_states_list: list[torch.Tensor],
        vision_features_before_mm_projection: torch.Tensor,
    ):
        """Score every candidate box for every ground token.

        Args:
            grd_token_hidden_states_list (list[torch.Tensor]): each element in this list contains
                the hidden states of the ground tokens in one sample, list[[varying N, D]]
            vision_features_before_mm_projection (torch.Tensor): [B, num_latents, 6 + 1]

        Returns:
            torch.Tensor: (total_num_ground_tokens_in_batch, num_latents) box scores
        """
        # get bbox position embeddings
        # xyz is batch x npoints x 3
        # get the center of the bbox
        bbox_center = (
            vision_features_before_mm_projection[:, :, 0:3]
            + vision_features_before_mm_projection[:, :, 3:6]
        ) / 2.0
        bbox_pos_embeddings = self.bbox_pos_embedding(
            xyz=bbox_center
        )  # (B, pos_emb_dim, num_latents)
        bbox_pos_embeddings = bbox_pos_embeddings.permute(0, 2, 1)  # (B, num_latents, pos_emb_dim)

        # get the object class embeddings
        obj_classes = vision_features_before_mm_projection[:, :, -1].long()
        obj_class_embeddings = self.obj_class_embedding(
            obj_classes
        )  # (B, num_latents, class_emb_dim)

        vision_feat = torch.concat(
            (obj_class_embeddings, bbox_pos_embeddings), dim=-1
        )  # (B, num_latents, class_emb_dim + pos_emb_dim)

        # get padding mask by checking where zero vectors are
        src_key_padding_mask = vision_features_before_mm_projection.eq(0).all(
            dim=-1
        )  # (B, num_latents)

        bbox_scores = []
        for batch_idx, grd_token_hidden_states in enumerate(grd_token_hidden_states_list):
            # vision_feat = vision_features_before_mm_projection[
            #     batch_idx
            # ]  # (num_latents, d_latents)
            cur_vision_feat = vision_feat[batch_idx]  # (num_latents, class_emb_dim + pos_emb_dim)
            for i in range(len(grd_token_hidden_states)):
                language_feat = grd_token_hidden_states[i]  # (lm_feat_dim_in,)
                language_feat = self.language_projection_mlp(language_feat)  # (class_emb_dim,)
                language_feat_repeat = language_feat.repeat(
                    cur_vision_feat.shape[0], 1
                )  # (num_latents, class_emb_dim)
                concat_feat = torch.cat(
                    (cur_vision_feat, language_feat_repeat), dim=-1
                )  # (num_latents, class_emb_dim + pos_emb_dim + class_emb_dim)
                concat_feat = concat_feat.unsqueeze(
                    0
                )  # (1, num_latents, class_emb_dim + pos_emb_dim + class_emb_dim)
                concat_feat = concat_feat.permute(
                    1, 0, 2
                )  # (num_latents, 1, class_emb_dim + pos_emb_dim + class_emb_dim)
                # nn.MultiheadAttention in encoder expects seqlen x batch x channel features
                cur_encoder_output = self.encoder(
                    concat_feat,
                    src_key_padding_mask=src_key_padding_mask[batch_idx].unsqueeze(0),
                )  # [num_latents, 1, class_emb_dim + pos_emb_dim + class_emb_dim]
                cur_encoder_output = cur_encoder_output.squeeze(
                    1
                )  # (num_latents, class_emb_dim + pos_emb_dim + class_emb_dim)
                bbox_score = self.scoring_mlp(cur_encoder_output).squeeze(-1)  # (num_latents,)
                bbox_scores.append(bbox_score)  # (num_latents,)
        bbox_scores = torch.stack(bbox_scores, dim=0)  # (N, num_latents)
        return bbox_scores

class BBoxHeadForGroundTruthBboxSelectionMLPPosEmbAndFusionOneHot(nn.Module):
    """A simple MLP head for bounding box selection using Fourier position embeddings and
    one-hot class encodings, for training with a CE loss."""

    def __init__(
        self,
        lm_feat_dim_in: int,
        vision_feat_dim_in: int,
        num_vision_feat: int,
        dim_feedforward: int = 4096,
    ):
        super().__init__()
        self.class_emb_dim = class_emb_dim = 265  # 265 classes in ScanRefer
        pos_emb_dim = 16
        self.bbox_pos_embedding = PositionEmbeddingCoordsSine(
            d_pos=pos_emb_dim,
            pos_type="fourier",
        )
        self.activation = nn.ReLU()
        self.language_vision_fusion_mlp = nn.Sequential(
            nn.Linear(class_emb_dim + pos_emb_dim + lm_feat_dim_in, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
        )
        self.scoring_mlp = nn.Sequential(
            nn.Linear(dim_feedforward, 1),
        )

    def forward(
        self,
        grd_token_hidden_states_list: list[torch.Tensor],
        vision_features_before_mm_projection: torch.Tensor,
    ):
        """Score every candidate box for every ground token.

        Args:
            grd_token_hidden_states_list (list[torch.Tensor]): each element in this list contains
                the hidden states of the ground tokens in one sample, list[[varying N, D]]
            vision_features_before_mm_projection (torch.Tensor): [B, num_latents, 6 + 1]

        Returns:
            torch.Tensor: (total_num_ground_tokens_in_batch, num_latents) box scores
        """
        # get bbox position embeddings
        # xyz is batch x npoints x 3
        # get the center of the bbox
        bbox_center = (
            vision_features_before_mm_projection[:, :, 0:3]
            + vision_features_before_mm_projection[:, :, 3:6]
        ) / 2.0
        bbox_pos_embeddings = self.bbox_pos_embedding(
            xyz=bbox_center
        )  # (B, pos_emb_dim, num_latents)
        bbox_pos_embeddings = bbox_pos_embeddings.permute(0, 2, 1)  # (B, num_latents, pos_emb_dim)

        # get the object class embeddings, one-hot encoding of self.class_emb_dim classes
        obj_classes = vision_features_before_mm_projection[:, :, -1].long()
        obj_class_embeddings = torch.eye(
            self.class_emb_dim,
            device=vision_features_before_mm_projection.device,
            dtype=vision_features_before_mm_projection.dtype,
        )[
            obj_classes
        ]  # (B, num_latents, class_emb_dim)

        vision_feat = torch.concat(
            (obj_class_embeddings, bbox_pos_embeddings), dim=-1
        )  # (B, num_latents, class_emb_dim + pos_emb_dim)

        # get padding mask by checking where zero vectors are
        src_key_padding_mask = vision_features_before_mm_projection.eq(0).all(
            dim=-1
        )  # (B, num_latents)
        # for the padded locations, we set the vision_feat to be zero
        vision_feat[src_key_padding_mask] = 0

        bbox_scores = []
        for batch_idx, grd_token_hidden_states in enumerate(grd_token_hidden_states_list):
            cur_vision_feat = vision_feat[batch_idx]  # (num_latents, class_emb_dim + pos_emb_dim)
            for i in range(len(grd_token_hidden_states)):
                language_feat = grd_token_hidden_states[i]  # (lm_feat_dim_in,)
                # concat the language feat with each vision feat
                language_feat_repeat = language_feat.repeat(
                    cur_vision_feat.shape[0], 1
                )  # (num_latents, lm_feat_dim_in)
                concat_feat = torch.cat(
                    (cur_vision_feat, language_feat_repeat), dim=-1
                )  # (num_latents, class_emb_dim + pos_emb_dim + lm_feat_dim_in)
                fused_feat = self.language_vision_fusion_mlp(concat_feat)
                bbox_score = self.scoring_mlp(fused_feat).squeeze(-1)  # (num_latents,)
                bbox_scores.append(bbox_score)  # (num_latents,)
        bbox_scores = torch.stack(bbox_scores, dim=0)  # (N, num_latents)
        return bbox_scores

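# Illustrative sketch (not part of the model code): indexing an identity matrix with
# class ids produces the one-hot class encodings used above. Toy values.
def _example_one_hot_class_embedding():
    num_classes = 5
    obj_classes = torch.tensor([[2, 0]])  # (B=1, num_latents=2) class ids
    one_hot = torch.eye(num_classes)[obj_classes]  # (1, 2, 5)
    # one_hot[0, 0] == tensor([0., 0., 1., 0., 0.])
    # one_hot[0, 1] == tensor([1., 0., 0., 0., 0.])
    return one_hot
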
class BBoxHeadForGroundTruthBboxSelectionMLPFusionBoxCoordsAndClassID(nn.Module):
    """A simple MLP head for bounding box selection that consumes raw box coordinates and
    class ids, for training with a CE loss."""

    def __init__(
        self,
        lm_feat_dim_in: int,
        vision_feat_dim_in: int,
        num_vision_feat: int,
        dim_feedforward: int = 1024,
    ):
        super().__init__()
        self.activation = nn.ReLU()
        self.language_vision_fusion_mlp = nn.Sequential(
            nn.Linear(vision_feat_dim_in + lm_feat_dim_in, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
        )
        self.scoring_mlp = nn.Sequential(
            nn.Linear(dim_feedforward, 1),
        )

    def forward(
        self,
        grd_token_hidden_states_list: list[torch.Tensor],
        vision_features_before_mm_projection: torch.Tensor,
    ):
        """Score every candidate box for every ground token.

        Args:
            grd_token_hidden_states_list (list[torch.Tensor]): each element in this list contains
                the hidden states of the ground tokens in one sample, list[[varying N, D]]
            vision_features_before_mm_projection (torch.Tensor): [B, num_latents, 6 + 1]

        Returns:
            torch.Tensor: (total_num_ground_tokens_in_batch, num_latents) box scores
        """
        bbox_scores = []
        for batch_idx, grd_token_hidden_states in enumerate(grd_token_hidden_states_list):
            vision_feat = vision_features_before_mm_projection[
                batch_idx
            ]  # (num_latents, d_latents)
            for i in range(len(grd_token_hidden_states)):
                language_feat = grd_token_hidden_states[i]  # (lm_feat_dim_in,)
                # concat the language feat with each vision feat
                language_feat_repeat = language_feat.repeat(
                    vision_feat.shape[0], 1
                )  # (num_latents, lm_feat_dim_in)
                concat_feat = torch.cat(
                    (vision_feat, language_feat_repeat), dim=-1
                )  # (num_latents, d_latents + lm_feat_dim_in)
                fused_feat = self.language_vision_fusion_mlp(concat_feat)
                bbox_score = self.scoring_mlp(fused_feat).squeeze(-1)  # (num_latents,)
                bbox_scores.append(bbox_score)  # (num_latents,)
        bbox_scores = torch.stack(bbox_scores, dim=0)  # (N, num_latents)
        return bbox_scores

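# Illustrative usage sketch (not part of the model code) for the simplest selection
# head: raw box coordinates plus class id are fused with each ground-token hidden
# state by an MLP and scored per candidate. Feature sizes are made-up toy values.
def _example_selection_head_usage():
    head = BBoxHeadForGroundTruthBboxSelectionMLPFusionBoxCoordsAndClassID(
        lm_feat_dim_in=32, vision_feat_dim_in=7, num_vision_feat=5
    )
    boxes = torch.randn(2, 5, 7)  # (B, num_latents, 6 + 1)
    grd_tokens = [torch.randn(2, 32), torch.randn(1, 32)]  # varying ground tokens per sample
    scores = head(grd_tokens, boxes)  # (2 + 1, num_latents)
    selected = scores.argmax(dim=-1)  # index of the chosen box per ground token
    return selected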