3D-GRAND / llava /model /bbox_head.py
jedyang97's picture
initial demo
947767a
raw
history blame
34.7 kB
import torch
import torch.nn as nn
from llava.model.multimodal_encoder.three_detr_model.models.transformer import (
TransformerEncoder,
TransformerEncoderLayer,
TransformerDecoder,
TransformerDecoderLayer,
)
from torch.nn.utils.rnn import pad_sequence
from llava.model.multimodal_encoder.mask3d_model.position_embedding import (
PositionEmbeddingCoordsSine,
)
from torch.nn.init import xavier_uniform_
class SimpleBBoxHead(nn.Module):
    """Regress 6 box values per ground token from flattened vision features.

    For every ground token, the vision features of its sample are flattened
    into one vector, concatenated with the token's hidden state, and mapped
    to 6 numbers by ``box_mlp``.
    """

    def __init__(
        self,
        lm_feat_dim_in: int,
        vision_feat_dim_in: int,
        num_vision_feat: int,
        dim_feedforward: int = 1024,
    ):
        """Build the regression MLP.

        Args:
            lm_feat_dim_in: size D of one ground-token hidden state.
            vision_feat_dim_in: channel size of one vision feature.
            num_vision_feat: number of vision features per sample. The MLP
                input width is
                ``vision_feat_dim_in * num_vision_feat + lm_feat_dim_in``.
            dim_feedforward: width of the hidden layers.
        """
        super().__init__()
        # A single ReLU instance is shared by every hidden layer.
        self.activation = nn.ReLU()
        self.box_mlp = nn.Sequential(
            nn.Linear(vision_feat_dim_in * num_vision_feat + lm_feat_dim_in, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, 6),
        )

    def forward(
        self,
        grd_token_hidden_states_list: list[torch.Tensor],
        vision_features_before_mm_projection: torch.Tensor,
    ):
        """Predict one 6-value box per ground token.

        Args:
            grd_token_hidden_states_list (list[torch.Tensor]): one tensor per
                sample holding the hidden states of that sample's ground
                tokens, each of shape [varying N, D].
            vision_features_before_mm_projection (torch.Tensor):
                [B, num_latents, d_latents]; entry ``i`` is paired with
                ``grd_token_hidden_states_list[i]``.

        Returns:
            torch.Tensor: (total ground tokens, 6), ordered sample by sample.
            When the list is empty the result has shape (0, 6) instead of
            raising.
        """
        per_sample_preds = []
        for batch_idx, grd_token_hidden_states in enumerate(grd_token_hidden_states_list):
            # (num_latents * d_latents,)
            vision_feat = vision_features_before_mm_projection[batch_idx].flatten()
            num_tokens = grd_token_hidden_states.shape[0]
            # Broadcast the (shared) vision vector to every ground token of the
            # sample so the MLP runs once per sample rather than once per token.
            concat_feat = torch.cat(
                (vision_feat.expand(num_tokens, -1), grd_token_hidden_states), dim=-1
            )  # (N, num_latents * d_latents + D)
            per_sample_preds.append(self.box_mlp(concat_feat))  # (N, 6)

        if not per_sample_preds:
            # Nothing to predict: return an empty result on the input's
            # device/dtype rather than failing on an empty tensor list.
            return vision_features_before_mm_projection.new_zeros((0, 6))
        return torch.cat(per_sample_preds, dim=0)  # (total N, 6)
