p1atdev committed on
Commit 0031938
1 Parent(s): 72034c5

Upload 2 files

Files changed (2)
  1. configuration_hiera.py +140 -0
  2. modeling_hiera.py +1086 -0
configuration_hiera.py ADDED
@@ -0,0 +1,140 @@
""" Hiera model configuration"""

import math

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)

# HIERA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
#     "hoge/hoge": ("/config.json"),
# }


class HieraConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`HieraModel`]. It is used to instantiate a Hiera
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with
    the defaults will yield a similar configuration to that of the Hiera
    [/]()
    architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`list(int)`, *optional*, defaults to `[7, 7]`):
            The size (resolution) of each patch.
        stride_size (`list(int)`, *optional*, defaults to `[4, 4]`):
            The size (resolution) of each stride.
        padding_size (`list(int)`, *optional*, defaults to `[3, 3]`):
            The size (resolution) of each padding.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        embed_dim (`int`, *optional*, defaults to 96):
            Dimensionality of patch embedding.
        depths (`list(int)`, *optional*, defaults to `[2, 3, 16, 3]`):
            Depth of each layer in the Transformer encoder.
        num_heads (`list(int)`, *optional*, defaults to `[1, 2, 4, 8]`):
            Number of attention heads in each layer of the Transformer encoder.
        q_pool (`int`, *optional*, defaults to 3):
            Number of q_pool stages.
        q_stride (`list(int)`, *optional*, defaults to `[2, 2]`):
            Stride used for query pooling (q_pool).
        mask_unit_size (`list(int)`, *optional*, defaults to `[8, 8]`):
            Size of the mask unit in attention.
        mask_unit_attention (`list(bool)`, *optional*, defaults to `[True, True, False, False]`):
            Whether or not to enable mask unit attention in each stage.
        separate_positional_embeds (`bool`, *optional*, defaults to `False`):
            Whether or not to use separate spatial and temporal positional embeddings.
        mlp_ratio (`float`, *optional*, defaults to 4.0):
            Ratio of MLP hidden dimensionality to embedding dimensionality.
        drop_path_rate (`float`, *optional*, defaults to 0.0):
            Stochastic depth rate.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`,
            `"selu"` and `"gelu_new"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings and encoder.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        initializer_bias (`float`, *optional*, defaults to 0.02):
            The constant value used to initialize all bias parameters.

    Example:

    ```python
    >>> from transformers import HieraConfig, HieraModel

    >>> # Initializing a Hiera / style configuration
    >>> configuration = HieraConfig()

    >>> # Initializing a model (with random weights) from the / style configuration
    >>> model = HieraModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "hiera"

    attribute_map = {}

    def __init__(
        self,
        image_size=224,
        patch_size=[7, 7],
        stride_size=[4, 4],
        padding_size=[3, 3],
        num_channels=3,
        embed_dim=96,
        depths=[2, 3, 16, 3],
        num_heads=[1, 2, 4, 8],
        q_pool=3,  # number of q_pool stages
        q_stride=[2, 2],
        mask_unit_size=[8, 8],
        mask_unit_attention=[True, True, False, False],
        separate_positional_embeds=False,
        mlp_ratio=4.0,
        drop_path_rate=0.0,
        hidden_act="gelu",
        layer_norm_eps=1e-6,
        hidden_dropout_prob=0.0,
        initializer_range=0.02,
        initializer_bias=0.02,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.image_size = image_size
        self.patch_size = patch_size
        self.stride_size = stride_size
        self.padding_size = padding_size
        self.num_channels = num_channels
        self.embed_dim = embed_dim
        self.depths = depths
        self.num_layers = len(depths)
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.hidden_dropout_prob = hidden_dropout_prob
        self.drop_path_rate = drop_path_rate
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps

        assert q_pool < len(depths), "q_pool must be less than the number of stages (len(depths))"

        self.mask_unit_size = mask_unit_size
        self.flat_mask_unit_size = int(math.prod(mask_unit_size))
        self.mask_unit_attention = mask_unit_attention
        self.q_pool = q_pool
        self.q_stride = q_stride
        self.flat_q_stride = int(math.prod(q_stride))
        self.separate_positional_embeds = separate_positional_embeds

        self.initializer_range = initializer_range
        self.initializer_bias = initializer_bias
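
A quick sketch of how the derived attributes fall out of the defaults above (assuming configuration_hiera.py is on the Python path):

    from configuration_hiera import HieraConfig

    config = HieraConfig()                 # defaults shown above (Hiera-B-like)
    print(config.num_layers)               # 4  == len([2, 3, 16, 3])
    print(config.flat_mask_unit_size)      # 64 == math.prod([8, 8])
    print(config.flat_q_stride)            # 4  == math.prod([2, 2])

    # q_pool has to stay below the number of stages, otherwise __init__ asserts:
    # HieraConfig(q_pool=4)  ->  AssertionError
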
modeling_hiera.py ADDED
@@ -0,0 +1,1086 @@
""" PyTorch Hiera Transformer model."""

import collections.abc
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union, Type, List

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
import torch.nn.functional as F

from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
    ImageClassifierOutput,
    BaseModelOutputWithPooling,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
)

from .configuration_hiera import HieraConfig


logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "HieraConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "/"
_EXPECTED_OUTPUT_SHAPE = [1, 64, 768]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "/"
_IMAGE_CLASS_EXPECTED_OUTPUT = ""


HIERA_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "/",
    # See all Hiera models at https://huggingface.co/models?filter=hiera
]


def conv_nd(n: int) -> Type[nn.Module]:
    """
    Returns a conv module for the given number of spatial dimensions (e.g., Conv2d for n=2). Works up to n=3.
    If you wanted a 4d Hiera, you could probably just implement this for n=4. (no promises)
    """
    return [nn.Identity, nn.Conv1d, nn.Conv2d, nn.Conv3d][n]


def do_pool(x: torch.Tensor, stride: int) -> torch.Tensor:
    # Refer to `HieraUnroll` to see how this performs a maxpool-Nd
    return x.view(x.shape[0], stride, -1, x.shape[-1]).max(dim=1).values


def get_resized_mask(target_size: torch.Size, mask: torch.Tensor) -> torch.Tensor:
    # target_size: [(T), (H), W]
    # (spatial) mask: [B, C, (t), (h), w]
    if mask is None:
        return mask

    assert len(mask.shape[2:]) == len(target_size)
    if mask.shape[2:] != target_size:
        return F.interpolate(mask.float(), size=target_size)
    return mask


def do_masked_conv(
    x: torch.Tensor, conv: nn.Module, mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Zero-out the masked regions of the input before conv.
    Prevents leakage of masked regions when using overlapping kernels.
    """
    if conv is None:
        return x
    if mask is None:
        return conv(x)

    mask = get_resized_mask(target_size=x.shape[2:], mask=mask)
    return conv(x * mask.bool())


def undo_windowing(
    x: torch.Tensor, shape: List[int], mu_shape: List[int]
) -> torch.Tensor:
    """
    Restore spatial organization by undoing the windowed organization of mask units.

    Args:
        x: organized by mask unit windows, e.g. in 2d [B, #MUy*#MUx, MUy, MUx, C]
        shape: current spatial shape, if it were not organized into mask unit
            windows, e.g. in 2d [B, #MUy*MUy, #MUx*MUx, C].
        mu_shape: current mask unit shape, e.g. in 2d [MUy, MUx]
    Returns:
        x: e.g. in 2d, [B, #MUy*MUy, #MUx*MUx, C]
    """
    D = len(shape)
    B, C = x.shape[0], x.shape[-1]
    # [B, #MUy*#MUx, MUy, MUx, C] -> [B, #MUy, #MUx, MUy, MUx, C]
    num_MUs = [s // mu for s, mu in zip(shape, mu_shape)]
    x = x.view(B, *num_MUs, *mu_shape, C)

    # [B, #MUy, #MUx, MUy, MUx, C] -> [B, #MUy*MUy, #MUx*MUx, C]
    permute = (
        [0]
        + sum(
            [list(p) for p in zip(range(1, 1 + D), range(1 + D, 1 + 2 * D))],
            [],
        )
        + [len(x.shape) - 1]
    )
    x = x.permute(permute).reshape(B, *shape, C)

    return x


# Copied from transformers.models.swin.modeling_swin.drop_path
def drop_path(
    input: torch.Tensor, drop_prob: float = 0.0, training: bool = False
) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (
        input.ndim - 1
    )  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(
        shape, dtype=input.dtype, device=input.device
    )
    random_tensor.floor_()  # binarize
    output = input.div(keep_prob) * random_tensor
    return output


# Copied from transformers.models.swin.modeling_swin.SwinDropPath with Swin->Hiera
class HieraDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: float) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return "p={}".format(self.drop_prob)


@dataclass
# Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->Hiera
class HieraEncoderOutput(ModelOutput):
    """
    Hiera encoder's outputs, with potential hidden states and attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    last_hidden_state: torch.FloatTensor
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
# Copied from transformers.models.swin.modeling_swin.SwinMaskedImageModelingOutput with Swin->Hiera
class HieraMaskedImageModelingOutput(ModelOutput):
    """
    Hiera masked image model outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
            Masked image modeling (MIM) loss.
        reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Reconstructed pixel values.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    reconstruction: torch.FloatTensor
    loss: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None

    @property
    def logits(self):
        warnings.warn(
            "logits attribute is deprecated and will be removed in version 5 of Transformers."
            " Please use the reconstruction attribute to retrieve the final output instead.",
            FutureWarning,
        )
        return self.reconstruction


class HieraPretrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = HieraConfig
    base_model_prefix = "hiera"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            nn.init.trunc_normal_(module.weight, std=self.config.initializer_range)
            if isinstance(module, nn.Linear) and module.bias is not None:
                nn.init.constant_(module.bias, val=self.config.initializer_bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.constant_(module.bias, val=self.config.initializer_bias)
            nn.init.constant_(module.weight, 1.0)


HIERA_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`HieraConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

HIERA_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
            for details.
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


class HieraUnroll(nn.Module):
    """
    Reorders the tokens such that patches are contiguous in memory.
    E.g., given [B, (H, W), C] and stride of (Sy, Sx), this will re-order the tokens as
    [B, (Sy, Sx, H // Sy, W // Sx), C].

    This allows operations like Max2d to be computed as x.view(B, Sx*Sy, -1, C).max(dim=1).
    Not only is this faster, but it also makes it easy to support inputs of arbitrary
    dimensions in addition to patch-wise sparsity.

    Performing this operation multiple times in sequence puts entire windows as contiguous
    in memory. For instance, if you applied the stride (2, 2) 3 times, entire windows of
    size 8x8 would be contiguous in memory, allowing operations like mask unit attention
    to be computed easily and efficiently, while also allowing max to be applied sequentially.

    Note: This means that intermediate values of the model are not in HxW order, so they
    need to be re-rolled if you want to use the intermediate values as a HxW feature map.
    The last block of the network is fine though, since by then the strides are all consumed.
    """

    def __init__(
        self,
        config: HieraConfig,
    ):
        super().__init__()

        image_size, stride_size = config.image_size, config.stride_size
        image_size = (
            image_size
            if isinstance(image_size, collections.abc.Iterable)
            else (image_size, image_size)
        )

        self.size = [i // s for i, s in zip(image_size, stride_size)]
        self.schedule = [config.q_stride] * (len(config.depths) - 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Input: Flattened patch embeddings [B, N, C]
        Output: Patch embeddings [B, N, C] permuted such that [B, 4, N//4, C].max(1) etc. performs MaxPoolNd
        """
        B, _, C = x.shape

        cur_size = self.size
        x = x.view(*([B] + cur_size + [C]))

        for strides in self.schedule:
            # Move patches with the given strides to the batch dimension

            # Create a view of the tensor with the patch stride as separate dims
            # For example in 2d: [B, H // Sy, Sy, W // Sx, Sx, C]
            cur_size = [i // s for i, s in zip(cur_size, strides)]
            new_shape = [B] + sum([[i, s] for i, s in zip(cur_size, strides)], []) + [C]
            x = x.view(new_shape)

            # Move the patch stride into the batch dimension
            # For example in 2d: [B, Sy, Sx, H // Sy, W // Sx, C]
            L = len(new_shape)
            permute = (
                [0] + list(range(2, L - 1, 2)) + list(range(1, L - 1, 2)) + [L - 1]
            )
            x = x.permute(permute)

            # Now finally flatten the relevant dims into the batch dimension
            x = x.flatten(0, len(strides))
            B *= math.prod(strides)

        x = x.reshape(-1, math.prod(self.size), C)
        return x

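
# Illustrative sketch (assuming the default HieraConfig: 224px input, stride 4 ->
# a 56x56 token grid, q_stride [2, 2]). The reordering above is what lets `do_pool`
# act as a spatial max-pool on flattened tokens:
#
#   config = HieraConfig()
#   x = torch.randn(1, 56 * 56, 96)        # [B, N, C] patch tokens in HxW order
#   unrolled = HieraUnroll(config)(x)      # same shape, stride dims moved up front
#   pooled = do_pool(unrolled, stride=4)   # [1, 784, 96]: the values of a 2x2
#                                          # max-pool over the original 56x56 grid
#                                          # (output order is still unrolled)
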

class HieraReroll(nn.Module):
    """
    Undoes the "unroll" operation so that you can use intermediate features.
    """

    def __init__(
        self,
        config: HieraConfig,
    ):
        super().__init__()

        image_size, stride_size = config.image_size, config.stride_size
        image_size = (
            image_size
            if isinstance(image_size, collections.abc.Iterable)
            else (image_size, image_size)
        )

        self.size = [i // s for i, s in zip(image_size, stride_size)]

        unroll_schedule = [config.q_stride] * (len(config.depths) - 1)

        # The first stage has to reverse everything
        # The next stage has to reverse all but the first unroll, etc.
        self.schedule = {}
        size = self.size
        for i in range(config.depths[-2]):
            self.schedule[i] = unroll_schedule, size
            # schedule unchanged if no pooling at a stage end
            if i + 1 in config.depths[: config.q_pool]:
                if len(unroll_schedule) > 0:
                    size = [n // s for n, s in zip(size, unroll_schedule[0])]
                unroll_schedule = unroll_schedule[1:]

    def forward(
        self, x: torch.Tensor, block_idx: int, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Roll the given tensor back up to spatial order assuming it's from the given block.

        If no mask is provided:
            - Returns [B, H, W, C] for 2d, [B, T, H, W, C] for 3d, etc.
        If a mask is provided:
            - Returns [B, #MUs, MUy, MUx, C] for 2d, etc.
        """
        schedule, size = self.schedule[block_idx]
        B, N, C = x.shape

        D = len(size)
        cur_mu_shape = [1] * D

        for strides in schedule:
            # Extract the current patch from N
            x = x.view(B, *strides, N // int(math.prod(strides)), *cur_mu_shape, C)

            # Move that patch into the current MU
            # Example in 2d: [B, Sy, Sx, N//(Sy*Sx), MUy, MUx, C] -> [B, N//(Sy*Sx), Sy, MUy, Sx, MUx, C]
            L = len(x.shape)
            permute = (
                [0, 1 + D]
                + sum(
                    [list(p) for p in zip(range(1, 1 + D), range(1 + D + 1, L - 1))],
                    [],
                )
                + [L - 1]
            )
            x = x.permute(permute)

            # Reshape to [B, N//(Sy*Sx), *MU, C]
            for i in range(D):
                cur_mu_shape[i] *= strides[i]
            x = x.reshape(B, -1, *cur_mu_shape, C)
            N = x.shape[1]

        # Current shape (e.g., 2d: [B, #MUy*#MUx, MUy, MUx, C])
        x = x.view(B, N, *cur_mu_shape, C)

        # If masked, return [B, #MUs, MUy, MUx, C]
        if mask is not None:
            return x

        # If not masked, we can return [B, H, W, C]
        x = undo_windowing(x, size, cur_mu_shape)

        return x


class HieraAttention(nn.Module):
    """
    Computes either Mask Unit or Global Attention. Also able to perform q pooling.

    Note: this assumes the tokens have already been flattened and unrolled into mask units.
    See `HieraUnroll` for more details.
    """

    def __init__(
        self,
        config: HieraConfig,
        dim: int,
        dim_out: int,
        num_heads: int,
        q_stride: int = 1,
        window_size: int = 0,
        use_mask_unit_attn: bool = False,
    ):
        """
        Args:
            - dim, dim_out: The input and output feature dimensions.
            - num_heads: The number of attention heads.
            - q_stride: If greater than 1, pool q with this stride. The stride should be flattened (e.g., 2x2 = 4).
            - window_size: The current (flattened) size of a mask unit *after* pooling (if any).
            - use_mask_unit_attn: Use Mask Unit or Global Attention.
        """
        super().__init__()

        self.dim = dim
        self.dim_out = dim_out
        self.num_heads = num_heads
        self.q_stride = q_stride

        self.head_dim = dim_out // num_heads
        self.scale = (self.head_dim) ** -0.5

        self.qkv = nn.Linear(dim, 3 * dim_out)
        self.proj = nn.Linear(dim_out, dim_out)

        self.window_size = window_size
        self.use_mask_unit_attn = use_mask_unit_attn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Input should be of shape [batch, tokens, channels]."""
        B, N, _ = x.shape
        num_windows = (
            (N // (self.q_stride * self.window_size)) if self.use_mask_unit_attn else 1
        )

        qkv = (
            self.qkv(x)
            .reshape(B, -1, num_windows, 3, self.num_heads, self.head_dim)
            .permute(3, 0, 4, 2, 1, 5)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        if self.q_stride > 1:
            # Refer to `HieraUnroll` to see how this performs a maxpool-Nd
            q = (
                q.view(B, self.num_heads, num_windows, self.q_stride, -1, self.head_dim)
                .max(dim=3)
                .values
            )

        if hasattr(F, "scaled_dot_product_attention"):
            # Note: the original paper did *not* use SDPA, it's a free boost!
            x = F.scaled_dot_product_attention(q, k, v)
        else:
            attn = (q * self.scale) @ k.transpose(-1, -2)
            attn = attn.softmax(dim=-1)
            x = attn @ v

        x = x.transpose(1, 3).reshape(B, -1, self.dim_out)
        x = self.proj(x)
        return x


class HieraMLP(nn.Module):
    def __init__(self, config: HieraConfig, dim: int):
        super().__init__()

        self.fc1 = nn.Linear(dim, int(config.mlp_ratio * dim))
        if isinstance(config.hidden_act, str):
            self.act_fn = ACT2FN[config.hidden_act]
        else:
            self.act_fn = config.hidden_act
        self.dropout1 = nn.Dropout(config.hidden_dropout_prob)
        self.fc2 = nn.Linear(int(config.mlp_ratio * dim), dim)
        self.dropout2 = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout1(x)
        x = self.fc2(x)
        x = self.dropout2(x)
        return x


class HieraLayer(nn.Module):
    def __init__(
        self,
        config: HieraConfig,
        dim: int,
        dim_out: int,
        num_heads: int,
        drop_path_rate: float = 0.0,
        q_stride: int = 1,
        window_size: int = 0,
        use_mask_unit_attn: bool = False,
    ):
        super().__init__()

        self.dim = dim
        self.dim_out = dim_out

        self.norm1 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.attn = HieraAttention(
            config=config,
            dim=dim,
            dim_out=dim_out,
            num_heads=num_heads,
            q_stride=q_stride,
            window_size=window_size,
            use_mask_unit_attn=use_mask_unit_attn,
        )

        self.norm2 = nn.LayerNorm(dim_out, eps=config.layer_norm_eps)
        self.mlp = HieraMLP(
            config,
            dim=dim_out,
        )

        self.drop_path = (
            HieraDropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
        )
        if dim != dim_out:
            self.proj = nn.Linear(dim, dim_out)
        else:
            self.proj = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Attention + Q Pooling
        x_norm = self.norm1(x)

        if self.proj is not None:
            x = do_pool(self.proj(x_norm), stride=self.attn.q_stride)
        x = x + self.drop_path(self.attn(x_norm))

        # MLP
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


class HieraStage(nn.Module):
    def __init__(
        self,
        config: HieraConfig,
        dim: int,
        depth: int,
        num_heads: int,
        window_size: int,
        has_q_pool: bool = True,
        drop_path_rate: float = 0.0,
        use_mask_unit_attention: bool = True,
    ):
        super().__init__()

        self.blocks = nn.ModuleList(
            [
                HieraLayer(
                    config=config,
                    dim=dim // 2 if i == 0 and has_q_pool else dim,
                    dim_out=dim,
                    num_heads=num_heads,
                    drop_path_rate=drop_path_rate,
                    q_stride=(config.flat_q_stride if i == 0 and has_q_pool else 1),
                    window_size=window_size,
                    use_mask_unit_attn=use_mask_unit_attention,
                )
                for i in range(depth)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        for _i, block in enumerate(self.blocks):
            hidden_states = block(hidden_states)

        return hidden_states


class HieraPatchEmbeddings(nn.Module):
    """Patch embed that supports any number of spatial dimensions (1d, 2d, 3d)."""

    def __init__(
        self,
        config: HieraConfig,
    ):
        super().__init__()
        image_size, patch_size, stride_size, padding_size = (
            config.image_size,
            config.patch_size,
            config.stride_size,
            config.padding_size,
        )
        num_channels, hidden_size = config.num_channels, config.embed_dim
        image_size = (
            image_size
            if isinstance(image_size, collections.abc.Iterable)
            else (image_size, image_size)
        )

        self.image_size = image_size
        self.patch_size = patch_size
        self.stride_size = stride_size
        self.padding_size = padding_size
        self.num_channels = num_channels

        self.num_patches = math.prod(patch_size)

        self.spatial_dims = len(patch_size)

        # Support any number of spatial dimensions
        self.projection = conv_nd(self.spatial_dims)(
            num_channels,
            hidden_size,
            kernel_size=patch_size,
            stride=stride_size,
            padding=padding_size,
        )

    def forward(
        self, pixel_values: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Tuple[int, ...]]:
        _, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )

        embeddings = do_masked_conv(pixel_values, self.projection, mask)

        _, _, height, width = embeddings.shape
        output_dimensions = (height, width)

        embeddings = embeddings.reshape(
            embeddings.shape[0], embeddings.shape[1], -1
        ).transpose(2, 1)

        return embeddings, output_dimensions


class HieraPositionEmbeddings(nn.Module):
    def __init__(
        self,
        config: HieraConfig,
    ):
        super().__init__()

        image_size, stride_size = config.image_size, config.stride_size
        image_size = (
            image_size
            if isinstance(image_size, collections.abc.Iterable)
            else (image_size, image_size)
        )

        self.tokens_spatial_shape = [i // s for i, s in zip(image_size, stride_size)]
        num_tokens = math.prod(self.tokens_spatial_shape)
        self.separate_positional_embeds = config.separate_positional_embeds
        self.mask_spatial_shape = [
            i // s for i, s in zip(self.tokens_spatial_shape, config.mask_unit_size)
        ]

        if self.separate_positional_embeds:
            self.pos_embeddings_spatial = nn.Parameter(
                torch.zeros(
                    1,
                    self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2],
                    config.embed_dim,
                )
            )
            self.pos_embeddings_temporal = nn.Parameter(
                torch.zeros(1, self.tokens_spatial_shape[0], config.embed_dim)
            )
        else:
            self.pos_embeddings = nn.Parameter(
                torch.zeros(1, num_tokens, config.embed_dim)
            )

    def forward(self) -> torch.Tensor:
        if self.separate_positional_embeds:
            return self.pos_embeddings_spatial.repeat(
                1, self.tokens_spatial_shape[0], 1
            ) + torch.repeat_interleave(
                self.pos_embeddings_temporal,
                self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2],
                dim=1,
            )
        else:
            return self.pos_embeddings


class HieraEmbeddings(nn.Module):
    def __init__(self, config: HieraConfig):
        super().__init__()

        self.patch_embeddings = HieraPatchEmbeddings(config)
        self.pos_embeddings = HieraPositionEmbeddings(config)

    def forward(
        self, pixel_values: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, ...]:
        embeddings, output_dimensions = self.patch_embeddings(
            pixel_values,
            mask=(
                # the mask comes in per mask unit, so reshape it to the mask-unit grid
                mask.view(
                    pixel_values.shape[0], 1, *self.pos_embeddings.mask_spatial_shape
                )
                if mask is not None
                else None
            ),
        )
        embeddings = embeddings + self.pos_embeddings()

        return embeddings, output_dimensions


class HieraEncoder(nn.Module):
    def __init__(self, config: HieraConfig):
        super().__init__()

        self.num_layers = len(config.depths)
        self.config = config

        dpr = [
            x.item()
            for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))
        ]

        self.layers = nn.ModuleList(
            [
                HieraStage(
                    config,
                    dim=int(config.embed_dim * (2**i_layer)),
                    depth=config.depths[i_layer],
                    num_heads=config.num_heads[i_layer],
                    drop_path_rate=dpr[i_layer],
                    has_q_pool=i_layer > 0,
                    window_size=config.flat_mask_unit_size
                    // (config.flat_q_stride**i_layer),
                    use_mask_unit_attention=config.mask_unit_attention[i_layer],
                )
                for i_layer in range(self.num_layers)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, HieraEncoderOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if output_hidden_states:
            assert isinstance(all_hidden_states, tuple)
            assert isinstance(all_reshaped_hidden_states, tuple)

            batch_size, _, hidden_size = hidden_states.shape
            # rearrange b (h w) c -> b c h w
            reshaped_hidden_state = hidden_states.view(
                batch_size, *input_dimensions, hidden_size
            )
            reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
            all_hidden_states += (hidden_states,)
            all_reshaped_hidden_states += (reshaped_hidden_state,)

        for _i, layer_module in enumerate(self.layers):
            layer_outputs = layer_module(hidden_states)
            hidden_states = layer_outputs

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, all_hidden_states, all_self_attentions]
                if v is not None
            )

        return HieraEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            reshaped_hidden_states=all_reshaped_hidden_states,
        )


class HieraHead(nn.Module):
    def __init__(self, config: HieraConfig):
        super().__init__()

        num_features = int(config.embed_dim * (2 ** (config.num_layers - 1)))

        self.dropout = (
            nn.Dropout(config.hidden_dropout_prob)
            if config.hidden_dropout_prob > 0
            else nn.Identity()
        )
        self.projection = nn.Linear(num_features, config.num_labels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.dropout(x)
        x = self.projection(x)

        return x


class HieraModel(HieraPretrainedModel):
    def __init__(
        self,
        config: HieraConfig,
        add_pooling_layer=True,
    ):
        super().__init__(config)

        self.config = config
        self.num_layers = len(config.depths)
        self.num_features = int(config.embed_dim * (2 ** (self.num_layers - 1)))

        self.embeddings = HieraEmbeddings(config)
        self.unroll = HieraUnroll(config)
        self.reroll = HieraReroll(config)

        self.encoder = HieraEncoder(config)

        self.norm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    @add_start_docstrings_to_model_forward(HIERA_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        mask: Optional[torch.BoolTensor] = None,
        # head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        mask (`torch.BoolTensor` of shape `(batch_size, #MUt*#MUy*#MUx)`, *optional*):
            Boolean mask over mask units, where #MU is the number of mask units in each dimension.
            Note: 1 in the mask means *keep* and 0 means *remove*; `mask.sum(dim=-1)` should be the same across the
            batch.
        """

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output, input_dimensions = self.embeddings(pixel_values, mask=mask)
        unrolled_embedding = self.unroll(embedding_output)

        # Discard masked tokens
        if mask is not None:
            unrolled_embedding = unrolled_embedding[
                mask[..., None].tile(
                    1, self.config.flat_mask_unit_size, unrolled_embedding.shape[2]
                )
            ].view(unrolled_embedding.shape[0], -1, unrolled_embedding.shape[-1])

        encoder_outputs = self.encoder(unrolled_embedding, input_dimensions)

        sequence_output = encoder_outputs[0].mean(dim=1)  # pooled last hidden states
        sequence_output = self.norm(sequence_output)

        pooled_output = None
        if self.pooler is not None:
            pooled_output = self.pooler(sequence_output.transpose(1, 0))
            pooled_output = torch.flatten(pooled_output, 1)

        if not return_dict:
            output = (sequence_output, pooled_output) + encoder_outputs[1:]
            return output

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            # hidden_states=encoder_outputs.hidden_states
        )


@add_start_docstrings(
    """
    Hiera Model transformer with an image classification head on top (a linear layer on top of the mean-pooled final
    hidden state) e.g. for ImageNet.
    """,
    HIERA_START_DOCSTRING,
)
class HieraForImageClassification(HieraPretrainedModel):
    def __init__(
        self,
        config,
        add_pooling_layer=False,
    ):
        super().__init__(
            config,
        )

        self.num_labels = config.num_labels
        self.hiera = HieraModel(config, add_pooling_layer=add_pooling_layer)

        # Classifier head
        self.head = HieraHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(HIERA_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        # head_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.hiera(
            pixel_values,
            # head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_states = outputs[0]

        logits = self.head(last_hidden_states)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
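
A minimal end-to-end sketch with randomly initialized weights (no checkpoint is named in these files). The import paths are placeholders: because modeling_hiera.py imports the config with a relative import, the two files are assumed to live inside a local package.

    import torch

    from some_package.configuration_hiera import HieraConfig            # placeholder package name
    from some_package.modeling_hiera import HieraForImageClassification

    config = HieraConfig()                       # defaults: 224px input, depths [2, 3, 16, 3]
    model = HieraForImageClassification(config)  # random weights
    model.eval()

    pixel_values = torch.randn(1, 3, 224, 224)   # [batch, channels, height, width]
    with torch.no_grad():
        outputs = model(pixel_values)

    print(outputs.logits.shape)                  # torch.Size([1, 2]) with the default two labels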