# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# mae: https://github.com/facebookresearch/mae
# slowfast: https://github.com/facebookresearch/SlowFast
# --------------------------------------------------------
from functools import partial
from typing import Tuple, Optional
import math
import torch
import torch.nn as nn
from .hiera import Hiera, HieraBlock
from .hiera_utils import pretrained_model, undo_windowing, conv_nd
def apply_fusion_head(head: nn.Module, x: torch.Tensor) -> torch.Tensor:
if isinstance(head, nn.Identity):
return x
B, num_mask_units = x.shape[0:2]
    # Apply head, e.g. [B, #MUs, My, Mx, C] -> head([B * #MUs, C, My, Mx])
permute = [0] + [len(x.shape) - 2] + list(range(1, len(x.shape) - 2))
x = head(x.reshape(B * num_mask_units, *x.shape[2:]).permute(permute))
# Restore original layout, e.g. [B * #MUs, C', My', Mx'] -> [B, #MUs, My', Mx', C']
permute = [0] + list(range(2, len(x.shape))) + [1]
x = x.permute(permute).reshape(B, num_mask_units, *x.shape[2:], x.shape[1])
return x
class MaskedAutoencoderHiera(Hiera):
"""Masked Autoencoder with Hiera backbone"""
def __init__(
self,
in_chans: int = 3,
patch_stride: Tuple[int, ...] = (4, 4),
mlp_ratio: float = 4.0,
decoder_embed_dim: int = 512,
decoder_depth: int = 8,
decoder_num_heads: int = 16,
norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
**kwdargs,
):
super().__init__(
in_chans=in_chans,
patch_stride=patch_stride,
mlp_ratio=mlp_ratio,
norm_layer=norm_layer,
**kwdargs,
)
del self.norm, self.head
encoder_dim_out = self.blocks[-1].dim_out
self.encoder_norm = norm_layer(encoder_dim_out)
self.mask_unit_spatial_shape_final = [
i // s ** (self.q_pool) for i, s in zip(self.mask_unit_size, self.q_stride)
]
self.tokens_spatial_shape_final = [
i // s ** (self.q_pool)
for i, s in zip(self.tokens_spatial_shape, self.q_stride)
]
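        # e.g. with the default 224x224 image settings (patch_stride=(4, 4), q_stride=(2, 2),
        # mask_unit_size=(8, 8), q_pool=2), the token grid goes from 56x56 to a final 14x14
        # and each mask unit shrinks from 8x8 to 2x2 tokens (illustrative values only)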
# --------------------------------------------------------------------------
# Multi-scale fusion heads
curr_mu_size = self.mask_unit_size
self.multi_scale_fusion_heads = nn.ModuleList()
for i in self.stage_ends[: self.q_pool]: # resolution constant after q_pool
kernel = [
i // s for i, s in zip(curr_mu_size, self.mask_unit_spatial_shape_final)
]
curr_mu_size = [i // s for i, s in zip(curr_mu_size, self.q_stride)]
self.multi_scale_fusion_heads.append(
conv_nd(len(self.q_stride))(
self.blocks[i].dim_out,
encoder_dim_out,
kernel_size=kernel,
stride=kernel,
)
)
self.multi_scale_fusion_heads.append(nn.Identity()) # final stage, no transform
# --------------------------------------------------------------------------
# MAE decoder specifics
self.decoder_embed = nn.Linear(encoder_dim_out, decoder_embed_dim)
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
self.decoder_pos_embed = nn.Parameter(
torch.zeros(
1, math.prod(self.tokens_spatial_shape_final), decoder_embed_dim
)
)
self.decoder_blocks = nn.ModuleList(
[
HieraBlock(
dim=decoder_embed_dim,
dim_out=decoder_embed_dim,
heads=decoder_num_heads,
norm_layer=norm_layer,
mlp_ratio=mlp_ratio,
)
for i in range(decoder_depth)
]
)
self.decoder_norm = norm_layer(decoder_embed_dim)
self.pred_stride = patch_stride[-1] * (
self.q_stride[-1] ** self.q_pool
) # patch stride of prediction
self.decoder_pred = nn.Linear(
decoder_embed_dim,
(self.pred_stride ** min(2, len(self.q_stride))) * in_chans,
) # predictor
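        # e.g. with the image defaults above (patch_stride=(4, 4), q_stride=(2, 2), q_pool=2,
        # in_chans=3): pred_stride = 4 * 2**2 = 16, so each prediction token reconstructs a
        # 16x16 pixel patch and decoder_pred outputs 16 * 16 * 3 = 768 values per token
        # (illustrative values only)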
# --------------------------------------------------------------------------
self.initialize_weights()
def initialize_weights(self):
nn.init.trunc_normal_(self.mask_token, std=0.02)
nn.init.trunc_normal_(self.decoder_pos_embed, std=0.02)
self.apply(self._mae_init_weights)
# initialize patch_embed like nn.Linear (instead of nn.Conv2d)
w = self.patch_embed.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
def _mae_init_weights(self, m: nn.Module):
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def get_pixel_label_2d(
self, input_img: torch.Tensor, mask: torch.Tensor, norm: bool = True
) -> torch.Tensor:
# mask (boolean tensor): True must correspond to *masked*
input_img = input_img.permute(0, 2, 3, 1)
size = self.pred_stride
label = input_img.unfold(1, size, size).unfold(2, size, size)
label = label.flatten(1, 2).flatten(2)
label = label[mask]
if norm:
mean = label.mean(dim=-1, keepdim=True)
var = label.var(dim=-1, keepdim=True)
label = (label - mean) / (var + 1.0e-6) ** 0.5
return label
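    # Shape sketch for get_pixel_label_2d (illustrative, assuming a [B, 3, 224, 224] input
    # and the default pred_stride of 16): unfold -> [B, 14, 14, 3, 16, 16], flatten ->
    # [B, 196, 16 * 16 * 3], and boolean indexing with `mask` keeps only the masked tokens,
    # giving [num_masked_tokens, 768]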
def get_pixel_label_3d(
self, input_vid: torch.Tensor, mask: torch.Tensor, norm: bool = True
) -> torch.Tensor:
# mask (boolean tensor): True must correspond to *masked*
        # We use a time-strided loss: only take the first frame from each token
input_vid = input_vid[:, :, ::self.patch_stride[0], :, :]
size = self.pred_stride
label = input_vid.unfold(3, size, size).unfold(4, size, size)
        label = label.permute(0, 2, 3, 4, 5, 6, 1)  # Note: dim order differs from the 2d case; kept as-is to match how the released checkpoints were trained
label = label.flatten(1, 3).flatten(2)
label = label[mask]
if norm:
mean = label.mean(dim=-1, keepdim=True)
var = label.var(dim=-1, keepdim=True)
label = (label - mean) / (var + 1.0e-6) ** 0.5
return label
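    # Shape sketch for get_pixel_label_3d (illustrative, assuming the video defaults below:
    # input [B, 3, 16, 224, 224], patch_stride=(2, 4, 4), pred_stride=16): frame striding
    # keeps 8 frames, unfold -> [B, 3, 8, 14, 14, 16, 16], permute + flatten ->
    # [B, 1568, 16 * 16 * 3], and boolean indexing keeps only the masked tokens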
def forward_encoder(
self, x: torch.Tensor, mask_ratio: float, mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
if mask is None:
mask = self.get_random_mask(x, mask_ratio) # [B, #MUs_all]
# Get multi-scale representations from encoder
_, intermediates = super().forward(x, mask, return_intermediates=True)
# Resolution unchanged after q_pool stages, so skip those features
intermediates = intermediates[: self.q_pool] + intermediates[-1:]
# Multi-scale fusion
x = 0.0
for head, interm_x in zip(self.multi_scale_fusion_heads, intermediates):
x += apply_fusion_head(head, interm_x)
x = self.encoder_norm(x)
return x, mask
def forward_decoder(
self, x: torch.Tensor, mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# Embed tokens
x = self.decoder_embed(x)
# Combine visible and mask tokens
# x: [B, #MUs, *mask_unit_spatial_shape_final, encoder_dim_out]
# mask: [B, #MUs_all]
x_dec = torch.zeros(*mask.shape, *x.shape[2:], device=x.device, dtype=x.dtype)
mask_tokens = self.mask_token.view(
(1,) * (len(mask.shape) + len(x.shape[2:-1])) + (-1,)
)
mask = mask.reshape(mask.shape + (1,) * len(x.shape[2:]))
mask = mask.expand((-1,) * 2 + x.shape[2:]).bool()
x_dec[mask] = x.flatten()
x_dec = ~mask * mask_tokens + mask * x_dec
# Get back spatial order
x = undo_windowing(
x_dec,
self.tokens_spatial_shape_final,
self.mask_unit_spatial_shape_final,
)
mask = undo_windowing(
mask[..., 0:1],
self.tokens_spatial_shape_final,
self.mask_unit_spatial_shape_final,
)
# Flatten
x = x.reshape(x.shape[0], -1, x.shape[-1])
mask = mask.view(x.shape[0], -1)
# Add pos embed
x = x + self.decoder_pos_embed
# Apply decoder blocks
for blk in self.decoder_blocks:
x = blk(x)
x = self.decoder_norm(x)
# Predictor projection
x = self.decoder_pred(x)
return x, mask
def forward_loss(
self, x: torch.Tensor, pred: torch.Tensor, mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Note: in mask, 0 is *visible*, 1 is *masked*
x: e.g. [B, 3, H, W]
pred: [B * num_pred_tokens, num_pixels_in_pred_patch * in_chans]
label: [B * num_pred_tokens, num_pixels_in_pred_patch * in_chans]
"""
if len(self.q_stride) == 2:
label = self.get_pixel_label_2d(x, mask)
elif len(self.q_stride) == 3:
label = self.get_pixel_label_3d(x, mask)
else:
raise NotImplementedError
pred = pred[mask]
loss = (pred - label) ** 2
return loss.mean(), pred, label
def forward(
self,
x: torch.Tensor,
mask_ratio: float = 0.6,
mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
latent, mask = self.forward_encoder(x, mask_ratio, mask=mask)
pred, pred_mask = self.forward_decoder(
latent, mask
) # pred_mask is mask at resolution of *prediction*
# Toggle mask, to generate labels for *masked* tokens
return *self.forward_loss(x, pred, ~pred_mask), mask
# Image Models
@pretrained_model({
"mae_in1k": "https://huggingface.co/merve/hiera-tiny-224-in1k/resolve/main/mae_hiera_tiny_224.pth",
}, default="mae_in1k")
def mae_hiera_tiny_224(**kwargs):
return MaskedAutoencoderHiera(
embed_dim=96, num_heads=1, stages=(1, 2, 7, 2), q_pool=2, **kwargs,
)
@pretrained_model({
"mae_in1k": "https://huggingface.co/merve/hiera-small-224-in1k/resolve/main/mae_hiera_small_224.pth",
}, default="mae_in1k")
def mae_hiera_small_224(**kwargs):
return MaskedAutoencoderHiera(
embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), q_pool=2, **kwargs,
)
@pretrained_model({
"mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_224.pth",
}, default="mae_in1k")
def mae_hiera_base_224(**kwargs):
return MaskedAutoencoderHiera(
embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), q_pool=2, **kwargs,
)
@pretrained_model({
"mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_224.pth",
}, default="mae_in1k")
def mae_hiera_base_plus_224(**kwargs):
return MaskedAutoencoderHiera(
embed_dim=112, num_heads=2, stages=(2, 3, 16, 3), q_pool=2, **kwargs,
)
@pretrained_model({
"mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_224.pth",
}, default="mae_in1k")
def mae_hiera_large_224(**kwargs):
return MaskedAutoencoderHiera(
embed_dim=144, num_heads=2, stages=(2, 6, 36, 4), q_pool=2, **kwargs,
)
@pretrained_model({
"mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_224.pth",
}, default="mae_in1k")
def mae_hiera_huge_224(**kwargs):
return MaskedAutoencoderHiera(
embed_dim=256, num_heads=4, stages=(2, 6, 36, 4), q_pool=2, **kwargs,
)
# Video Models
@pretrained_model({
"mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_16x224.pth",
}, default="mae_k400")
def mae_hiera_base_16x224(num_classes: int = 400, **kwdargs):
return MaskedAutoencoderHiera(
num_classes=num_classes, # K400 has 400 classes
input_size=(16, 224, 224),
q_stride=(1, 2, 2),
mask_unit_size=(1, 8, 8),
patch_kernel=(3, 7, 7),
patch_stride=(2, 4, 4),
patch_padding=(1, 3, 3),
sep_pos_embed=True,
q_pool=2,
**kwdargs
)
@pretrained_model({
"mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_16x224.pth",
}, default="mae_k400")
def mae_hiera_base_plus_16x224(**kwdargs):
return mae_hiera_base_16x224(
embed_dim=112, num_heads=2, stages=(2, 3, 16, 3), **kwdargs
)
@pretrained_model({
"mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_16x224.pth",
}, default="mae_k400")
def mae_hiera_large_16x224(**kwdargs):
return mae_hiera_base_16x224(
embed_dim=144, num_heads=2, stages=(2, 6, 36, 4), **kwdargs
)
@pretrained_model({
"mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_16x224.pth",
}, default="mae_k400")
def mae_hiera_huge_16x224(**kwdargs):
return mae_hiera_base_16x224(
embed_dim=256, num_heads=4, stages=(2, 6, 36, 4), **kwdargs
)
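# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; assumes this file is used as part of
# the hiera package so the relative imports above resolve, and that the chosen
# checkpoint can be downloaded when pretrained=True):
#
#   import torch
#   import hiera
#
#   model = hiera.mae_hiera_tiny_224(pretrained=True, checkpoint="mae_in1k")
#   imgs = torch.randn(2, 3, 224, 224)                   # [B, C, H, W]
#   loss, pred, label, mask = model(imgs, mask_ratio=0.6)
#   loss.backward()                                       # MAE reconstruction loss
# ---------------------------------------------------------------------------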