class BBoxHead(nn.Module):
    """Box regression head built from a transformer encoder and decoder.

    Vision features go through the encoder; ground-token hidden states are
    linearly projected to the vision feature width and used as decoder
    targets against the encoder output. A small GELU MLP turns every decoder
    output into 6 box values.
    """

    def __init__(self, lm_feat_dim_in: int, vision_feat_dim_in: int, dim_feedforward: int = 128):
        """Create encoder, decoder, language projection and box MLP.

        Args:
            lm_feat_dim_in: size D of one ground-token hidden state.
            vision_feat_dim_in: channel size of the vision features; also the
                model width of both transformer stacks.
            dim_feedforward: feed-forward width inside the transformer layers.
        """
        super().__init__()
        # Settings that the encoder layer and the decoder layer have in common.
        layer_kwargs = dict(
            d_model=vision_feat_dim_in,
            nhead=4,
            dim_feedforward=dim_feedforward,
            dropout=0.0,
        )
        self.encoder = TransformerEncoder(
            encoder_layer=TransformerEncoderLayer(activation="relu", **layer_kwargs),
            num_layers=4,
        )
        self.decoder = TransformerDecoder(
            decoder_layer=TransformerDecoderLayer(normalize_before=False, **layer_kwargs),
            num_layers=4,
            return_intermediate=False,
        )
        # Single linear map from the language width to the vision width.
        self.language_projection = nn.Sequential(
            nn.Linear(lm_feat_dim_in, vision_feat_dim_in),
        )
        self.activation = nn.GELU()
        self.box_mlp = nn.Sequential(
            nn.Linear(vision_feat_dim_in, 256),
            self.activation,
            nn.Linear(256, 256),
            self.activation,
            nn.Linear(256, 6),
        )

    def forward(
        self,
        grd_token_hidden_states_list: list[torch.Tensor],
        vision_features_before_mm_projection: torch.Tensor,
    ):
        """Predict one 6-value box for every ground token in the batch.

        Args:
            grd_token_hidden_states_list (list[torch.Tensor]): one tensor per
                sample with the hidden states of its ground tokens,
                list[[varying N, D]].
            vision_features_before_mm_projection (torch.Tensor):
                [B, num_latents, d_latents].

        Returns:
            torch.Tensor: (num_boxes_in_batch, 6); padded slots are dropped.
        """
        # The encoder is fed sequence-first input; the vision features are
        # used as-is (the original notes say they already carry positional
        # embeddings). Only the middle element of its 3-tuple is used here.
        _, memory, _ = self.encoder(
            src=vision_features_before_mm_projection.permute(1, 0, 2)
        )  # [num_latents, B, d_latents]

        # Right-pad every sample with zero vectors up to the longest sample.
        padded_tokens = pad_sequence(
            grd_token_hidden_states_list, batch_first=True, padding_value=0
        )  # (B, N', D)
        # A slot counts as padding when its whole vector is zero (True = ignore).
        is_padding = padded_tokens.eq(0).all(dim=-1)  # (B, N')
        # Ground tokens are independent, so each may only attend to itself.
        self_only_mask = self.create_diag_mask(padded_tokens.shape[1]).to(
            padded_tokens.device
        )  # (N', N')

        queries = self.language_projection(padded_tokens)  # (B, N', d_latents)
        decoded, _attns = self.decoder(
            tgt=queries.permute(1, 0, 2),  # [N', B, d_latents]
            memory=memory,
            tgt_mask=self_only_mask,
            tgt_key_padding_mask=is_padding,
        )  # decoded: [N', B, d_latents]

        boxes = self.box_mlp(decoded).transpose(0, 1)  # (B, N', 6)
        # Boolean indexing flattens (B, N') and keeps only the real tokens.
        return boxes[~is_padding]

    @staticmethod
    def create_diag_mask(size):
        """Return a (size, size) bool mask that is False only on the diagonal.

        In the transformer convention used here ``True`` means "may NOT
        attend" and ``False`` means "may attend", so each position is only
        allowed to attend to itself.
        """
        positions = torch.arange(size)
        return positions.unsqueeze(0) != positions.unsqueeze(1)
