""" PyTorch Hiera Transformer model.""" # references # - https://github.com/facebookresearch/hiera/blob/main/hiera/hiera.py # - https://github.com/facebookresearch/hiera/blob/main/hiera/hiera_utils.py import collections.abc import math import warnings from dataclasses import dataclass from typing import Optional, Tuple, Union, Type, List import torch import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss import torch.nn.functional as F from transformers.activations import ACT2FN from transformers.modeling_outputs import ( ImageClassifierOutput, BaseModelOutputWithPooling, ) from transformers.modeling_utils import PreTrainedModel from transformers.utils import ( ModelOutput, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging, ) from .configuration_hiera import HieraConfig logger = logging.get_logger(__name__) # General docstring _CONFIG_FOR_DOC = "HieraConfig" # Base docstring _CHECKPOINT_FOR_DOC = "/" _EXPECTED_OUTPUT_SHAPE = [1, 64, 768] # Image classification docstring _IMAGE_CLASS_CHECKPOINT = "/" _IMAGE_CLASS_EXPECTED_OUTPUT = "" HIERA_PRETRAINED_MODEL_ARCHIVE_LIST = [ "/", # See all Hiera models at https://huggingface.co/models?filter=hiera ] def conv_nd(n: int) -> Type[nn.Module]: """ Returns a conv with nd (e.g., Conv2d for n=2). Work up to n=3. If you wanted a 4d Hiera, you could probably just implement this for n=4. (no promises) """ return [nn.Identity, nn.Conv1d, nn.Conv2d, nn.Conv3d][n] def do_pool(x: torch.Tensor, stride: int) -> torch.Tensor: # Refer to `Unroll` to see how this performs a maxpool-Nd return x.view(x.shape[0], stride, -1, x.shape[-1]).max(dim=1).values def get_resized_mask(target_size: torch.Size, mask: torch.Tensor) -> torch.Tensor: # target_size: [(T), (H), W] # (spatial) mask: [B, C, (t), (h), w] if mask is None: return mask assert len(mask.shape[2:]) == len(target_size) if mask.shape[2:] != target_size: return F.interpolate(mask.float(), size=target_size) return mask def do_masked_conv( x: torch.Tensor, conv: nn.Module, mask: Optional[torch.Tensor] = None ) -> torch.Tensor: """Zero-out the masked regions of the input before conv. Prevents leakage of masked regions when using overlapping kernels. """ if conv is None: return x if mask is None: return conv(x) mask = get_resized_mask(target_size=x.shape[2:], mask=mask) return conv(x * mask.bool()) def undo_windowing( x: torch.Tensor, shape: List[int], mu_shape: List[int] ) -> torch.Tensor: """ Restore spatial organization by undoing windowed organization of mask units. Args: x: organized by mask units windows, e.g. in 2d [B, #MUy*#MUx, MUy, MUx, C] shape: current spatial shape, if it were not organized into mask unit windows, e.g. in 2d [B, #MUy*MUy, #MUx*MUx, C]. mu_shape: current mask unit shape, e.g. in 2d [MUy, MUx] Returns: x: e.g. 
in 2d, [B, #MUy*MUy, #MUx*MUx, C] """ D = len(shape) B, C = x.shape[0], x.shape[-1] # [B, #MUy*#MUx, MUy, MUx, C] -> [B, #MUy, #MUx, MUy, MUx, C] num_MUs = [s // mu for s, mu in zip(shape, mu_shape)] x = x.view(B, *num_MUs, *mu_shape, C) # [B, #MUy, #MUx, MUy, MUx, C] -> [B, #MUy*MUy, #MUx*MUx, C] permute = ( [0] + sum( [list(p) for p in zip(range(1, 1 + D), range(1 + D, 1 + 2 * D))], [], ) + [len(x.shape) - 1] ) x = x.permute(permute).reshape(B, *shape, C) return x # Copied from transformers.models.swin.modeling_swin.drop_path def drop_path( input: torch.Tensor, drop_prob: float = 0.0, training: bool = False ) -> torch.Tensor: """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument. """ if drop_prob == 0.0 or not training: return input keep_prob = 1 - drop_prob shape = (input.shape[0],) + (1,) * ( input.ndim - 1 ) # work with diff dim tensors, not just 2D ConvNets random_tensor = keep_prob + torch.rand( shape, dtype=input.dtype, device=input.device ) random_tensor.floor_() # binarize output = input.div(keep_prob) * random_tensor return output # Copied from transformers.models.swin.modeling_swin.SwinDropPath with Swin->Hiera class HieraDropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" def __init__(self, drop_prob: float) -> None: super().__init__() self.drop_prob = drop_prob def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return drop_path(hidden_states, self.drop_prob, self.training) def extra_repr(self) -> str: return "p={}".format(self.drop_prob) @dataclass # Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->Swinv2 class HieraEncoderOutput(ModelOutput): """ Hiera encoder's outputs, with potential hidden states and attentions. Args: last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of shape `(batch_size, hidden_size, height, width)`. 
Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to include the spatial dimensions. """ last_hidden_state: torch.FloatTensor hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None @dataclass # Copied from transformers.models.swin.modeling_swin.SwinMaskedImageModelingOutput with Swin->Swinv2 class HieraMaskedImageModelingOutput(ModelOutput): """ Hiera masked image model outputs. Args: loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided): Masked image modeling (MLM) loss. reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): Reconstructed pixel values. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of shape `(batch_size, hidden_size, height, width)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to include the spatial dimensions. """ reconstruction: torch.FloatTensor loss: Optional[torch.FloatTensor] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None @property def logits(self): warnings.warn( "logits attribute is deprecated and will be removed in version 5 of Transformers." " Please use the reconstruction attribute to retrieve the final output instead.", FutureWarning, ) return self.reconstruction class HieraPretrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ config_class = HieraConfig base_model_prefix = "hiera" main_input_name = "pixel_values" supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)): nn.init.trunc_normal_(module.weight, std=self.config.initializer_range) if isinstance(module, nn.Linear) and module.bias is not None: nn.init.constant_(module.bias, val=self.config.initializer_bias) elif isinstance(module, nn.LayerNorm): nn.init.constant_(module.bias, val=self.config.initializer_bias) nn.init.constant_(module.weight, 1.0) HIERA_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. 
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`HieraConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ HIERA_INPUTS_DOCSTRING = r""" Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`] for details. head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ class HieraUnroll(nn.Module): """ Reorders the tokens such that patches are contiguous in memory. E.g., given [B, (H, W), C] and stride of (Sy, Sx), this will re-order the tokens as [B, (Sy, Sx, H // Sy, W // Sx), C] This allows operations like Max2d to be computed as x.view(B, Sx*Sy, -1, C).max(dim=1). Not only is this faster, but it also makes it easy to support inputs of arbitrary dimensions in addition to patch-wise sparsity. Performing this operation multiple times in sequence puts entire windows as contiguous in memory. For instance, if you applied the stride (2, 2) 3 times, entire windows of size 8x8 would be contiguous in memory, allowing operations like mask unit attention computed easily and efficiently, while also allowing max to be applied sequentially. Note: This means that intermediate values of the model are not in HxW order, so they need to be re-rolled if you want to use the intermediate values as a HxW feature map. The last block of the network is fine though, since by then the strides are all consumed. """ def __init__( self, config: HieraConfig, ): super().__init__() image_size, stride_size = config.image_size, config.stride_size image_size = ( image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) ) self.size = [i // s for i, s in zip(image_size, stride_size)] self.schedule = [config.q_stride] * (len(config.depths) - 1) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Input: Flattened patch embeddings [B, N, C] Output: Patch embeddings [B, N, C] permuted such that [B, 4, N//4, C].max(1) etc. 
performs MaxPoolNd
        """
        B, _, C = x.shape

        cur_size = self.size
        x = x.view(*([B] + cur_size + [C]))

        for strides in self.schedule:
            # Move patches with the given strides to the batch dimension

            # Create a view of the tensor with the patch stride as separate dims
            # For example in 2d: [B, H // Sy, Sy, W // Sx, Sx, C]
            cur_size = [i // s for i, s in zip(cur_size, strides)]
            new_shape = [B] + sum([[i, s] for i, s in zip(cur_size, strides)], []) + [C]
            x = x.view(new_shape)

            # Move the patch stride into the batch dimension
            # For example in 2d: [B, Sy, Sx, H // Sy, W // Sx, C]
            L = len(new_shape)
            permute = [0] + list(range(2, L - 1, 2)) + list(range(1, L - 1, 2)) + [L - 1]
            x = x.permute(permute)

            # Now finally flatten the relevant dims into the batch dimension
            x = x.flatten(0, len(strides))
            B *= math.prod(strides)

        x = x.reshape(-1, math.prod(self.size), C)
        return x


class HieraReroll(nn.Module):
    """
    Undoes the "unroll" operation so that you can use intermediate features.
    """

    def __init__(
        self,
        config: HieraConfig,
    ):
        super().__init__()
        image_size, stride_size = config.image_size, config.stride_size
        image_size = (
            image_size
            if isinstance(image_size, collections.abc.Iterable)
            else (image_size, image_size)
        )
        self.size = [i // s for i, s in zip(image_size, stride_size)]
        unroll_schedule = [config.q_stride] * (len(config.depths) - 1)

        # The first stage has to reverse everything
        # The next stage has to reverse all but the first unroll, etc.
        # Blocks are indexed globally across stages, so the schedule is keyed by the cumulative
        # stage ends rather than by the per-stage depths.
        stage_ends = [sum(config.depths[: i + 1]) - 1 for i in range(len(config.depths))]
        self.schedule = {}
        size = self.size
        for i in range(stage_ends[-1] + 1):
            self.schedule[i] = unroll_schedule, size
            # schedule unchanged if no pooling at a stage end
            if i in stage_ends[: config.q_pool]:
                if len(unroll_schedule) > 0:
                    size = [n // s for n, s in zip(size, unroll_schedule[0])]
                unroll_schedule = unroll_schedule[1:]

    def forward(
        self, x: torch.Tensor, block_idx: int, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Roll the given tensor back up to spatial order assuming it's from the given block.

        If no mask is provided:
            - Returns [B, H, W, C] for 2d, [B, T, H, W, C] for 3d, etc.
        If a mask is provided:
            - Returns [B, #MUs, MUy, MUx, C] for 2d, etc.
        """
        schedule, size = self.schedule[block_idx]
        B, N, C = x.shape

        D = len(size)
        cur_mu_shape = [1] * D

        for strides in schedule:
            # Extract the current patch from N
            x = x.view(B, *strides, N // int(math.prod(strides)), *cur_mu_shape, C)

            # Move that patch into the current MU
            # Example in 2d: [B, Sy, Sx, N//(Sy*Sx), MUy, MUx, C] -> [B, N//(Sy*Sx), Sy, MUy, Sx, MUx, C]
            L = len(x.shape)
            permute = (
                [0, 1 + D]
                + sum(
                    [list(p) for p in zip(range(1, 1 + D), range(1 + D + 1, L - 1))],
                    [],
                )
                + [L - 1]
            )
            x = x.permute(permute)

            # Reshape to [B, N//(Sy*Sx), *MU, C]
            for i in range(D):
                cur_mu_shape[i] *= strides[i]
            x = x.reshape(B, -1, *cur_mu_shape, C)
            N = x.shape[1]

        # Current shape (e.g., 2d: [B, #MUy*#MUx, MUy, MUx, C])
        x = x.view(B, N, *cur_mu_shape, C)

        # If masked, return [B, #MUs, MUy, MUx, C]
        if mask is not None:
            return x

        # If not masked, we can return [B, H, W, C]
        x = undo_windowing(x, size, cur_mu_shape)
        return x


class HieraAttention(nn.Module):
    """
    Computes either Mask Unit or Global Attention. Also is able to perform q pooling.

    Note: this assumes the tokens have already been flattened and unrolled into mask units. See `Unroll` for more
    details.
    """

    def __init__(
        self,
        config: HieraConfig,
        dim: int,
        dim_out: int,
        num_heads: int,
        q_stride: int = 1,
        window_size: int = 0,
        use_mask_unit_attn: bool = False,
    ):
        """
        Args:
            - dim, dim_out: The input and output feature dimensions.
            - num_heads: The number of attention heads.
- q_stride: If greater than 1, pool q with this stride. The stride should be flattened (e.g., 2x2 = 4). - window_size: The current (flattened) size of a mask unit *after* pooling (if any). - use_mask_unit_attn: Use Mask Unit or Global Attention. """ super().__init__() self.dim = dim self.dim_out = dim_out self.num_heads = num_heads self.q_stride = q_stride self.head_dim = dim_out // num_heads self.scale = (self.head_dim) ** -0.5 self.qkv = nn.Linear(dim, 3 * dim_out) self.proj = nn.Linear(dim_out, dim_out) self.window_size = window_size self.use_mask_unit_attn = use_mask_unit_attn def forward(self, x: torch.Tensor) -> torch.Tensor: """Input should be of shape [batch, tokens, channels].""" B, N, _ = x.shape num_windows = ( (N // (self.q_stride * self.window_size)) if self.use_mask_unit_attn else 1 ) qkv = ( self.qkv(x) .reshape(B, -1, num_windows, 3, self.num_heads, self.head_dim) .permute(3, 0, 4, 2, 1, 5) ) q, k, v = qkv[0], qkv[1], qkv[2] if self.q_stride > 1: # Refer to Unroll to see how this performs a maxpool-Nd q = ( q.view(B, self.num_heads, num_windows, self.q_stride, -1, self.head_dim) .max(dim=3) .values ) if hasattr(F, "scaled_dot_product_attention"): # Note: the original paper did *not* use SDPA, it's a free boost! x = F.scaled_dot_product_attention(q, k, v) else: attn = (q * self.scale) @ k.transpose(-1, -2) attn = attn.softmax(dim=-1) x = attn @ v x = x.transpose(1, 3).reshape(B, -1, self.dim_out) x = self.proj(x) return x class HieraMLP(nn.Module): def __init__(self, config: HieraConfig, dim: int): super().__init__() self.fc1 = nn.Linear(dim, int(config.mlp_ratio * dim)) if isinstance(config.hidden_act, str): self.act_fn = ACT2FN[config.hidden_act] else: self.act_fn = config.hidden_act self.dropout1 = nn.Dropout(config.hidden_dropout_prob) self.fc2 = nn.Linear(int(config.mlp_ratio * dim), dim) self.dropout2 = nn.Dropout(config.hidden_dropout_prob) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.fc1(x) x = self.act_fn(x) x = self.dropout1(x) x = self.fc2(x) x = self.dropout2(x) return x class HieraLayer(nn.Module): def __init__( self, config: HieraConfig, dim: int, dim_out: int, num_heads: int, drop_path_rate: float = 0.0, q_stride: int = 1, window_size: int = 0, use_mask_unit_attn: bool = False, ): super().__init__() self.dim = dim self.dim_out = dim_out self.norm1 = nn.LayerNorm(dim, eps=config.layer_norm_eps) self.attn = HieraAttention( config=config, dim=dim, dim_out=dim_out, num_heads=num_heads, q_stride=q_stride, window_size=window_size, use_mask_unit_attn=use_mask_unit_attn, ) self.norm2 = nn.LayerNorm(dim_out, eps=config.layer_norm_eps) self.mlp = HieraMLP( config, dim=dim_out, ) self.drop_path = ( HieraDropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() ) if dim != dim_out: self.proj = nn.Linear(dim, dim_out) else: self.proj = None def forward(self, x: torch.Tensor) -> torch.Tensor: # Attention + Q Pooling x_norm = self.norm1(x) if self.proj is not None: x = do_pool(self.proj(x_norm), stride=self.attn.q_stride) x = x + self.drop_path(self.attn(x_norm)) # MLP x = x + self.drop_path(self.mlp(self.norm2(x))) return x class HieraStage(nn.Module): def __init__( self, config: HieraConfig, dim: int, depth: int, num_heads: int, window_size: int, has_q_pool: bool = True, drop_path_rate: float = 0.0, use_mask_unit_attention: bool = True, ): super().__init__() self.blocks = nn.ModuleList( [ HieraLayer( config=config, dim=dim // 2 if i == 0 and has_q_pool else dim, dim_out=dim, num_heads=num_heads, drop_path_rate=drop_path_rate, 
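                    # Q-pooling only happens in the first block of a pooling stage; the stride passed here is the
                    # flattened product of the spatial q_stride (e.g., a 2x2 q_stride gives flat_q_stride = 4).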
q_stride=(config.flat_q_stride if i == 0 and has_q_pool else 1), window_size=window_size, use_mask_unit_attn=use_mask_unit_attention, ) for i in range(depth) ] ) def forward( self, hidden_states: torch.Tensor, ) -> torch.Tensor: for _i, block in enumerate(self.blocks): hidden_states = block(hidden_states) return hidden_states class HieraPatchEmbeddings(nn.Module): """Patch embed that supports any number of spatial dimensions (1d, 2d, 3d).""" def __init__( self, config: HieraConfig, ): super().__init__() image_size, patch_size, stride_size, padding_size = ( config.image_size, config.patch_size, config.stride_size, config.padding_size, ) num_channels, hidden_size = config.num_channels, config.embed_dim image_size = ( image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) ) self.image_size = image_size self.patch_size = patch_size self.stride_size = stride_size self.padding_size = padding_size self.num_channels = num_channels self.num_patches = math.prod(patch_size) self.spatial_dims = len(patch_size) # Support any number of spatial dimensions self.projection = conv_nd(self.spatial_dims)( num_channels, hidden_size, kernel_size=patch_size, stride=stride_size, padding=padding_size, ) def forward( self, pixel_values: torch.Tensor, mask: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, Tuple[int, ...]]: _, num_channels, height, width = pixel_values.shape if num_channels != self.num_channels: raise ValueError( "Make sure that the channel dimension of the pixel values match with the one set in the configuration." ) embeddings = do_masked_conv(pixel_values, self.projection, mask) _, _, height, width = embeddings.shape output_dimensions = (height, width) embeddings = embeddings.reshape( embeddings.shape[0], embeddings.shape[1], -1 ).transpose(2, 1) return embeddings, output_dimensions class HieraPositionEmbeddings(nn.Module): def __init__( self, config: HieraConfig, ): super().__init__() image_size, stride_size = config.image_size, config.stride_size image_size = ( image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) ) self.tokens_spatial_shape = [i // s for i, s in zip(image_size, stride_size)] num_tokens = math.prod(self.tokens_spatial_shape) self.separate_positional_embeds = config.separate_positional_embeds self.mask_spatial_shape = [ i // s for i, s in zip(self.tokens_spatial_shape, config.mask_unit_size) ] if self.separate_positional_embeds: self.pos_embeddings_spatial = nn.Parameter( torch.zeros( 1, self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2], config.embed_dim, ) ) self.pos_embeddings_temporal = nn.Parameter( torch.zeros(1, self.tokens_spatial_shape[0], config.embed_dim) ) else: self.pos_embeddings = nn.Parameter( torch.zeros(1, num_tokens, config.embed_dim) ) def forward(self) -> torch.Tensor: if self.separate_positional_embeds: return self.pos_embeddings_spatial.repeat( 1, self.tokens_spatial_shape[0], 1 ) + torch.repeat_interleave( self.pos_embeddings_temporal, self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2], dim=1, ) else: return self.pos_embeddings class HieraEmbeddings(nn.Module): def __init__(self, config: HieraConfig): super().__init__() self.patch_embeddings = HieraPatchEmbeddings(config) self.pos_embeddings = HieraPositionEmbeddings(config) def forward( self, pixel_values: torch.Tensor, mask: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, ...]: embeddings, output_dimensions = self.patch_embeddings( pixel_values, mask=( mask.view(pixel_values.shape[0], 1, 
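                    # `mask` is given per mask unit; reshape it to the spatial grid of mask units so that
                    # `do_masked_conv` can zero out the masked patches before the projection.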
                *self.pos_embeddings.mask_spatial_shape)
                if mask is not None
                else None
            ),
        )
        embeddings = embeddings + self.pos_embeddings()
        return embeddings, output_dimensions


class HieraEncoder(nn.Module):
    def __init__(self, config: HieraConfig):
        super().__init__()
        self.num_layers = len(config.depths)
        self.config = config

        # stochastic depth decay rule, one rate per block
        dpr = [
            x.item()
            for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))
        ]
        self.layers = nn.ModuleList(
            [
                HieraStage(
                    config,
                    dim=int(config.embed_dim * (2**i_layer)),
                    depth=config.depths[i_layer],
                    num_heads=config.num_heads[i_layer],
                    # each stage uses the decay rate of its first block
                    drop_path_rate=dpr[sum(config.depths[:i_layer])],
                    has_q_pool=i_layer > 0,
                    window_size=config.flat_mask_unit_size
                    // (config.flat_q_stride**i_layer),
                    use_mask_unit_attention=config.mask_unit_attention[i_layer],
                )
                for i_layer in range(self.num_layers)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, HieraEncoderOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if output_hidden_states:
            assert isinstance(all_hidden_states, Tuple)
            assert isinstance(all_reshaped_hidden_states, Tuple)

            batch_size, _, hidden_size = hidden_states.shape
            # rearrange b (h w) c -> b c h w
            reshaped_hidden_state = hidden_states.view(
                batch_size, *input_dimensions, hidden_size
            )
            reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
            all_hidden_states += (hidden_states,)
            all_reshaped_hidden_states += (reshaped_hidden_state,)

        for _i, layer_module in enumerate(self.layers):
            layer_outputs = layer_module(hidden_states)
            hidden_states = layer_outputs

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, all_hidden_states, all_self_attentions]
                if v is not None
            )

        return HieraEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            reshaped_hidden_states=all_reshaped_hidden_states,
        )


class HieraHead(nn.Module):
    def __init__(self, config: HieraConfig):
        super().__init__()
        num_features = int(config.embed_dim * (2 ** (config.num_layers - 1)))
        self.dropout = (
            nn.Dropout(config.hidden_dropout_prob)
            if config.hidden_dropout_prob > 0
            else nn.Identity()
        )
        self.projection = nn.Linear(num_features, config.num_labels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.dropout(x)
        x = self.projection(x)
        return x


class HieraModel(HieraPretrainedModel):
    def __init__(
        self,
        config: HieraConfig,
        add_pooling_layer=True,
    ):
        super().__init__(config)
        self.config = config
        self.num_layers = len(config.depths)
        self.num_features = int(config.embed_dim * (2 ** (self.num_layers - 1)))

        self.embeddings = HieraEmbeddings(config)
        self.unroll = HieraUnroll(config)
        self.reroll = HieraReroll(config)
        self.encoder = HieraEncoder(config)

        self.norm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings
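    # Forward pass (see `forward` below): patch embedding + position embedding -> `HieraUnroll` puts the tokens in
    # mask-unit order -> masked mask units are dropped (if `mask` is given) -> `HieraEncoder` -> layer norm, with an
    # optional average pooling over the tokens when `add_pooling_layer=True`.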
    @add_start_docstrings_to_model_forward(HIERA_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        mask: Optional[torch.BoolTensor] = None,
        # head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        mask (`torch.BoolTensor` of shape `(batch_size, num_mask_units)`, *optional*):
            Boolean mask over the mask units, where `num_mask_units = #MUt * #MUy * #MUx` and #MU is the number of
            mask units in each spatial dimension. A value of 1 means *keep* the mask unit and 0 means *remove* it;
            `mask.sum(dim=-1)` should be the same for every sample in the batch.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output, input_dimensions = self.embeddings(pixel_values, mask=mask)
        unrolled_embedding = self.unroll(embedding_output)

        # Discard masked tokens
        if mask is not None:
            unrolled_embedding = unrolled_embedding[
                mask[..., None].tile(
                    1, self.config.flat_mask_unit_size, unrolled_embedding.shape[2]
                )
            ].view(unrolled_embedding.shape[0], -1, unrolled_embedding.shape[-1])

        encoder_outputs = self.encoder(
            unrolled_embedding,
            input_dimensions,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # last hidden states, [batch, tokens, channels]
        sequence_output = self.norm(encoder_outputs[0])

        pooled_output = None
        if self.pooler is not None:
            # average pool over the token dimension: [batch, channels, tokens] -> [batch, channels]
            pooled_output = self.pooler(sequence_output.transpose(1, 2))
            pooled_output = torch.flatten(pooled_output, 1)

        if not return_dict:
            output = (sequence_output, pooled_output) + encoder_outputs[1:]
            return output

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


@add_start_docstrings(
    """
    Hiera Model transformer with an image classification head on top (a linear layer on top of the mean-pooled final
    hidden states) e.g. for ImageNet.
    """,
    HIERA_START_DOCSTRING,
)
class HieraForImageClassification(HieraPretrainedModel):
    def __init__(
        self,
        config,
        add_pooling_layer=False,
    ):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.hiera = HieraModel(config, add_pooling_layer=add_pooling_layer)

        # Classifier head
        self.head = HieraHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(HIERA_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        # head_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in
            `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square
            loss), if `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
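
        Example (a minimal sketch; it assumes the default `HieraConfig` describes a 224x224 RGB model and that both
        classes are exported from `transformers`; shapes depend on the actual configuration):

        ```python
        >>> import torch
        >>> from transformers import HieraConfig, HieraForImageClassification

        >>> config = HieraConfig(num_labels=3)
        >>> model = HieraForImageClassification(config)

        >>> pixel_values = torch.randn(1, config.num_channels, 224, 224)
        >>> with torch.no_grad():
        ...     logits = model(pixel_values).logits
        >>> list(logits.shape)
        [1, 3]
        ```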
""" return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) outputs = self.hiera( pixel_values, # head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_states = outputs[0] logits = self.head(last_hidden_states) loss = None if labels is not None: if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and ( labels.dtype == torch.long or labels.dtype == torch.int ): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": loss_fct = MSELoss() if self.num_labels == 1: loss = loss_fct(logits.squeeze(), labels.squeeze()) else: loss = loss_fct(logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output return ImageClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, )