# Retrieved from Hugging Face: hiera-tiny-ft-224-in1k/hiera/hiera_utils.py
# (uploaded by merve (HF staff) in commit e5d3156, "Upload 4 files").
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
#
# Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles
#
# Chaitanya Ryali, Yuan-Ting Hu, Daniel Bolya, Chen Wei, Haoqi Fan,
# Po-Yao Huang, Vaibhav Aggarwal, Arkabandhu Chowdhury, Omid Poursaeed,
# Judy Hoffman, Jitendra Malik, Yanghao Li, Christoph Feichtenhofer.
#
# Paper: https://arxiv.org/abs/2306.00989/
#
# References:
# slowfast: https://github.com/facebookresearch/SlowFast
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# --------------------------------------------------------
import functools
import math
from typing import List, Tuple, Optional, Type, Callable, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
def pretrained_model(checkpoints: Optional[Dict[str, str]], default: Optional[str] = None) -> Callable:
    """
    Decorator factory that adds optional pretrained-weight loading to a model constructor.

    The decorated constructor accepts three extra keyword arguments:
      pretrained: if True, download ``checkpoints[checkpoint]`` with ``torch.hub`` and load its
                  ``"model_state"`` entry into the freshly built model.
      checkpoint: key into ``checkpoints`` selecting the weights (defaults to ``default``).
      strict:     forwarded to ``nn.Module.load_state_dict``.
    Every other keyword argument is forwarded to the wrapped constructor.

    If the checkpoint has a classification head, ``num_classes`` is taken from it unless the
    caller supplies one. When the caller asks for a different number of classes, the head
    weights are dropped and only those keys are exempt from strict loading.

    Args:
        checkpoints: mapping from checkpoint name to download URL, or None when no
            pretrained weights exist for this model.
        default: checkpoint name used when the caller does not pass ``checkpoint``.

    Raises (from the decorated constructor, only when pretrained=True):
        RuntimeError: no checkpoints exist, none was selected, the name is unknown, or
            (under strict loading) the weights do not match the model.
    """
    head_keys = ("head.projection.weight", "head.projection.bias")

    def inner(model_func: Callable) -> Callable:
        # Keep the builder's __name__/__doc__ so registries and hub help still see them.
        @functools.wraps(model_func)
        def model_def(
            pretrained: bool = False,
            checkpoint: Optional[str] = default,
            strict: bool = True,
            **kwdargs,
        ) -> nn.Module:
            model_state = None
            dropped_head = False
            if pretrained:
                if checkpoints is None:
                    raise RuntimeError("This model currently doesn't have pretrained weights available.")
                elif checkpoint is None:
                    raise RuntimeError("No checkpoint specified.")
                elif checkpoint not in checkpoints:
                    raise RuntimeError(f"Invalid checkpoint specified ({checkpoint}). Options are: {list(checkpoints.keys())}.")
                state_dict = torch.hub.load_state_dict_from_url(checkpoints[checkpoint], map_location="cpu")
                model_state = state_dict["model_state"]

                if head_keys[0] in model_state:
                    ckpt_num_classes = model_state[head_keys[0]].shape[0]
                    # Set the number of classes from the checkpoint only if the user didn't override it
                    if "num_classes" not in kwdargs:
                        kwdargs["num_classes"] = ckpt_num_classes
                    # A different head size can't reuse the checkpoint's projection, so drop it
                    elif kwdargs["num_classes"] != ckpt_num_classes:
                        for key in head_keys:
                            model_state.pop(key, None)
                        dropped_head = True

            model = model_func(**kwdargs)

            if pretrained:
                # Disable being strict when loading an encoder-decoder checkpoint into an encoder-only model
                if "decoder_pos_embed" in model_state and not hasattr(model, "decoder_pos_embed"):
                    strict = False

                if strict and dropped_head:
                    # The freshly initialized head is intentionally missing from the checkpoint;
                    # stay strict about every other key.
                    incompatible = model.load_state_dict(model_state, strict=False)
                    missing = [k for k in incompatible.missing_keys if k not in head_keys]
                    unexpected = list(incompatible.unexpected_keys)
                    if missing or unexpected:
                        raise RuntimeError(
                            f"Error(s) in loading state_dict for {type(model).__name__}: "
                            f"missing keys {missing}, unexpected keys {unexpected}."
                        )
                else:
                    model.load_state_dict(model_state, strict=strict)

            return model

        return model_def

    return inner
