# Source: SAM2Long-Demo/sam2/modeling/memory_attention.py
# Uploaded by Mar2Ding ("Upload 61 files", commit b578f14, 7.83 kB).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import torch
from torch import nn, Tensor
from sam2.modeling.sam.transformer import RoPEAttention
from sam2.modeling.sam2_utils import get_activation_fn, get_clones
import pdb
class MemoryAttentionLayer(nn.Module):
    """One memory-attention block: self-attention, cross-attention, then an MLP.

    Each sub-layer is pre-normalised (LayerNorm applied to its input) and its
    output is added back residually through a dedicated dropout.

    When per-frame / per-object-pointer scores are passed to ``forward``, the
    cross-attention keys built from the memory are rescaled according to the
    rank of those scores (see ``_rescale_memory_keys``).
    """

    # Number of memory tokens contributed by each object pointer. This was
    # hard-coded as 4 originally; presumably hidden_dim / mem_dim of the model
    # (e.g. 256 / 64) — TODO confirm against the model config.
    tokens_per_obj_ptr: int = 4
    # Range of the rank-based key scaling: within each object, the lowest
    # score maps to ``key_scale_low`` and the highest to ``key_scale_high``.
    key_scale_low: float = 0.95
    key_scale_high: float = 1.05

    def __init__(
        self,
        activation: str,
        cross_attention: nn.Module,
        d_model: int,
        dim_feedforward: int,
        dropout: float,
        pos_enc_at_attn: bool,
        pos_enc_at_cross_attn_keys: bool,
        pos_enc_at_cross_attn_queries: bool,
        self_attention: nn.Module,
    ):
        super().__init__()
        self.d_model = d_model
        self.dim_feedforward = dim_feedforward
        self.dropout_value = dropout
        self.self_attn = self_attention
        self.cross_attn_image = cross_attention

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation_str = activation
        self.activation = get_activation_fn(activation)

        # Where to add pos enc
        self.pos_enc_at_attn = pos_enc_at_attn
        self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
        self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys

    def _forward_sa(self, tgt, query_pos):
        """Pre-norm self-attention sub-layer with a residual connection."""
        tgt2 = self.norm1(tgt)
        # Positional encoding goes into queries/keys only; values stay un-encoded.
        q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
        tgt2 = self.self_attn(q, k, v=tgt2)
        tgt = tgt + self.dropout1(tgt2)
        return tgt

    def _forward_ca(
        self,
        tgt,
        memory,
        query_pos,
        pos,
        num_k_exclude_rope=0,
        object_frame_scores=None,
        object_ptr_scores=None,
    ):
        """Pre-norm cross-attention from ``tgt`` to ``memory`` with a residual.

        Args:
            tgt: current-frame tokens (queries).
            memory: memory-bank tokens (keys/values).
            query_pos: positional encoding added to the queries when enabled.
            pos: positional encoding added to the keys when enabled.
            num_k_exclude_rope: trailing key tokens that the RoPE cross-attention
                should not rotate; requires ``cross_attn_image`` to be a
                ``RoPEAttention``.
            object_frame_scores / object_ptr_scores: optional score sequences; when
                ``object_frame_scores`` is given the keys are rank-rescaled via
                ``_rescale_memory_keys``.
        """
        kwds = {}
        if num_k_exclude_rope > 0:
            assert isinstance(self.cross_attn_image, RoPEAttention)
            kwds = {"num_k_exclude_rope": num_k_exclude_rope}

        # Cross-Attention
        tgt2 = self.norm2(tgt)
        key = memory + pos if self.pos_enc_at_cross_attn_keys else memory
        if object_frame_scores is not None:
            key = self._rescale_memory_keys(key, object_frame_scores, object_ptr_scores)
        tgt2 = self.cross_attn_image(
            q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
            k=key,
            # Only the keys are rescaled; values use the raw memory.
            v=memory,
            **kwds,
        )
        tgt = tgt + self.dropout2(tgt2)
        return tgt

    def _rank_weights(self, scores):
        """Turn a sequence of scores into per-object rank-based scaling factors.

        ``scores`` is a sequence of ``n`` tensors, each presumably of shape
        (num_object,). They are stacked into (num_object, n); for every object
        the i-th smallest score receives the i-th value of
        ``linspace(key_scale_low, key_scale_high, n)``.
        """
        weight = torch.stack(scores, dim=1)  # (num_object, n)
        num_object, n = weight.shape[0], weight.shape[1]
        standard = torch.linspace(self.key_scale_low, self.key_scale_high, n).to(weight)
        ranked = torch.zeros_like(weight)
        # argsort gives, per row, the positions ordered from lowest to highest
        # score; scattering the increasing ``standard`` values there assigns
        # each position the factor matching its rank.
        ranked.scatter_(
            1,
            torch.argsort(weight, dim=1),
            standard.unsqueeze(0).repeat([num_object, 1]),
        )
        return ranked

    def _rescale_memory_keys(self, key, object_frame_scores, object_ptr_scores):
        """Scale memory keys per frame and per object pointer by score rank.

        Args:
            key: keys of shape (num_object, seq_len, C), laid out as all frame
                tokens first, followed by all object-pointer tokens.
            object_frame_scores: sequence of ``num_frame`` score tensors.
            object_ptr_scores: sequence of ``num_ptr`` score tensors, or ``None``
                (treated as no pointers).

        Returns:
            The rescaled keys. With fewer than two frames the keys are returned
            unchanged (pointer tokens included).

        Raises:
            ValueError: if ``seq_len`` minus the pointer tokens cannot be split
                evenly into ``num_frame`` frame segments.
        """
        num_frame = len(object_frame_scores)
        num_ptr = 0 if object_ptr_scores is None else len(object_ptr_scores)
        if num_frame <= 1:
            # A single frame gives nothing to rank against; leave keys as-is.
            return key

        num_object, seq_len = key.shape[0], key.shape[1]
        num_ptr_tokens = num_ptr * self.tokens_per_obj_ptr
        num_frame_tokens = seq_len - num_ptr_tokens
        if num_frame_tokens <= 0 or num_frame_tokens % num_frame != 0:
            raise ValueError(
                f"Cannot split {seq_len} memory tokens into {num_frame} frames "
                f"plus {num_ptr} object pointers of {self.tokens_per_obj_ptr} tokens"
            )
        # Derived from the sequence length instead of a fixed 4096 so that other
        # feature-map resolutions work; equals 4096 for the original layout.
        tokens_per_frame = num_frame_tokens // num_frame

        frame_weight = self._rank_weights(object_frame_scores).to(key.device)
        key_frame = key[:, :num_frame_tokens].reshape(
            num_object, num_frame, tokens_per_frame, -1
        )
        pieces = [
            (frame_weight[:, :, None, None] * key_frame).reshape(
                num_object, num_frame_tokens, -1
            )
        ]

        if num_ptr > 0:
            ptr_weight = self._rank_weights(object_ptr_scores).to(key.device)
            key_ptr = key[:, num_frame_tokens:].reshape(
                num_object, num_ptr, self.tokens_per_obj_ptr, -1
            )
            pieces.append(
                (ptr_weight[:, :, None, None] * key_ptr).reshape(
                    num_object, num_ptr_tokens, -1
                )
            )
        return torch.cat(pieces, dim=1)

    def forward(
        self,
        tgt,
        memory,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        num_k_exclude_rope: int = 0,
        object_frame_scores=None,
        object_ptr_scores=None,
    ) -> torch.Tensor:
        """Apply self-attention, cross-attention to ``memory``, then the MLP."""
        # Self-Attn, Cross-Attn
        tgt = self._forward_sa(tgt, query_pos)
        tgt = self._forward_ca(
            tgt,
            memory,
            query_pos,
            pos,
            num_k_exclude_rope,
            object_frame_scores,
            object_ptr_scores,
        )
        # MLP
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt
class MemoryAttention(nn.Module):
    """Stack of cloned memory-attention layers followed by a final LayerNorm.

    Inputs are taken sequence-first (dim 1 is the batch); when ``batch_first``
    is set they are transposed to batch-first for the layers and the result is
    transposed back before returning.
    """

    def __init__(
        self,
        d_model: int,
        pos_enc_at_input: bool,
        layer: nn.Module,
        num_layers: int,
        batch_first: bool = True,  # Do layers expect batch first input?
    ):
        super().__init__()
        self.d_model = d_model
        self.layers = get_clones(layer, num_layers)
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(d_model)
        self.pos_enc_at_input = pos_enc_at_input
        self.batch_first = batch_first

    def forward(
        self,
        curr: torch.Tensor,  # self-attention inputs
        memory: torch.Tensor,  # cross-attention inputs
        curr_pos: Optional[Tensor] = None,  # pos_enc for self-attention inputs
        memory_pos: Optional[Tensor] = None,  # pos_enc for cross-attention inputs
        num_obj_ptr_tokens: int = 0,  # number of object pointer *tokens*
        object_frame_scores=None,
        object_ptr_scores=None,
    ):
        """Condition ``curr`` on ``memory`` through every layer.

        Args:
            curr: current-frame tokens, or a one-element list of them (then
                ``curr_pos`` must be a one-element list too).
            memory: memory-bank tokens; dim 1 must match ``curr``'s batch size.
            curr_pos / memory_pos: optional positional encodings; ``None`` is
                accepted for either.
            num_obj_ptr_tokens: forwarded as ``num_k_exclude_rope`` to layers
                whose cross-attention is a ``RoPEAttention``.
            object_frame_scores / object_ptr_scores: forwarded to those same
                layers for rank-based key rescaling.

        Returns:
            The normalised output in the same (sequence-first) layout as ``curr``.
        """
        if isinstance(curr, list):
            assert isinstance(curr_pos, list)
            assert len(curr) == len(curr_pos) == 1
            curr, curr_pos = curr[0], curr_pos[0]

        assert (
            curr.shape[1] == memory.shape[1]
        ), "Batch size must be the same for curr and memory"

        output = curr
        if self.pos_enc_at_input and curr_pos is not None:
            # The input positional encoding is added with a fixed 0.1 weight.
            output = output + 0.1 * curr_pos

        if self.batch_first:
            # Convert to batch first; positional encodings are optional.
            output = output.transpose(0, 1)
            memory = memory.transpose(0, 1)
            if curr_pos is not None:
                curr_pos = curr_pos.transpose(0, 1)
            if memory_pos is not None:
                memory_pos = memory_pos.transpose(0, 1)

        for layer in self.layers:
            kwds = {}
            # Only RoPE cross-attention layers receive the pointer-token count
            # and the reliability scores.
            if isinstance(layer.cross_attn_image, RoPEAttention):
                kwds = {
                    "num_k_exclude_rope": num_obj_ptr_tokens,
                    "object_frame_scores": object_frame_scores,
                    "object_ptr_scores": object_ptr_scores,
                }
            output = layer(
                tgt=output,
                memory=memory,
                pos=memory_pos,
                query_pos=curr_pos,
                **kwds,
            )

        normed_output = self.norm(output)
        if self.batch_first:
            # Convert back to seq first
            normed_output = normed_output.transpose(0, 1)
        return normed_output