class BBoxHeadForGroundTruthBboxRegressionV2(nn.Module):
    """Box regression head that fuses one ground token with the vision features.

    The projected ground token is prepended to the projected vision features,
    the sequence is run through a transformer encoder, and the encoder output
    at the token's position is mapped to 6 box values by an MLP.
    """

    def __init__(
        self,
        lm_feat_dim_in: int,
        vision_feat_dim_in: int,
        num_vision_feat: int,
        dim_feedforward: int = 1024,
    ):
        """Create the projections, the encoder and the box MLP.

        Args:
            lm_feat_dim_in: size D of one ground-token hidden state.
            vision_feat_dim_in: channel size of one vision feature.
            num_vision_feat: accepted for a uniform constructor signature;
                not used in this head.
            dim_feedforward: hidden width of the MLPs and of the encoder's
                feed-forward part.
        """
        super().__init__()
        # Model width: vision width rounded up to the next multiple of 4
        # (the encoder below uses 4 attention heads).
        model_dim = (vision_feat_dim_in + 3) // 4 * 4

        self.vision_projection_mlp = nn.Sequential(
            nn.Linear(vision_feat_dim_in, model_dim),
        )

        self.activation = nn.ReLU()
        self.language_projection_mlp = nn.Sequential(
            nn.Linear(lm_feat_dim_in, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, model_dim),
        )

        self.encoder = TransformerEncoder(
            encoder_layer=TransformerEncoderLayer(
                d_model=model_dim,
                nhead=4,
                dim_feedforward=dim_feedforward,
                dropout=0.0,
                activation="relu",
                normalize_before=True,
            ),
            num_layers=4,
        )

        # The activation attribute is re-created at this point, exactly as in
        # the original construction sequence; box_mlp uses this second instance.
        self.activation = nn.ReLU()
        box_layers = []
        widths = [model_dim] + [dim_feedforward] * 4
        for width_in, width_out in zip(widths[:-1], widths[1:]):
            box_layers.append(nn.Linear(width_in, width_out))
            box_layers.append(self.activation)
        box_layers.append(nn.Linear(dim_feedforward, 6))
        self.box_mlp = nn.Sequential(*box_layers)

    def forward(
        self,
        grd_token_hidden_states_list: list[torch.Tensor],
        vision_features_before_mm_projection: torch.Tensor,
    ):
        """Predict one 6-value box per ground token.

        Args:
            grd_token_hidden_states_list (list[torch.Tensor]): one tensor per
                sample with the hidden states of its ground tokens,
                list[[varying N, D]].
            vision_features_before_mm_projection (torch.Tensor):
                [B, num_latents, d_latents].

        Returns:
            torch.Tensor: (N, 6) with N the total number of ground tokens,
            ordered sample by sample.
        """
        preds = []
        for sample_idx, sample_tokens in enumerate(grd_token_hidden_states_list):
            # Project this sample's vision features once; they are shared by
            # all of its ground tokens.
            projected_vision = self.vision_projection_mlp(
                vision_features_before_mm_projection[sample_idx].unsqueeze(0)
            )  # (1, num_latents, model_dim)

            for token in sample_tokens:  # token: (D,)
                projected_token = self.language_projection_mlp(token)  # (model_dim,)
                projected_token = projected_token.unsqueeze(0).unsqueeze(0)  # (1, 1, model_dim)

                # Language token first, vision features after it.
                sequence = torch.cat(
                    (projected_token, projected_vision), dim=1
                )  # (1, 1 + num_latents, model_dim)

                # The encoder takes sequence-first input; only the middle
                # element of its 3-tuple is used.
                _, encoded, _ = self.encoder(
                    src=sequence.permute(1, 0, 2)
                )  # [1 + num_latents, 1, model_dim]

                # Position 0 is the language token; its output is the fused feature.
                fused = encoded[0, 0]  # (model_dim,)
                preds.append(self.box_mlp(fused))  # (6,)

        return torch.stack(preds, dim=0)  # (N, 6)
