michalk8 commited on
Commit
45e617d
·
1 Parent(s): bf755d9

Upload files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ aimv2_overview_light.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,80 @@
1
- ---
2
- license: apple-ascl
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ license: apple-ascl
4
+ metrics:
5
+ - accuracy
6
+ pipeline_tag: image-feature-extraction
7
+ tags:
8
+ - vision
9
+ - image-feature-extraction
10
+ - mlx
11
+ - pytorch
12
+ ---
13
+ # Introduction
14
+ [[`AIMv2 Paper`](#)] [[`BibTeX`](#citation)]
15
+
16
+ We introduce the AIMv2 family of vision models pre-trained with a multimodal autoregressive objective.
17
+ AIMv2 pre-training is simple and straightforward to train and scale effectively. Some AIMv2 highlights include:
18
+
19
+ 1. Outperforms OAI CLIP and SigLIP on the majority of multimodal understanding benchmarks.
20
+ 2. Outperforms DINOv2 on open-vocabulary object detection and referring expression comprehension.
21
+ 3. Exhibits strong recognition performance with AIMv2-3B achieving *89.5% on ImageNet using a frozen trunk*.
22
+
23
+ <img src="aimv2_overview_light.png" alt="AIMv2 Overview"/>
24
+
25
+ ## Usage
26
+
27
+ ### PyTorch
28
+ ```python
29
+ import requests
30
+ from PIL import Image
31
+ from transformers import AutoImageProcessor, AutoModel
32
+
33
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
34
+ image = Image.open(requests.get(url, stream=True).raw)
35
+
36
+ processor = AutoImageProcessor.from_pretrained(
37
+ "apple/aimv2-large-patch14-native",
38
+ )
39
+ model = AutoModel.from_pretrained(
40
+ "apple/aimv2-large-patch14-native",
41
+ trust_remote_code=True,
42
+ )
43
+
44
+ inputs = processor(images=image, return_tensors="pt")
45
+ outputs = model(**inputs)
46
+ ```
47
+
48
+ ### JAX
49
+ ```python
50
+ import requests
51
+ from PIL import Image
52
+ from transformers import AutoImageProcessor, FlaxAutoModel
53
+
54
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
55
+ image = Image.open(requests.get(url, stream=True).raw)
56
+
57
+ processor = AutoImageProcessor.from_pretrained(
58
+ "apple/aimv2-large-patch14-native",
59
+ )
60
+ model = FlaxAutoModel.from_pretrained(
61
+ "apple/aimv2-large-patch14-native",
62
+ trust_remote_code=True,
63
+ )
64
+
65
+ inputs = processor(images=image, return_tensors="jax")
66
+ outputs = model(**inputs)
67
+ ```
68
+
69
+ ## Citation
70
+ If you find our work useful, please consider citing us as:
71
+ ```bibtex
72
+ @misc{fini2024multimodal,
73
+ title = {Multimodal Autoregressive Pre-training of Large Vision Encoders},
74
+ author = {Enrico Fini and Mustafa Shukor and Xiujun Li and Philipp Dufter and Michal Klein and David Haldimann and Sai Aitharaju and Victor Guilherme Turrisi da Costa and Louis Béthune and Zhe Gan and Alexander T Toshev and Marcin Eichner and Moin Nabi and Yinfei Yang and Joshua M. Susskind and Alaaeldin El-Nouby},
75
+ year = {2024},
76
+ archivePrefix = {arXiv},
77
+ primaryClass = {cs.CV},
78
+ }
79
+ ```
80
+
aimv2_overview_light.png ADDED

Git LFS Details

  • SHA256: 524b6eb5049fb4bac6303ecee386d0e885fa69a96756557d843084ba4caae08f
  • Pointer size: 131 Bytes
  • Size of remote file: 336 kB
config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "AIMv2Model"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_aimv2.AIMv2Config",
8
+ "AutoModel": "modeling_aimv2.AIMv2Model",
9
+ "FlaxAutoModel": "modeling_flax_aimv2.FlaxAIMv2Model"
10
+ },
11
+ "hidden_size": 1024,
12
+ "intermediate_size": 2816,
13
+ "model_type": "aimv2",
14
+ "num_attention_heads": 8,
15
+ "num_channels": 3,
16
+ "num_hidden_layers": 24,
17
+ "num_queries": 256,
18
+ "patch_size": 14,
19
+ "projection_dropout": 0.0,
20
+ "qkv_bias": false,
21
+ "rms_norm_eps": 1e-05,
22
+ "torch_dtype": "float32",
23
+ "transformers_version": "4.46.3",
24
+ "use_bias": false
25
+ }
configuration_aimv2.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+
5
+ __all__ = ["AIMv2Config"]
6
+
7
+
8
+ class AIMv2Config(PretrainedConfig):
9
+ """This is the configuration class to store the configuration of an [`AIMv2Model`].
10
+
11
+ Instantiating a configuration with the defaults will yield a similar configuration
12
+ to that of the [apple/aimv2-large-patch14-native](https://huggingface.co/apple/aimv2-large-patch14-native)
13
+
14
+ Args:
15
+ hidden_size: Dimension of the hidden representations.
16
+ intermediate_size: Dimension of the SwiGLU representations.
17
+ num_hidden_layers: Number of hidden layers in the Transformer.
18
+ num_attention_heads: Number of attention heads for each attention layer
19
+ in the Transformer.
20
+ num_channels: Number of input channels.
21
+ num_queries: Number of learnable queries in the head.
22
+ patch_size: Patch size.
23
+ rms_norm_eps: Epsilon value used for the RMS normalization layer.
24
+ attention_dropout: Dropout ratio for attention probabilities.
25
+ projection_dropout: Dropout ratio for the projection layer after the attention.
26
+ qkv_bias: Whether to add a bias to the queries, keys and values.
27
+ use_bias: Whether to add a bias in the feed-forward and projection layers.
28
+ kwargs: Keyword arguments for the [`PretrainedConfig`].
29
+ """
30
+
31
+ model_type: str = "aimv2"
32
+
33
+ def __init__(
34
+ self,
35
+ hidden_size: int = 1024,
36
+ intermediate_size: int = 2816,
37
+ num_hidden_layers: int = 24,
38
+ num_attention_heads: int = 8,
39
+ num_channels: int = 3,
40
+ num_queries: int = 256,
41
+ patch_size: int = 14,
42
+ rms_norm_eps: float = 1e-5,
43
+ attention_dropout: float = 0.0,
44
+ projection_dropout: float = 0.0,
45
+ qkv_bias: bool = False,
46
+ use_bias: bool = False,
47
+ **kwargs: Any,
48
+ ):
49
+ super().__init__(**kwargs)
50
+ self.hidden_size = hidden_size
51
+ self.intermediate_size = intermediate_size
52
+ self.num_hidden_layers = num_hidden_layers
53
+ self.num_attention_heads = num_attention_heads
54
+ self.num_channels = num_channels
55
+ self.num_queries = num_queries
56
+ self.patch_size = patch_size
57
+ self.attention_dropout = attention_dropout
58
+ self.rms_norm_eps = rms_norm_eps
59
+
60
+ self.projection_dropout = projection_dropout
61
+ self.qkv_bias = qkv_bias
62
+ self.use_bias = use_bias
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d52bb8cda1854e55348e9e6046046cc8a8b6218167d8b8580a340f6f4b172ca4
3
+ size 1235749750
mlx_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a2de3d578c9360aaadb2174946f1c72a93b3861f0687dc1b66ad254ddf80f2da
3
+ size 1235760719
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6cdc4c4ea6f2a477edebb482cc36ba021409a313eabdf3e6be62eb722771e7d1
3
+ size 1235760720
modeling_aimv2.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+ from .configuration_aimv2 import AIMv2Config
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+ from transformers.modeling_outputs import BaseModelOutputWithNoAttention
8
+ from transformers.modeling_utils import PreTrainedModel
9
+
10
+ __all__ = ["AIMv2Model"]
11
+
12
+
13
+ def _get_1d_sincos_pos_embed_from_grid(
14
+ embed_dim: int, pos: torch.Tensor
15
+ ) -> torch.Tensor:
16
+ omega = torch.arange(embed_dim // 2).float()
17
+ omega /= embed_dim / 2.0
18
+ omega = 1.0 / 10000**omega # (D / 2,)
19
+ pos = pos.reshape(-1) # (M,)
20
+ out = pos[:, None] * omega[None, :] # (M, D / 2), outer product
21
+ emb_sin, emb_cos = torch.sin(out), torch.cos(out) # (M, D / 2)
22
+ emb = torch.concatenate([emb_sin, emb_cos], dim=1) # (M, D)
23
+ return emb
24
+
25
+
26
+ def get_sincos_pos_embed(h: int, w: int, embed_dim: int) -> torch.Tensor:
27
+ assert embed_dim % 2 == 0, embed_dim
28
+ grid_h = torch.arange(h).float()
29
+ grid_w = torch.arange(w).float()
30
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
31
+ grid = torch.stack(grid, dim=0)
32
+ grid = grid.reshape([2, 1, h, w])
33
+ emb_h = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
34
+ emb_w = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
35
+ pos_embed = torch.concatenate([emb_h, emb_w], dim=1) # (H * W, D)
36
+ return pos_embed
37
+
38
+
39
+ class RMSNorm(nn.Module):
40
+ def __init__(self, dim: int, eps: float = 1e-6):
41
+ super().__init__()
42
+ self.weight = nn.Parameter(torch.ones(dim))
43
+ self.eps = eps
44
+
45
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
46
+ output = self._norm(x.float()).type_as(x)
47
+ return output * self.weight
48
+
49
+ def extra_repr(self) -> str:
50
+ return f"{tuple(self.weight.shape)}, eps={self.eps}"
51
+
52
+ def _norm(self, x: torch.Tensor) -> torch.Tensor:
53
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
54
+
55
+
56
+ class AIMv2SwiGLUFFN(nn.Module):
57
+ def __init__(self, config: AIMv2Config):
58
+ super().__init__()
59
+ hidden_features = config.intermediate_size
60
+ in_features = config.hidden_size
61
+ bias = config.use_bias
62
+
63
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
64
+ self.fc2 = nn.Linear(hidden_features, in_features, bias=bias)
65
+ self.fc3 = nn.Linear(in_features, hidden_features, bias=bias)
66
+
67
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
68
+ x = F.silu(self.fc1(x)) * self.fc3(x)
69
+ x = self.fc2(x)
70
+ return x
71
+
72
+
73
+ class AIMv2PatchEmbed(nn.Module):
74
+ def __init__(self, config: AIMv2Config):
75
+ super().__init__()
76
+ self.proj = nn.Conv2d(
77
+ config.num_channels,
78
+ config.hidden_size,
79
+ kernel_size=(config.patch_size, config.patch_size),
80
+ stride=(config.patch_size, config.patch_size),
81
+ )
82
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
83
+
84
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
85
+ x = self.proj(x).flatten(2).transpose(1, 2)
86
+ x = self.norm(x)
87
+ return x
88
+
89
+
90
+ class AIMv2ViTPreprocessor(nn.Module):
91
+ def __init__(self, config: AIMv2Config):
92
+ super().__init__()
93
+ self.patch_h = config.patch_size
94
+ self.patch_w = config.patch_size
95
+ self.embed_dim = config.hidden_size
96
+
97
+ self.patchifier = AIMv2PatchEmbed(config)
98
+
99
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
100
+ _, _, H, W = x.shape
101
+ tokens = self.patchifier(x)
102
+ pos_embed = get_sincos_pos_embed(
103
+ H // self.patch_h, W // self.patch_w, embed_dim=self.embed_dim
104
+ )
105
+ tokens = tokens + pos_embed
106
+ return tokens
107
+
108
+
109
+ class AIMv2Attention(nn.Module):
110
+ def __init__(self, config: AIMv2Config):
111
+ super().__init__()
112
+ dim = config.hidden_size
113
+
114
+ self.num_heads = config.num_attention_heads
115
+ self.qkv = nn.Linear(dim, dim * 3, bias=config.qkv_bias)
116
+ self.attn_drop = nn.Dropout(config.attention_dropout)
117
+ self.proj = nn.Linear(dim, dim, bias=config.use_bias)
118
+ self.proj_drop = nn.Dropout(config.projection_dropout)
119
+
120
+ def forward(
121
+ self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
122
+ ) -> torch.Tensor:
123
+ B, N, C = x.shape
124
+ qkv = (
125
+ self.qkv(x)
126
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
127
+ .permute(2, 0, 3, 1, 4)
128
+ )
129
+ q, k, v = qkv.unbind(0)
130
+
131
+ x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
132
+ x = x.transpose(1, 2).contiguous().reshape(B, N, C)
133
+ x = self.proj(x)
134
+ x = self.proj_drop(x)
135
+ return x
136
+
137
+
138
+ class AIMv2Block(nn.Module):
139
+ def __init__(self, config: AIMv2Config):
140
+ super().__init__()
141
+ self.attn = AIMv2Attention(config)
142
+ self.norm_1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
143
+ self.mlp = AIMv2SwiGLUFFN(config)
144
+ self.norm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
145
+
146
+ def forward(
147
+ self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
148
+ ) -> torch.Tensor:
149
+ x = x + self.attn(self.norm_1(x), mask)
150
+ x = x + self.mlp(self.norm_2(x))
151
+ return x
152
+
153
+
154
+ class AIMv2Transformer(nn.Module):
155
+ def __init__(self, config: AIMv2Config):
156
+ super().__init__()
157
+ self.blocks = nn.ModuleList(
158
+ [AIMv2Block(config) for _ in range(config.num_hidden_layers)]
159
+ )
160
+ self.post_trunk_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
161
+
162
+ def forward(
163
+ self,
164
+ tokens: torch.Tensor,
165
+ mask: Optional[torch.Tensor] = None,
166
+ output_hidden_states: bool = False,
167
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, ...]]]:
168
+ hidden_states = () if output_hidden_states else None
169
+ for block in self.blocks:
170
+ tokens = block(tokens, mask)
171
+ if output_hidden_states:
172
+ hidden_states += (tokens,)
173
+ tokens = self.post_trunk_norm(tokens)
174
+ return tokens, hidden_states
175
+
176
+
177
+ class AIMv2PretrainedModel(PreTrainedModel):
178
+ config_class = AIMv2Config
179
+ base_model_prefix = "aimv2"
180
+ main_input_name = "pixel_values"
181
+ _supports_sdpa = True
182
+
183
+
184
+ class AIMv2Model(AIMv2PretrainedModel):
185
+ def __init__(self, config: AIMv2Config):
186
+ super().__init__(config)
187
+ self.preprocessor = AIMv2ViTPreprocessor(config)
188
+ self.trunk = AIMv2Transformer(config)
189
+
190
+ def forward(
191
+ self,
192
+ pixel_values: torch.Tensor,
193
+ mask: Optional[torch.Tensor] = None,
194
+ output_hidden_states: Optional[bool] = None,
195
+ return_dict: Optional[bool] = None,
196
+ ) -> Union[
197
+ Tuple[torch.Tensor],
198
+ Tuple[torch.Tensor, Tuple[torch.Tensor, ...]],
199
+ BaseModelOutputWithNoAttention,
200
+ ]:
201
+ if output_hidden_states is None:
202
+ output_hidden_states = self.config.output_hidden_states
203
+ if return_dict is None:
204
+ return_dict = self.config.use_return_dict
205
+
206
+ x = self.preprocessor(pixel_values)
207
+ x, hidden_states = self.trunk(
208
+ x, mask, output_hidden_states=output_hidden_states
209
+ )
210
+
211
+ if not return_dict:
212
+ res = (x,)
213
+ res += (hidden_states,) if output_hidden_states else ()
214
+ return res
215
+
216
+ return BaseModelOutputWithNoAttention(
217
+ last_hidden_state=x,
218
+ hidden_states=hidden_states,
219
+ )
220
+
modeling_flax_aimv2.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Optional, Tuple, Union
2
+
3
+ import flax.linen as nn
4
+ import jax
5
+ import jax.numpy as jnp
6
+ from .configuration_aimv2 import AIMv2Config
7
+ from flax.core import frozen_dict
8
+ from transformers import FlaxPreTrainedModel
9
+ from transformers.modeling_flax_outputs import FlaxBaseModelOutput
10
+
11
+ __all__ = ["FlaxAIMv2Model"]
12
+
13
+
14
+ def _get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: jax.Array) -> jax.Array:
15
+ omega = jnp.arange(embed_dim // 2, dtype=pos.dtype)
16
+ omega /= embed_dim / 2.0
17
+ omega = 1.0 / 10000**omega # (D / 2,)
18
+ pos = pos.reshape(-1) # (M,)
19
+ out = pos[:, None] * omega[None, :] # (M, D / 2), outer product
20
+ emb_sin, emb_cos = jnp.sin(out), jnp.cos(out) # (M, D / 2)
21
+ emb = jnp.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
22
+ return emb
23
+
24
+
25
+ def get_sincos_pos_embed(
26
+ h: int,
27
+ w: int,
28
+ embed_dim: int,
29
+ dtype: jnp.dtype = jnp.float32,
30
+ ) -> jax.Array:
31
+ assert embed_dim % 2 == 0, embed_dim
32
+ grid_h = jnp.arange(h, dtype=dtype)
33
+ grid_w = jnp.arange(w, dtype=dtype)
34
+ grid = jnp.meshgrid(grid_w, grid_h, indexing="xy")
35
+ grid = jnp.stack(grid, axis=0)
36
+ grid = grid.reshape([2, 1, h, w])
37
+ emb_h = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
38
+ emb_w = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
39
+ pos_embed = jnp.concatenate([emb_h, emb_w], axis=1) # (H * W, D)
40
+ return pos_embed
41
+
42
+
43
+ class FlaxRMSNorm(nn.Module):
44
+ eps: float = 1e-6
45
+
46
+ @nn.compact
47
+ def __call__(self, x: jax.Array) -> jax.Array:
48
+ dim = x.shape[-1]
49
+ scale = self.param("scale", nn.initializers.ones_init(), (dim,))
50
+ output = self._norm(x.astype(jnp.float32)).astype(x.dtype)
51
+ output = output * scale.astype(x.dtype)
52
+ return output
53
+
54
+ def _norm(self, x: jax.Array) -> jax.Array:
55
+ return x * jax.lax.rsqrt(jnp.power(x, 2).mean(-1, keepdims=True) + self.eps)
56
+
57
+
58
+ class FlaxAIMv2SwiGLUFFN(nn.Module):
59
+ config: AIMv2Config
60
+ dtype: jnp.dtype = jnp.float32
61
+
62
+ @nn.compact
63
+ def __call__(self, x: jax.Array) -> jax.Array:
64
+ hidden_features = self.config.intermediate_size
65
+ in_features = self.config.hidden_size
66
+ bias = self.config.use_bias
67
+
68
+ x1 = nn.Dense(hidden_features, use_bias=bias, dtype=self.dtype, name="fc1")(x)
69
+ x2 = nn.Dense(hidden_features, use_bias=bias, dtype=self.dtype, name="fc3")(x)
70
+ x = nn.silu(x1) * x2
71
+ x = nn.Dense(in_features, use_bias=bias, dtype=self.dtype, name="fc2")(x)
72
+ return x
73
+
74
+
75
+ class FlaxAIMv2PatchEmbed(nn.Module):
76
+ config: AIMv2Config
77
+ dtype: jnp.dtype = jnp.float32
78
+
79
+ @nn.compact
80
+ def __call__(self, x: jax.Array) -> jax.Array:
81
+ patch_size = (self.config.patch_size, self.config.patch_size)
82
+ x = x.transpose(0, 2, 3, 1) # (N C H W) -> (N H W C)
83
+ x = nn.Conv(
84
+ self.config.hidden_size,
85
+ kernel_size=patch_size,
86
+ strides=patch_size,
87
+ padding=(0, 0),
88
+ dtype=self.dtype,
89
+ name="proj",
90
+ )(x)
91
+ x = jax.lax.collapse(x, 1, 3) # (N, H * W, F)
92
+ x = FlaxRMSNorm(self.config.rms_norm_eps, name="norm")(x)
93
+ return x
94
+
95
+
96
+ class FlaxAIMv2ViTPreprocessor(nn.Module):
97
+ config: AIMv2Config
98
+ dtype: jnp.dtype = jnp.float32
99
+
100
+ @nn.compact
101
+ def __call__(self, x: jax.Array) -> jax.Array:
102
+ _, _, H, W = x.shape
103
+ patch_h = self.config.patch_size
104
+ patch_w = self.config.patch_size
105
+
106
+ tokens = FlaxAIMv2PatchEmbed(self.config, dtype=self.dtype, name="patchifier")(
107
+ x
108
+ )
109
+ pos_embed = get_sincos_pos_embed(
110
+ H // patch_h,
111
+ W // patch_w,
112
+ embed_dim=self.config.hidden_size,
113
+ dtype=self.dtype,
114
+ )
115
+ tokens = tokens + pos_embed
116
+ return tokens
117
+
118
+
119
+ class FlaxAIMv2Attention(nn.Module):
120
+ config: AIMv2Config
121
+ dtype: jnp.dtype = jnp.float32
122
+
123
+ @nn.compact
124
+ def __call__(
125
+ self,
126
+ x: jax.Array,
127
+ mask: Optional[jax.Array] = None,
128
+ deterministic: bool = True,
129
+ output_attentions: bool = False,
130
+ ) -> Tuple[jax.Array, Optional[jax.Array]]:
131
+ B, N, C = x.shape
132
+ dim, num_heads = self.config.hidden_size, self.config.num_attention_heads
133
+
134
+ qkv = nn.Dense(
135
+ dim * 3, use_bias=self.config.qkv_bias, dtype=self.dtype, name="qkv"
136
+ )(x)
137
+ qkv = qkv.reshape(B, N, 3, num_heads, C // num_heads).transpose(2, 0, 3, 1, 4)
138
+ q, k, v = qkv[0], qkv[1], qkv[2]
139
+
140
+ attn_weights = nn.dot_product_attention_weights(
141
+ q.swapaxes(-3, -2), # [B, N, H, C]
142
+ k.swapaxes(-3, -2),
143
+ mask=mask,
144
+ deterministic=deterministic,
145
+ dtype=self.dtype,
146
+ )
147
+ attn_weights = nn.Dropout(
148
+ self.config.attention_dropout, deterministic=deterministic, name="attn_drop"
149
+ )(attn_weights)
150
+
151
+ x = (attn_weights @ v).swapaxes(1, 2).reshape(B, N, C)
152
+ x = nn.Dense(dim, use_bias=self.config.use_bias, dtype=self.dtype, name="proj")(
153
+ x
154
+ )
155
+ x = nn.Dropout(
156
+ self.config.projection_dropout,
157
+ deterministic=deterministic,
158
+ name="proj_drop",
159
+ )(x)
160
+ return (x, attn_weights) if output_attentions else (x, None)
161
+
162
+
163
+ class FlaxAIMv2Block(nn.Module):
164
+ config: AIMv2Config
165
+ dtype: jnp.dtype = jnp.float32
166
+
167
+ def setup(self):
168
+ self.attn = FlaxAIMv2Attention(self.config, dtype=self.dtype, name="attn")
169
+ self.norm_1 = FlaxRMSNorm(self.config.rms_norm_eps, name="norm_1")
170
+ self.mlp = FlaxAIMv2SwiGLUFFN(self.config, dtype=self.dtype, name="mlp")
171
+ self.norm_2 = FlaxRMSNorm(self.config.rms_norm_eps, name="norm_2")
172
+
173
+ def __call__(
174
+ self,
175
+ x: jax.Array,
176
+ mask: Optional[jax.Array] = None,
177
+ deterministic: bool = True,
178
+ output_attentions: bool = False,
179
+ ) -> Tuple[jax.Array, Optional[jax.Array]]:
180
+ features, attention = self.attn(
181
+ self.norm_1(x),
182
+ mask,
183
+ deterministic=deterministic,
184
+ output_attentions=output_attentions,
185
+ )
186
+ x = x + features
187
+ x = x + self.mlp(self.norm_2(x))
188
+ return x, attention
189
+
190
+
191
+ class FlaxAIMv2Transformer(nn.Module):
192
+ config: AIMv2Config
193
+ dtype: jnp.dtype = jnp.float32
194
+
195
+ @nn.compact
196
+ def __call__(
197
+ self,
198
+ tokens: jax.Array,
199
+ mask: Optional[jax.Array] = None,
200
+ deterministic: bool = True,
201
+ output_attentions: bool = False,
202
+ output_hidden_states: bool = False,
203
+ ) -> Tuple[
204
+ jax.Array, Optional[Tuple[jax.Array, ...]], Optional[Tuple[jax.Array, ...]]
205
+ ]:
206
+ hidden_states = () if output_hidden_states else None
207
+ attentions = () if output_attentions else None
208
+ for blk_id, block in enumerate(range(self.config.num_hidden_layers)):
209
+ tokens, attention = FlaxAIMv2Block(
210
+ self.config, dtype=self.dtype, name=f"layers_{blk_id}"
211
+ )(
212
+ tokens,
213
+ mask,
214
+ deterministic=deterministic,
215
+ output_attentions=output_attentions,
216
+ )
217
+ if output_hidden_states:
218
+ hidden_states += (tokens,)
219
+ if output_attentions:
220
+ attentions += (attention,)
221
+ tokens = FlaxRMSNorm(self.config.rms_norm_eps, name="post_trunk_norm")(tokens)
222
+ return tokens, hidden_states, attentions
223
+
224
+
225
+ class FlaxAIMv2Module(nn.Module):
226
+ config: AIMv2Config
227
+ dtype: jnp.dtype = jnp.float32
228
+
229
+ @nn.compact
230
+ def __call__(
231
+ self,
232
+ x: jax.Array,
233
+ mask: Optional[jax.Array] = None,
234
+ deterministic: bool = True,
235
+ output_attentions: bool = False,
236
+ output_hidden_states: bool = False,
237
+ ) -> Tuple[
238
+ jax.Array, Optional[Tuple[jax.Array, ...]], Optional[Tuple[jax.Array, ...]]
239
+ ]:
240
+ x = FlaxAIMv2ViTPreprocessor(
241
+ self.config, dtype=self.dtype, name="preprocessor"
242
+ )(x)
243
+ x, hidden_states, attentions = FlaxAIMv2Transformer(
244
+ self.config, dtype=self.dtype, name="trunk"
245
+ )(
246
+ x,
247
+ mask,
248
+ deterministic=deterministic,
249
+ output_attentions=output_attentions,
250
+ output_hidden_states=output_hidden_states,
251
+ )
252
+ return x, hidden_states, attentions
253
+
254
+
255
+ class FlaxAIMv2PretrainedModel(FlaxPreTrainedModel):
256
+ config_class = AIMv2Config
257
+ base_model_prefix = "aimv2"
258
+ main_input_name = "pixel_values"
259
+
260
+ def __init__(
261
+ self,
262
+ config: AIMv2Config,
263
+ input_shape: Optional[Tuple[int, int, int, int]] = None, # [B, C, H, W]
264
+ dtype: jnp.dtype = jnp.float32,
265
+ **kwargs: Any,
266
+ ):
267
+ if input_shape is None:
268
+ # no effect on the param shape
269
+ input_shape = (1, 3, 224, 224)
270
+ super().__init__(
271
+ config,
272
+ module=FlaxAIMv2Module(config, dtype=dtype),
273
+ input_shape=input_shape,
274
+ dtype=dtype,
275
+ **kwargs,
276
+ )
277
+
278
+ def init_weights(
279
+ self,
280
+ rng: jax.Array,
281
+ input_shape: Tuple[int, ...],
282
+ params: Optional[frozen_dict.FrozenDict] = None,
283
+ ) -> frozen_dict.FrozenDict:
284
+ del params
285
+ input_pixels = jnp.empty(input_shape)
286
+ params = self.module.init(rng, input_pixels, deterministic=True)
287
+ return params["params"]
288
+
289
+
290
+ class FlaxAIMv2Model(FlaxAIMv2PretrainedModel):
291
+ def __call__(
292
+ self,
293
+ pixel_values: jax.Array,
294
+ params: Optional[frozen_dict.FrozenDict] = None,
295
+ mask: Optional[jax.Array] = None,
296
+ dropout_rng: Optional[jax.Array] = None,
297
+ deterministic: bool = True,
298
+ output_attentions: Optional[bool] = None,
299
+ output_hidden_states: Optional[bool] = None,
300
+ return_dict: Optional[bool] = None,
301
+ ) -> Union[
302
+ Tuple[jax.Array],
303
+ Tuple[jax.Array, Tuple[jax.Array, ...]],
304
+ Tuple[jax.Array, Tuple[jax.Array, ...], Tuple[jax.Array, ...]],
305
+ FlaxBaseModelOutput,
306
+ ]:
307
+ if params is None:
308
+ params = self.params
309
+ if output_attentions is None:
310
+ output_attentions = self.config.output_attentions
311
+ if output_hidden_states is None:
312
+ output_hidden_states = self.config.output_hidden_states
313
+ if return_dict is None:
314
+ return_dict = self.config.use_return_dict
315
+
316
+ rngs = None if deterministic else {"dropout": dropout_rng}
317
+
318
+ x, hidden_states, attentions = self.module.apply(
319
+ {"params": params},
320
+ pixel_values,
321
+ mask,
322
+ rngs=rngs,
323
+ deterministic=deterministic,
324
+ output_attentions=output_attentions,
325
+ output_hidden_states=output_hidden_states,
326
+ )
327
+
328
+ if not return_dict:
329
+ res = (x,)
330
+ res += (hidden_states,) if output_hidden_states else ()
331
+ res += (attentions,) if output_attentions else ()
332
+ return res
333
+
334
+ return FlaxBaseModelOutput(
335
+ last_hidden_state=x,
336
+ hidden_states=hidden_states,
337
+ attentions=attentions,
338
+ )
339
+
preprocessor_config.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "crop_size": {
3
+ "height": 224,
4
+ "width": 224
5
+ },
6
+ "do_center_crop": true,
7
+ "do_convert_rgb": true,
8
+ "do_normalize": true,
9
+ "do_rescale": true,
10
+ "do_resize": true,
11
+ "image_mean": [
12
+ 0.48145466,
13
+ 0.4578275,
14
+ 0.40821073
15
+ ],
16
+ "image_processor_type": "CLIPImageProcessor",
17
+ "image_std": [
18
+ 0.26862954,
19
+ 0.26130258,
20
+ 0.27577711
21
+ ],
22
+ "resample": 3,
23
+ "rescale_factor": 0.00392156862745098,
24
+ "size": {
25
+ "shortest_edge": 224
26
+ }
27
+ }