# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
#
# Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles
#
# Chaitanya Ryali, Yuan-Ting Hu, Daniel Bolya, Chen Wei, Haoqi Fan,
# Po-Yao Huang, Vaibhav Aggarwal, Arkabandhu Chowdhury, Omid Poursaeed,
# Judy Hoffman, Jitendra Malik, Yanghao Li, Christoph Feichtenhofer.
#
# Paper: https://arxiv.org/abs/2306.00989/
#
# References:
# slowfast: https://github.com/facebookresearch/SlowFast
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# --------------------------------------------------------
import math
from functools import partial
from typing import List, Tuple, Callable, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath, Mlp
from .hiera_utils import pretrained_model, conv_nd, do_pool, do_masked_conv, Unroll, Reroll
class MaskUnitAttention(nn.Module):
"""
    Computes either Mask Unit or Global Attention. It can also perform q pooling.
Note: this assumes the tokens have already been flattened and unrolled into mask units.
See `Unroll` for more details.
"""
def __init__(
self,
dim: int,
dim_out: int,
heads: int,
q_stride: int = 1,
window_size: int = 0,
use_mask_unit_attn: bool = False,
):
"""
Args:
- dim, dim_out: The input and output feature dimensions.
- heads: The number of attention heads.
- q_stride: If greater than 1, pool q with this stride. The stride should be flattened (e.g., 2x2 = 4).
- window_size: The current (flattened) size of a mask unit *after* pooling (if any).
        - use_mask_unit_attn: If True, use Mask Unit Attention; otherwise use Global Attention.
"""
super().__init__()
self.dim = dim
self.dim_out = dim_out
self.heads = heads
self.q_stride = q_stride
self.head_dim = dim_out // heads
self.scale = (self.head_dim) ** -0.5
self.qkv = nn.Linear(dim, 3 * dim_out)
self.proj = nn.Linear(dim_out, dim_out)
self.window_size = window_size
self.use_mask_unit_attn = use_mask_unit_attn
def forward(self, x: torch.Tensor) -> torch.Tensor:
""" Input should be of shape [batch, tokens, channels]. """
B, N, _ = x.shape
num_windows = (
(N // (self.q_stride * self.window_size)) if self.use_mask_unit_attn else 1
)
qkv = (
self.qkv(x)
.reshape(B, -1, num_windows, 3, self.heads, self.head_dim)
.permute(3, 0, 4, 2, 1, 5)
)
q, k, v = qkv[0], qkv[1], qkv[2]
if self.q_stride > 1:
# Refer to Unroll to see how this performs a maxpool-Nd
q = (
q.view(B, self.heads, num_windows, self.q_stride, -1, self.head_dim)
.max(dim=3)
.values
)
if hasattr(F, "scaled_dot_product_attention"):
            # Note: the original paper did *not* use SDPA; it's a free performance boost.
x = F.scaled_dot_product_attention(q, k, v)
else:
attn = (q * self.scale) @ k.transpose(-1, -2)
attn = attn.softmax(dim=-1)
x = (attn @ v)
x = x.transpose(1, 3).reshape(B, -1, self.dim_out)
x = self.proj(x)
return x
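
# Illustrative shape check (an added sketch, not part of the upstream file): with
# dim = dim_out = 96, heads = 1, q_stride = 1, window_size = 64, and
# use_mask_unit_attn = True, an input of shape [2, 256, 96] is split into
# 256 // (1 * 64) = 4 mask-unit windows, attention runs independently within each
# window, and the output shape is again [2, 256, 96]:
#
#   attn = MaskUnitAttention(96, 96, heads=1, window_size=64, use_mask_unit_attn=True)
#   out = attn(torch.randn(2, 256, 96))  # -> torch.Size([2, 256, 96])
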
class HieraBlock(nn.Module):
def __init__(
self,
dim: int,
dim_out: int,
heads: int,
mlp_ratio: float = 4.0,
drop_path: float = 0.0,
norm_layer: nn.Module = nn.LayerNorm,
act_layer: nn.Module = nn.GELU,
q_stride: int = 1,
window_size: int = 0,
use_mask_unit_attn: bool = False,
):
super().__init__()
self.dim = dim
self.dim_out = dim_out
self.norm1 = norm_layer(dim)
self.attn = MaskUnitAttention(
dim, dim_out, heads, q_stride, window_size, use_mask_unit_attn
)
self.norm2 = norm_layer(dim_out)
self.mlp = Mlp(dim_out, int(dim_out * mlp_ratio), act_layer=act_layer)
self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()
if dim != dim_out:
self.proj = nn.Linear(dim, dim_out)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Attention + Q Pooling
x_norm = self.norm1(x)
if self.dim != self.dim_out:
x = do_pool(self.proj(x_norm), stride=self.attn.q_stride)
x = x + self.drop_path(self.attn(x_norm))
# MLP
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
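
# Illustrative q-pooling sketch (an added example, not from the upstream file): a block
# that doubles the channel dim and pools q with a flattened stride of 4 reduces the
# token count by 4x. With dim=96, dim_out=192, heads=2, q_stride=4, window_size=16
# (the post-pooling mask-unit size), and use_mask_unit_attn=True:
#
#   block = HieraBlock(dim=96, dim_out=192, heads=2, q_stride=4, window_size=16,
#                      use_mask_unit_attn=True)
#   out = block(torch.randn(2, 256, 96))  # -> torch.Size([2, 64, 192])
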
class Head(nn.Module):
def __init__(
self,
dim: int,
num_classes: int,
dropout_rate: float = 0.0,
act_func: Callable[[torch.Tensor], torch.Tensor] = lambda x: x.softmax(dim=-1),
):
super().__init__()
self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity()
self.projection = nn.Linear(dim, num_classes)
        # act_func is applied at eval/test time only
self.act_func = act_func
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.dropout(x)
x = self.projection(x)
if not self.training:
x = self.act_func(x)
return x
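
# Illustrative note (added, not from the upstream file): the head returns raw logits
# during training; in eval mode the configured act_func (softmax by default) is applied:
#
#   head = Head(dim=768, num_classes=1000)
#   head.eval()
#   probs = head(torch.randn(2, 768))  # rows sum to ~1.0 after the softmax
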
class PatchEmbed(nn.Module):
"""Patch embed that supports any number of spatial dimensions (1d, 2d, 3d)."""
def __init__(
self,
dim_in: int,
dim_out: int,
kernel: Tuple[int, ...],
stride: Tuple[int, ...],
padding: Tuple[int, ...],
):
super().__init__()
# Support any number of spatial dimensions
self.spatial_dims = len(kernel)
self.proj = conv_nd(self.spatial_dims)(
dim_in,
dim_out,
kernel_size=kernel,
stride=stride,
padding=padding,
)
def forward(
self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
x = do_masked_conv(x, self.proj, mask)
x = x.reshape(x.shape[0], x.shape[1], -1).transpose(2, 1)
return x
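
# Illustrative shape check (added, not from the upstream file): with the default image
# settings (kernel (7, 7), stride (4, 4), padding (3, 3)), a 224x224 RGB input yields a
# 56x56 token grid, flattened to 3136 tokens:
#
#   embed = PatchEmbed(3, 96, (7, 7), (4, 4), (3, 3))
#   tokens = embed(torch.randn(1, 3, 224, 224))  # -> torch.Size([1, 3136, 96])
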
class Hiera(nn.Module):
def __init__(
self,
input_size: Tuple[int, ...] = (224, 224),
in_chans: int = 3,
embed_dim: int = 96, # initial embed dim
num_heads: int = 1, # initial number of heads
num_classes: int = 1000,
stages: Tuple[int, ...] = (2, 3, 16, 3),
q_pool: int = 3, # number of q_pool stages
q_stride: Tuple[int, ...] = (2, 2),
mask_unit_size: Tuple[int, ...] = (8, 8), # must divide q_stride ** (#stages-1)
# mask_unit_attn: which stages use mask unit attention?
mask_unit_attn: Tuple[bool, ...] = (True, True, False, False),
dim_mul: float = 2.0,
head_mul: float = 2.0,
patch_kernel: Tuple[int, ...] = (7, 7),
patch_stride: Tuple[int, ...] = (4, 4),
patch_padding: Tuple[int, ...] = (3, 3),
mlp_ratio: float = 4.0,
drop_path_rate: float = 0.0,
norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
head_dropout: float = 0.0,
head_init_scale: float = 0.001,
sep_pos_embed: bool = False,
):
super().__init__()
depth = sum(stages)
self.patch_stride = patch_stride
self.tokens_spatial_shape = [i // s for i, s in zip(input_size, patch_stride)]
num_tokens = math.prod(self.tokens_spatial_shape)
flat_mu_size = math.prod(mask_unit_size)
flat_q_stride = math.prod(q_stride)
assert q_pool < len(stages)
self.q_pool, self.q_stride = q_pool, q_stride
self.mu_size, self.mask_unit_size = flat_mu_size, mask_unit_size
self.mask_spatial_shape = [
i // s for i, s in zip(self.tokens_spatial_shape, self.mask_unit_size)
]
self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
self.patch_embed = PatchEmbed(
in_chans, embed_dim, patch_kernel, patch_stride, patch_padding
)
self.sep_pos_embed = sep_pos_embed
if sep_pos_embed:
self.pos_embed_spatial = nn.Parameter(
torch.zeros(
1,
self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2],
embed_dim,
)
)
self.pos_embed_temporal = nn.Parameter(
torch.zeros(1, self.tokens_spatial_shape[0], embed_dim)
)
else:
self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim))
# Setup roll and reroll modules
self.unroll = Unroll(
input_size, patch_stride, [q_stride] * len(self.stage_ends[:-1])
)
self.reroll = Reroll(
input_size,
patch_stride,
[q_stride] * len(self.stage_ends[:-1]),
self.stage_ends,
q_pool,
)
# q_pool locations
q_pool_blocks = [x + 1 for x in self.stage_ends[:q_pool]]
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
# Transformer blocks
cur_stage = 0
self.blocks = nn.ModuleList()
for i in range(depth):
dim_out = embed_dim
# Mask unit or global attention.
            # Lag by one block so that global attention is
            # applied post-pooling, on the lower resolution.
use_mask_unit_attn = mask_unit_attn[cur_stage]
if i - 1 in self.stage_ends:
dim_out = int(embed_dim * dim_mul)
num_heads = int(num_heads * head_mul)
cur_stage += 1
if i in q_pool_blocks:
flat_mu_size //= flat_q_stride
block = HieraBlock(
dim=embed_dim,
dim_out=dim_out,
heads=num_heads,
mlp_ratio=mlp_ratio,
drop_path=dpr[i],
norm_layer=norm_layer,
q_stride=(flat_q_stride if i in q_pool_blocks else 1),
window_size=flat_mu_size,
use_mask_unit_attn=use_mask_unit_attn,
)
embed_dim = dim_out
self.blocks.append(block)
self.norm = norm_layer(embed_dim)
self.head = Head(embed_dim, num_classes, dropout_rate=head_dropout)
# Initialize everything
if sep_pos_embed:
nn.init.trunc_normal_(self.pos_embed_spatial, std=0.02)
nn.init.trunc_normal_(self.pos_embed_temporal, std=0.02)
else:
nn.init.trunc_normal_(self.pos_embed, std=0.02)
self.apply(partial(self._init_weights))
self.head.projection.weight.data.mul_(head_init_scale)
self.head.projection.bias.data.mul_(head_init_scale)
def _init_weights(self, m, init_bias=0.02):
if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
nn.init.trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, init_bias)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, init_bias)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
if self.sep_pos_embed:
return ["pos_embed_spatial", "pos_embed_temporal"]
else:
return ["pos_embed"]
def get_random_mask(self, x: torch.Tensor, mask_ratio: float) -> torch.Tensor:
"""
        Generates a random mask in which a mask_ratio fraction of mask units is dropped.
1 is *keep*, 0 is *remove*. Useful for MAE, FLIP, etc.
"""
B = x.shape[0]
# Tokens selected for masking at mask unit level
num_windows = math.prod(self.mask_spatial_shape) # num_mask_units
len_keep = int(num_windows * (1 - mask_ratio))
noise = torch.rand(B, num_windows, device=x.device)
# Sort noise for each sample
ids_shuffle = torch.argsort(
noise, dim=1
) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1)
# Generate the binary mask: 1 is *keep*, 0 is *remove*
# Note this is opposite to original MAE
mask = torch.zeros([B, num_windows], device=x.device)
mask[:, :len_keep] = 1
# Unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)
return mask.bool()
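    # Illustrative note (added, not from the upstream file): for the default 224x224
    # image configuration, tokens_spatial_shape = [56, 56] and mask_unit_size = (8, 8),
    # so mask_spatial_shape = [7, 7] and there are 49 mask units per sample. With
    # mask_ratio = 0.6, int(49 * 0.4) = 19 mask units are kept (mask value 1) per sample.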
def get_pos_embed(self) -> torch.Tensor:
if self.sep_pos_embed:
return self.pos_embed_spatial.repeat(
1, self.tokens_spatial_shape[0], 1
) + torch.repeat_interleave(
self.pos_embed_temporal,
self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2],
dim=1,
)
else:
return self.pos_embed
def forward(
self,
x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
return_intermediates: bool = False,
) -> torch.Tensor:
"""
mask should be a boolean tensor of shape [B, #MUt*#MUy*#MUx] where #MU are the number of mask units in that dim.
Note: 1 in mask is *keep*, 0 is *remove*; mask.sum(dim=-1) should be the same across the batch.
"""
# Slowfast training passes in a list
if isinstance(x, list):
x = x[0]
intermediates = []
x = self.patch_embed(
x,
mask=mask.view(
x.shape[0], 1, *self.mask_spatial_shape
) # B, C, *mask_spatial_shape
if mask is not None
else None,
)
x = x + self.get_pos_embed()
x = self.unroll(x)
# Discard masked tokens
if mask is not None:
x = x[mask[..., None].tile(1, self.mu_size, x.shape[2])].view(
x.shape[0], -1, x.shape[-1]
)
for i, blk in enumerate(self.blocks):
x = blk(x)
if return_intermediates and i in self.stage_ends:
intermediates.append(self.reroll(x, i, mask=mask))
if mask is None:
x = x.mean(dim=1)
x = self.norm(x)
x = self.head(x)
        # x may not always be in spatial order here.
        # E.g., if q_pool = 2, mask_unit_size = (8, 8), and q_stride = (2, 2),
        # not all unrolls were consumed; intermediates[-1] is x in spatial order.
if return_intermediates:
return x, intermediates
return x
# Image models
@pretrained_model({
"mae_in1k_ft_in1k": "https://huggingface.co/merve/hiera-tiny-ft-224-in1k/resolve/main/hiera_tiny_224.pth",
"mae_in1k": "https://huggingface.co/merve/hiera-tiny-224-in1k/resolve/main/mae_hiera_tiny_224.pth",
}, default="mae_in1k_ft_in1k")
def hiera_tiny_224(**kwdargs):
return Hiera(embed_dim=96, num_heads=1, stages=(1, 2, 7, 2), **kwdargs)
@pretrained_model({
"mae_in1k_ft_in1k": "https://huggingface.co/merve/hiera-small-ft-224-in1k/resolve/main/hiera_small_224.pth",
"mae_in1k": "https://huggingface.co/merve/hiera-small-224-in1k/resolve/main/mae_hiera_small_224.pth",
}, default="mae_in1k_ft_in1k")
def hiera_small_224(**kwdargs):
return Hiera(embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), **kwdargs)
@pretrained_model({
"mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_base_224.pth",
"mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_224.pth",
}, default="mae_in1k_ft_in1k")
def hiera_base_224(**kwdargs):
return Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), **kwdargs)
@pretrained_model({
"mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_base_plus_224.pth",
"mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_224.pth",
}, default="mae_in1k_ft_in1k")
def hiera_base_plus_224(**kwdargs):
return Hiera(embed_dim=112, num_heads=2, stages=(2, 3, 16, 3), **kwdargs)
@pretrained_model({
"mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_large_224.pth",
"mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_224.pth",
}, default="mae_in1k_ft_in1k")
def hiera_large_224(**kwdargs):
return Hiera(embed_dim=144, num_heads=2, stages=(2, 6, 36, 4), **kwdargs)
@pretrained_model({
"mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_huge_224.pth",
"mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_224.pth",
}, default="mae_in1k_ft_in1k")
def hiera_huge_224(**kwdargs):
return Hiera(embed_dim=256, num_heads=4, stages=(2, 6, 36, 4), **kwdargs)
# Video models
@pretrained_model({
"mae_k400_ft_k400": "https://dl.fbaipublicfiles.com/hiera/hiera_base_16x224.pth",
"mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_16x224.pth",
}, default="mae_k400_ft_k400")
def hiera_base_16x224(num_classes: int = 400, **kwdargs):
return Hiera(
num_classes=num_classes, # K400 has 400 classes
input_size=(16, 224, 224),
q_stride=(1, 2, 2),
mask_unit_size=(1, 8, 8),
patch_kernel=(3, 7, 7),
patch_stride=(2, 4, 4),
patch_padding=(1, 3, 3),
sep_pos_embed=True,
**kwdargs
)
@pretrained_model({
"mae_k400_ft_k400": "https://dl.fbaipublicfiles.com/hiera/hiera_base_plus_16x224.pth",
"mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_16x224.pth",
}, default="mae_k400_ft_k400")
def hiera_base_plus_16x224(**kwdargs):
return hiera_base_16x224(
embed_dim=112, num_heads=2, stages=(2, 3, 16, 3), **kwdargs
)
@pretrained_model({
"mae_k400_ft_k400": "https://dl.fbaipublicfiles.com/hiera/hiera_large_16x224.pth",
"mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_16x224.pth",
}, default="mae_k400_ft_k400")
def hiera_large_16x224(**kwdargs):
return hiera_base_16x224(
embed_dim=144, num_heads=2, stages=(2, 6, 36, 4), **kwdargs
)
@pretrained_model({
"mae_k400_ft_k400": "https://dl.fbaipublicfiles.com/hiera/hiera_huge_16x224.pth",
"mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_16x224.pth",
}, default="mae_k400_ft_k400")
def hiera_huge_16x224(**kwdargs):
return hiera_base_16x224(
embed_dim=256, num_heads=4, stages=(2, 6, 36, 4), **kwdargs
)
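
# A minimal usage sketch (added for illustration, not part of the upstream file). It
# instantiates the tiny image configuration directly via `Hiera` with randomly
# initialized weights, so no pretrained checkpoint is needed, and runs a dummy batch
# through both the classification path and the masked path.
if __name__ == "__main__":
    model = Hiera(embed_dim=96, num_heads=1, stages=(1, 2, 7, 2))
    images = torch.randn(2, 3, 224, 224)

    # Classification path: tokens are mean-pooled, normalized, then classified.
    logits = model(images)
    print(logits.shape)  # torch.Size([2, 1000])

    # Masked path (e.g., for MAE-style pretraining): keep roughly 40% of mask units
    # and collect the rerolled per-stage intermediates.
    mask = model.get_random_mask(images, mask_ratio=0.6)
    tokens, intermediates = model(images, mask=mask, return_intermediates=True)
    print(mask.shape, tokens.shape, len(intermediates))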