"""ViTDet backbone adapted from Detectron2"""
from functools import partial
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from sam2.modeling.backbones.utils import (
PatchEmbed,
window_partition,
window_unpartition,
get_abs_pos,
)
from sam2.modeling.sam2_utils import DropPath, MLP, LayerScale
class Attention(nn.Module):
"""Multi-head Attention block with relative position embeddings."""
def __init__(
self,
dim,
num_heads=8,
qkv_bias=True,
use_rel_pos=False,
rel_pos_zero_init=True,
input_size=None,
):
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
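
        Example (illustrative sketch; the dim, num_heads, and tensor shape below are arbitrary):
            >>> attn = Attention(dim=256, num_heads=8)
            >>> x = torch.randn(1, 14, 14, 256)  # channels-last tokens (B, H, W, C)
            >>> out = attn(x)                    # out.shape == (1, 14, 14, 256)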
"""
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.use_rel_pos = use_rel_pos
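        # NOTE: `use_rel_pos` is stored for config compatibility, but the SDPA-based
        # forward below does not add relative position embeddings to the attention map.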
def forward(self, x):
B, H, W, _ = x.shape
        # qkv with shape (3, B, nHead, H * W, C // nHead)
qkv = (
self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
)
        # q, k, v with shape (B * nHead, H * W, C // nHead)
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
q = q.view(B, self.num_heads, H * W, -1)
k = k.view(B, self.num_heads, H * W, -1)
v = v.view(B, self.num_heads, H * W, -1)
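        # Compute attention with PyTorch's scaled_dot_product_attention, allowing the
        # flash, memory-efficient, or math kernel depending on availability.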
with torch.backends.cuda.sdp_kernel(
enable_flash=True,
enable_math=True,
enable_mem_efficient=True,
):
x = F.scaled_dot_product_attention(q, k, v)
x = (
x.view(B, self.num_heads, H, W, -1)
.permute(0, 2, 3, 1, 4)
.reshape(B, H, W, -1)
)
x = self.proj(x)
return x
class Block(nn.Module):
"""Transformer blocks with support of window attention"""
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=True,
drop_path=0.0,
norm_layer=nn.LayerNorm,
act_layer=nn.GELU,
use_rel_pos=False,
rel_pos_zero_init=True,
window_size=0,
input_size=None,
dropout=0.0,
init_values=None,
):
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
drop_path (float): Stochastic depth rate.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks. If it is 0, global
                attention is used instead of window attention.
input_size (int or None): Input resolution for calculating the relative positional
parameter size.
            dropout (float): Dropout rate applied to the attention and MLP residual branches.
            init_values (float or None): Init value for LayerScale; if None, LayerScale is disabled.
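
        Example (illustrative sketch; the dim, num_heads, and window_size values are arbitrary):
            >>> blk = Block(dim=256, num_heads=8, window_size=14, input_size=(64, 64))
            >>> x = torch.randn(1, 64, 64, 256)  # channels-last tokens (B, H, W, C)
            >>> out = blk(x)                     # padded to 70x70 internally; out.shape == (1, 64, 64, 256)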
"""
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
input_size=input_size if window_size == 0 else (window_size, window_size),
)
self.ls1 = (
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = MLP(
dim,
int(dim * mlp_ratio),
dim,
num_layers=2,
activation=act_layer,
)
# self.mlp = Mlp2(
# in_features=dim,
# hidden_features=int(dim * mlp_ratio),
# act_layer=act_layer,
# drop=(dropout, 0.0),
# )
self.ls2 = (
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
)
self.dropout = nn.Dropout(dropout)
self.window_size = window_size
def forward(self, x):
shortcut = x
x = self.norm1(x)
# Window partition
if self.window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, self.window_size)
x = self.ls1(self.attn(x))
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
x = shortcut + self.dropout(self.drop_path(x))
x = x + self.dropout(self.drop_path(self.ls2(self.mlp(self.norm2(x)))))
return x
class ViT(nn.Module):
"""
This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`.
"Exploring Plain Vision Transformer Backbones for Object Detection",
https://arxiv.org/abs/2203.16527
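
    Example (illustrative sketch; default config on a single 1024x1024 image):
        >>> backbone = ViT()                   # defaults: 12 blocks, patch size 16, embed dim 768
        >>> x = torch.randn(1, 3, 1024, 1024)
        >>> feats = backbone(x)                # list with a single (1, 768, 64, 64) feature map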
"""
def __init__(
self,
img_size=1024,
patch_size=16,
in_chans=3,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.0,
qkv_bias=True,
drop_path_rate=0.0,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
act_layer=nn.GELU,
use_abs_pos=True,
use_rel_pos=False,
rel_pos_zero_init=True,
window_size=14,
window_block_indexes=(0, 1, 3, 4, 6, 7, 9, 10),
use_act_checkpoint=False,
pretrain_img_size=224,
pretrain_use_cls_token=True,
dropout=0.0,
weights_path=None,
return_interm_layers=False,
init_values=None,
):
"""
Args:
img_size (int): Input image size. Only relevant for rel pos.
patch_size (int): Patch size.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
depth (int): Depth of ViT.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
drop_path_rate (float): Stochastic depth rate.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_abs_pos (bool): If True, use absolute positional embeddings.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks.
window_block_indexes (list): Indexes for blocks using window attention.
use_act_checkpoint (bool): If True, use activation checkpointing.
            pretrain_img_size (int): Input image size for the pretraining model.
            pretrain_use_cls_token (bool): If True, the pretraining model uses a class token.
            dropout (float): Dropout rate applied to the attention and MLP residual branches.
            weights_path (str or None): Path to the pretrained weights, if any.
            return_interm_layers (bool): Whether to return intermediate layers (all global attention blocks).
            init_values (float or None): Init value for LayerScale in each block; if None, LayerScale is disabled.
"""
super().__init__()
self.pretrain_use_cls_token = pretrain_use_cls_token
self.patch_embed = PatchEmbed(
kernel_size=(patch_size, patch_size),
stride=(patch_size, patch_size),
padding=(0, 0),
in_chans=in_chans,
embed_dim=embed_dim,
)
if use_abs_pos:
# Initialize absolute positional embedding with pretrain image size.
num_patches = (pretrain_img_size // patch_size) * (
pretrain_img_size // patch_size
)
num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
else:
self.pos_embed = None
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.blocks = nn.ModuleList()
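        # Indices of blocks that use global (non-windowed) attention; forward() takes
        # intermediate features from these blocks.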
self.full_attn_ids = []
cur_stage = 1
for i in range(depth):
block = Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
window_size=window_size if i in window_block_indexes else 0,
input_size=(img_size // patch_size, img_size // patch_size),
dropout=dropout,
init_values=init_values,
)
if i not in window_block_indexes:
self.full_attn_ids.append(i)
cur_stage += 1
self.blocks.append(block)
self.return_interm_layers = return_interm_layers
self.channel_list = (
[embed_dim] * len(self.full_attn_ids)
if return_interm_layers
else [embed_dim]
)
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
x = self.patch_embed(x)
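        # x: (B, H // patch_size, W // patch_size, embed_dim), channels-last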
if self.pos_embed is not None:
x = x + get_abs_pos(
self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
)
outputs = []
for i, blk in enumerate(self.blocks):
x = blk(x)
if (i == self.full_attn_ids[-1]) or (
self.return_interm_layers and i in self.full_attn_ids
):
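                # convert channels-last (B, H, W, C) features to (B, C, H, W)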
feats = x.permute(0, 3, 1, 2)
outputs.append(feats)
return outputs