class BBoxHeadForGroundTruthBboxRegressionV1(nn.Module):
    """Box regression head over ground-truth boxes given as corners plus class id.

    Each input box is embedded from its two corners and its class id, the box
    embeddings are run through a transformer encoder, and the flattened
    encoder output is concatenated with a ground token and regressed to 6
    values by an MLP.
    """

    def __init__(
        self,
        lm_feat_dim_in: int,
        vision_feat_dim_in: int,
        num_vision_feat: int,
        dim_feedforward: int = 1024,
    ):
        """Create the embeddings, the encoder and the box MLP.

        Args:
            lm_feat_dim_in: size D of one ground-token hidden state.
            vision_feat_dim_in: accepted for a uniform constructor signature;
                not used in this head.
            num_vision_feat: number of boxes per sample; fixes the MLP input
                width.
            dim_feedforward: hidden width of the MLP and of the encoder's
                feed-forward part.
        """
        super().__init__()
        self.bbox_pos_embedding = PositionEmbeddingCoordsSine(
            d_pos=10,
            pos_type="fourier",
        )
        # 265 classes in ScanNet, learnable embedding of size 64.
        self.obj_class_embedding = nn.Embedding(265, 64)
        self.activation = nn.ReLU()

        # One box token = two corner embeddings (10 each) + class embedding (64).
        token_dim = 10 * 2 + 64
        self.encoder = TransformerEncoder(
            encoder_layer=TransformerEncoderLayer(
                d_model=token_dim,
                nhead=4,
                dim_feedforward=dim_feedforward,
                dropout=0.0,
                activation="relu",
                normalize_before=False,
            ),
            num_layers=2,
        )

        mlp_layers = []
        widths = [token_dim * num_vision_feat + lm_feat_dim_in] + [dim_feedforward] * 4
        for width_in, width_out in zip(widths[:-1], widths[1:]):
            mlp_layers.append(nn.Linear(width_in, width_out))
            mlp_layers.append(self.activation)
        mlp_layers.append(nn.Linear(dim_feedforward, 6))
        self.box_mlp = nn.Sequential(*mlp_layers)

    def forward(
        self,
        grd_token_hidden_states_list: list[torch.Tensor],
        vision_features_before_mm_projection: torch.Tensor,
    ):
        """Predict one 6-value box per ground token.

        Args:
            grd_token_hidden_states_list (list[torch.Tensor]): one tensor per
                sample with the hidden states of its ground tokens,
                list[[varying N, D]].
            vision_features_before_mm_projection (torch.Tensor):
                [B, num_latents, 6 + 1]; channels 0:3 and 3:6 are treated as
                the min and max corner, the last channel as the class id.

        Returns:
            torch.Tensor: (N, 6) with N the total number of ground tokens,
            ordered sample by sample.
        """
        boxes = vision_features_before_mm_projection

        def embed_corner(corner_xyz):
            # The embedding takes (batch, npoints, 3) and presumably returns
            # channels-first (B, d_pos, num_latents) - the permute relies on
            # that; verify against PositionEmbeddingCoordsSine.
            return self.bbox_pos_embedding(xyz=corner_xyz).permute(0, 2, 1)

        min_corner_emb = embed_corner(boxes[:, :, 0:3])  # (B, num_latents, d_pos)
        max_corner_emb = embed_corner(boxes[:, :, 3:6])  # (B, num_latents, d_pos)

        class_ids = boxes[:, :, -1].long()
        class_emb = self.obj_class_embedding(class_ids)  # (B, num_latents, 64)

        box_tokens = torch.cat(
            (min_corner_emb, max_corner_emb, class_emb), dim=-1
        )  # (B, num_latents, 2 * d_pos + 64)

        # A box whose 7 input values are all zero is treated as padding.
        is_padding = boxes.eq(0).all(dim=-1)  # (B, num_latents)

        # The encoder takes sequence-first input; only the middle element of
        # its 3-tuple is used.
        _, encoded, _ = self.encoder(
            src=box_tokens.permute(1, 0, 2),
            src_key_padding_mask=is_padding,
        )  # [num_latents, B, token_dim]
        encoded = encoded.permute(1, 0, 2)  # [B, num_latents, token_dim]

        preds = []
        for sample_idx, sample_tokens in enumerate(grd_token_hidden_states_list):
            flat_vision = encoded[sample_idx].flatten()  # (num_latents * token_dim,)
            for token in sample_tokens:  # token: (D,)
                fused_input = torch.cat((flat_vision, token), dim=-1)
                preds.append(self.box_mlp(fused_input))  # (6,)

        return torch.stack(preds, dim=0)  # (N, 6)
