# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
from torch import nn, Tensor

from sam2.modeling.sam.transformer import RoPEAttention
from sam2.modeling.sam2_utils import get_activation_fn, get_clones
import pdb  # noqa: F401  -- not referenced anywhere in this module


class MemoryAttentionLayer(nn.Module):
    """One memory-attention block: self-attention, cross-attention to memory, then an MLP.

    Each sub-layer is pre-norm (LayerNorm applied to its input) with a residual
    connection. When per-frame / per-pointer scores are supplied, the
    cross-attention keys are rescaled according to the rank of those scores
    (see ``_forward_ca``).
    """

    def __init__(
        self,
        activation: str,
        cross_attention: nn.Module,
        d_model: int,
        dim_feedforward: int,
        dropout: float,
        pos_enc_at_attn: bool,
        pos_enc_at_cross_attn_keys: bool,
        pos_enc_at_cross_attn_queries: bool,
        self_attention: nn.Module,
        num_tokens_per_frame: int = 4096,
        num_tokens_per_obj_ptr: int = 4,
        score_scale_low: float = 0.95,
        score_scale_high: float = 1.05,
    ):
        """
        Args:
            activation: Name of the MLP activation, resolved by ``get_activation_fn``.
            cross_attention: Module attending from the current tokens to memory.
            d_model: Token feature width.
            dim_feedforward: Hidden width of the MLP.
            dropout: Dropout probability used throughout the layer.
            pos_enc_at_attn: Add ``query_pos`` to q/k of self-attention.
            pos_enc_at_cross_attn_keys: Add ``pos`` to the cross-attention keys.
            pos_enc_at_cross_attn_queries: Add ``query_pos`` to the cross-attention queries.
            self_attention: Module for self-attention over the current tokens.
            num_tokens_per_frame: Memory tokens contributed by each memory frame
                (only used for score-based key rescaling; default 4096).
            num_tokens_per_obj_ptr: Memory tokens contributed by each object
                pointer (only used for score-based key rescaling; default 4).
            score_scale_low: Key scale given to the lowest-scored frame/pointer.
            score_scale_high: Key scale given to the highest-scored frame/pointer.
        """
        super().__init__()
        self.d_model = d_model
        self.dim_feedforward = dim_feedforward
        self.dropout_value = dropout
        self.self_attn = self_attention
        self.cross_attn_image = cross_attention

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation_str = activation
        self.activation = get_activation_fn(activation)

        # Where to add pos enc
        self.pos_enc_at_attn = pos_enc_at_attn
        self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
        self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys

        # Layout / range used by score-based key rescaling in _forward_ca.
        self.num_tokens_per_frame = num_tokens_per_frame
        self.num_tokens_per_obj_ptr = num_tokens_per_obj_ptr
        self.score_scale_low = score_scale_low
        self.score_scale_high = score_scale_high

    def _forward_sa(self, tgt, query_pos):
        """Pre-norm self-attention sub-layer with a residual connection."""
        tgt2 = self.norm1(tgt)
        q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
        tgt2 = self.self_attn(q, k, v=tgt2)
        tgt = tgt + self.dropout1(tgt2)
        return tgt

    def _rank_weights(self, scores):
        """Turn a list of per-item score tensors into rank-based scale weights.

        ``scores`` holds one tensor per item (frame or pointer). The stacked
        result is (num_object, num_items), per the original shape comment, so
        each entry is assumed to be shape (num_object,). For each object, items
        are ranked by score, and evenly spaced values from ``score_scale_low``
        (lowest score) to ``score_scale_high`` (highest score) are assigned by
        that rank. ``scores`` must be non-empty.
        """
        weight = torch.stack(scores, dim=1)  # num_object, num_items
        num_object, num_items = weight.shape[0], weight.shape[1]
        standard = torch.linspace(
            self.score_scale_low, self.score_scale_high, num_items
        ).to(weight)  # num_items
        ranked = torch.zeros_like(weight)
        # Position argsort[i] (the i-th smallest score) receives standard[i].
        ranked.scatter_(
            1, torch.argsort(weight, dim=1), standard.unsqueeze(0).repeat([num_object, 1])
        )
        # NOTE(review): with a single item, linspace yields [score_scale_low], so
        # a lone pointer is scaled by the low value rather than left at 1.0.
        return ranked

    def _rescale_keys_by_score(self, key, frame_scores, ptr_scores):
        """Scale memory-frame and object-pointer keys by the rank of their scores.

        ``key`` is indexed as (num_object, tokens, C). The first
        ``len(frame_scores) * num_tokens_per_frame`` tokens are treated as frame
        tokens and the remainder as object-pointer tokens
        (``num_tokens_per_obj_ptr`` per pointer). With at most one frame the
        keys are returned unchanged. With no pointer scores, the pointer tokens
        (if any) are kept unscaled.
        """
        num_frame, num_ptr = len(frame_scores), len(ptr_scores)
        if num_frame <= 1:
            return key

        num_object = key.shape[0]
        num_frame_tokens = num_frame * self.num_tokens_per_frame

        key_frame = key[:, :num_frame_tokens].reshape(
            num_object, num_frame, self.num_tokens_per_frame, -1
        )
        # Cast to the key's device AND dtype so a float32 score tensor does not
        # silently promote a reduced-precision key.
        frame_weight = self._rank_weights(frame_scores).to(key_frame)
        key_frame = (frame_weight[:, :, None, None] * key_frame).reshape(
            num_object, num_frame_tokens, -1
        )

        key_ptr = key[:, num_frame_tokens:]
        if num_ptr > 0:
            num_ptr_tokens = num_ptr * self.num_tokens_per_obj_ptr
            key_ptr = key_ptr.reshape(
                num_object, num_ptr, self.num_tokens_per_obj_ptr, -1
            )
            ptr_weight = self._rank_weights(ptr_scores).to(key_ptr)
            key_ptr = (ptr_weight[:, :, None, None] * key_ptr).reshape(
                num_object, num_ptr_tokens, -1
            )

        return torch.cat([key_frame, key_ptr], dim=1)

    def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0, object_frame_scores=None, object_ptr_scores=None):
        """Pre-norm cross-attention from ``tgt`` to ``memory`` with a residual connection.

        Args:
            tgt: Current tokens (queries).
            memory: Memory tokens. Used as values, and (plus ``pos`` when
                enabled) as keys.
            query_pos: Positional encoding for ``tgt``.
            pos: Positional encoding for ``memory``.
            num_k_exclude_rope: Forwarded to a RoPEAttention cross-attention
                module when > 0.
            object_frame_scores: Optional list of per-frame score tensors. When
                given, keys (not values) are rescaled by score rank.
            object_ptr_scores: Optional list of per-pointer score tensors used
                alongside ``object_frame_scores``. None is treated as empty.
        """
        kwds = {}
        if num_k_exclude_rope > 0:
            assert isinstance(self.cross_attn_image, RoPEAttention)
            kwds = {"num_k_exclude_rope": num_k_exclude_rope}

        # Cross-Attention
        tgt2 = self.norm2(tgt)
        key = memory + pos if self.pos_enc_at_cross_attn_keys else memory
        if object_frame_scores is not None:
            # relative
            ptr_scores = [] if object_ptr_scores is None else object_ptr_scores
            key = self._rescale_keys_by_score(key, object_frame_scores, ptr_scores)

        tgt2 = self.cross_attn_image(
            q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
            k=key,
            v=memory,
            **kwds,
        )
        tgt = tgt + self.dropout2(tgt2)
        return tgt

    def forward(
        self,
        tgt,
        memory,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        num_k_exclude_rope: int = 0,
        object_frame_scores=None,
        object_ptr_scores=None,
    ) -> torch.Tensor:
        """Apply self-attention, cross-attention to ``memory``, then the MLP."""
        # Self-Attn, Cross-Attn
        tgt = self._forward_sa(tgt, query_pos)
        tgt = self._forward_ca(
            tgt,
            memory,
            query_pos,
            pos,
            num_k_exclude_rope,
            object_frame_scores,
            object_ptr_scores,
        )
        # MLP
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt


class MemoryAttention(nn.Module):
    """A stack of cloned ``MemoryAttentionLayer``s followed by a final LayerNorm."""

    def __init__(
        self,
        d_model: int,
        pos_enc_at_input: bool,
        layer: nn.Module,
        num_layers: int,
        batch_first: bool = True,  # Do layers expect batch first input?
    ):
        super().__init__()
        self.d_model = d_model
        self.layers = get_clones(layer, num_layers)
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(d_model)
        self.pos_enc_at_input = pos_enc_at_input
        self.batch_first = batch_first

    def forward(
        self,
        curr: torch.Tensor,  # self-attention inputs
        memory: torch.Tensor,  # cross-attention inputs
        curr_pos: Optional[Tensor] = None,  # pos_enc for self-attention inputs
        memory_pos: Optional[Tensor] = None,  # pos_enc for cross-attention inputs
        num_obj_ptr_tokens: int = 0,  # number of object pointer *tokens*
        object_frame_scores=None,
        object_ptr_scores=None,
    ):
        """Run ``curr`` through every layer, attending to ``memory``.

        Inputs use dim 1 as the batch dimension, which is what the assertion
        below checks. When ``batch_first`` is set, inputs are transposed for
        the layers and the output is transposed back. The scores and
        ``num_obj_ptr_tokens`` are only forwarded to layers whose
        cross-attention is a RoPEAttention.
        """
        if isinstance(curr, list):
            assert isinstance(curr_pos, list)
            assert len(curr) == len(curr_pos) == 1
            curr, curr_pos = curr[0], curr_pos[0]

        assert (
            curr.shape[1] == memory.shape[1]
        ), "Batch size must be the same for curr and memory"

        output = curr
        if self.pos_enc_at_input and curr_pos is not None:
            output = output + 0.1 * curr_pos

        if self.batch_first:
            # Convert to batch first
            output = output.transpose(0, 1)
            memory = memory.transpose(0, 1)
            # Positional encodings are optional; only transpose the ones given.
            if curr_pos is not None:
                curr_pos = curr_pos.transpose(0, 1)
            if memory_pos is not None:
                memory_pos = memory_pos.transpose(0, 1)

        for layer in self.layers:
            kwds = {}
            if isinstance(layer.cross_attn_image, RoPEAttention):
                kwds = {
                    "num_k_exclude_rope": num_obj_ptr_tokens,
                    "object_frame_scores": object_frame_scores,
                    "object_ptr_scores": object_ptr_scores,
                }

            output = layer(
                tgt=output,
                memory=memory,
                pos=memory_pos,
                query_pos=curr_pos,
                **kwds,
            )
        normed_output = self.norm(output)

        if self.batch_first:
            # Convert back to seq first
            normed_output = normed_output.transpose(0, 1)

        return normed_output