def conv_nd(n: int) -> Type[nn.Module]:
    """
    Return the convolution class for ``n`` spatial dimensions (e.g., ``nn.Conv2d`` for n=2).

    Supports n in 0..3, where n=0 maps to ``nn.Identity``. If you wanted a 4d Hiera, you could
    probably just add an n=4 conv here. (no promises)

    Raises:
        ValueError: if ``n`` is outside 0..3. The range is checked explicitly because plain list
            indexing would silently map negative ``n`` to the wrong class (e.g. -1 -> Conv3d).
    """
    convs = (nn.Identity, nn.Conv1d, nn.Conv2d, nn.Conv3d)
    if not 0 <= n < len(convs):
        raise ValueError(f"conv_nd supports n in [0, {len(convs) - 1}], got {n}.")
    return convs[n]
def do_pool(x: torch.Tensor, stride: int) -> torch.Tensor:
    """
    Max-pool tokens that have been laid out as ``stride`` contiguous groups.

    See `Unroll` for how this ordering turns a group-wise max into a MaxPoolNd.

    Args:
        x: tensor with batch as its first dim and channels as its last dim.
        stride: number of contiguous token groups to reduce over.

    Returns:
        Element-wise max across the groups, shaped [B, N // stride, C] for a [B, N, C] input.
    """
    batch, channels = x.shape[0], x.shape[-1]
    grouped = x.view(batch, stride, -1, channels)
    pooled, _ = torch.max(grouped, dim=1)
    return pooled