class BBoxHeadForGroundTruthBboxSelectionTransformerLateFusion(nn.Module):
    """Box selection head (scores for a CE loss) with late language fusion.

    Box embeddings (class embedding + centre position embedding) are encoded
    by a transformer on their own; only afterwards is each encoded box
    concatenated with the ground token and scored by an MLP.
    """

    def __init__(
        self,
        lm_feat_dim_in: int,
        vision_feat_dim_in: int,
        num_vision_feat: int,
        dim_feedforward: int = 2048,
    ):
        """Create embeddings, fusion MLP, encoder and scoring layer.

        Args:
            lm_feat_dim_in: size of one ground-token hidden state.
            vision_feat_dim_in: accepted for a uniform constructor signature;
                not used in this head.
            num_vision_feat: accepted for a uniform constructor signature;
                not used in this head.
            dim_feedforward: width of the fusion MLP and of the encoder's
                feed-forward part.
        """
        super().__init__()
        class_emb_dim = 256
        pos_emb_dim = 16
        box_dim = class_emb_dim + pos_emb_dim

        self.bbox_pos_embedding = PositionEmbeddingCoordsSine(
            d_pos=pos_emb_dim,
            pos_type="fourier",
        )
        # 265 classes in ScanNet, learnable embedding of size class_emb_dim.
        self.obj_class_embedding = nn.Embedding(265, class_emb_dim)
        self.activation = nn.GELU()

        # Five linear layers with the shared activation between consecutive ones.
        fusion_linears = [nn.Linear(box_dim + lm_feat_dim_in, dim_feedforward)]
        fusion_linears += [nn.Linear(dim_feedforward, dim_feedforward) for _ in range(4)]
        fusion_layers = []
        for position, linear in enumerate(fusion_linears):
            if position > 0:
                fusion_layers.append(self.activation)
            fusion_layers.append(linear)
        self.language_vision_fusion_mlp = nn.Sequential(*fusion_layers)

        # torch's built-in encoder (pre-norm) over the box embeddings.
        self.encoder = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(
                d_model=box_dim,
                nhead=8,
                dim_feedforward=dim_feedforward,
                norm_first=True,
            ),
            num_layers=2,
        )
        self.scoring_mlp = nn.Sequential(
            nn.Linear(dim_feedforward, 1),
        )
        self._reset_parameters()

    def _reset_parameters(self):
        r"""Xavier-uniform initialise every parameter with more than one dimension."""
        for param in self.parameters():
            if param.dim() <= 1:
                # Biases, norm scales and other vectors keep their defaults.
                continue
            xavier_uniform_(param)

    def forward(
        self,
        grd_token_hidden_states_list: list[torch.Tensor],
        vision_features_before_mm_projection: torch.Tensor,
    ):
        """Score every box of a sample for each of its ground tokens.

        Args:
            grd_token_hidden_states_list (list[torch.Tensor]): one tensor per
                sample with the hidden states of its ground tokens,
                list[[varying N, D]].
            vision_features_before_mm_projection (torch.Tensor):
                [B, num_latents, 6 + 1]; channels 0:3 and 3:6 are averaged to
                the box centre, the last channel is the class id.

        Returns:
            torch.Tensor: (N, num_latents) scores, one row per ground token,
            ordered sample by sample.
        """
        boxes = vision_features_before_mm_projection

        # Box centre = mean of the two corner triples.
        centers = (boxes[:, :, 0:3] + boxes[:, :, 3:6]) / 2.0
        # Presumably channels-first (B, pos_emb_dim, num_latents) - the
        # permute relies on that; verify against PositionEmbeddingCoordsSine.
        pos_emb = self.bbox_pos_embedding(xyz=centers).permute(
            0, 2, 1
        )  # (B, num_latents, pos_emb_dim)

        class_ids = boxes[:, :, -1].long()
        class_emb = self.obj_class_embedding(class_ids)  # (B, num_latents, class_emb_dim)

        box_feat = torch.cat(
            (class_emb, pos_emb), dim=-1
        )  # (B, num_latents, class_emb_dim + pos_emb_dim)

        # A box whose 7 input values are all zero is treated as padding.
        is_padding = boxes.eq(0).all(dim=-1)  # (B, num_latents)

        scores = []
        for sample_idx, sample_tokens in enumerate(grd_token_hidden_states_list):
            # The encoder is sequence-first: (num_latents, 1, box_dim).
            encoder_input = box_feat[sample_idx].unsqueeze(1)
            encoded = self.encoder(
                encoder_input,
                src_key_padding_mask=is_padding[sample_idx].unsqueeze(0),
            ).squeeze(1)  # (num_latents, box_dim)
            num_boxes = encoded.shape[0]

            for token in sample_tokens:  # token: (lm_feat_dim_in,)
                # Pair the same language vector with every encoded box.
                token_per_box = token.unsqueeze(0).expand(
                    num_boxes, -1
                )  # (num_latents, lm_feat_dim_in)
                fused_input = torch.cat(
                    (encoded, token_per_box), dim=-1
                )  # (num_latents, box_dim + lm_feat_dim_in)
                fused = self.language_vision_fusion_mlp(
                    fused_input
                )  # (num_latents, dim_feedforward)
                scores.append(self.scoring_mlp(fused).squeeze(-1))  # (num_latents,)

        return torch.stack(scores, dim=0)  # (N, num_latents)
