|
|
|
|
|
""" |
|
|
Transformer decoder. |
|
|
Inspired from Pytorch's version, adds the pre-norm variant |
|
|
""" |
|
|
|
|
|
from typing import Any, Dict, Optional
|
|
|
|
|
import numpy as np |
|
|
|
|
|
import torch |
|
|
|
|
|
from sam3.sam.transformer import RoPEAttention |
|
|
|
|
|
from torch import nn, Tensor |
|
|
from torchvision.ops.roi_align import RoIAlign |
|
|
|
|
|
from .act_ckpt_utils import activation_ckpt_wrapper |
|
|
|
|
|
from .box_ops import box_cxcywh_to_xyxy |
|
|
|
|
|
from .model_misc import ( |
|
|
gen_sineembed_for_position, |
|
|
get_activation_fn, |
|
|
get_clones, |
|
|
inverse_sigmoid, |
|
|
MLP, |
|
|
) |
|
|
|
|
|
|
|
|
class TransformerDecoderLayer(nn.Module): |
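    """
    A single decoder layer: self-attention over the queries, optional
    cross-attention to text tokens, cross-attention to the image memory, and
    a feed-forward block, each followed by a residual connection and
    LayerNorm. Also supports DAC query duplication and a presence token (see
    ``forward``).
    """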
|
|
def __init__( |
|
|
self, |
|
|
activation: str, |
|
|
d_model: int, |
|
|
dim_feedforward: int, |
|
|
dropout: float, |
|
|
cross_attention: nn.Module, |
|
|
n_heads: int, |
|
|
use_text_cross_attention: bool = False, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
self.cross_attn = cross_attention |
|
|
self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() |
|
|
self.norm1 = nn.LayerNorm(d_model) |
|
|
|
|
|
|
|
|
self.use_text_cross_attention = use_text_cross_attention |
|
|
if use_text_cross_attention: |
|
|
self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) |
|
|
self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity() |
|
|
self.catext_norm = nn.LayerNorm(d_model) |
|
|
|
|
|
|
|
|
self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) |
|
|
self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() |
|
|
self.norm2 = nn.LayerNorm(d_model) |
|
|
|
|
|
|
|
|
self.linear1 = nn.Linear(d_model, dim_feedforward) |
|
|
self.activation = get_activation_fn(activation) |
|
|
self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() |
|
|
self.linear2 = nn.Linear(dim_feedforward, d_model) |
|
|
self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() |
|
|
self.norm3 = nn.LayerNorm(d_model) |
|
|
|
|
|
@staticmethod |
|
|
def with_pos_embed(tensor, pos): |
|
|
return tensor if pos is None else tensor + pos |
|
|
|
|
|
def forward_ffn(self, tgt): |
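        # Run the feed-forward block in full precision even under autocast,
        # presumably for numerical stability.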
|
|
with torch.amp.autocast(device_type="cuda", enabled=False): |
|
|
tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) |
|
|
tgt = tgt + self.dropout4(tgt2) |
|
|
tgt = self.norm3(tgt) |
|
|
return tgt |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
|
|
|
tgt: Optional[Tensor], |
|
|
tgt_query_pos: Optional[Tensor] = None, |
|
|
tgt_query_sine_embed: Optional[Tensor] = None, |
|
|
tgt_key_padding_mask: Optional[Tensor] = None, |
|
|
tgt_reference_points: Optional[Tensor] = None, |
|
|
memory_text: Optional[Tensor] = None, |
|
|
text_attention_mask: Optional[Tensor] = None, |
|
|
|
|
|
memory: Optional[Tensor] = None, |
|
|
memory_key_padding_mask: Optional[Tensor] = None, |
|
|
memory_level_start_index: Optional[Tensor] = None, |
|
|
memory_spatial_shapes: Optional[Tensor] = None, |
|
|
memory_pos: Optional[Tensor] = None, |
|
|
|
|
|
self_attn_mask: Optional[Tensor] = None, |
|
|
cross_attn_mask: Optional[Tensor] = None, |
|
|
|
|
|
        dac: bool = False,
        dac_use_selfatt_ln: bool = True,
        presence_token: Optional[Tensor] = None,
|
|
|
|
|
identity=0.0, |
|
|
**kwargs, |
|
|
): |
|
|
""" |
|
|
Input: |
|
|
- tgt/tgt_query_pos: nq, bs, d_model |
|
|
- |
|
|
""" |
|
|
|
|
|
if self.self_attn is not None: |
|
|
if dac: |
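                # DAC: the queries are duplicated, with the first half
                # assigned one-to-one (o2o) and the second half one-to-many
                # (o2m); only the o2o queries go through self-attention.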
|
|
|
|
|
assert tgt.shape[0] % 2 == 0 |
|
|
num_o2o_queries = tgt.shape[0] // 2 |
|
|
tgt_o2o = tgt[:num_o2o_queries] |
|
|
tgt_query_pos_o2o = tgt_query_pos[:num_o2o_queries] |
|
|
tgt_o2m = tgt[num_o2o_queries:] |
|
|
else: |
|
|
tgt_o2o = tgt |
|
|
tgt_query_pos_o2o = tgt_query_pos |
|
|
|
|
|
if presence_token is not None: |
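                # Prepend the presence token to the o2o queries with a zero
                # positional embedding so it participates in self-attention.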
|
|
tgt_o2o = torch.cat([presence_token, tgt_o2o], dim=0) |
|
|
tgt_query_pos_o2o = torch.cat( |
|
|
[torch.zeros_like(presence_token), tgt_query_pos_o2o], dim=0 |
|
|
) |
|
|
tgt_query_pos = torch.cat( |
|
|
[torch.zeros_like(presence_token), tgt_query_pos], dim=0 |
|
|
) |
|
|
|
|
|
q = k = self.with_pos_embed(tgt_o2o, tgt_query_pos_o2o) |
|
|
tgt2 = self.self_attn(q, k, tgt_o2o, attn_mask=self_attn_mask)[0] |
|
|
tgt_o2o = tgt_o2o + self.dropout2(tgt2) |
|
|
if dac: |
|
|
if not dac_use_selfatt_ln: |
|
|
tgt_o2o = self.norm2(tgt_o2o) |
|
|
tgt = torch.cat((tgt_o2o, tgt_o2m), dim=0) |
|
|
if dac_use_selfatt_ln: |
|
|
tgt = self.norm2(tgt) |
|
|
else: |
|
|
tgt = tgt_o2o |
|
|
tgt = self.norm2(tgt) |
|
|
|
|
|
if self.use_text_cross_attention: |
|
|
tgt2 = self.ca_text( |
|
|
self.with_pos_embed(tgt, tgt_query_pos), |
|
|
memory_text, |
|
|
memory_text, |
|
|
key_padding_mask=text_attention_mask, |
|
|
)[0] |
|
|
tgt = tgt + self.catext_dropout(tgt2) |
|
|
tgt = self.catext_norm(tgt) |
|
|
|
|
|
if presence_token is not None: |
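            # Extend the cross-attention mask with an all-zeros (unmasked) row
            # so the presence token can attend to every memory location.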
|
|
presence_token_mask = torch.zeros_like(cross_attn_mask[:, :1, :]) |
|
|
cross_attn_mask = torch.cat( |
|
|
[presence_token_mask, cross_attn_mask], dim=1 |
|
|
) |
|
|
|
|
|
|
|
|
tgt2 = self.cross_attn( |
|
|
query=self.with_pos_embed(tgt, tgt_query_pos), |
|
|
key=self.with_pos_embed(memory, memory_pos), |
|
|
value=memory, |
|
|
attn_mask=cross_attn_mask, |
|
|
key_padding_mask=( |
|
|
memory_key_padding_mask.transpose(0, 1) |
|
|
if memory_key_padding_mask is not None |
|
|
else None |
|
|
), |
|
|
)[0] |
|
|
|
|
|
tgt = tgt + self.dropout1(tgt2) |
|
|
tgt = self.norm1(tgt) |
|
|
|
|
|
|
|
|
tgt = self.forward_ffn(tgt) |
|
|
|
|
|
presence_token_out = None |
|
|
if presence_token is not None: |
|
|
presence_token_out = tgt[:1] |
|
|
tgt = tgt[1:] |
|
|
|
|
|
return tgt, presence_token_out |
|
|
|
|
|
|
|
|
class TransformerDecoder(nn.Module): |
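    """
    A stack of decoder layers with iterative box refinement: reference boxes
    are refined after every layer, per-layer outputs are collected
    (``return_intermediate``), and an optional presence token and box
    relative position bias (``boxRPB``) are supported.
    """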
|
|
def __init__( |
|
|
self, |
|
|
d_model: int, |
|
|
frozen: bool, |
|
|
interaction_layer, |
|
|
layer, |
|
|
num_layers: int, |
|
|
num_queries: int, |
|
|
return_intermediate: bool, |
|
|
box_refine: bool = False, |
|
|
num_o2m_queries: int = 0, |
|
|
dac: bool = False, |
|
|
boxRPB: str = "none", |
|
|
|
|
|
instance_query: bool = False, |
|
|
|
|
|
|
|
|
num_instances: int = 1, |
|
|
dac_use_selfatt_ln: bool = True, |
|
|
use_act_checkpoint: bool = False, |
|
|
compile_mode=None, |
|
|
presence_token: bool = False, |
|
|
clamp_presence_logits: bool = True, |
|
|
clamp_presence_logit_max_val: float = 10.0, |
|
|
use_normed_output_consistently: bool = True, |
|
|
separate_box_head_instance: bool = False, |
|
|
separate_norm_instance: bool = False, |
|
|
resolution: Optional[int] = None, |
|
|
stride: Optional[int] = None, |
|
|
): |
|
|
super().__init__() |
|
|
self.d_model = d_model |
|
|
self.layers = get_clones(layer, num_layers) |
|
|
self.fine_layers = ( |
|
|
get_clones(interaction_layer, num_layers) |
|
|
if interaction_layer is not None |
|
|
else [None] * num_layers |
|
|
) |
|
|
self.num_layers = num_layers |
|
|
self.num_queries = num_queries |
|
|
self.dac = dac |
|
|
if dac: |
|
|
self.num_o2m_queries = num_queries |
|
|
tot_num_queries = num_queries |
|
|
else: |
|
|
self.num_o2m_queries = num_o2m_queries |
|
|
tot_num_queries = num_queries + num_o2m_queries |
|
|
self.norm = nn.LayerNorm(d_model) |
|
|
self.return_intermediate = return_intermediate |
|
|
self.bbox_embed = MLP(d_model, d_model, 4, 3) |
|
|
self.query_embed = nn.Embedding(tot_num_queries, d_model) |
|
|
self.instance_query_embed = None |
|
|
self.instance_query_reference_points = None |
|
|
self.use_instance_query = instance_query |
|
|
self.num_instances = num_instances |
|
|
self.use_normed_output_consistently = use_normed_output_consistently |
|
|
|
|
|
self.instance_norm = nn.LayerNorm(d_model) if separate_norm_instance else None |
|
|
self.instance_bbox_embed = None |
|
|
if separate_box_head_instance: |
|
|
self.instance_bbox_embed = MLP(d_model, d_model, 4, 3) |
|
|
if instance_query: |
|
|
self.instance_query_embed = nn.Embedding(num_instances, d_model) |
|
|
self.box_refine = box_refine |
|
|
if box_refine: |
|
|
nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0) |
|
|
nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0) |
|
|
|
|
|
self.reference_points = nn.Embedding(num_queries, 4) |
|
|
if instance_query: |
|
|
self.instance_reference_points = nn.Embedding(num_instances, 4) |
|
|
|
|
|
assert boxRPB in ["none", "log", "linear", "both"] |
|
|
self.boxRPB = boxRPB |
|
|
if boxRPB != "none": |
|
|
try: |
|
|
nheads = self.layers[0].cross_attn_image.num_heads |
|
|
except AttributeError: |
|
|
nheads = self.layers[0].cross_attn.num_heads |
|
|
|
|
|
n_input = 4 if boxRPB == "both" else 2 |
|
|
self.boxRPB_embed_x = MLP(n_input, d_model, nheads, 2) |
|
|
self.boxRPB_embed_y = MLP(n_input, d_model, nheads, 2) |
|
|
            self.compilable_coord_cache = None
|
|
self.compilable_stored_size = None |
|
|
self.coord_cache = {} |
|
|
|
|
|
if resolution is not None and stride is not None: |
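            # Pre-compute the coordinate cache at init time (on CUDA) so a
            # compiled forward pass does not need to build it on the fly.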
|
|
feat_size = resolution // stride |
|
|
coords_h, coords_w = self._get_coords( |
|
|
feat_size, feat_size, device="cuda" |
|
|
) |
|
|
            self.compilable_coord_cache = (coords_h, coords_w)
|
|
self.compilable_stored_size = (feat_size, feat_size) |
|
|
|
|
|
self.roi_pooler = ( |
|
|
RoIAlign(output_size=7, spatial_scale=1, sampling_ratio=-1, aligned=True) |
|
|
if interaction_layer is not None |
|
|
else None |
|
|
) |
|
|
if frozen: |
|
|
for p in self.parameters(): |
|
|
p.requires_grad_(False) |
|
|
|
|
|
self.presence_token = None |
|
|
self.clamp_presence_logits = clamp_presence_logits |
|
|
self.clamp_presence_logit_max_val = clamp_presence_logit_max_val |
|
|
if presence_token: |
|
|
self.presence_token = nn.Embedding(1, d_model) |
|
|
self.presence_token_head = MLP(d_model, d_model, 1, 3) |
|
|
self.presence_token_out_norm = nn.LayerNorm(d_model) |
|
|
|
|
|
self.ref_point_head = MLP(2 * self.d_model, self.d_model, self.d_model, 2) |
|
|
self.dac_use_selfatt_ln = dac_use_selfatt_ln |
|
|
self.use_act_checkpoint = use_act_checkpoint |
|
|
|
|
|
nn.init.normal_(self.query_embed.weight.data) |
|
|
if self.instance_query_embed is not None: |
|
|
nn.init.normal_(self.instance_query_embed.weight.data) |
|
|
|
|
|
        # Only the box-refinement path with intermediate outputs (and without
        # the ROI-pooling interaction layers) is currently supported.
        assert self.roi_pooler is None, "interaction layers are not supported"
        assert self.return_intermediate, "only return_intermediate is supported"
        assert self.box_refine, "only box refinement is supported"
|
|
|
|
|
self.compile_mode = compile_mode |
|
|
self.compiled = False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for layer_idx, layer in enumerate(self.layers): |
|
|
layer.layer_idx = layer_idx |
|
|
|
|
|
@staticmethod |
|
|
def _get_coords(H, W, device): |
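        # Feature-grid coordinates normalized to [0, 1).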
|
|
coords_h = torch.arange(0, H, device=device, dtype=torch.float32) / H |
|
|
coords_w = torch.arange(0, W, device=device, dtype=torch.float32) / W |
|
|
return coords_h, coords_w |
|
|
|
|
|
def _get_rpb_matrix(self, reference_boxes, feat_size): |
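        """
        Compute the boxRPB attention bias from the reference boxes.

        Returns a tensor of shape (bs, n_heads, num_queries, H * W); the call
        site flattens the first two dims before passing it as an attention
        mask.
        """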
|
|
H, W = feat_size |
|
|
boxes_xyxy = box_cxcywh_to_xyxy(reference_boxes).transpose(0, 1) |
|
|
bs, num_queries, _ = boxes_xyxy.shape |
|
|
        if self.compilable_coord_cache is None:
|
|
            self.compilable_coord_cache = self._get_coords(H, W, reference_boxes.device)
|
|
self.compilable_stored_size = (H, W) |
|
|
|
|
|
if torch.compiler.is_dynamo_compiling() or self.compilable_stored_size == ( |
|
|
H, |
|
|
W, |
|
|
): |
|
|
|
|
|
            coords_h, coords_w = self.compilable_coord_cache
|
|
else: |
|
|
|
|
|
|
|
|
if feat_size not in self.coord_cache: |
|
|
self.coord_cache[feat_size] = self._get_coords( |
|
|
H, W, reference_boxes.device |
|
|
) |
|
|
coords_h, coords_w = self.coord_cache[feat_size] |
|
|
|
|
|
assert coords_h.shape == (H,) |
|
|
assert coords_w.shape == (W,) |
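        # Offsets from every pixel row/column to the box edges; slices 1:4:2
        # and 0:3:2 pick (y1, y2) and (x1, x2) from the xyxy boxes.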
|
|
|
|
|
deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2] |
|
|
deltas_y = deltas_y.view(bs, num_queries, -1, 2) |
|
|
deltas_x = coords_w.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 0:3:2] |
|
|
deltas_x = deltas_x.view(bs, num_queries, -1, 2) |
|
|
|
|
|
if self.boxRPB in ["log", "both"]: |
|
|
deltas_x_log = deltas_x * 8 |
|
|
deltas_x_log = ( |
|
|
torch.sign(deltas_x_log) |
|
|
* torch.log2(torch.abs(deltas_x_log) + 1.0) |
|
|
/ np.log2(8) |
|
|
) |
|
|
|
|
|
deltas_y_log = deltas_y * 8 |
|
|
deltas_y_log = ( |
|
|
torch.sign(deltas_y_log) |
|
|
* torch.log2(torch.abs(deltas_y_log) + 1.0) |
|
|
/ np.log2(8) |
|
|
) |
|
|
if self.boxRPB == "log": |
|
|
deltas_x = deltas_x_log |
|
|
deltas_y = deltas_y_log |
|
|
else: |
|
|
deltas_x = torch.cat([deltas_x, deltas_x_log], dim=-1) |
|
|
deltas_y = torch.cat([deltas_y, deltas_y_log], dim=-1) |
|
|
|
|
|
if self.training: |
|
|
assert self.use_act_checkpoint, "activation ckpt not enabled in decoder" |
|
|
deltas_x = activation_ckpt_wrapper(self.boxRPB_embed_x)( |
|
|
x=deltas_x, |
|
|
act_ckpt_enable=self.training and self.use_act_checkpoint, |
|
|
) |
|
|
deltas_y = activation_ckpt_wrapper(self.boxRPB_embed_y)( |
|
|
x=deltas_y, |
|
|
act_ckpt_enable=self.training and self.use_act_checkpoint, |
|
|
) |
|
|
|
|
|
if not torch.compiler.is_dynamo_compiling(): |
|
|
assert deltas_x.shape[:3] == (bs, num_queries, W) |
|
|
assert deltas_y.shape[:3] == (bs, num_queries, H) |
|
|
|
|
|
        B = deltas_y.unsqueeze(3) + deltas_x.unsqueeze(2)
|
|
if not torch.compiler.is_dynamo_compiling(): |
|
|
assert B.shape[:4] == (bs, num_queries, H, W) |
|
|
B = B.flatten(2, 3) |
|
|
B = B.permute(0, 3, 1, 2) |
|
|
B = B.contiguous() |
|
|
if not torch.compiler.is_dynamo_compiling(): |
|
|
assert B.shape[2:] == (num_queries, H * W) |
|
|
return B |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
tgt, |
|
|
memory, |
|
|
tgt_mask: Optional[Tensor] = None, |
|
|
memory_mask: Optional[Tensor] = None, |
|
|
tgt_key_padding_mask: Optional[Tensor] = None, |
|
|
memory_key_padding_mask: Optional[Tensor] = None, |
|
|
pos: Optional[Tensor] = None, |
|
|
reference_boxes: Optional[Tensor] = None, |
|
|
|
|
|
level_start_index: Optional[Tensor] = None, |
|
|
spatial_shapes: Optional[Tensor] = None, |
|
|
valid_ratios: Optional[Tensor] = None, |
|
|
|
|
|
memory_text: Optional[Tensor] = None, |
|
|
text_attention_mask: Optional[Tensor] = None, |
|
|
|
|
|
apply_dac: Optional[bool] = None, |
|
|
is_instance_prompt=False, |
|
|
decoder_extra_kwargs: Optional[Dict] = None, |
|
|
|
|
|
obj_roi_memory_feat=None, |
|
|
obj_roi_memory_mask=None, |
|
|
box_head_trk=None, |
|
|
): |
|
|
""" |
|
|
Input: |
|
|
- tgt: nq, bs, d_model |
|
|
- memory: \\sum{hw}, bs, d_model |
|
|
- pos: \\sum{hw}, bs, d_model |
|
|
- reference_boxes: nq, bs, 4 (after sigmoid) |
|
|
- valid_ratios/spatial_shapes: bs, nlevel, 2 |
|
|
""" |
|
|
if memory_mask is not None: |
|
|
            assert (
                self.boxRPB == "none"
            ), "passing a memory_mask together with boxRPB is not implemented"
|
|
|
|
|
apply_dac = apply_dac if apply_dac is not None else self.dac |
|
|
if apply_dac: |
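            # DAC: duplicate the queries (and reference boxes below) so the
            # second half can be supervised with one-to-many assignment.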
|
|
assert (tgt.shape[0] == self.num_queries) or ( |
|
|
self.use_instance_query |
|
|
and (tgt.shape[0] == self.instance_query_embed.num_embeddings) |
|
|
) |
|
|
|
|
|
tgt = tgt.repeat(2, 1, 1) |
|
|
|
|
|
|
|
|
if reference_boxes is not None: |
|
|
assert (reference_boxes.shape[0] == self.num_queries) or ( |
|
|
self.use_instance_query |
|
|
and ( |
|
|
reference_boxes.shape[0] |
|
|
== self.instance_query_embed.num_embeddings |
|
|
) |
|
|
) |
|
|
reference_boxes = reference_boxes.repeat(2, 1, 1) |
|
|
|
|
|
bs = tgt.shape[1] |
|
|
intermediate = [] |
|
|
intermediate_presence_logits = [] |
|
|
presence_feats = None |
|
|
|
|
|
if self.box_refine: |
|
|
if reference_boxes is None: |
|
|
|
|
|
reference_boxes = self.reference_points.weight.unsqueeze(1) |
|
|
reference_boxes = ( |
|
|
reference_boxes.repeat(2, bs, 1) |
|
|
if apply_dac |
|
|
else reference_boxes.repeat(1, bs, 1) |
|
|
) |
|
|
reference_boxes = reference_boxes.sigmoid() |
|
|
intermediate_ref_boxes = [reference_boxes] |
|
|
else: |
|
|
reference_boxes = None |
|
|
intermediate_ref_boxes = None |
|
|
|
|
|
output = tgt |
|
|
presence_out = None |
|
|
if self.presence_token is not None and is_instance_prompt is False: |
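            # Initialize the presence token from its learned embedding,
            # broadcast across the batch; each layer updates it.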
|
|
|
|
|
presence_out = self.presence_token.weight[None].expand(1, bs, -1) |
|
|
|
|
|
box_head = self.bbox_embed |
|
|
if is_instance_prompt and self.instance_bbox_embed is not None: |
|
|
box_head = self.instance_bbox_embed |
|
|
|
|
|
out_norm = self.norm |
|
|
if is_instance_prompt and self.instance_norm is not None: |
|
|
out_norm = self.instance_norm |
|
|
|
|
|
for layer_idx, layer in enumerate(self.layers): |
|
|
reference_points_input = ( |
|
|
reference_boxes[:, :, None] |
|
|
* torch.cat([valid_ratios, valid_ratios], -1)[None, :] |
|
|
) |
|
|
|
|
|
query_sine_embed = gen_sineembed_for_position( |
|
|
reference_points_input[:, :, 0, :], self.d_model |
|
|
) |
|
|
|
|
|
|
|
|
query_pos = self.ref_point_head(query_sine_embed) |
|
|
|
|
|
if self.boxRPB != "none" and reference_boxes is not None: |
|
|
assert ( |
|
|
spatial_shapes.shape[0] == 1 |
|
|
), "only single scale support implemented" |
|
|
memory_mask = self._get_rpb_matrix( |
|
|
reference_boxes, |
|
|
(spatial_shapes[0, 0], spatial_shapes[0, 1]), |
|
|
) |
|
|
memory_mask = memory_mask.flatten(0, 1) |
|
|
if self.training: |
|
|
assert ( |
|
|
self.use_act_checkpoint |
|
|
), "Activation checkpointing not enabled in the decoder" |
|
|
output, presence_out = activation_ckpt_wrapper(layer)( |
|
|
tgt=output, |
|
|
tgt_query_pos=query_pos, |
|
|
tgt_query_sine_embed=query_sine_embed, |
|
|
tgt_key_padding_mask=tgt_key_padding_mask, |
|
|
tgt_reference_points=reference_points_input, |
|
|
memory_text=memory_text, |
|
|
text_attention_mask=text_attention_mask, |
|
|
memory=memory, |
|
|
memory_key_padding_mask=memory_key_padding_mask, |
|
|
memory_level_start_index=level_start_index, |
|
|
memory_spatial_shapes=spatial_shapes, |
|
|
memory_pos=pos, |
|
|
self_attn_mask=tgt_mask, |
|
|
cross_attn_mask=memory_mask, |
|
|
dac=apply_dac, |
|
|
dac_use_selfatt_ln=self.dac_use_selfatt_ln, |
|
|
presence_token=presence_out, |
|
|
**(decoder_extra_kwargs or {}), |
|
|
act_ckpt_enable=self.training and self.use_act_checkpoint, |
|
|
|
|
|
obj_roi_memory_feat=obj_roi_memory_feat, |
|
|
obj_roi_memory_mask=obj_roi_memory_mask, |
|
|
) |
|
|
|
|
|
|
|
|
if self.box_refine: |
|
|
reference_before_sigmoid = inverse_sigmoid(reference_boxes) |
|
|
if box_head_trk is None: |
|
|
|
|
|
if not self.use_normed_output_consistently: |
|
|
delta_unsig = box_head(output) |
|
|
else: |
|
|
delta_unsig = box_head(out_norm(output)) |
|
|
else: |
|
|
|
|
|
Q_det = decoder_extra_kwargs["Q_det"] |
|
|
assert output.size(0) >= Q_det |
|
|
delta_unsig_det = self.bbox_embed(output[:Q_det]) |
|
|
delta_unsig_trk = box_head_trk(output[Q_det:]) |
|
|
delta_unsig = torch.cat([delta_unsig_det, delta_unsig_trk], dim=0) |
|
|
outputs_unsig = delta_unsig + reference_before_sigmoid |
|
|
new_reference_points = outputs_unsig.sigmoid() |
|
|
|
|
|
reference_boxes = new_reference_points.detach() |
|
|
if layer_idx != self.num_layers - 1: |
|
|
intermediate_ref_boxes.append(new_reference_points) |
|
|
else: |
|
|
raise NotImplementedError("not implemented yet") |
|
|
|
|
|
intermediate.append(out_norm(output)) |
|
|
if self.presence_token is not None and is_instance_prompt is False: |
|
|
|
|
|
intermediate_layer_presence_logits = self.presence_token_head( |
|
|
self.presence_token_out_norm(presence_out) |
|
|
).squeeze(-1) |
|
|
|
|
|
|
|
|
                if self.clamp_presence_logits:
                    # Tensor.clamp is not in-place; keep the clamped result.
                    intermediate_layer_presence_logits = (
                        intermediate_layer_presence_logits.clamp(
                            min=-self.clamp_presence_logit_max_val,
                            max=self.clamp_presence_logit_max_val,
                        )
                    )
|
|
|
|
|
intermediate_presence_logits.append(intermediate_layer_presence_logits) |
|
|
presence_feats = presence_out.clone() |
|
|
|
|
|
if not self.compiled and self.compile_mode is not None: |
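            # Compile lazily after the first call so the initial pass runs
            # eagerly; subsequent calls use the compiled forward.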
|
|
self.forward = torch.compile( |
|
|
self.forward, mode=self.compile_mode, fullgraph=True |
|
|
) |
|
|
self.compiled = True |
|
|
|
|
|
return ( |
|
|
torch.stack(intermediate), |
|
|
torch.stack(intermediate_ref_boxes), |
|
|
( |
|
|
torch.stack(intermediate_presence_logits) |
|
|
if self.presence_token is not None and is_instance_prompt is False |
|
|
else None |
|
|
), |
|
|
presence_feats, |
|
|
) |
|
|
|
|
|
|
|
|
class TransformerEncoderCrossAttention(nn.Module): |
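    """
    Fuses the image features (``src``) with a prompt sequence through a stack
    of cross-attention layers; cross-attention can be removed from individual
    layers via ``remove_cross_attention_layers``.
    """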
|
|
def __init__( |
|
|
self, |
|
|
d_model: int, |
|
|
frozen: bool, |
|
|
pos_enc_at_input: bool, |
|
|
layer, |
|
|
num_layers: int, |
|
|
use_act_checkpoint: bool = False, |
|
|
batch_first: bool = False, |
|
|
|
|
|
|
|
|
remove_cross_attention_layers: Optional[list] = None, |
|
|
): |
|
|
super().__init__() |
|
|
self.d_model = d_model |
|
|
self.layers = get_clones(layer, num_layers) |
|
|
self.num_layers = num_layers |
|
|
self.norm = nn.LayerNorm(d_model) |
|
|
self.pos_enc_at_input = pos_enc_at_input |
|
|
self.use_act_checkpoint = use_act_checkpoint |
|
|
|
|
|
if frozen: |
|
|
for p in self.parameters(): |
|
|
p.requires_grad_(False) |
|
|
|
|
|
self.batch_first = batch_first |
|
|
|
|
|
|
|
|
self.remove_cross_attention_layers = [False] * self.num_layers |
|
|
if remove_cross_attention_layers is not None: |
|
|
for i in remove_cross_attention_layers: |
|
|
self.remove_cross_attention_layers[i] = True |
|
|
assert len(self.remove_cross_attention_layers) == len(self.layers) |
|
|
|
|
|
for i, remove_cross_attention in enumerate(self.remove_cross_attention_layers): |
|
|
if remove_cross_attention: |
|
|
self.layers[i].cross_attn_image = None |
|
|
self.layers[i].norm2 = None |
|
|
self.layers[i].dropout2 = None |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
src, |
|
|
prompt, |
|
|
src_mask: Optional[Tensor] = None, |
|
|
prompt_mask: Optional[Tensor] = None, |
|
|
src_key_padding_mask: Optional[Tensor] = None, |
|
|
prompt_key_padding_mask: Optional[Tensor] = None, |
|
|
src_pos: Optional[Tensor] = None, |
|
|
prompt_pos: Optional[Tensor] = None, |
|
|
feat_sizes: Optional[list] = None, |
|
|
num_obj_ptr_tokens: int = 0, |
|
|
): |
|
|
if isinstance(src, list): |
|
|
assert isinstance(src_key_padding_mask, list) and isinstance(src_pos, list) |
|
|
assert len(src) == len(src_key_padding_mask) == len(src_pos) == 1 |
|
|
src, src_key_padding_mask, src_pos = ( |
|
|
src[0], |
|
|
src_key_padding_mask[0], |
|
|
src_pos[0], |
|
|
) |
|
|
|
|
|
assert ( |
|
|
src.shape[1] == prompt.shape[1] |
|
|
), "Batch size must be the same for src and prompt" |
|
|
|
|
|
output = src |
|
|
|
|
|
if self.pos_enc_at_input and src_pos is not None: |
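            # Add the positional encoding at the input with a small fixed
            # scale of 0.1.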
|
|
output = output + 0.1 * src_pos |
|
|
|
|
|
if self.batch_first: |
|
|
|
|
|
output = output.transpose(0, 1) |
|
|
src_pos = src_pos.transpose(0, 1) |
|
|
prompt = prompt.transpose(0, 1) |
|
|
prompt_pos = prompt_pos.transpose(0, 1) |
|
|
|
|
|
for layer in self.layers: |
|
|
kwds = {} |
|
|
if isinstance(layer.cross_attn_image, RoPEAttention): |
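                # For RoPE cross-attention, exclude the trailing object-pointer
                # tokens from the rotary position encoding.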
|
|
kwds = {"num_k_exclude_rope": num_obj_ptr_tokens} |
|
|
|
|
|
output = activation_ckpt_wrapper(layer)( |
|
|
tgt=output, |
|
|
memory=prompt, |
|
|
tgt_mask=src_mask, |
|
|
memory_mask=prompt_mask, |
|
|
tgt_key_padding_mask=src_key_padding_mask, |
|
|
memory_key_padding_mask=prompt_key_padding_mask, |
|
|
pos=prompt_pos, |
|
|
query_pos=src_pos, |
|
|
dac=False, |
|
|
attn_bias=None, |
|
|
act_ckpt_enable=self.training and self.use_act_checkpoint, |
|
|
**kwds, |
|
|
) |
|
|
normed_output = self.norm(output) |
|
|
|
|
|
if self.batch_first: |
|
|
|
|
|
normed_output = normed_output.transpose(0, 1) |
|
|
src_pos = src_pos.transpose(0, 1) |
|
|
|
|
|
return { |
|
|
"memory": normed_output, |
|
|
"pos_embed": src_pos, |
|
|
"padding_mask": src_key_padding_mask, |
|
|
} |
|
|
|
|
|
|
|
|
class TransformerDecoderLayerv1(nn.Module): |
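    """
    Decoder layer in the standard PyTorch layout (self-attention, image
    cross-attention, feed-forward) with both post-norm (``forward_post``) and
    pre-norm (``forward_pre``) variants.
    """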
|
|
def __init__( |
|
|
self, |
|
|
activation: str, |
|
|
cross_attention: nn.Module, |
|
|
d_model: int, |
|
|
dim_feedforward: int, |
|
|
dropout: float, |
|
|
pos_enc_at_attn: bool, |
|
|
pos_enc_at_cross_attn_keys: bool, |
|
|
pos_enc_at_cross_attn_queries: bool, |
|
|
pre_norm: bool, |
|
|
self_attention: nn.Module, |
|
|
): |
|
|
super().__init__() |
|
|
self.d_model = d_model |
|
|
self.dim_feedforward = dim_feedforward |
|
|
self.dropout_value = dropout |
|
|
self.self_attn = self_attention |
|
|
self.cross_attn_image = cross_attention |
|
|
|
|
|
|
|
|
self.linear1 = nn.Linear(d_model, dim_feedforward) |
|
|
self.dropout = nn.Dropout(dropout) |
|
|
self.linear2 = nn.Linear(dim_feedforward, d_model) |
|
|
|
|
|
self.norm1 = nn.LayerNorm(d_model) |
|
|
self.norm2 = nn.LayerNorm(d_model) |
|
|
self.norm3 = nn.LayerNorm(d_model) |
|
|
self.dropout1 = nn.Dropout(dropout) |
|
|
self.dropout2 = nn.Dropout(dropout) |
|
|
self.dropout3 = nn.Dropout(dropout) |
|
|
|
|
|
self.activation_str = activation |
|
|
self.activation = get_activation_fn(activation) |
|
|
self.pre_norm = pre_norm |
|
|
|
|
|
self.pos_enc_at_attn = pos_enc_at_attn |
|
|
self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries |
|
|
self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys |
|
|
|
|
|
def forward_post( |
|
|
self, |
|
|
tgt, |
|
|
memory, |
|
|
tgt_mask: Optional[Tensor] = None, |
|
|
memory_mask: Optional[Tensor] = None, |
|
|
tgt_key_padding_mask: Optional[Tensor] = None, |
|
|
memory_key_padding_mask: Optional[Tensor] = None, |
|
|
pos: Optional[Tensor] = None, |
|
|
query_pos: Optional[Tensor] = None, |
|
|
**kwargs, |
|
|
): |
|
|
q = k = tgt + query_pos if self.pos_enc_at_attn else tgt |
|
|
|
|
|
|
|
|
tgt2 = self.self_attn( |
|
|
q, |
|
|
k, |
|
|
value=tgt, |
|
|
attn_mask=tgt_mask, |
|
|
key_padding_mask=tgt_key_padding_mask, |
|
|
)[0] |
|
|
tgt = tgt + self.dropout1(tgt2) |
|
|
tgt = self.norm1(tgt) |
|
|
|
|
|
|
|
|
tgt2 = self.cross_attn_image( |
|
|
query=tgt + query_pos if self.pos_enc_at_cross_attn_queries else tgt, |
|
|
key=memory + pos if self.pos_enc_at_cross_attn_keys else memory, |
|
|
value=memory, |
|
|
attn_mask=memory_mask, |
|
|
key_padding_mask=memory_key_padding_mask, |
|
|
)[0] |
|
|
tgt = tgt + self.dropout2(tgt2) |
|
|
tgt = self.norm2(tgt) |
|
|
|
|
|
|
|
|
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) |
|
|
tgt = tgt + self.dropout3(tgt2) |
|
|
tgt = self.norm3(tgt) |
|
|
return tgt |
|
|
|
|
|
def forward_pre( |
|
|
self, |
|
|
tgt, |
|
|
memory, |
|
|
dac: bool = False, |
|
|
tgt_mask: Optional[Tensor] = None, |
|
|
memory_mask: Optional[Tensor] = None, |
|
|
tgt_key_padding_mask: Optional[Tensor] = None, |
|
|
memory_key_padding_mask: Optional[Tensor] = None, |
|
|
pos: Optional[Tensor] = None, |
|
|
query_pos: Optional[Tensor] = None, |
|
|
attn_bias: Optional[Tensor] = None, |
|
|
**kwargs, |
|
|
): |
|
|
if dac: |
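            # DAC: run self-attention on the first half of the queries only;
            # the second half rejoins before cross-attention.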
|
|
|
|
|
assert tgt.shape[0] % 2 == 0 |
|
|
other_tgt = tgt[tgt.shape[0] // 2 :] |
|
|
tgt = tgt[: tgt.shape[0] // 2] |
|
|
tgt2 = self.norm1(tgt) |
|
|
q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 |
|
|
tgt2 = self.self_attn( |
|
|
q, |
|
|
k, |
|
|
value=tgt2, |
|
|
attn_mask=tgt_mask, |
|
|
key_padding_mask=tgt_key_padding_mask, |
|
|
)[0] |
|
|
tgt = tgt + self.dropout1(tgt2) |
|
|
if dac: |
|
|
|
|
|
tgt = torch.cat((tgt, other_tgt), dim=0) |
|
|
tgt2 = self.norm2(tgt) |
|
|
tgt2 = self.cross_attn_image( |
|
|
query=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, |
|
|
key=memory + pos if self.pos_enc_at_cross_attn_keys else memory, |
|
|
value=memory, |
|
|
attn_mask=memory_mask, |
|
|
key_padding_mask=memory_key_padding_mask, |
|
|
attn_bias=attn_bias, |
|
|
)[0] |
|
|
tgt = tgt + self.dropout2(tgt2) |
|
|
tgt2 = self.norm3(tgt) |
|
|
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) |
|
|
tgt = tgt + self.dropout3(tgt2) |
|
|
return tgt |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
tgt, |
|
|
memory, |
|
|
dac: bool = False, |
|
|
tgt_mask: Optional[Tensor] = None, |
|
|
memory_mask: Optional[Tensor] = None, |
|
|
tgt_key_padding_mask: Optional[Tensor] = None, |
|
|
memory_key_padding_mask: Optional[Tensor] = None, |
|
|
pos: Optional[Tensor] = None, |
|
|
query_pos: Optional[Tensor] = None, |
|
|
attn_bias: Optional[Tensor] = None, |
|
|
**kwds: Any, |
|
|
) -> torch.Tensor: |
|
|
fwd_fn = self.forward_pre if self.pre_norm else self.forward_post |
|
|
return fwd_fn( |
|
|
tgt, |
|
|
memory, |
|
|
dac=dac, |
|
|
tgt_mask=tgt_mask, |
|
|
memory_mask=memory_mask, |
|
|
tgt_key_padding_mask=tgt_key_padding_mask, |
|
|
memory_key_padding_mask=memory_key_padding_mask, |
|
|
pos=pos, |
|
|
query_pos=query_pos, |
|
|
attn_bias=attn_bias, |
|
|
**kwds, |
|
|
) |
|
|
|
|
|
|
|
|
class TransformerDecoderLayerv2(TransformerDecoderLayerv1): |
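    """
    Pre-norm-only variant of ``TransformerDecoderLayerv1`` that can optionally
    run cross-attention before self-attention (``cross_attention_first``).
    """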
|
|
def __init__(self, cross_attention_first=False, *args: Any, **kwds: Any): |
|
|
super().__init__(*args, **kwds) |
|
|
self.cross_attention_first = cross_attention_first |
|
|
|
|
|
def _forward_sa(self, tgt, query_pos): |
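        # Self-attention block (pre-norm).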
|
|
|
|
|
tgt2 = self.norm1(tgt) |
|
|
q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 |
|
|
tgt2 = self.self_attn(q, k, v=tgt2) |
|
|
tgt = tgt + self.dropout1(tgt2) |
|
|
return tgt |
|
|
|
|
|
def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): |
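        # Cross-attention block (pre-norm); a no-op for layers whose
        # cross-attention has been removed.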
|
|
if self.cross_attn_image is None: |
|
|
return tgt |
|
|
|
|
|
kwds = {} |
|
|
if num_k_exclude_rope > 0: |
|
|
assert isinstance(self.cross_attn_image, RoPEAttention) |
|
|
kwds = {"num_k_exclude_rope": num_k_exclude_rope} |
|
|
|
|
|
|
|
|
tgt2 = self.norm2(tgt) |
|
|
tgt2 = self.cross_attn_image( |
|
|
q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, |
|
|
k=memory + pos if self.pos_enc_at_cross_attn_keys else memory, |
|
|
v=memory, |
|
|
**kwds, |
|
|
) |
|
|
tgt = tgt + self.dropout2(tgt2) |
|
|
return tgt |
|
|
|
|
|
def forward_pre( |
|
|
self, |
|
|
tgt, |
|
|
memory, |
|
|
dac: bool, |
|
|
tgt_mask: Optional[Tensor] = None, |
|
|
memory_mask: Optional[Tensor] = None, |
|
|
tgt_key_padding_mask: Optional[Tensor] = None, |
|
|
memory_key_padding_mask: Optional[Tensor] = None, |
|
|
pos: Optional[Tensor] = None, |
|
|
query_pos: Optional[Tensor] = None, |
|
|
attn_bias: Optional[Tensor] = None, |
|
|
num_k_exclude_rope: int = 0, |
|
|
): |
|
|
assert dac is False |
|
|
assert tgt_mask is None |
|
|
assert memory_mask is None |
|
|
assert tgt_key_padding_mask is None |
|
|
assert memory_key_padding_mask is None |
|
|
assert attn_bias is None |
|
|
|
|
|
if self.cross_attention_first: |
|
|
tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) |
|
|
tgt = self._forward_sa(tgt, query_pos) |
|
|
else: |
|
|
tgt = self._forward_sa(tgt, query_pos) |
|
|
tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) |
|
|
|
|
|
|
|
|
tgt2 = self.norm3(tgt) |
|
|
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) |
|
|
tgt = tgt + self.dropout3(tgt2) |
|
|
return tgt |
|
|
|
|
|
def forward(self, *args: Any, **kwds: Any) -> torch.Tensor: |
|
|
if self.pre_norm: |
|
|
return self.forward_pre(*args, **kwds) |
|
|
raise NotImplementedError |
|
|
|