def get_resized_mask(target_size: torch.Size, mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """
    Resize a spatial mask to ``target_size`` using ``F.interpolate`` (default mode).

    Args:
        target_size: desired spatial size [(T), (H), W]; any sequence of ints is accepted.
        mask: spatial mask shaped [B, C, (t), (h), w], or None.

    Returns:
        None if ``mask`` is None; ``mask`` unchanged if it already has the target spatial
        size; otherwise a float tensor of shape [B, C, *target_size].
    """
    if mask is None:
        return mask
    # Normalize to a tuple: torch.Size is a tuple subclass, so comparing it to a list would
    # always report a mismatch and trigger a pointless resize.
    target_size = tuple(target_size)
    assert len(mask.shape[2:]) == len(target_size), (
        f"mask has {len(mask.shape[2:])} spatial dims but target_size has {len(target_size)}"
    )
    if tuple(mask.shape[2:]) != target_size:
        return F.interpolate(mask.float(), size=target_size)
    return mask
def do_masked_conv(
    x: torch.Tensor, conv: nn.Module, mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Apply ``conv`` to ``x`` after zeroing the masked-out positions.

    Zeroing first keeps overlapping kernels from reading values in masked regions.
    If ``conv`` is None, ``x`` is returned untouched; if ``mask`` is None, the conv
    is applied directly. Otherwise the mask is resized to ``x``'s spatial size first.
    """
    if conv is None:
        return x
    if mask is not None:
        resized = get_resized_mask(target_size=x.shape[2:], mask=mask)
        x = x * resized.bool()
    return conv(x)
def undo_windowing(
    x: torch.Tensor, shape: List[int], mu_shape: List[int]
) -> torch.Tensor:
    """
    Put mask-unit-windowed tokens back into plain spatial order.

    Args:
        x: tokens grouped by mask-unit window, e.g. in 2d [B, #MUy*#MUx, MUy, MUx, C].
        shape: spatial size without windowing, e.g. in 2d [#MUy*MUy, #MUx*MUx].
        mu_shape: size of one mask unit, e.g. in 2d [MUy, MUx].
    Returns:
        The same values laid out spatially, e.g. in 2d [B, #MUy*MUy, #MUx*MUx, C].
    """
    ndim = len(shape)
    batch, channels = x.shape[0], x.shape[-1]

    # Split the window index into one axis per spatial dim:
    # [B, #MUy*#MUx, MUy, MUx, C] -> [B, #MUy, #MUx, MUy, MUx, C]
    grid = [size // unit for size, unit in zip(shape, mu_shape)]
    x = x.view(batch, *grid, *mu_shape, channels)

    # Pair each window-count axis with its within-window axis:
    # [B, #MUy, #MUx, MUy, MUx, C] -> [B, #MUy, MUy, #MUx, MUx, C]
    order = [0]
    for axis in range(1, ndim + 1):
        order.extend((axis, axis + ndim))
    order.append(x.dim() - 1)

    return x.permute(order).reshape(batch, *shape, channels)
class Unroll(nn.Module):
    """
    Reorders tokens so that patches are contiguous in memory.

    For example, given [B, (H, W), C] and a stride of (Sy, Sx), the tokens are re-ordered as
    [B, (Sy, Sx, H // Sy, W // Sx), C].

    This lets an operation like Max2d be computed as x.view(B, Sx*Sy, -1, C).max(dim=1).
    That is faster, and it makes inputs of arbitrary dimension and patch-wise sparsity easy
    to support.

    Applying the operation several times in sequence makes whole windows contiguous. For
    instance, applying a (2, 2) stride 3 times leaves each 8x8 window contiguous in memory.
    Mask unit attention can then be computed easily and efficiently, and max pooling can
    still be applied one stage at a time.

    Note: intermediate values of the model are therefore not in HxW order, so they must be
    re-rolled (see `Reroll`) before being used as HxW feature maps. The output of the last
    block is fine, since by then every stride in the schedule has been consumed.
    """

    def __init__(
        self,
        input_size: Tuple[int, ...],
        patch_stride: Tuple[int, ...],
        unroll_schedule: List[Tuple[int, ...]],
    ):
        super().__init__()
        # Token-grid size after patch embedding.
        self.size = [dim // stride for dim, stride in zip(input_size, patch_stride)]
        self.schedule = unroll_schedule

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Input: flattened patch embeddings [B, N, C].
        Output: the same embeddings [B, N, C], permuted so that e.g. [B, 4, N//4, C].max(1)
        performs a MaxPoolNd.
        """
        batch, _, channels = x.shape
        spatial = self.size
        x = x.view(batch, *spatial, channels)

        for strides in self.schedule:
            # Expose each stride as its own axis, e.g. in 2d [B, H // Sy, Sy, W // Sx, Sx, C]
            spatial = [dim // stride for dim, stride in zip(spatial, strides)]
            split_shape = [batch]
            for dim, stride in zip(spatial, strides):
                split_shape += [dim, stride]
            split_shape.append(channels)
            x = x.view(split_shape)

            # Gather the stride axes up front, e.g. in 2d [B, Sy, Sx, H // Sy, W // Sx, C]
            last = len(split_shape) - 1
            stride_axes = list(range(2, last, 2))
            coarse_axes = list(range(1, last, 2))
            x = x.permute([0, *stride_axes, *coarse_axes, last])

            # Fold the stride axes into the batch so the next step treats them as samples.
            x = x.flatten(0, len(strides))
            batch *= math.prod(strides)

        return x.reshape(-1, math.prod(self.size), channels)
class Reroll(nn.Module):
    """
    Undoes the "unroll" operation so that intermediate features can be used spatially.
    """

    def __init__(
        self,
        input_size: Tuple[int, ...],
        patch_stride: Tuple[int, ...],
        unroll_schedule: List[Tuple[int, ...]],
        stage_ends: List[int],
        q_pool: int,
    ):
        super().__init__()
        self.size = [dim // stride for dim, stride in zip(input_size, patch_stride)]

        # Block i maps to (strides still to undo, spatial size at that block).
        # Blocks in the first stage must reverse every unroll; after each pooled stage end,
        # the first remaining unroll has been consumed and the spatial size shrinks.
        pooled_stage_ends = stage_ends[:q_pool]
        remaining = unroll_schedule
        cur_size = self.size
        schedule = {}
        for block in range(stage_ends[-1] + 1):
            schedule[block] = remaining, cur_size
            if block in pooled_stage_ends and len(remaining) > 0:
                cur_size = [dim // stride for dim, stride in zip(cur_size, remaining[0])]
                remaining = remaining[1:]
        self.schedule = schedule

    def forward(
        self, x: torch.Tensor, block_idx: int, mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Roll ``x`` back into spatial order, treating it as the output of block ``block_idx``.

        Without a mask, returns [B, H, W, C] for 2d, [B, T, H, W, C] for 3d, etc.
        With a mask, returns mask-unit windows [B, #MUs, MUy, MUx, C] for 2d, etc.
        """
        remaining, size = self.schedule[block_idx]
        B, N, C = x.shape
        ndim = len(size)
        mu_shape = [1] * ndim

        for strides in remaining:
            # Peel the stride offsets off the front of the token dim:
            # [B, *strides, N // prod(strides), *mu_shape, C]
            x = x.view(B, *strides, N // math.prod(strides), *mu_shape, C)

            # Interleave each stride axis with its mask-unit axis, e.g. in 2d
            # [B, Sy, Sx, N', MUy, MUx, C] -> [B, N', Sy, MUy, Sx, MUx, C]
            last = len(x.shape) - 1
            order = [0, ndim + 1]
            for stride_axis, mu_axis in zip(range(1, ndim + 1), range(ndim + 2, last)):
                order.extend([stride_axis, mu_axis])
            order.append(last)
            x = x.permute(order)

            # Each mask unit grows by the stride it just absorbed: [B, N', *MU, C]
            mu_shape = [mu_shape[d] * strides[d] for d in range(ndim)]
            x = x.reshape(B, -1, *mu_shape, C)
            N = x.shape[1]

        # e.g. in 2d: [B, #MUy*#MUx, MUy, MUx, C]
        x = x.view(B, N, *mu_shape, C)

        if mask is not None:
            return x
        return undo_windowing(x, size, mu_shape)