class BBoxHeadForGroundTruthBboxSelectionTransformerEarlyFusion(nn.Module):
    """Box selection head (scores for a CE loss) with early language fusion.

    The projected ground token is concatenated to every box embedding (class
    embedding + centre position embedding) before the transformer encoder;
    the encoder output of each box is then scored by an MLP.
    """

    def __init__(
        self,
        lm_feat_dim_in: int,
        vision_feat_dim_in: int,
        num_vision_feat: int,
        dim_feedforward: int = 2048,
    ):
        """Create embeddings, language projection, encoder and scoring MLP.

        Args:
            lm_feat_dim_in: size of one ground-token hidden state.
            vision_feat_dim_in: accepted for a uniform constructor signature;
                not used in this head.
            num_vision_feat: accepted for a uniform constructor signature;
                not used in this head.
            dim_feedforward: width of the scoring MLP and of the encoder's
                feed-forward part.
        """
        super().__init__()
        class_emb_dim = 256
        pos_emb_dim = 16
        # Box embedding plus the language token projected to class_emb_dim.
        fused_dim = class_emb_dim + pos_emb_dim + class_emb_dim

        self.bbox_pos_embedding = PositionEmbeddingCoordsSine(
            d_pos=pos_emb_dim,
            pos_type="fourier",
        )
        # 265 classes in ScanNet, learnable embedding of size class_emb_dim.
        self.obj_class_embedding = nn.Embedding(265, class_emb_dim)
        self.activation = nn.GELU()

        self.language_projection_mlp = nn.Sequential(
            nn.Linear(lm_feat_dim_in, class_emb_dim),
        )

        # torch's built-in encoder (post-norm) over the fused box/language tokens.
        self.encoder = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(
                d_model=fused_dim,
                nhead=8,
                dim_feedforward=dim_feedforward,
                norm_first=False,
            ),
            num_layers=2,
        )

        self.scoring_mlp = nn.Sequential(
            nn.Linear(fused_dim, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, 1),
        )
        self._reset_parameters()

    def _reset_parameters(self):
        r"""Xavier-uniform initialise every parameter with more than one dimension."""
        for param in self.parameters():
            if param.dim() <= 1:
                # Biases, norm scales and other vectors keep their defaults.
                continue
            xavier_uniform_(param)

    def forward(
        self,
        grd_token_hidden_states_list: list[torch.Tensor],
        vision_features_before_mm_projection: torch.Tensor,
    ):
        """Score every box of a sample for each of its ground tokens.

        Args:
            grd_token_hidden_states_list (list[torch.Tensor]): one tensor per
                sample with the hidden states of its ground tokens,
                list[[varying N, D]].
            vision_features_before_mm_projection (torch.Tensor):
                [B, num_latents, 6 + 1]; channels 0:3 and 3:6 are averaged to
                the box centre, the last channel is the class id.

        Returns:
            torch.Tensor: (N, num_latents) scores, one row per ground token,
            ordered sample by sample.
        """
        boxes = vision_features_before_mm_projection

        # Box centre = mean of the two corner triples.
        centers = (boxes[:, :, 0:3] + boxes[:, :, 3:6]) / 2.0
        # Presumably channels-first (B, pos_emb_dim, num_latents) - the
        # permute relies on that; verify against PositionEmbeddingCoordsSine.
        pos_emb = self.bbox_pos_embedding(xyz=centers).permute(
            0, 2, 1
        )  # (B, num_latents, pos_emb_dim)

        class_ids = boxes[:, :, -1].long()
        class_emb = self.obj_class_embedding(class_ids)  # (B, num_latents, class_emb_dim)

        box_feat = torch.cat(
            (class_emb, pos_emb), dim=-1
        )  # (B, num_latents, class_emb_dim + pos_emb_dim)

        # A box whose 7 input values are all zero is treated as padding.
        is_padding = boxes.eq(0).all(dim=-1)  # (B, num_latents)

        scores = []
        for sample_idx, sample_tokens in enumerate(grd_token_hidden_states_list):
            sample_boxes = box_feat[sample_idx]  # (num_latents, class_emb_dim + pos_emb_dim)
            sample_padding = is_padding[sample_idx].unsqueeze(0)  # (1, num_latents)
            num_boxes = sample_boxes.shape[0]

            for token in sample_tokens:  # token: (lm_feat_dim_in,)
                projected = self.language_projection_mlp(token)  # (class_emb_dim,)
                # Attach the same projected language vector to every box.
                token_per_box = projected.unsqueeze(0).expand(
                    num_boxes, -1
                )  # (num_latents, class_emb_dim)
                fused_input = torch.cat(
                    (sample_boxes, token_per_box), dim=-1
                )  # (num_latents, fused_dim)

                # The encoder is sequence-first: (num_latents, 1, fused_dim).
                encoded = self.encoder(
                    fused_input.unsqueeze(1),
                    src_key_padding_mask=sample_padding,
                ).squeeze(1)  # (num_latents, fused_dim)

                scores.append(self.scoring_mlp(encoded).squeeze(-1))  # (num_latents,)

        return torch.stack(scores, dim=0)  # (N, num_latents)
