toilaluan committed
Commit 09b2c2d · verified · 1 Parent(s): 4fb1fc2

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -1,35 +1,2 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
  *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,32 @@
+ ---
+ library_name: transformers
+ tags:
+ - vision
+ - image-reconstruction
+ - siglip2
+ - safetensors
+ ---
+
+ # F2P Decoder
+
+ Hugging Face `AutoModel` wrapper for the SigLIP2 feature-to-pixel decoder used in this repository.
+
+ ```python
+ import torch
+ from transformers import AutoModel
+
+ model = AutoModel.from_pretrained(
+     "your-namespace/f2p_decoder",
+     trust_remote_code=True,
+ ).eval()
+
+ features = torch.randn(1, 257, 1152)
+ reconstruction = model(features)
+ print(reconstruction.shape)  # (1, 3, 224, 224)
+ ```
+
+ The model expects SigLIP2 patch features with a CLS token, for example from
+ `google/siglip2-so400m-patch14-224`. The output is an image tensor in the
+ decoder's reconstructed pixel space.
+
+ Source `.pt` checkpoint: `nyu-visionx/siglip2_decoder/model.pt`.
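The forward pass can also return a structured output instead of the bare tensor. A minimal sketch based on `F2PDecoderOutput` in `modeling_f2p_decoder.py` below (the repo id is a placeholder, as in the README example):

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("your-namespace/f2p_decoder", trust_remote_code=True).eval()

features = torch.randn(1, 257, 1152)  # leading CLS token + 256 patch tokens, SigLIP2 width 1152
with torch.no_grad():
    outputs = model(features, return_dict=True)

print(outputs.reconstruction.shape)  # (1, 3, 224, 224) de-normalized pixels
print(outputs.logits.shape)          # (1, 256, 588) flattened per-patch pixels, 14 * 14 * 3 = 588
```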
__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .configuration_f2p_decoder import F2PDecoderConfig
+ from .modeling_f2p_decoder import F2PDecoderModel
+
+ __all__ = ["F2PDecoderConfig", "F2PDecoderModel"]
__pycache__/configuration_f2p_decoder.cpython-312.pyc ADDED
Binary file (3.53 kB).
 
__pycache__/decoder.cpython-312.pyc ADDED
Binary file (53.7 kB).
 
__pycache__/modeling_f2p_decoder.cpython-312.pyc ADDED
Binary file (4.62 kB).
 
config.json ADDED
@@ -0,0 +1,40 @@
+ {
+   "architectures": [
+     "F2PDecoderModel"
+   ],
+   "attention_probs_dropout_prob": 0.0,
+   "auto_map": {
+     "AutoConfig": "configuration_f2p_decoder.F2PDecoderConfig",
+     "AutoModel": "modeling_f2p_decoder.F2PDecoderModel"
+   },
+   "decoder_hidden_size": 1152,
+   "decoder_intermediate_size": 4096,
+   "decoder_num_attention_heads": 16,
+   "decoder_num_hidden_layers": 28,
+   "drop_cls_token": true,
+   "dtype": "float32",
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.0,
+   "hidden_size": 1152,
+   "image_mean": [
+     0.5,
+     0.5,
+     0.5
+   ],
+   "image_size": 224,
+   "image_std": [
+     0.5,
+     0.5,
+     0.5
+   ],
+   "initializer_range": 0.02,
+   "layer_norm_eps": 1e-12,
+   "model_type": "f2p_decoder",
+   "num_channels": 3,
+   "num_patches": 256,
+   "patch_size": 14,
+   "pretrained_encoder_name": "google/siglip2-so400m-patch14-224",
+   "qkv_bias": true,
+   "source_decoder_repo": "nyu-visionx/siglip2_decoder",
+   "transformers_version": "4.57.6"
+ }
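The geometry fields in this config are mutually consistent: `num_patches` is the patch-grid size implied by `image_size` and `patch_size`, the decoder consumes one extra leading token, and `hidden_size` matches the SigLIP2 so400m feature width. A quick check of the arithmetic (plain Python, no model required):

```python
image_size, patch_size = 224, 14
num_patches = (image_size // patch_size) ** 2   # 16 * 16 = 256, matches "num_patches"
seq_len = num_patches + 1                       # 257 tokens once the leading CLS token is included
hidden_size = 1152                              # width of google/siglip2-so400m-patch14-224 features
print(num_patches, seq_len, hidden_size)        # 256 257 1152
```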
configuration_f2p_decoder.py ADDED
@@ -0,0 +1,68 @@
+ from transformers import PretrainedConfig
+
+
+ class F2PDecoderConfig(PretrainedConfig):
+     """Configuration for a feature-to-pixel reconstruction decoder."""
+
+     model_type = "f2p_decoder"
+
+     def __init__(
+         self,
+         pretrained_encoder_name: str = "google/siglip2-so400m-patch14-224",
+         source_decoder_repo: str = "nyu-visionx/siglip2_decoder",
+         image_size: int = 224,
+         patch_size: int = 14,
+         num_channels: int = 3,
+         hidden_size: int = 1152,
+         decoder_hidden_size: int = 1152,
+         decoder_num_hidden_layers: int = 28,
+         decoder_num_attention_heads: int = 16,
+         decoder_intermediate_size: int = 4096,
+         hidden_act: str = "gelu",
+         hidden_dropout_prob: float = 0.0,
+         attention_probs_dropout_prob: float = 0.0,
+         initializer_range: float = 0.02,
+         layer_norm_eps: float = 1e-12,
+         qkv_bias: bool = True,
+         num_patches: int = 256,
+         drop_cls_token: bool = True,
+         image_mean: list[float] | None = None,
+         image_std: list[float] | None = None,
+         **kwargs,
+     ) -> None:
+         super().__init__(**kwargs)
+         if getattr(self, "auto_map", None) is None:
+             self.auto_map = {
+                 "AutoConfig": "configuration_f2p_decoder.F2PDecoderConfig",
+                 "AutoModel": "modeling_f2p_decoder.F2PDecoderModel",
+             }
+
+         if image_mean is None:
+             image_mean = [0.5, 0.5, 0.5]
+         if image_std is None:
+             image_std = [0.5, 0.5, 0.5]
+         if len(image_mean) != num_channels or len(image_std) != num_channels:
+             raise ValueError("image_mean and image_std must match num_channels.")
+         if not drop_cls_token:
+             raise ValueError("Only drop_cls_token=True is supported by this decoder.")
+
+         self.pretrained_encoder_name = pretrained_encoder_name
+         self.source_decoder_repo = source_decoder_repo
+         self.image_size = int(image_size)
+         self.patch_size = int(patch_size)
+         self.num_channels = int(num_channels)
+         self.hidden_size = int(hidden_size)
+         self.decoder_hidden_size = int(decoder_hidden_size)
+         self.decoder_num_hidden_layers = int(decoder_num_hidden_layers)
+         self.decoder_num_attention_heads = int(decoder_num_attention_heads)
+         self.decoder_intermediate_size = int(decoder_intermediate_size)
+         self.hidden_act = hidden_act
+         self.hidden_dropout_prob = float(hidden_dropout_prob)
+         self.attention_probs_dropout_prob = float(attention_probs_dropout_prob)
+         self.initializer_range = float(initializer_range)
+         self.layer_norm_eps = float(layer_norm_eps)
+         self.qkv_bias = bool(qkv_bias)
+         self.num_patches = int(num_patches)
+         self.drop_cls_token = bool(drop_cls_token)
+         self.image_mean = [float(value) for value in image_mean]
+         self.image_std = [float(value) for value in image_std]
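Given the checks in `__init__` above, the config can be instantiated directly and rejects unsupported combinations. A small sketch, assuming it is run from the repository root so the local module resolves:

```python
from configuration_f2p_decoder import F2PDecoderConfig

config = F2PDecoderConfig()
print(config.num_patches, config.decoder_num_hidden_layers)  # 256 28

# Mismatched normalization stats are rejected by the length check above.
try:
    F2PDecoderConfig(image_mean=[0.5, 0.5])
except ValueError as err:
    print(err)  # image_mean and image_std must match num_channels.

# Only the drop_cls_token=True path is implemented (see GeneralDecoder.forward in decoder.py).
try:
    F2PDecoderConfig(drop_cls_token=False)
except ValueError as err:
    print(err)  # Only drop_cls_token=True is supported by this decoder.
```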
convert_original_checkpoint.py ADDED
@@ -0,0 +1,42 @@
+ import argparse
+ from pathlib import Path
+
+ import torch
+ from huggingface_hub import hf_hub_download
+
+ from configuration_f2p_decoder import F2PDecoderConfig
+ from modeling_f2p_decoder import F2PDecoderModel
+
+
+ def convert(output_dir: str) -> None:
+     output_path = Path(output_dir)
+     checkpoint_path = hf_hub_download("nyu-visionx/siglip2_decoder", "model.pt")
+     state_dict = torch.load(checkpoint_path, map_location="cpu")
+     state_dict = {f"decoder.{key}": value for key, value in state_dict.items()}
+
+     config = F2PDecoderConfig()
+     model = F2PDecoderModel(config)
+     missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+     unexpected_keys = [key for key in unexpected_keys if key]
+     missing_keys = [
+         key for key in missing_keys if key not in {"image_mean", "image_std"}
+     ]
+     if missing_keys or unexpected_keys:
+         raise RuntimeError(
+             "Checkpoint conversion mismatch: "
+             f"missing={missing_keys}, unexpected={unexpected_keys}"
+         )
+
+     model.save_pretrained(output_path, safe_serialization=True)
+     print(f"Saved Hugging Face artifact to {output_path}")
+
+
+ def main() -> None:
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--output_dir", default="hf_artifacts/f2p_decoder")
+     args = parser.parse_args()
+     convert(args.output_dir)
+
+
+ if __name__ == "__main__":
+     main()
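The script imports the local modules by file name, so it is meant to be run from the repository root; a typical invocation with the default output directory would be `python convert_original_checkpoint.py --output_dir hf_artifacts/f2p_decoder`. The same conversion can also be driven from Python; a sketch under that same working-directory assumption:

```python
# A sketch, assuming the working directory is this repository so the local modules resolve.
from convert_original_checkpoint import convert

# Downloads nyu-visionx/siglip2_decoder/model.pt, remaps every key under the "decoder." prefix,
# checks for missing/unexpected keys, and writes model.safetensors plus config.json.
convert("hf_artifacts/f2p_decoder")
```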
decoder.py ADDED
@@ -0,0 +1,1149 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch ViT MAE (masked autoencoder) model."""
16
+
17
+ import collections.abc
18
+ import math
19
+ from copy import deepcopy
20
+ from dataclasses import dataclass
21
+ from typing import Optional, Set, Tuple, Union
22
+
23
+ import numpy as np
24
+ import torch
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+
28
+ # correct the above import to the following
30
+ from transformers.utils import (
31
+ ModelOutput,
32
+ add_start_docstrings,
33
+ logging,
34
+ replace_return_docstrings,
35
+ add_start_docstrings_to_model_forward,
36
+ )
37
+ from transformers.pytorch_utils import (
38
+ find_pruneable_heads_and_indices,
39
+ prune_linear_layer,
40
+ )
41
+ from transformers.activations import ACT2FN
42
+ from transformers.modeling_outputs import BaseModelOutput
43
+ from transformers.modeling_utils import PreTrainedModel
44
+
45
+ logger = logging.get_logger(__name__)
46
+ _CONFIG_FOR_DOC = "ViTMAEConfig"
47
+ _CHECKPOINT_FOR_DOC = "facebook/vit-mae-base"
48
+
49
+
50
+ @dataclass
51
+ class ViTMAEModelOutput(ModelOutput):
52
+ """
53
+ Class for ViTMAEModel's outputs, with potential hidden states and attentions.
54
+
55
+ Args:
56
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
57
+ Sequence of hidden-states at the output of the last layer of the model.
58
+ mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
59
+ Tensor indicating which patches are masked (1) and which are not (0).
60
+ ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
61
+ Tensor containing the original index of the (shuffled) masked patches.
62
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
63
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
64
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
65
+ plus the initial embedding outputs.
66
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
67
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
68
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
69
+ the self-attention heads.
70
+ """
71
+
72
+ last_hidden_state: torch.FloatTensor = None
73
+ mask: torch.LongTensor = None
74
+ ids_restore: torch.LongTensor = None
75
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
76
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
77
+
78
+
79
+ @dataclass
80
+ class ViTMAEDecoderOutput(ModelOutput):
81
+ """
82
+ Class for ViTMAEDecoder's outputs, with potential hidden states and attentions.
83
+
84
+ Args:
85
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
86
+ Pixel reconstruction logits.
87
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
88
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
89
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
90
+ plus the initial embedding outputs.
91
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
92
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
93
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
94
+ the self-attention heads.
95
+ """
96
+
97
+ logits: torch.FloatTensor = None
98
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
99
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
100
+
101
+
102
+ @dataclass
103
+ class ViTMAEForPreTrainingOutput(ModelOutput):
104
+ """
105
+ Class for ViTMAEForPreTraining's outputs, with potential hidden states and attentions.
106
+
107
+ Args:
108
+ loss (`torch.FloatTensor` of shape `(1,)`):
109
+ Pixel reconstruction loss.
110
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
111
+ Pixel reconstruction logits.
112
+ mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
113
+ Tensor indicating which patches are masked (1) and which are not (0).
114
+ ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
115
+ Tensor containing the original index of the (shuffled) masked patches.
116
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
117
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
118
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
119
+ plus the initial embedding outputs.
120
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
121
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
122
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
123
+ the self-attention heads.
124
+ """
125
+
126
+ loss: Optional[torch.FloatTensor] = None
127
+ logits: torch.FloatTensor = None
128
+ mask: torch.LongTensor = None
129
+ ids_restore: torch.LongTensor = None
130
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
131
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
132
+
133
+
134
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
135
+ """
136
+ Create 2D sin/cos positional embeddings.
137
+
138
+ Args:
139
+ embed_dim (`int`):
140
+ Embedding dimension.
141
+ grid_size (`int`):
142
+ The grid height and width.
143
+ add_cls_token (`bool`, *optional*, defaults to `False`):
144
+ Whether or not to add a classification (CLS) token.
145
+
146
+ Returns:
147
+ (`torch.FloatTensor` of shape (grid_size*grid_size, embed_dim) or (1+grid_size*grid_size, embed_dim): the
148
+ position embeddings (with or without classification token)
149
+ """
150
+ grid_h = np.arange(grid_size, dtype=np.float32)
151
+ grid_w = np.arange(grid_size, dtype=np.float32)
152
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
153
+ grid = np.stack(grid, axis=0)
154
+
155
+ grid = grid.reshape([2, 1, grid_size, grid_size])
156
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
157
+ if add_cls_token:
158
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
159
+ return pos_embed
160
+
161
+
162
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
163
+ if embed_dim % 2 != 0:
164
+ raise ValueError("embed_dim must be even")
165
+
166
+ # use half of dimensions to encode grid_h
167
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
168
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
169
+
170
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
171
+ return emb
172
+
173
+
174
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
175
+ """
176
+ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
177
+ """
178
+ if embed_dim % 2 != 0:
179
+ raise ValueError("embed_dim must be even")
180
+
181
+ omega = np.arange(embed_dim // 2, dtype=float)
182
+ omega /= embed_dim / 2.0
183
+ omega = 1.0 / 10000**omega # (D/2,)
184
+
185
+ pos = pos.reshape(-1) # (M,)
186
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
187
+
188
+ emb_sin = np.sin(out) # (M, D/2)
189
+ emb_cos = np.cos(out) # (M, D/2)
190
+
191
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
192
+ return emb
193
+
194
+
195
+ class ViTMAEEmbeddings(nn.Module):
196
+ """
197
+ Construct the CLS token, position and patch embeddings.
198
+
199
+ """
200
+
201
+ def __init__(self, config):
202
+ super().__init__()
203
+
204
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
205
+ self.patch_embeddings = ViTMAEPatchEmbeddings(config)
206
+ self.num_patches = self.patch_embeddings.num_patches
207
+ # fixed sin-cos embedding
208
+ self.position_embeddings = nn.Parameter(
209
+ torch.zeros(1, self.num_patches + 1, config.hidden_size),
210
+ requires_grad=False,
211
+ )
212
+ self.config = config
213
+ self.initialize_weights()
214
+
215
+ def initialize_weights(self):
216
+ # initialize (and freeze) position embeddings by sin-cos embedding
217
+ pos_embed = get_2d_sincos_pos_embed(
218
+ self.position_embeddings.shape[-1],
219
+ int(self.patch_embeddings.num_patches**0.5),
220
+ add_cls_token=True,
221
+ )
222
+ self.position_embeddings.data.copy_(
223
+ torch.from_numpy(pos_embed).float().unsqueeze(0)
224
+ )
225
+
226
+ # initialize patch_embeddings like nn.Linear (instead of nn.Conv2d)
227
+ w = self.patch_embeddings.projection.weight.data
228
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
229
+
230
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
231
+ torch.nn.init.normal_(self.cls_token, std=self.config.initializer_range)
232
+
233
+ def interpolate_pos_encoding(
234
+ self, embeddings: torch.Tensor, height: int, width: int
235
+ ) -> torch.Tensor:
236
+ """
237
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
238
+ resolution images.
239
+
240
+ Source:
241
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
242
+ """
243
+ num_patches = embeddings.shape[1] - 1
244
+ num_positions = self.position_embeddings.shape[1] - 1
245
+
246
+ if num_patches == num_positions and height == width:
247
+ return self.position_embeddings
248
+
249
+ class_pos_embed = self.position_embeddings[:, 0, :]
250
+ patch_pos_embed = self.position_embeddings[:, 1:, :]
251
+ dim = embeddings.shape[-1]
252
+ h0 = height // self.config.patch_size
253
+ w0 = width // self.config.patch_size
254
+ # we add a small number to avoid floating point error in the interpolation
255
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
256
+ h0, w0 = h0 + 0.1, w0 + 0.1
257
+ patch_pos_embed = patch_pos_embed.reshape(
258
+ 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
259
+ )
260
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
261
+ patch_pos_embed = nn.functional.interpolate(
262
+ patch_pos_embed,
263
+ scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
264
+ mode="bicubic",
265
+ align_corners=False,
266
+ )
267
+ if int(h0) != patch_pos_embed.shape[-2] or int(w0) != patch_pos_embed.shape[-1]:
268
+ raise ValueError(
269
+ "Width or height does not match with the interpolated position embeddings"
270
+ )
271
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
272
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
273
+
274
+ def random_masking(self, sequence, noise=None):
275
+ """
276
+ Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
277
+ noise.
278
+
279
+ Args:
280
+ sequence (`torch.LongTensor` of shape `(batch_size, sequence_length, dim)`)
281
+ noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) which is
282
+ mainly used for testing purposes to control randomness and maintain the reproducibility
283
+ """
284
+ batch_size, seq_length, dim = sequence.shape
285
+ len_keep = int(seq_length * (1 - self.config.mask_ratio))
286
+
287
+ if noise is None:
288
+ noise = torch.rand(
289
+ batch_size, seq_length, device=sequence.device
290
+ ) # noise in [0, 1]
291
+
292
+ # sort noise for each sample
293
+ ids_shuffle = torch.argsort(noise, dim=1).to(
294
+ sequence.device
295
+ ) # ascend: small is keep, large is remove
296
+ ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device)
297
+
298
+ # keep the first subset
299
+ ids_keep = ids_shuffle[:, :len_keep]
300
+ sequence_unmasked = torch.gather(
301
+ sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim)
302
+ )
303
+
304
+ # generate the binary mask: 0 is keep, 1 is remove
305
+ mask = torch.ones([batch_size, seq_length], device=sequence.device)
306
+ mask[:, :len_keep] = 0
307
+ # unshuffle to get the binary mask
308
+ mask = torch.gather(mask, dim=1, index=ids_restore)
309
+
310
+ return sequence_unmasked, mask, ids_restore
311
+
312
+ def forward(self, pixel_values, noise=None, interpolate_pos_encoding: bool = False):
313
+ batch_size, num_channels, height, width = pixel_values.shape
314
+ embeddings = self.patch_embeddings(
315
+ pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
316
+ )
317
+ if interpolate_pos_encoding:
318
+ position_embeddings = self.interpolate_pos_encoding(
319
+ embeddings, height, width
320
+ )
321
+ else:
322
+ position_embeddings = self.position_embeddings
323
+
324
+ # add position embeddings w/o cls token
325
+ embeddings = embeddings + position_embeddings[:, 1:, :]
326
+
327
+ # masking: length -> length * config.mask_ratio
328
+ embeddings, mask, ids_restore = self.random_masking(embeddings, noise)
329
+
330
+ # append cls token
331
+ cls_token = self.cls_token + position_embeddings[:, :1, :]
332
+ cls_tokens = cls_token.expand(embeddings.shape[0], -1, -1)
333
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
334
+
335
+ return embeddings, mask, ids_restore
336
+
337
+
338
+ class ViTMAEPatchEmbeddings(nn.Module):
339
+ """
340
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
341
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
342
+ Transformer.
343
+ """
344
+
345
+ def __init__(self, config):
346
+ super().__init__()
347
+ image_size, patch_size = config.image_size, config.patch_size
348
+ num_channels, hidden_size = config.num_channels, config.hidden_size
349
+ image_size = (
350
+ image_size
351
+ if isinstance(image_size, collections.abc.Iterable)
352
+ else (image_size, image_size)
353
+ )
354
+ patch_size = (
355
+ patch_size
356
+ if isinstance(patch_size, collections.abc.Iterable)
357
+ else (patch_size, patch_size)
358
+ )
359
+ num_patches = (image_size[1] // patch_size[1]) * (
360
+ image_size[0] // patch_size[0]
361
+ )
362
+ self.image_size = image_size
363
+ self.patch_size = patch_size
364
+ self.num_channels = num_channels
365
+ self.num_patches = num_patches
366
+
367
+ self.projection = nn.Conv2d(
368
+ num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
369
+ )
370
+
371
+ def forward(self, pixel_values, interpolate_pos_encoding: bool = False):
372
+ batch_size, num_channels, height, width = pixel_values.shape
373
+ if num_channels != self.num_channels:
374
+ raise ValueError(
375
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
376
+ )
377
+
378
+ if not interpolate_pos_encoding and (
379
+ height != self.image_size[0] or width != self.image_size[1]
380
+ ):
381
+ raise ValueError(
382
+ f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
383
+ )
384
+ x = self.projection(pixel_values).flatten(2).transpose(1, 2)
385
+ return x
386
+
387
+
388
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention ViT->ViTMAE
389
+ class ViTMAESelfAttention(nn.Module):
390
+ def __init__(self, config: ViTMAEConfig) -> None:
391
+ super().__init__()
392
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
393
+ config, "embedding_size"
394
+ ):
395
+ raise ValueError(
396
+ f"The hidden size {(config.hidden_size,)} is not a multiple of the number of attention "
397
+ f"heads {config.num_attention_heads}."
398
+ )
399
+
400
+ self.num_attention_heads = config.num_attention_heads
401
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
402
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
403
+
404
+ self.query = nn.Linear(
405
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
406
+ )
407
+ self.key = nn.Linear(
408
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
409
+ )
410
+ self.value = nn.Linear(
411
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
412
+ )
413
+
414
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
415
+
416
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
417
+ new_x_shape = x.size()[:-1] + (
418
+ self.num_attention_heads,
419
+ self.attention_head_size,
420
+ )
421
+ x = x.view(new_x_shape)
422
+ return x.permute(0, 2, 1, 3)
423
+
424
+ def forward(
425
+ self,
426
+ hidden_states,
427
+ head_mask: Optional[torch.Tensor] = None,
428
+ output_attentions: bool = False,
429
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
430
+ mixed_query_layer = self.query(hidden_states)
431
+
432
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
433
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
434
+ query_layer = self.transpose_for_scores(mixed_query_layer)
435
+
436
+ # Take the dot product between "query" and "key" to get the raw attention scores.
437
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
438
+
439
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
440
+
441
+ # Normalize the attention scores to probabilities.
442
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
443
+
444
+ # This is actually dropping out entire tokens to attend to, which might
445
+ # seem a bit unusual, but is taken from the original Transformer paper.
446
+ attention_probs = self.dropout(attention_probs)
447
+
448
+ # Mask heads if we want to
449
+ if head_mask is not None:
450
+ attention_probs = attention_probs * head_mask
451
+
452
+ context_layer = torch.matmul(attention_probs, value_layer)
453
+
454
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
455
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
456
+ context_layer = context_layer.view(new_context_layer_shape)
457
+
458
+ outputs = (
459
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
460
+ )
461
+
462
+ return outputs
463
+
464
+
465
+ # Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention ViT->ViTMAE
466
+ class ViTMAESdpaSelfAttention(ViTMAESelfAttention):
467
+ def __init__(self, config: ViTMAEConfig) -> None:
468
+ super().__init__(config)
469
+ self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
470
+
471
+ def forward(
472
+ self,
473
+ hidden_states,
474
+ head_mask: Optional[torch.Tensor] = None,
475
+ output_attentions: bool = False,
476
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
477
+ mixed_query_layer = self.query(hidden_states)
478
+
479
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
480
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
481
+ query_layer = self.transpose_for_scores(mixed_query_layer)
482
+
483
+ context_layer = torch.nn.functional.scaled_dot_product_attention(
484
+ query_layer,
485
+ key_layer,
486
+ value_layer,
487
+ head_mask,
488
+ self.attention_probs_dropout_prob if self.training else 0.0,
489
+ is_causal=False,
490
+ scale=None,
491
+ )
492
+
493
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
494
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
495
+ context_layer = context_layer.view(new_context_layer_shape)
496
+
497
+ return context_layer, None
498
+
499
+
500
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMAE
501
+ class ViTMAESelfOutput(nn.Module):
502
+ """
503
+ The residual connection is defined in ViTMAELayer instead of here (as is the case with other models), due to the
504
+ layernorm applied before each block.
505
+ """
506
+
507
+ def __init__(self, config: ViTMAEConfig) -> None:
508
+ super().__init__()
509
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
510
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
511
+
512
+ def forward(
513
+ self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
514
+ ) -> torch.Tensor:
515
+ hidden_states = self.dense(hidden_states)
516
+ hidden_states = self.dropout(hidden_states)
517
+
518
+ return hidden_states
519
+
520
+
521
+ # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->ViTMAE
522
+ class ViTMAEAttention(nn.Module):
523
+ def __init__(self, config: ViTMAEConfig) -> None:
524
+ super().__init__()
525
+ self.attention = ViTMAESelfAttention(config)
526
+ self.output = ViTMAESelfOutput(config)
527
+ self.pruned_heads = set()
528
+
529
+ def prune_heads(self, heads: Set[int]) -> None:
530
+ if len(heads) == 0:
531
+ return
532
+ heads, index = find_pruneable_heads_and_indices(
533
+ heads,
534
+ self.attention.num_attention_heads,
535
+ self.attention.attention_head_size,
536
+ self.pruned_heads,
537
+ )
538
+
539
+ # Prune linear layers
540
+ self.attention.query = prune_linear_layer(self.attention.query, index)
541
+ self.attention.key = prune_linear_layer(self.attention.key, index)
542
+ self.attention.value = prune_linear_layer(self.attention.value, index)
543
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
544
+
545
+ # Update hyper params and store pruned heads
546
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(
547
+ heads
548
+ )
549
+ self.attention.all_head_size = (
550
+ self.attention.attention_head_size * self.attention.num_attention_heads
551
+ )
552
+ self.pruned_heads = self.pruned_heads.union(heads)
553
+
554
+ def forward(
555
+ self,
556
+ hidden_states: torch.Tensor,
557
+ head_mask: Optional[torch.Tensor] = None,
558
+ output_attentions: bool = False,
559
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
560
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions)
561
+
562
+ attention_output = self.output(self_outputs[0], hidden_states)
563
+
564
+ outputs = (attention_output,) + self_outputs[
565
+ 1:
566
+ ] # add attentions if we output them
567
+ return outputs
568
+
569
+
570
+ # Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->ViTMAE
571
+ class ViTMAESdpaAttention(ViTMAEAttention):
572
+ def __init__(self, config: ViTMAEConfig) -> None:
573
+ super().__init__(config)
574
+ self.attention = ViTMAESdpaSelfAttention(config)
575
+
576
+
577
+ # Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->ViTMAE
578
+ class ViTMAEIntermediate(nn.Module):
579
+ def __init__(self, config: ViTMAEConfig) -> None:
580
+ super().__init__()
581
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
582
+ if isinstance(config.hidden_act, str):
583
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
584
+ else:
585
+ self.intermediate_act_fn = config.hidden_act
586
+
587
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
588
+ hidden_states = self.dense(hidden_states)
589
+ hidden_states = self.intermediate_act_fn(hidden_states)
590
+
591
+ return hidden_states
592
+
593
+
594
+ # Copied from transformers.models.vit.modeling_vit.ViTOutput ViT->ViTMAE
595
+ class ViTMAEOutput(nn.Module):
596
+ def __init__(self, config: ViTMAEConfig) -> None:
597
+ super().__init__()
598
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
599
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
600
+
601
+ def forward(
602
+ self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
603
+ ) -> torch.Tensor:
604
+ hidden_states = self.dense(hidden_states)
605
+ hidden_states = self.dropout(hidden_states)
606
+
607
+ hidden_states = hidden_states + input_tensor
608
+
609
+ return hidden_states
610
+
611
+
612
+ VITMAE_ATTENTION_CLASSES = {
613
+ "eager": ViTMAEAttention,
614
+ "sdpa": ViTMAESdpaAttention,
615
+ }
616
+
617
+
618
+ # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMAE,VIT->VITMAE
619
+ class ViTMAELayer(nn.Module):
620
+ """This corresponds to the Block class in the timm implementation."""
621
+
622
+ def __init__(self, config: ViTMAEConfig) -> None:
623
+ super().__init__()
624
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
625
+ self.seq_len_dim = 1
626
+ self.attention = VITMAE_ATTENTION_CLASSES[config._attn_implementation](config)
627
+ self.intermediate = ViTMAEIntermediate(config)
628
+ self.output = ViTMAEOutput(config)
629
+ self.layernorm_before = nn.LayerNorm(
630
+ config.hidden_size, eps=config.layer_norm_eps
631
+ )
632
+ self.layernorm_after = nn.LayerNorm(
633
+ config.hidden_size, eps=config.layer_norm_eps
634
+ )
635
+
636
+ def forward(
637
+ self,
638
+ hidden_states: torch.Tensor,
639
+ head_mask: Optional[torch.Tensor] = None,
640
+ output_attentions: bool = False,
641
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
642
+ self_attention_outputs = self.attention(
643
+ self.layernorm_before(
644
+ hidden_states
645
+ ), # in ViTMAE, layernorm is applied before self-attention
646
+ head_mask,
647
+ output_attentions=output_attentions,
648
+ )
649
+ attention_output = self_attention_outputs[0]
650
+ outputs = self_attention_outputs[
651
+ 1:
652
+ ] # add self attentions if we output attention weights
653
+
654
+ # first residual connection
655
+ hidden_states = attention_output + hidden_states
656
+
657
+ # in ViTMAE, layernorm is also applied after self-attention
658
+ layer_output = self.layernorm_after(hidden_states)
659
+ layer_output = self.intermediate(layer_output)
660
+
661
+ # second residual connection is done here
662
+ layer_output = self.output(layer_output, hidden_states)
663
+
664
+ outputs = (layer_output,) + outputs
665
+
666
+ return outputs
667
+
668
+
669
+ # Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->ViTMAE
670
+ class ViTMAEEncoder(nn.Module):
671
+ def __init__(self, config: ViTMAEConfig) -> None:
672
+ super().__init__()
673
+ self.config = config
674
+ self.layer = nn.ModuleList(
675
+ [ViTMAELayer(config) for _ in range(config.num_hidden_layers)]
676
+ )
677
+ self.gradient_checkpointing = False
678
+
679
+ def forward(
680
+ self,
681
+ hidden_states: torch.Tensor,
682
+ head_mask: Optional[torch.Tensor] = None,
683
+ output_attentions: bool = False,
684
+ output_hidden_states: bool = False,
685
+ return_dict: bool = True,
686
+ ) -> Union[tuple, BaseModelOutput]:
687
+ all_hidden_states = () if output_hidden_states else None
688
+ all_self_attentions = () if output_attentions else None
689
+
690
+ for i, layer_module in enumerate(self.layer):
691
+ if output_hidden_states:
692
+ all_hidden_states = all_hidden_states + (hidden_states,)
693
+
694
+ layer_head_mask = head_mask[i] if head_mask is not None else None
695
+
696
+ if self.gradient_checkpointing and self.training:
697
+ layer_outputs = self._gradient_checkpointing_func(
698
+ layer_module.__call__,
699
+ hidden_states,
700
+ layer_head_mask,
701
+ output_attentions,
702
+ )
703
+ else:
704
+ layer_outputs = layer_module(
705
+ hidden_states, layer_head_mask, output_attentions
706
+ )
707
+
708
+ hidden_states = layer_outputs[0]
709
+
710
+ if output_attentions:
711
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
712
+
713
+ if output_hidden_states:
714
+ all_hidden_states = all_hidden_states + (hidden_states,)
715
+
716
+ if not return_dict:
717
+ return tuple(
718
+ v
719
+ for v in [hidden_states, all_hidden_states, all_self_attentions]
720
+ if v is not None
721
+ )
722
+ return BaseModelOutput(
723
+ last_hidden_state=hidden_states,
724
+ hidden_states=all_hidden_states,
725
+ attentions=all_self_attentions,
726
+ )
727
+
728
+
729
+ class ViTMAEPreTrainedModel(PreTrainedModel):
730
+ """
731
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
732
+ models.
733
+ """
734
+
735
+ config_class = ViTMAEConfig
736
+ base_model_prefix = "vit"
737
+ main_input_name = "pixel_values"
738
+ supports_gradient_checkpointing = True
739
+ _supports_sdpa = True
740
+
741
+ def _init_weights(self, module):
742
+ """Initialize the weights"""
743
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
744
+ # Slightly different from the TF version which uses truncated_normal for initialization
745
+ # cf https://github.com/pytorch/pytorch/pull/5617
746
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
747
+ if module.bias is not None:
748
+ module.bias.data.zero_()
749
+ elif isinstance(module, nn.LayerNorm):
750
+ module.bias.data.zero_()
751
+ module.weight.data.fill_(1.0)
752
+
753
+
754
+ VIT_MAE_START_DOCSTRING = r"""
755
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
756
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
757
+ behavior.
758
+
759
+ Parameters:
760
+ config ([`ViTMAEConfig`]): Model configuration class with all the parameters of the model.
761
+ Initializing with a config file does not load the weights associated with the model, only the
762
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
763
+ """
764
+
765
+ VIT_MAE_INPUTS_DOCSTRING = r"""
766
+ Args:
767
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
768
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
769
+ for details.
770
+
771
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
772
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
773
+
774
+ - 1 indicates the head is **not masked**,
775
+ - 0 indicates the head is **masked**.
776
+
777
+ output_attentions (`bool`, *optional*):
778
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
779
+ tensors for more detail.
780
+ output_hidden_states (`bool`, *optional*):
781
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
782
+ more detail.
783
+ return_dict (`bool`, *optional*):
784
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
785
+ interpolate_pos_encoding (`bool`, *optional*, default `False`):
786
+ Whether to interpolate the pre-trained position encodings. This is mainly used to use the model on higher
787
+ resolution images.
788
+ """
789
+
790
+
791
+ @add_start_docstrings(
792
+ "The bare ViTMAE Model transformer outputting raw hidden-states without any specific head on top.",
793
+ VIT_MAE_START_DOCSTRING,
794
+ )
795
+ class ViTMAEModel(ViTMAEPreTrainedModel):
796
+ def __init__(self, config):
797
+ super().__init__(config)
798
+ self.config = config
799
+
800
+ self.embeddings = ViTMAEEmbeddings(config)
801
+ self.encoder = ViTMAEEncoder(config)
802
+
803
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
804
+
805
+ # Initialize weights and apply final processing
806
+ self.post_init()
807
+
808
+ def get_input_embeddings(self):
809
+ return self.embeddings.patch_embeddings
810
+
811
+ def _prune_heads(self, heads_to_prune):
812
+ """
813
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
814
+ class PreTrainedModel
815
+ """
816
+ for layer, heads in heads_to_prune.items():
817
+ self.encoder.layer[layer].attention.prune_heads(heads)
818
+
819
+ @add_start_docstrings_to_model_forward(VIT_MAE_INPUTS_DOCSTRING)
820
+ @replace_return_docstrings(
821
+ output_type=ViTMAEModelOutput, config_class=_CONFIG_FOR_DOC
822
+ )
823
+ def forward(
824
+ self,
825
+ pixel_values: Optional[torch.FloatTensor] = None,
826
+ noise: Optional[torch.FloatTensor] = None,
827
+ head_mask: Optional[torch.FloatTensor] = None,
828
+ output_attentions: Optional[bool] = None,
829
+ output_hidden_states: Optional[bool] = None,
830
+ return_dict: Optional[bool] = None,
831
+ interpolate_pos_encoding: bool = False,
832
+ ) -> Union[Tuple, ViTMAEModelOutput]:
833
+ r"""
834
+ Returns:
835
+
836
+ Examples:
837
+
838
+ ```python
839
+ >>> from transformers import AutoImageProcessor, ViTMAEModel
840
+ >>> from PIL import Image
841
+ >>> import requests
842
+
843
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
844
+ >>> image = Image.open(requests.get(url, stream=True).raw)
845
+
846
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base")
847
+ >>> model = ViTMAEModel.from_pretrained("facebook/vit-mae-base")
848
+
849
+ >>> inputs = image_processor(images=image, return_tensors="pt")
850
+ >>> outputs = model(**inputs)
851
+ >>> last_hidden_states = outputs.last_hidden_state
852
+ ```"""
853
+ output_attentions = (
854
+ output_attentions
855
+ if output_attentions is not None
856
+ else self.config.output_attentions
857
+ )
858
+ output_hidden_states = (
859
+ output_hidden_states
860
+ if output_hidden_states is not None
861
+ else self.config.output_hidden_states
862
+ )
863
+ return_dict = (
864
+ return_dict if return_dict is not None else self.config.use_return_dict
865
+ )
866
+
867
+ if pixel_values is None:
868
+ raise ValueError("You have to specify pixel_values")
869
+
870
+ # Prepare head mask if needed
871
+ # 1.0 in head_mask indicate we keep the head
872
+ # attention_probs has shape bsz x n_heads x N x N
873
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
874
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
875
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
876
+
877
+ embedding_output, mask, ids_restore = self.embeddings(
878
+ pixel_values, noise=noise, interpolate_pos_encoding=interpolate_pos_encoding
879
+ )
880
+
881
+ encoder_outputs = self.encoder(
882
+ embedding_output,
883
+ head_mask=head_mask,
884
+ output_attentions=output_attentions,
885
+ output_hidden_states=output_hidden_states,
886
+ return_dict=return_dict,
887
+ )
888
+ sequence_output = encoder_outputs[0]
889
+ sequence_output = self.layernorm(sequence_output)
890
+
891
+ if not return_dict:
892
+ return (sequence_output, mask, ids_restore) + encoder_outputs[1:]
893
+
894
+ return ViTMAEModelOutput(
895
+ last_hidden_state=sequence_output,
896
+ mask=mask,
897
+ ids_restore=ids_restore,
898
+ hidden_states=encoder_outputs.hidden_states,
899
+ attentions=encoder_outputs.attentions,
900
+ )
901
+
902
+
903
+ class GeneralDecoder(nn.Module):
904
+ def __init__(self, config, num_patches):
905
+ super().__init__()
906
+ self.decoder_embed = nn.Linear(
907
+ config.hidden_size, config.decoder_hidden_size, bias=True
908
+ )
909
+ self.decoder_pos_embed = nn.Parameter(
910
+ torch.zeros(1, num_patches + 1, config.decoder_hidden_size),
911
+ requires_grad=False,
912
+ ) # fixed sin-cos embedding
913
+
914
+ decoder_config = deepcopy(config)
915
+ decoder_config.hidden_size = config.decoder_hidden_size
916
+ decoder_config.num_hidden_layers = config.decoder_num_hidden_layers
917
+ decoder_config.num_attention_heads = config.decoder_num_attention_heads
918
+ decoder_config.intermediate_size = config.decoder_intermediate_size
919
+ self.decoder_layers = nn.ModuleList(
920
+ [
921
+ ViTMAELayer(decoder_config)
922
+ for _ in range(config.decoder_num_hidden_layers)
923
+ ]
924
+ )
925
+
926
+ self.decoder_norm = nn.LayerNorm(
927
+ config.decoder_hidden_size, eps=config.layer_norm_eps
928
+ )
929
+ self.decoder_pred = nn.Linear(
930
+ config.decoder_hidden_size,
931
+ config.patch_size**2 * config.num_channels,
932
+ bias=True,
933
+ ) # encoder to decoder
934
+ self.gradient_checkpointing = False
935
+ self.config = config
936
+ self.num_patches = num_patches
937
+ self.initialize_weights(num_patches)
938
+ self.decoder_config = decoder_config
939
+ self.set_trainable_cls_token()
940
+
941
+ def set_trainable_cls_token(self, tensor: Optional[torch.Tensor] = None):
942
+ # register a trainable CLS token
943
+ tensor = (
944
+ torch.zeros(1, 1, self.decoder_config.hidden_size)
945
+ if tensor is None
946
+ else tensor
947
+ )
948
+ self.trainable_cls_token = nn.Parameter(tensor)
949
+
950
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
951
+ """
952
+ This method is a modified version of the ViT-MAE interpolation function, adapted for the decoder. It
954
+ interpolates the pre-trained decoder position encodings so that the model can be used on higher
955
+ resolution images.
955
+
956
+ Source:
957
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
958
+ """
959
+
960
+ # -1 removes the class dimension since we later append it without interpolation
961
+ embeddings_positions = embeddings.shape[1] - 1
962
+ num_positions = self.decoder_pos_embed.shape[1] - 1
963
+
964
+ # Separation of class token and patch tokens
965
+ class_pos_embed = self.decoder_pos_embed[:, 0, :]
966
+ patch_pos_embed = self.decoder_pos_embed[:, 1:, :]
967
+
968
+ # To retain the final 3d tensor with the required dimensions
969
+ dim = self.decoder_pos_embed.shape[-1]
970
+
971
+ # Increasing a dimension to enable bicubic interpolation
972
+ patch_pos_embed = patch_pos_embed.reshape(1, 1, -1, dim)
973
+
974
+ # permute to bring the dimension to be interpolated, to the last
975
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
976
+
977
+ # Interpolating the decoder position embeddings shape wrt embeddings shape i.e (x).
978
+ # 1 keeps the other dimension constant
979
+ patch_pos_embed = nn.functional.interpolate(
980
+ patch_pos_embed,
981
+ scale_factor=(1, embeddings_positions / num_positions),
982
+ mode="bicubic",
983
+ align_corners=False,
984
+ )
985
+
986
+ # Converting back to the original shape
987
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
988
+ # Adding the class token back
989
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
990
+
991
+ def interpolate_latent(self, x: torch.Tensor) -> torch.Tensor:
992
+ b, l, c = x.shape
993
+ if l == self.num_patches:
994
+ return x
995
+ # interpolate the latent
996
+ # print(f"interpolating latent from {l} to {self.num_patches}, x.shape = {x.shape}")
997
+ h, w = int(l**0.5), int(l**0.5)
998
+ x = x.reshape(b, h, w, c)
999
+ x = x.permute(0, 3, 1, 2)
1000
+ target_size = (int(self.num_patches**0.5), int(self.num_patches**0.5))
1001
+ x = nn.functional.interpolate(
1002
+ x, size=target_size, mode="bilinear", align_corners=False
1003
+ )
1004
+ x = x.permute(0, 2, 3, 1).contiguous().view(b, self.num_patches, c)
1005
+ return x
1006
+
1007
+ def initialize_weights(self, num_patches):
1008
+ # initialize (and freeze) position embeddings by sin-cos embedding
1009
+ decoder_pos_embed = get_2d_sincos_pos_embed(
1010
+ self.decoder_pos_embed.shape[-1], int(num_patches**0.5), add_cls_token=True
1011
+ )
1012
+ self.decoder_pos_embed.data.copy_(
1013
+ torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
1014
+ )
1015
+
1016
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
1017
+ # torch.nn.init.normal_(self.mask_token, std=self.config.initializer_range)
1018
+
1019
+ def unpatchify(
1020
+ self,
1021
+ patchified_pixel_values,
1022
+ original_image_size: Optional[Tuple[int, int]] = None,
1023
+ ):
1024
+ """
1025
+ Args:
1026
+ patchified_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
1027
+ Patchified pixel values.
1028
+ original_image_size (`Tuple[int, int]`, *optional*):
1029
+ Original image size.
1030
+
1031
+ Returns:
1032
+ `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
1033
+ Pixel values.
1034
+ """
1035
+ patch_size, num_channels = self.config.patch_size, self.config.num_channels
1036
+ original_image_size = (
1037
+ original_image_size
1038
+ if original_image_size is not None
1039
+ else (self.config.image_size, self.config.image_size)
1040
+ )
1041
+ original_height, original_width = original_image_size
1042
+ num_patches_h = original_height // patch_size
1043
+ num_patches_w = original_width // patch_size
1044
+ # sanity check
1045
+ if num_patches_h * num_patches_w != patchified_pixel_values.shape[1]:
1046
+ raise ValueError(
1047
+ f"The number of patches in the patchified pixel values {patchified_pixel_values.shape[1]}, does not match the number of patches on original image {num_patches_h}*{num_patches_w}"
1048
+ )
1049
+
1050
+ # unpatchify
1051
+ batch_size = patchified_pixel_values.shape[0]
1052
+ patchified_pixel_values = patchified_pixel_values.reshape(
1053
+ batch_size,
1054
+ num_patches_h,
1055
+ num_patches_w,
1056
+ patch_size,
1057
+ patch_size,
1058
+ num_channels,
1059
+ )
1060
+ patchified_pixel_values = torch.einsum(
1061
+ "nhwpqc->nchpwq", patchified_pixel_values
1062
+ )
1063
+ pixel_values = patchified_pixel_values.reshape(
1064
+ batch_size,
1065
+ num_channels,
1066
+ num_patches_h * patch_size,
1067
+ num_patches_w * patch_size,
1068
+ )
1069
+ return pixel_values
1070
+
1071
+ def forward(
1072
+ self,
1073
+ hidden_states,
1074
+ output_attentions=False,
1075
+ output_hidden_states=False,
1076
+ return_dict=True,
1077
+ interpolate_pos_encoding: bool = False,
1078
+ drop_cls_token: bool = False,
1079
+ ):
1080
+ # embed tokens
1081
+ x = self.decoder_embed(hidden_states)
1082
+ # print(f"x.shape = {x.shape}")
1083
+ # append mask tokens to sequence
1084
+ # mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
1085
+ # append mask tokens to sequence
1086
+ # x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
1087
+ x_ = x[:, 1:, :] # no cls token
1088
+ if drop_cls_token:
1089
+ cls_token = self.trainable_cls_token.expand(x_.shape[0], -1, -1)
1090
+ # print(f"cls_token.shape = {cls_token.shape}, x_.shape = {x_.shape}")
1091
+ x_ = self.interpolate_latent(x_)
1092
+ x = torch.cat([cls_token, x_], dim=1)
1093
+ else:
1094
+ raise NotImplementedError("only the drop_cls_token=True path is implemented")
1095
+ x = self.interpolate_latent(x) # interpolate the whole latent
1096
+ x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
1097
+ # add pos embed
1098
+ if interpolate_pos_encoding:
1099
+ decoder_pos_embed = self.interpolate_pos_encoding(x)
1100
+ else:
1101
+ decoder_pos_embed = self.decoder_pos_embed
1102
+ hidden_states = x + decoder_pos_embed
1103
+
1104
+ # apply Transformer layers (blocks)
1105
+ all_hidden_states = () if output_hidden_states else None
1106
+ all_self_attentions = () if output_attentions else None
1107
+ for i, layer_module in enumerate(self.decoder_layers):
1108
+ if output_hidden_states:
1109
+ all_hidden_states = all_hidden_states + (hidden_states,)
1110
+
1111
+ if self.gradient_checkpointing and self.training:
1112
+ layer_outputs = self._gradient_checkpointing_func(
1113
+ layer_module.__call__,
1114
+ hidden_states,
1115
+ None,
1116
+ output_attentions,
1117
+ )
1118
+ else:
1119
+ layer_outputs = layer_module(
1120
+ hidden_states, head_mask=None, output_attentions=output_attentions
1121
+ )
1122
+
1123
+ hidden_states = layer_outputs[0]
1124
+
1125
+ if output_attentions:
1126
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
1127
+
1128
+ if output_hidden_states:
1129
+ all_hidden_states = all_hidden_states + (hidden_states,)
1130
+
1131
+ hidden_states = self.decoder_norm(hidden_states)
1132
+
1133
+ # predictor projection
1134
+ logits = self.decoder_pred(hidden_states)
1135
+
1136
+ # remove cls token
1137
+ logits = logits[:, 1:, :]
1138
+
1139
+ if not return_dict:
1140
+ return tuple(
1141
+ v
1142
+ for v in [logits, all_hidden_states, all_self_attentions]
1143
+ if v is not None
1144
+ )
1145
+ return ViTMAEDecoderOutput(
1146
+ logits=logits,
1147
+ hidden_states=all_hidden_states,
1148
+ attentions=all_self_attentions,
1149
+ )
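For reference, a standalone sketch of the shape arithmetic used by `unpatchify` above, with the repository defaults (14-pixel patches, 224x224 images, 3 channels); the reshape/einsum mirrors the code but is rewritten here purely for illustration:

```python
import torch

batch_size, patch_size, num_channels = 2, 14, 3
num_patches_h = num_patches_w = 224 // patch_size  # 16

# (batch, 256, 14*14*3): one flattened patch per token, as produced by decoder_pred.
patchified = torch.randn(batch_size, num_patches_h * num_patches_w, patch_size**2 * num_channels)

x = patchified.reshape(batch_size, num_patches_h, num_patches_w, patch_size, patch_size, num_channels)
x = torch.einsum("nhwpqc->nchpwq", x)              # same axis permutation as unpatchify
pixels = x.reshape(batch_size, num_channels, num_patches_h * patch_size, num_patches_w * patch_size)
print(pixels.shape)                                # torch.Size([2, 3, 224, 224])
```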
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0760e38a1d417bef3dc9916d9add8f50c8a0e3ce06b56c84d8319fabfdc466cc
+ size 1662407016
modeling_f2p_decoder.py ADDED
@@ -0,0 +1,93 @@
+ from dataclasses import dataclass
+ from typing import Optional, Tuple
+
+ import torch
+ from torch import nn
+ from transformers.modeling_outputs import ModelOutput
+ from transformers.modeling_utils import PreTrainedModel
+
+ try:
+     from .configuration_f2p_decoder import F2PDecoderConfig
+     from .decoder import GeneralDecoder
+ except ImportError:
+     from configuration_f2p_decoder import F2PDecoderConfig
+     from decoder import GeneralDecoder
+
+
+ @dataclass
+ class F2PDecoderOutput(ModelOutput):
+     reconstruction: torch.FloatTensor = None
+     logits: torch.FloatTensor = None
+     hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+     attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+ class F2PDecoderModel(PreTrainedModel):
+     """Feature-to-pixel decoder for SigLIP2 patch features."""
+
+     config_class = F2PDecoderConfig
+     base_model_prefix = "f2p_decoder"
+     main_input_name = "hidden_states"
+     supports_gradient_checkpointing = True
+
+     def __init__(self, config: F2PDecoderConfig):
+         super().__init__(config)
+         image_mean = torch.tensor(config.image_mean, dtype=torch.float32).view(
+             1, config.num_channels, 1, 1
+         )
+         image_std = torch.tensor(config.image_std, dtype=torch.float32).view(
+             1, config.num_channels, 1, 1
+         )
+         self.register_buffer("image_mean", image_mean)
+         self.register_buffer("image_std", image_std)
+         self.decoder = GeneralDecoder(config, num_patches=config.num_patches)
+
+     def _set_gradient_checkpointing(self, module, value=False):
+         if isinstance(module, GeneralDecoder):
+             module.gradient_checkpointing = value
+
+     def forward(
+         self,
+         hidden_states: Optional[torch.Tensor] = None,
+         zs: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ):
+         if hidden_states is None:
+             hidden_states = zs
+         if hidden_states is None:
+             raise ValueError("Pass SigLIP2 features as hidden_states or zs.")
+
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         decoder_output = self.decoder(
+             hidden_states,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=True,
+             drop_cls_token=self.config.drop_cls_token,
+         )
+         reconstruction = self.decoder.unpatchify(decoder_output.logits)
+         reconstruction = reconstruction * self.image_std + self.image_mean
+
+         if return_dict:
+             return F2PDecoderOutput(
+                 reconstruction=reconstruction,
+                 logits=decoder_output.logits,
+                 hidden_states=decoder_output.hidden_states,
+                 attentions=decoder_output.attentions,
+             )
+         return reconstruction
+
+     @torch.no_grad()
+     def infer(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         return self.forward(hidden_states, return_dict=False)
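To visualize a reconstruction, the de-normalized tensor returned by `infer` can be converted to an 8-bit image. A sketch, assuming the output lands approximately in the [0, 1] range after the `* image_std + image_mean` step above (the clamp guards against small overshoots):

```python
import numpy as np
import torch
from PIL import Image


def to_pil(reconstruction: torch.Tensor) -> Image.Image:
    # reconstruction: (1, 3, 224, 224) float tensor, e.g. from model.infer(features)
    image = reconstruction[0].clamp(0.0, 1.0)  # assumed pixel range after de-normalization
    array = (image.permute(1, 2, 0).cpu().numpy() * 255).round().astype(np.uint8)
    return Image.fromarray(array)


# Example with random features; real features should come from the SigLIP2 encoder.
# model = AutoModel.from_pretrained("your-namespace/f2p_decoder", trust_remote_code=True).eval()
# to_pil(model.infer(torch.randn(1, 257, 1152))).save("reconstruction.png")
```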