lucasjin committed
Commit
9c82578
1 Parent(s): d4cfa96

Upload folder using huggingface_hub

config.json ADDED
@@ -0,0 +1,26 @@
+ {
+   "_name_or_path": "checkpoints/aimv2-3B-patch14-448",
+   "architectures": [
+     "AIMv2Model"
+   ],
+   "attention_dropout": 0.0,
+   "auto_map": {
+     "AutoConfig": "configuration_aimv2.AIMv2Config",
+     "AutoModel": "modeling_aimv2.AIMv2Model",
+     "FlaxAutoModel": "modeling_flax_aimv2.FlaxAIMv2Model"
+   },
+   "hidden_size": 3072,
+   "image_size": 448,
+   "intermediate_size": 8192,
+   "model_type": "aimv2",
+   "num_attention_heads": 24,
+   "num_channels": 3,
+   "num_hidden_layers": 24,
+   "patch_size": 14,
+   "projection_dropout": 0.0,
+   "qkv_bias": false,
+   "rms_norm_eps": 1e-05,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.46.2",
+   "use_bias": false
+ }
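
Because the auto_map above points at modules shipped inside the repository (configuration_aimv2.py, modeling_aimv2.py), loading through the Auto classes requires trust_remote_code=True. A minimal sketch, with a placeholder repo id (substitute the repository this commit belongs to):

from transformers import AutoConfig

# "lucasjin/aimv2-3B-patch14-448" is a hypothetical repo id used for illustration.
config = AutoConfig.from_pretrained(
    "lucasjin/aimv2-3B-patch14-448",
    trust_remote_code=True,  # resolves AIMv2Config via the auto_map entry above
)
print(config.hidden_size, config.num_hidden_layers)  # 3072 24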
configuration_aimv2.py ADDED
@@ -0,0 +1,62 @@
+ from typing import Any
+
+ from transformers.configuration_utils import PretrainedConfig
+
+ __all__ = ["AIMv2Config"]
+
+
+ class AIMv2Config(PretrainedConfig):
+     """This is the configuration class to store the configuration of an [`AIMv2Model`].
+
+     Instantiating a configuration with the defaults will yield a similar configuration
+     to that of the [apple/aimv2-large-patch14-224](https://huggingface.co/apple/aimv2-large-patch14-224).
+
+     Args:
+         hidden_size: Dimension of the hidden representations.
+         intermediate_size: Dimension of the SwiGLU representations.
+         num_hidden_layers: Number of hidden layers in the Transformer.
+         num_attention_heads: Number of attention heads for each attention layer
+             in the Transformer.
+         num_channels: Number of input channels.
+         image_size: Image size.
+         patch_size: Patch size.
+         rms_norm_eps: Epsilon value used for the RMS normalization layer.
+         attention_dropout: Dropout ratio for attention probabilities.
+         projection_dropout: Dropout ratio for the projection layer after the attention.
+         qkv_bias: Whether to add a bias to the queries, keys and values.
+         use_bias: Whether to add a bias in the feed-forward and projection layers.
+         kwargs: Keyword arguments for the [`PretrainedConfig`].
+     """
+
+     model_type: str = "aimv2"
+
+     def __init__(
+         self,
+         hidden_size: int = 1024,
+         intermediate_size: int = 2816,
+         num_hidden_layers: int = 24,
+         num_attention_heads: int = 8,
+         num_channels: int = 3,
+         image_size: int = 224,
+         patch_size: int = 14,
+         rms_norm_eps: float = 1e-5,
+         attention_dropout: float = 0.0,
+         projection_dropout: float = 0.0,
+         qkv_bias: bool = False,
+         use_bias: bool = False,
+         **kwargs: Any,
+     ):
+         super().__init__(**kwargs)
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.num_channels = num_channels
+         self.patch_size = patch_size
+         self.image_size = image_size
+         self.attention_dropout = attention_dropout
+         self.rms_norm_eps = rms_norm_eps
+
+         self.projection_dropout = projection_dropout
+         self.qkv_bias = qkv_bias
+         self.use_bias = use_bias
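
The defaults above describe the large (1024-dim) variant; the values in config.json override them for this 3B, 448 px checkpoint. A quick sketch of instantiating both, assuming the file is importable from the working directory:

from configuration_aimv2 import AIMv2Config

cfg_large = AIMv2Config()   # defaults: aimv2-large-patch14-224 geometry
cfg_3b = AIMv2Config(       # overrides matching config.json in this commit
    hidden_size=3072,
    intermediate_size=8192,
    num_attention_heads=24,
    image_size=448,
)
assert cfg_3b.patch_size == 14  # shared default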
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f445eaeab8c48ae50ab0de0157b47747b484064ed73e770e64db09eabc93927a
+ size 5446053960
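
This is a Git LFS pointer, not the weights themselves; the ~5.4 GB blob is fetched on download, and the oid lets you check its integrity. A small sketch:

import hashlib

def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    # Stream the file so the 5.4 GB blob never has to fit in memory.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):
            digest.update(chunk)
    return digest.hexdigest()

# Should equal the oid above once the actual blob has been pulled.
print(sha256_of("model.safetensors") ==
      "f445eaeab8c48ae50ab0de0157b47747b484064ed73e770e64db09eabc93927a")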
model.safetensors.index.json ADDED
@@ -0,0 +1,180 @@
+ {
+   "metadata": {
+     "total_size": 5446035456
+   },
+   "weight_map": {
+     "preprocessor.patchifier.norm.weight": "model-00001-of-00002.safetensors",
+     "preprocessor.patchifier.proj.bias": "model-00001-of-00002.safetensors",
+     "preprocessor.patchifier.proj.weight": "model-00001-of-00002.safetensors",
+     "preprocessor.pos_embed": "model-00001-of-00002.safetensors",
+     "trunk.blocks.0.attn.proj.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.0.attn.qkv.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.0.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.0.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.0.mlp.fc3.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.0.norm_1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.0.norm_2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.1.attn.proj.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.1.attn.qkv.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.1.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.1.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.1.mlp.fc3.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.1.norm_1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.1.norm_2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.10.attn.proj.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.10.attn.qkv.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.10.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.10.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.10.mlp.fc3.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.10.norm_1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.10.norm_2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.11.attn.proj.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.11.attn.qkv.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.11.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.11.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.11.mlp.fc3.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.11.norm_1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.11.norm_2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.12.attn.proj.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.12.attn.qkv.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.12.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.12.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.12.mlp.fc3.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.12.norm_1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.12.norm_2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.13.attn.proj.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.13.attn.qkv.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.13.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.13.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.13.mlp.fc3.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.13.norm_1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.13.norm_2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.14.attn.proj.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.14.attn.qkv.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.14.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.14.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.14.mlp.fc3.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.14.norm_1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.14.norm_2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.15.attn.proj.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.15.attn.qkv.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.15.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.15.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.15.mlp.fc3.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.15.norm_1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.15.norm_2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.16.attn.proj.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.16.attn.qkv.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.16.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.16.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.16.mlp.fc3.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.16.norm_1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.16.norm_2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.17.attn.proj.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.17.attn.qkv.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.17.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.17.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.17.mlp.fc3.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.17.norm_1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.17.norm_2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.18.attn.proj.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.18.attn.qkv.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.18.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.18.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.18.mlp.fc3.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.18.norm_1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.18.norm_2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.19.attn.proj.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.19.attn.qkv.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.19.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.19.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.19.mlp.fc3.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.19.norm_1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.19.norm_2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.2.attn.proj.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.2.attn.qkv.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.2.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.2.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.2.mlp.fc3.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.2.norm_1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.2.norm_2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.20.attn.proj.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.20.attn.qkv.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.20.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.20.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.20.mlp.fc3.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.20.norm_1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.20.norm_2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.21.attn.proj.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.21.attn.qkv.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.21.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.21.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.21.mlp.fc3.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.21.norm_1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.21.norm_2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.22.attn.proj.weight": "model-00002-of-00002.safetensors",
+     "trunk.blocks.22.attn.qkv.weight": "model-00002-of-00002.safetensors",
+     "trunk.blocks.22.mlp.fc1.weight": "model-00002-of-00002.safetensors",
+     "trunk.blocks.22.mlp.fc2.weight": "model-00002-of-00002.safetensors",
+     "trunk.blocks.22.mlp.fc3.weight": "model-00002-of-00002.safetensors",
+     "trunk.blocks.22.norm_1.weight": "model-00002-of-00002.safetensors",
+     "trunk.blocks.22.norm_2.weight": "model-00002-of-00002.safetensors",
+     "trunk.blocks.23.attn.proj.weight": "model-00002-of-00002.safetensors",
+     "trunk.blocks.23.attn.qkv.weight": "model-00002-of-00002.safetensors",
+     "trunk.blocks.23.mlp.fc1.weight": "model-00002-of-00002.safetensors",
+     "trunk.blocks.23.mlp.fc2.weight": "model-00002-of-00002.safetensors",
+     "trunk.blocks.23.mlp.fc3.weight": "model-00002-of-00002.safetensors",
+     "trunk.blocks.23.norm_1.weight": "model-00002-of-00002.safetensors",
+     "trunk.blocks.23.norm_2.weight": "model-00002-of-00002.safetensors",
+     "trunk.blocks.3.attn.proj.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.3.attn.qkv.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.3.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.3.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.3.mlp.fc3.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.3.norm_1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.3.norm_2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.4.attn.proj.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.4.attn.qkv.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.4.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.4.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.4.mlp.fc3.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.4.norm_1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.4.norm_2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.5.attn.proj.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.5.attn.qkv.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.5.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.5.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.5.mlp.fc3.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.5.norm_1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.5.norm_2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.6.attn.proj.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.6.attn.qkv.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.6.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.6.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.6.mlp.fc3.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.6.norm_1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.6.norm_2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.7.attn.proj.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.7.attn.qkv.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.7.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.7.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.7.mlp.fc3.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.7.norm_1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.7.norm_2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.8.attn.proj.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.8.attn.qkv.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.8.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.8.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.8.mlp.fc3.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.8.norm_1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.8.norm_2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.9.attn.proj.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.9.attn.qkv.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.9.mlp.fc1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.9.mlp.fc2.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.9.mlp.fc3.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.9.norm_1.weight": "model-00001-of-00002.safetensors",
+     "trunk.blocks.9.norm_2.weight": "model-00001-of-00002.safetensors",
+     "trunk.post_trunk_norm.weight": "model-00002-of-00002.safetensors"
+   }
+ }
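
The index maps every parameter name to the shard that stores it; here blocks 22-23 and the final trunk norm live in the second shard and everything else in the first. Note the index references two shard files (model-00001-of-00002.safetensors, model-00002-of-00002.safetensors) rather than the single model.safetensors uploaded above. A minimal sketch of consulting it directly:

import json

with open("model.safetensors.index.json") as f:
    index = json.load(f)

# Look up which shard holds a given tensor.
print(index["weight_map"]["trunk.blocks.23.attn.qkv.weight"])
# -> model-00002-of-00002.safetensors

# Sanity check: every mapped shard is one of the two expected files.
assert set(index["weight_map"].values()) == {
    "model-00001-of-00002.safetensors",
    "model-00002-of-00002.safetensors",
}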
modeling_aimv2.py ADDED
@@ -0,0 +1,191 @@
+ from typing import Optional, Tuple, Union
+
+ import torch
+ from .configuration_aimv2 import AIMv2Config
+ from torch import nn
+ from torch.nn import functional as F
+ from transformers.modeling_outputs import BaseModelOutputWithNoAttention
+ from transformers.modeling_utils import PreTrainedModel
+
+ __all__ = ["AIMv2Model"]
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-6):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(dim))
+         self.eps = eps
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         output = self._norm(x.float()).type_as(x)
+         return output * self.weight
+
+     def extra_repr(self) -> str:
+         return f"{tuple(self.weight.shape)}, eps={self.eps}"
+
+     def _norm(self, x: torch.Tensor) -> torch.Tensor:
+         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+
+ class AIMv2SwiGLUFFN(nn.Module):
+     def __init__(self, config: AIMv2Config):
+         super().__init__()
+         hidden_features = config.intermediate_size
+         in_features = config.hidden_size
+         bias = config.use_bias
+
+         self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+         self.fc2 = nn.Linear(hidden_features, in_features, bias=bias)
+         self.fc3 = nn.Linear(in_features, hidden_features, bias=bias)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = F.silu(self.fc1(x)) * self.fc3(x)
+         x = self.fc2(x)
+         return x
+
+
+ class AIMv2PatchEmbed(nn.Module):
+     def __init__(self, config: AIMv2Config):
+         super().__init__()
+         self.proj = nn.Conv2d(
+             config.num_channels,
+             config.hidden_size,
+             kernel_size=(config.patch_size, config.patch_size),
+             stride=(config.patch_size, config.patch_size),
+         )
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.proj(x).flatten(2).transpose(1, 2)
+         x = self.norm(x)
+         return x
+
+
+ class AIMv2ViTPreprocessor(nn.Module):
+     def __init__(self, config: AIMv2Config):
+         super().__init__()
+         num_patches = (config.image_size // config.patch_size) ** 2
+
+         self.patchifier = AIMv2PatchEmbed(config)
+         self.pos_embed = nn.Parameter(torch.zeros((1, num_patches, config.hidden_size)))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         tokens = self.patchifier(x)
+         _, N, _ = tokens.shape
+         pos_embed = self.pos_embed.to(tokens.device)
+         tokens = tokens + pos_embed[:, :N]
+         return tokens
+
+
+ class AIMv2Attention(nn.Module):
+     def __init__(self, config: AIMv2Config):
+         super().__init__()
+         dim = config.hidden_size
+
+         self.num_heads = config.num_attention_heads
+         self.qkv = nn.Linear(dim, dim * 3, bias=config.qkv_bias)
+         self.attn_drop = nn.Dropout(config.attention_dropout)
+         self.proj = nn.Linear(dim, dim, bias=config.use_bias)
+         self.proj_drop = nn.Dropout(config.projection_dropout)
+
+     def forward(
+         self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
+     ) -> torch.Tensor:
+         B, N, C = x.shape
+         qkv = (
+             self.qkv(x)
+             .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+             .permute(2, 0, 3, 1, 4)
+         )
+         q, k, v = qkv.unbind(0)
+
+         x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
+         x = x.transpose(1, 2).contiguous().reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
+ class AIMv2Block(nn.Module):
+     def __init__(self, config: AIMv2Config):
+         super().__init__()
+         self.attn = AIMv2Attention(config)
+         self.norm_1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.mlp = AIMv2SwiGLUFFN(config)
+         self.norm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(
+         self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
+     ) -> torch.Tensor:
+         x = x + self.attn(self.norm_1(x), mask)
+         x = x + self.mlp(self.norm_2(x))
+         return x
+
+
+ class AIMv2Transformer(nn.Module):
+     def __init__(self, config: AIMv2Config):
+         super().__init__()
+         self.blocks = nn.ModuleList(
+             [AIMv2Block(config) for _ in range(config.num_hidden_layers)]
+         )
+         self.post_trunk_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(
+         self,
+         tokens: torch.Tensor,
+         mask: Optional[torch.Tensor] = None,
+         output_hidden_states: bool = False,
+     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, ...]]]:
+         hidden_states = () if output_hidden_states else None
+         for block in self.blocks:
+             tokens = block(tokens, mask)
+             if output_hidden_states:
+                 hidden_states += (tokens,)
+         tokens = self.post_trunk_norm(tokens)
+         return tokens, hidden_states
+
+
+ class AIMv2PretrainedModel(PreTrainedModel):
+     config_class = AIMv2Config
+     base_model_prefix = "aimv2"
+     main_input_name = "pixel_values"
+     _supports_sdpa = True
+
+
+ class AIMv2Model(AIMv2PretrainedModel):
+     def __init__(self, config: AIMv2Config):
+         super().__init__(config)
+         self.preprocessor = AIMv2ViTPreprocessor(config)
+         self.trunk = AIMv2Transformer(config)
+
+     def forward(
+         self,
+         pixel_values: torch.Tensor,
+         mask: Optional[torch.Tensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[
+         Tuple[torch.Tensor],
+         Tuple[torch.Tensor, Tuple[torch.Tensor, ...]],
+         BaseModelOutputWithNoAttention,
+     ]:
+         if output_hidden_states is None:
+             output_hidden_states = self.config.output_hidden_states
+         if return_dict is None:
+             return_dict = self.config.use_return_dict
+
+         x = self.preprocessor(pixel_values)
+         x, hidden_states = self.trunk(
+             x, mask, output_hidden_states=output_hidden_states
+         )
+
+         if not return_dict:
+             res = (x,)
+             res += (hidden_states,) if output_hidden_states else ()
+             return res
+
+         return BaseModelOutputWithNoAttention(
+             last_hidden_state=x,
+             hidden_states=hidden_states,
+         )
+
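
End to end, the model patchifies a 448x448 image into (448/14)^2 = 1024 tokens, adds learned position embeddings, and runs 24 pre-norm blocks with SwiGLU MLPs. A usage sketch, again with a placeholder repo id:

import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel

repo = "lucasjin/aimv2-3B-patch14-448"  # hypothetical repo id for illustration
processor = AutoImageProcessor.from_pretrained(repo)
model = AutoModel.from_pretrained(repo, trust_remote_code=True, torch_dtype=torch.bfloat16)

pixel_values = processor(images=Image.open("example.jpg"), return_tensors="pt").pixel_values
with torch.no_grad():
    out = model(pixel_values.to(model.dtype))  # match the bfloat16 weights
print(out.last_hidden_state.shape)  # torch.Size([1, 1024, 3072])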
preprocessor_config.json ADDED
@@ -0,0 +1,27 @@
+ {
+   "crop_size": {
+     "height": 448,
+     "width": 448
+   },
+   "do_center_crop": true,
+   "do_convert_rgb": true,
+   "do_normalize": true,
+   "do_rescale": true,
+   "do_resize": true,
+   "image_mean": [
+     0.48145466,
+     0.4578275,
+     0.40821073
+   ],
+   "image_processor_type": "CLIPImageProcessor",
+   "image_std": [
+     0.26862954,
+     0.26130258,
+     0.27577711
+   ],
+   "resample": 3,
+   "rescale_factor": 0.00392156862745098,
+   "size": {
+     "shortest_edge": 448
+   }
+ }
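
The processor resizes the shortest edge to 448, center-crops to 448x448, rescales by 1/255 (the 0.00392... factor above), and normalizes with the CLIP mean/std. A sketch of running it standalone, assuming this file sits in the current directory:

from PIL import Image
from transformers import CLIPImageProcessor

processor = CLIPImageProcessor.from_pretrained(".")  # directory containing this file
pixel_values = processor(images=Image.open("example.jpg"), return_tensors="pt").pixel_values
print(pixel_values.shape)  # torch.Size([1, 3, 448, 448])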