class BBoxHeadForGroundTruthBboxSelectionMLPPosEmbAndFusionOneHot(nn.Module):
    """MLP-only box selection head (scores for a CE loss) with one-hot classes.

    Every box is described by a one-hot class vector plus a position embedding
    of its centre; that description is concatenated with the ground token and
    scored by an MLP. No transformer is involved.
    """

    def __init__(
        self,
        lm_feat_dim_in: int,
        vision_feat_dim_in: int,
        num_vision_feat: int,
        dim_feedforward: int = 4096,
    ):
        """Create the position embedding, fusion MLP and scoring layer.

        Args:
            lm_feat_dim_in: size of one ground-token hidden state.
            vision_feat_dim_in: accepted for a uniform constructor signature;
                not used in this head.
            num_vision_feat: accepted for a uniform constructor signature;
                not used in this head.
            dim_feedforward: width of the fusion MLP.
        """
        super().__init__()
        # 265 classes in ScanRefer; also the length of the one-hot vector.
        self.class_emb_dim = class_emb_dim = 265
        pos_emb_dim = 16
        self.bbox_pos_embedding = PositionEmbeddingCoordsSine(
            d_pos=pos_emb_dim,
            pos_type="fourier",
        )
        self.activation = nn.ReLU()

        # Five linear layers with the shared activation between consecutive ones.
        fusion_linears = [
            nn.Linear(class_emb_dim + pos_emb_dim + lm_feat_dim_in, dim_feedforward)
        ]
        fusion_linears += [nn.Linear(dim_feedforward, dim_feedforward) for _ in range(4)]
        fusion_layers = []
        for position, linear in enumerate(fusion_linears):
            if position > 0:
                fusion_layers.append(self.activation)
            fusion_layers.append(linear)
        self.language_vision_fusion_mlp = nn.Sequential(*fusion_layers)

        self.scoring_mlp = nn.Sequential(
            nn.Linear(dim_feedforward, 1),
        )

    def forward(
        self,
        grd_token_hidden_states_list: list[torch.Tensor],
        vision_features_before_mm_projection: torch.Tensor,
    ):
        """Score every box of a sample for each of its ground tokens.

        Args:
            grd_token_hidden_states_list (list[torch.Tensor]): one tensor per
                sample with the hidden states of its ground tokens,
                list[[varying N, D]].
            vision_features_before_mm_projection (torch.Tensor):
                [B, num_latents, 6 + 1]; channels 0:3 and 3:6 are averaged to
                the box centre, the last channel is the class id.

        Returns:
            torch.Tensor: (N, num_latents) scores, one row per ground token,
            ordered sample by sample.
        """
        boxes = vision_features_before_mm_projection

        # Box centre = mean of the two corner triples.
        centers = (boxes[:, :, 0:3] + boxes[:, :, 3:6]) / 2.0
        # Presumably channels-first (B, pos_emb_dim, num_latents) - the
        # permute relies on that; verify against PositionEmbeddingCoordsSine.
        pos_emb = self.bbox_pos_embedding(xyz=centers).permute(
            0, 2, 1
        )  # (B, num_latents, pos_emb_dim)

        # One-hot class vectors: pick rows of an identity matrix by class id.
        class_ids = boxes[:, :, -1].long()
        identity = torch.eye(
            self.class_emb_dim,
            device=boxes.device,
            dtype=boxes.dtype,
        )
        one_hot = identity[class_ids]  # (B, num_latents, class_emb_dim)

        box_feat = torch.cat(
            (one_hot, pos_emb), dim=-1
        )  # (B, num_latents, class_emb_dim + pos_emb_dim)

        # A box whose 7 input values are all zero is treated as padding, and
        # its feature vector is zeroed out.
        is_padding = boxes.eq(0).all(dim=-1)  # (B, num_latents)
        box_feat = box_feat.masked_fill(is_padding.unsqueeze(-1), 0)

        scores = []
        for sample_idx, sample_tokens in enumerate(grd_token_hidden_states_list):
            sample_boxes = box_feat[sample_idx]  # (num_latents, class_emb_dim + pos_emb_dim)
            num_boxes = sample_boxes.shape[0]
            for token in sample_tokens:  # token: (lm_feat_dim_in,)
                # Pair the same language vector with every box.
                token_per_box = token.unsqueeze(0).expand(
                    num_boxes, -1
                )  # (num_latents, lm_feat_dim_in)
                fused_input = torch.cat(
                    (sample_boxes, token_per_box), dim=-1
                )  # (num_latents, class_emb_dim + pos_emb_dim + lm_feat_dim_in)
                fused = self.language_vision_fusion_mlp(fused_input)
                scores.append(self.scoring_mlp(fused).squeeze(-1))  # (num_latents,)

        return torch.stack(scores, dim=0)  # (N, num_latents)
class BBoxHeadForGroundTruthBboxSelectionMLPFusionBoxCoordsAndClassID(nn.Module):
    """MLP box selection head (scores for a CE loss) on raw box features.

    Each box's raw feature vector is concatenated with the ground token's
    hidden state, passed through a fusion MLP and reduced to one score by a
    final linear layer.
    """

    def __init__(
        self,
        lm_feat_dim_in: int,
        vision_feat_dim_in: int,
        num_vision_feat: int,
        dim_feedforward: int = 1024,
    ):
        """Build the fusion MLP and the scoring layer.

        Args:
            lm_feat_dim_in: size of one ground-token hidden state.
            vision_feat_dim_in: size of one box feature vector (the last
                dimension of the vision features given to ``forward``).
            num_vision_feat: accepted for a uniform constructor signature;
                not used in this head.
            dim_feedforward: width of the fusion MLP.
        """
        super().__init__()
        # A single ReLU instance is shared by every hidden layer.
        self.activation = nn.ReLU()
        self.language_vision_fusion_mlp = nn.Sequential(
            nn.Linear(vision_feat_dim_in + lm_feat_dim_in, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
            self.activation,
            nn.Linear(dim_feedforward, dim_feedforward),
        )
        self.scoring_mlp = nn.Sequential(
            nn.Linear(dim_feedforward, 1),
        )

    def forward(
        self,
        grd_token_hidden_states_list: list[torch.Tensor],
        vision_features_before_mm_projection: torch.Tensor,
    ):
        """Score every box of a sample for each of its ground tokens.

        Args:
            grd_token_hidden_states_list (list[torch.Tensor]): one tensor per
                sample with the hidden states of its ground tokens,
                list[[varying N, D]].
            vision_features_before_mm_projection (torch.Tensor):
                [B, num_latents, 6 + 1]; entry ``i`` is paired with
                ``grd_token_hidden_states_list[i]``.

        Returns:
            torch.Tensor: (total ground tokens, num_latents) scores, ordered
            sample by sample. When the list is empty the result has shape
            (0, num_latents) instead of raising.
        """
        per_sample_scores = []
        for batch_idx, grd_token_hidden_states in enumerate(grd_token_hidden_states_list):
            vision_feat = vision_features_before_mm_projection[
                batch_idx
            ]  # (num_latents, d_latents)
            num_latents = vision_feat.shape[0]
            num_tokens = grd_token_hidden_states.shape[0]
            # Pair every ground token with every box through broadcasting so
            # both MLPs run once per sample rather than once per token.
            concat_feat = torch.cat(
                (
                    vision_feat.unsqueeze(0).expand(num_tokens, -1, -1),
                    grd_token_hidden_states.unsqueeze(1).expand(-1, num_latents, -1),
                ),
                dim=-1,
            )  # (N, num_latents, d_latents + lm_feat_dim_in)
            fused_feat = self.language_vision_fusion_mlp(concat_feat)
            per_sample_scores.append(
                self.scoring_mlp(fused_feat).squeeze(-1)
            )  # (N, num_latents)

        if not per_sample_scores:
            # Nothing to score: return an empty result on the input's
            # device/dtype rather than failing on an empty tensor list.
            return vision_features_before_mm_projection.new_zeros(
                (0, vision_features_before_mm_projection.shape[1])
            )
        return torch.cat(per_sample_scores, dim=0)  # (total N, num_latents)