zekaic committed
Commit f94c6d3 · verified · 1 Parent(s): 2c0a3a0

Upload folder using huggingface_hub

config.json ADDED
@@ -0,0 +1,48 @@
+ {
+   "architectures": [
+     "SMBVisionModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_smb_vision.SMBVisionModelConfig",
+     "AutoModel": "modeling_smb_vision.SMBVisionModel"
+   },
+   "dtype": "bfloat16",
+   "hidden_size": 1152,
+   "masking_ratio": 0.65,
+   "model_type": "smb_vision_model",
+   "predictor_config": {
+     "depth": 12,
+     "dtype": "bfloat16",
+     "hidden_act": "gelu_pytorch_tanh",
+     "hidden_size": 512,
+     "in_channels": 1,
+     "in_hidden_size": 1152,
+     "initializer_range": 0.02,
+     "intermediate_size": 1536,
+     "model_type": "smb_vision_predictor",
+     "num_heads": 16
+   },
+   "transformers_version": "4.57.3",
+   "use_cache": true,
+   "vision_config": {
+     "deepstack_visual_indexes": [
+       8,
+       16,
+       24
+     ],
+     "depth": 27,
+     "dtype": "bfloat16",
+     "hidden_act": "gelu_pytorch_tanh",
+     "hidden_size": 1152,
+     "in_channels": 1,
+     "initializer_range": 0.02,
+     "intermediate_size": 4304,
+     "model_type": "smb_vision_encoder",
+     "num_heads": 16,
+     "num_position_embeddings": 2304,
+     "out_hidden_size": 2048,
+     "patch_size": 16,
+     "spatial_merge_size": 2,
+     "temporal_patch_size": 16
+   }
+ }
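
The `auto_map` entries above point `AutoConfig`/`AutoModel` at the custom classes shipped in this repo, so loading requires `trust_remote_code=True`. A minimal loading sketch (the repo id is a placeholder for wherever this folder is hosted):

    from transformers import AutoConfig, AutoModel

    repo_id = "<user>/<repo>"  # placeholder: the Hub repo this folder was uploaded to
    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
    print(type(config).__name__, type(model).__name__)  # SMBVisionModelConfig SMBVisionModel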
configuration_smb_vision.py ADDED
@@ -0,0 +1,112 @@
+ # coding=utf-8
+ # Copyright 2025 The SMB Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class SMBVisionConfig(PretrainedConfig):
+     model_type = "smb_vision_encoder"
+     base_config_key = "vision_config"
+
+     def __init__(
+         self,
+         depth=27,
+         hidden_size=1152,
+         hidden_act="gelu_pytorch_tanh",
+         intermediate_size=4304,
+         num_heads=16,
+         in_channels=3,
+         patch_size=16,
+         spatial_merge_size=2,
+         temporal_patch_size=2,
+         out_hidden_size=3584,
+         num_position_embeddings=2304,
+         deepstack_visual_indexes=[8, 16, 24],
+         initializer_range=0.02,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         self.depth = depth
+         self.hidden_size = hidden_size
+         self.hidden_act = hidden_act
+         self.intermediate_size = intermediate_size
+         self.num_heads = num_heads
+         self.in_channels = in_channels
+         self.patch_size = patch_size
+         self.spatial_merge_size = spatial_merge_size
+         self.temporal_patch_size = temporal_patch_size
+         self.out_hidden_size = out_hidden_size
+         self.num_position_embeddings = num_position_embeddings
+         self.initializer_range = initializer_range
+         self.deepstack_visual_indexes = deepstack_visual_indexes
+
+
+ class SMBVisionPredictorConfig(PretrainedConfig):
+     model_type = "smb_vision_predictor"
+     base_config_key = "predictor_config"
+
+     def __init__(
+         self,
+         depth=27,
+         in_hidden_size=1152,
+         hidden_size=512,
+         hidden_act="gelu_pytorch_tanh",
+         intermediate_size=1536,
+         num_heads=16,
+         in_channels=1,
+         initializer_range=0.02,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         self.depth = depth
+         self.in_hidden_size = in_hidden_size
+         self.hidden_size = hidden_size
+         self.hidden_act = hidden_act
+         self.intermediate_size = intermediate_size
+         self.num_heads = num_heads
+         self.in_channels = in_channels
+         self.initializer_range = initializer_range
+
+
+ class SMBVisionModelConfig(PretrainedConfig):
+     model_type = "smb_vision_model"
+     sub_configs = {"vision_config": SMBVisionConfig, "predictor_config": SMBVisionPredictorConfig}
+
+     def __init__(
+         self,
+         vision_config=None,
+         predictor_config=None,
+         hidden_size=1152,
+         masking_ratio=0.1,
+         **kwargs,
+     ):
+         if isinstance(vision_config, dict):
+             self.vision_config = self.sub_configs["vision_config"](**vision_config)
+         elif vision_config is None:
+             self.vision_config = self.sub_configs["vision_config"]()
+
+         if isinstance(predictor_config, dict):
+             self.predictor_config = self.sub_configs["predictor_config"](**predictor_config)
+         elif predictor_config is None:
+             self.predictor_config = self.sub_configs["predictor_config"]()
+
+         self.hidden_size = hidden_size
+         self.masking_ratio = masking_ratio
+
+         super().__init__(**kwargs)
+
+
+ __all__ = ["SMBVisionConfig", "SMBVisionPredictorConfig", "SMBVisionModelConfig"]
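
For reference, a minimal sketch of how the nested sub-configs are built (run from a local copy of this folder; values mirror the config.json above, and unspecified keys fall back to the defaults defined in this file):

    from configuration_smb_vision import SMBVisionModelConfig

    config = SMBVisionModelConfig(
        vision_config={"depth": 27, "in_channels": 1, "out_hidden_size": 2048, "temporal_patch_size": 16},
        predictor_config={"depth": 12, "hidden_size": 512, "in_hidden_size": 1152},
        hidden_size=1152,
        masking_ratio=0.65,
    )
    print(type(config.vision_config).__name__)        # SMBVisionConfig
    print(config.predictor_config.intermediate_size)  # 1536 (default)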
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:07c086b52183afd27fca41a807a7c3b9359913143a243f8299b1c279ef2d6922
+ size 1224159656
modeling_smb_vision.py ADDED
@@ -0,0 +1,842 @@
+ # coding=utf-8
+ # Copyright 2025 The SMB Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from dataclasses import dataclass
+ from typing import Callable, Optional
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from transformers.activations import ACT2FN
+ from transformers.modeling_layers import GradientCheckpointingLayer
+ from transformers.modeling_outputs import BaseModelOutput, ModelOutput
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+ from transformers.processing_utils import Unpack
+ from transformers.utils import TransformersKwargs
+ from .configuration_smb_vision import (
+     SMBVisionConfig,
+     SMBVisionPredictorConfig,
+     SMBVisionModelConfig,
+ )
+
+
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """
+     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+     """
+     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+     if n_rep == 1:
+         return hidden_states
+     hidden_states = hidden_states[:, :, None, :, :].expand(
+         batch, num_key_value_heads, n_rep, slen, head_dim
+     )
+     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+ def eager_attention_forward(
+     module: nn.Module,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: Optional[torch.Tensor],
+     scaling: float,
+     dropout: float = 0.0,
+     **kwargs: Unpack[TransformersKwargs],
+ ):
+     key_states = repeat_kv(key, module.num_key_value_groups)
+     value_states = repeat_kv(value, module.num_key_value_groups)
+
+     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+     if attention_mask is not None:
+         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+         attn_weights = attn_weights + causal_mask
+
+     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
+         query.dtype
+     )
+     attn_weights = nn.functional.dropout(
+         attn_weights, p=dropout, training=module.training
+     )
+     attn_output = torch.matmul(attn_weights, value_states)
+     attn_output = attn_output.transpose(1, 2).contiguous()
+
+     return attn_output, attn_weights
+
+
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+     """Applies Rotary Position Embedding to the query and key tensors.
+
+     Args:
+         q (`torch.Tensor`): The query tensor.
+         k (`torch.Tensor`): The key tensor.
+         cos (`torch.Tensor`): The cosine part of the rotary embedding.
+         sin (`torch.Tensor`): The sine part of the rotary embedding.
+         position_ids (`torch.Tensor`, *optional*):
+             Deprecated and unused.
+         unsqueeze_dim (`int`, *optional*, defaults to 1):
+             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+             that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+             k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+             cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+             the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+     Returns:
+         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+     """
+     cos = cos.unsqueeze(unsqueeze_dim)
+     sin = sin.unsqueeze(unsqueeze_dim)
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+
+ class SMBVisionMLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+         self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
+         self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
+         self.act_fn = ACT2FN[config.hidden_act]
+
+     def forward(self, hidden_state):
+         return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state)))
+
+
+ class SMBVisionPatchEmbed(nn.Module):
+     def __init__(self, config) -> None:
+         super().__init__()
+         self.patch_size = config.patch_size
+         self.temporal_patch_size = config.temporal_patch_size
+         self.in_channels = config.in_channels
+         self.embed_dim = config.hidden_size
+
+         kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
+         for in_channels in [1, 3, 4]:
+             setattr(
+                 self,
+                 f"proj_c{in_channels}",
+                 nn.Conv3d(
+                     in_channels,
+                     self.embed_dim,
+                     kernel_size=kernel_size,
+                     stride=kernel_size,
+                     bias=True,
+                 ),
+             )
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         target_dtype = self.proj_c1.weight.dtype
+         if self.in_channels == 1:  # grayscale
+             hidden_states = hidden_states.view(
+                 -1, 1, self.temporal_patch_size, self.patch_size, self.patch_size
+             )
+             hidden_states = self.proj_c1(hidden_states.to(dtype=target_dtype)).view(
+                 -1, self.embed_dim
+             )
+         elif self.in_channels == 3:  # rgb
+             hidden_states = hidden_states.view(
+                 -1, 3, self.temporal_patch_size, self.patch_size, self.patch_size
+             )
+             hidden_states = self.proj_c3(hidden_states.to(dtype=target_dtype)).view(
+                 -1, self.embed_dim
+             )
+         elif self.in_channels == 4:  # multi sequence
+             hidden_states = hidden_states.view(
+                 -1, 4, self.temporal_patch_size, self.patch_size, self.patch_size
+             )
+             hidden_states = self.proj_c4(hidden_states.to(dtype=target_dtype)).view(
+                 -1, self.embed_dim
+             )
+         else:
+             raise ValueError(f"Unsupported number of channels: {self.in_channels}")
+         return hidden_states
+
+
+ class SMBVisionRotaryEmbedding(nn.Module):
+     inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
+     def __init__(self, dim: int, theta: float = 10000.0) -> None:
+         super().__init__()
+         inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+     def forward(self, seqlen: int) -> torch.Tensor:
+         seq = torch.arange(
+             seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
+         )
+         freqs = torch.outer(seq, self.inv_freq)
+         return freqs
+
+
+ class SMBVisionPatchMerger(nn.Module):
+     def __init__(self, config, use_postshuffle_norm=False) -> None:
+         super().__init__()
+         self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
+         self.use_postshuffle_norm = use_postshuffle_norm
+         self.norm = nn.LayerNorm(
+             self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6
+         )
+         self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
+         self.act_fn = nn.GELU()
+         self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.norm(
+             x.contiguous().view(-1, self.hidden_size)
+             if self.use_postshuffle_norm
+             else x
+         ).view(-1, self.hidden_size)
+         x = self.linear_fc2(self.act_fn(self.linear_fc1(x)))
+         return x
+
+
+ def apply_rotary_pos_emb_vision(
+     q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     orig_q_dtype = q.dtype
+     orig_k_dtype = k.dtype
+     q, k = q.float(), k.float()
+     cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     q_embed = q_embed.to(orig_q_dtype)
+     k_embed = k_embed.to(orig_k_dtype)
+     return q_embed, k_embed
+
+
+ class SMBVisionAttention(nn.Module):
+     def __init__(self, config) -> None:
+         super().__init__()
+         self.dim = config.hidden_size
+         self.num_heads = config.num_heads
+         self.head_dim = self.dim // self.num_heads
+         self.num_key_value_groups = 1  # needed for eager attention
+         self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
+         self.proj = nn.Linear(self.dim, self.dim)
+         self.scaling = self.head_dim**-0.5
+         self.config = config
+         self.attention_dropout = 0.0
+         self.is_causal = False
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         cu_seqlens: torch.Tensor,
+         rotary_pos_emb: Optional[torch.Tensor] = None,
+         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+         **kwargs,
+     ) -> torch.Tensor:
+         seq_length = hidden_states.shape[0]
+         query_states, key_states, value_states = (
+             self.qkv(hidden_states)
+             .reshape(seq_length, 3, self.num_heads, -1)
+             .permute(1, 0, 2, 3)
+             .unbind(0)
+         )
+         cos, sin = position_embeddings
+         query_states, key_states = apply_rotary_pos_emb_vision(
+             query_states, key_states, cos, sin
+         )
+
+         query_states = query_states.transpose(0, 1).unsqueeze(0)
+         key_states = key_states.transpose(0, 1).unsqueeze(0)
+         value_states = value_states.transpose(0, 1).unsqueeze(0)
+
+         attention_interface: Callable = eager_attention_forward
+         if self.config._attn_implementation != "eager":
+             attention_interface = ALL_ATTENTION_FUNCTIONS[
+                 self.config._attn_implementation
+             ]
+
+         if self.config._attn_implementation == "flash_attention_2":
+             # Flash Attention 2: Use cu_seqlens for variable length attention
+             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+             attn_output, _ = attention_interface(
+                 self,
+                 query_states,
+                 key_states,
+                 value_states,
+                 attention_mask=None,
+                 scaling=self.scaling,
+                 dropout=0.0 if not self.training else self.attention_dropout,
+                 cu_seq_lens_q=cu_seqlens,
+                 cu_seq_lens_k=cu_seqlens,
+                 max_length_q=max_seqlen,
+                 max_length_k=max_seqlen,
+                 is_causal=False,
+                 **kwargs,
+             )
+         else:
+             # Other implementations: Process each chunk separately
+             lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+             splits = [
+                 torch.split(tensor, lengths.tolist(), dim=2)
+                 for tensor in (query_states, key_states, value_states)
+             ]
+
+             attn_outputs = [
+                 attention_interface(
+                     self,
+                     q,
+                     k,
+                     v,
+                     attention_mask=None,
+                     scaling=self.scaling,
+                     dropout=0.0 if not self.training else self.attention_dropout,
+                     is_causal=False,
+                     **kwargs,
+                 )[0]
+                 for q, k, v in zip(*splits)
+             ]
+             attn_output = torch.cat(attn_outputs, dim=1)
+
+         attn_output = attn_output.reshape(seq_length, -1).contiguous()
+         attn_output = self.proj(attn_output)
+         return attn_output
+
+
+ class SMBVisionBlock(GradientCheckpointingLayer):
+     def __init__(self, config, attn_implementation: str = "sdpa") -> None:
+         super().__init__()
+         self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6)
+         self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6)
+         self.attn = SMBVisionAttention(config=config)
+         self.mlp = SMBVisionMLP(config=config)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         cu_seqlens: torch.Tensor,
+         rotary_pos_emb: Optional[torch.Tensor] = None,
+         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+         **kwargs,
+     ) -> torch.Tensor:
+         hidden_states = hidden_states + self.attn(
+             self.norm1(hidden_states),
+             cu_seqlens=cu_seqlens,
+             rotary_pos_emb=rotary_pos_emb,
+             position_embeddings=position_embeddings,
+             **kwargs,
+         )
+         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+         return hidden_states
+
+
+ class SMBVisionEncoder(PreTrainedModel):
+     config: SMBVisionConfig
+     _no_split_modules = ["SMBVisionBlock"]
+     _supports_flash_attn = True
+     _supports_sdpa = True
+     _supports_flex_attn = True
+     _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+     _supports_attention_backend = True
+
+     def __init__(self, config, *inputs, **kwargs) -> None:
+         super().__init__(config, *inputs, **kwargs)
+         self.spatial_merge_size = config.spatial_merge_size
+         self.patch_size = config.patch_size
+         self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
+
+         self.patch_embed = SMBVisionPatchEmbed(
+             config=config,
+         )
+
+         self.pos_embed = nn.Embedding(
+             config.num_position_embeddings, config.hidden_size
+         )
+         self.num_grid_per_side = int(config.num_position_embeddings**0.5)
+
+         head_dim = config.hidden_size // config.num_heads
+         self.rotary_pos_emb = SMBVisionRotaryEmbedding(head_dim // 2)
+
+         self.blocks = nn.ModuleList(
+             [SMBVisionBlock(config) for _ in range(config.depth)]
+         )
+         self.merger = SMBVisionPatchMerger(
+             config=config,
+             use_postshuffle_norm=False,
+         )
+
+         self.deepstack_visual_indexes = config.deepstack_visual_indexes
+         self.deepstack_merger_list = nn.ModuleList(
+             [
+                 SMBVisionPatchMerger(
+                     config=config,
+                     use_postshuffle_norm=True,
+                 )
+                 for _ in range(len(config.deepstack_visual_indexes))
+             ]
+         )
+
+         self.gradient_checkpointing = False
+
+     def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+         merge_size = self.spatial_merge_size
+
+         max_hw = int(grid_thw[:, 1:].max().item())
+         freq_table = self.rotary_pos_emb(max_hw)  # (max_hw, dim // 2)
+         device = freq_table.device
+
+         total_tokens = int(torch.prod(grid_thw, dim=1).sum().item())
+         pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
+
+         offset = 0
+         for num_frames, height, width in grid_thw:
+             merged_h, merged_w = height // merge_size, width // merge_size
+
+             block_rows = torch.arange(merged_h, device=device)  # block row indices
+             block_cols = torch.arange(merged_w, device=device)  # block col indices
+             intra_row = torch.arange(
+                 merge_size, device=device
+             )  # intra-block row offsets
+             intra_col = torch.arange(
+                 merge_size, device=device
+             )  # intra-block col offsets
+
+             # Compute full-resolution positions
+             row_idx = (
+                 block_rows[:, None, None, None] * merge_size
+                 + intra_row[None, None, :, None]
+             )
+             col_idx = (
+                 block_cols[None, :, None, None] * merge_size
+                 + intra_col[None, None, None, :]
+             )
+
+             row_idx = row_idx.expand(
+                 merged_h, merged_w, merge_size, merge_size
+             ).reshape(-1)
+             col_idx = col_idx.expand(
+                 merged_h, merged_w, merge_size, merge_size
+             ).reshape(-1)
+
+             coords = torch.stack((row_idx, col_idx), dim=-1)
+
+             if num_frames > 1:
+                 coords = coords.repeat(num_frames, 1)
+
+             num_tokens = coords.shape[0]
+             pos_ids[offset : offset + num_tokens] = coords
+             offset += num_tokens
+
+         embeddings = freq_table[pos_ids]  # lookup rotary embeddings
+         embeddings = embeddings.flatten(1)
+         return embeddings
+
+     def fast_pos_embed_interpolate(self, grid_thw):
+         grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
+
+         idx_list = [[] for _ in range(4)]
+         weight_list = [[] for _ in range(4)]
+
+         for t, h, w in zip(grid_ts, grid_hs, grid_ws):
+             h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
+             w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)
+
+             h_idxs_floor = h_idxs.int()
+             w_idxs_floor = w_idxs.int()
+             h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
+             w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
+
+             dh = h_idxs - h_idxs_floor
+             dw = w_idxs - w_idxs_floor
+
+             base_h = h_idxs_floor * self.num_grid_per_side
+             base_h_ceil = h_idxs_ceil * self.num_grid_per_side
+
+             indices = [
+                 (base_h[None].T + w_idxs_floor[None]).flatten(),
+                 (base_h[None].T + w_idxs_ceil[None]).flatten(),
+                 (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
+                 (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
+             ]
+
+             weights = [
+                 ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
+                 ((1 - dh)[None].T * dw[None]).flatten(),
+                 (dh[None].T * (1 - dw)[None]).flatten(),
+                 (dh[None].T * dw[None]).flatten(),
+             ]
+
+             for i in range(4):
+                 idx_list[i].extend(indices[i].tolist())
+                 weight_list[i].extend(weights[i].tolist())
+
+         idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device)
+         weight_tensor = torch.tensor(
+             weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device
+         )
+         pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
+         patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
+
+         patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])
+
+         patch_pos_embeds_permute = []
+         merge_size = self.config.spatial_merge_size
+         for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
+             pos_embed = pos_embed.repeat(t, 1)
+             pos_embed = (
+                 pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
+                 .permute(0, 1, 3, 2, 4, 5)
+                 .flatten(0, 4)
+             )
+             patch_pos_embeds_permute.append(pos_embed)
+         patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
+         return patch_pos_embeds
+
+     def forward(
+         self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs
+     ) -> torch.Tensor:
+         """
+         Args:
+             hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
+                 The final hidden states of the model.
+             grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
+                 The temporal, height and width of feature shape of each image in LLM.
+
+         Returns:
+             `torch.Tensor`: hidden_states.
+         """
+         hidden_states = self.patch_embed(hidden_states)
+
+         pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
+         hidden_states = hidden_states + pos_embeds
+
+         rotary_pos_emb = self.rot_pos_emb(grid_thw)
+
+         seq_len, _ = hidden_states.size()
+         hidden_states = hidden_states.reshape(seq_len, -1)
+         rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+         emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+         position_embeddings = (emb.cos(), emb.sin())
+
+         cu_seqlens = torch.repeat_interleave(
+             grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+         ).cumsum(
+             dim=0,
+             # Select dtype based on the following factors:
+             #  - FA2 requires that cu_seqlens_q must have dtype int32
+             #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
+             # See https://github.com/huggingface/transformers/pull/34852 for more information
+             dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+         )
+         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+         deepstack_feature_lists = []
+         for layer_num, blk in enumerate(self.blocks):
+             hidden_states = blk(
+                 hidden_states,
+                 cu_seqlens=cu_seqlens,
+                 position_embeddings=position_embeddings,
+                 **kwargs,
+             )
+             if layer_num in self.deepstack_visual_indexes:
+                 deepstack_feature = self.deepstack_merger_list[
+                     self.deepstack_visual_indexes.index(layer_num)
+                 ](hidden_states)
+                 deepstack_feature_lists.append(deepstack_feature)
+
+         # hidden_states = self.merger(hidden_states)
+
+         return hidden_states, deepstack_feature_lists
+
+
+ class SMBVisionPredictor(PreTrainedModel):
+     config: SMBVisionPredictorConfig
+     _no_split_modules = ["SMBVisionBlock"]
+     _supports_flash_attn = True
+     _supports_sdpa = True
+     _supports_flex_attn = True
+     _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+     _supports_attention_backend = True
+
+     def __init__(self, config, *inputs, **kwargs) -> None:
+         super().__init__(config, *inputs, **kwargs)
+
+         head_dim = config.hidden_size // config.num_heads
+         self.rotary_pos_emb = SMBVisionRotaryEmbedding(head_dim // 2)
+
+         self.blocks = nn.ModuleList(
+             [SMBVisionBlock(config) for _ in range(config.depth)]
+         )
+
+         self.in_proj = nn.Linear(config.in_hidden_size, config.hidden_size)
+         self.out_proj = nn.Linear(config.hidden_size, config.in_hidden_size)
+         self.mask_token = nn.Parameter(torch.zeros(config.hidden_size))
+         self.gradient_checkpointing = False
+
+     def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+         merge_size = 1
+
+         max_hw = int(grid_thw[:, 1:].max().item())
+         freq_table = self.rotary_pos_emb(max_hw)  # (max_hw, dim // 2)
+         device = freq_table.device
+
+         total_tokens = int(torch.prod(grid_thw, dim=1).sum().item())
+         pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
+
+         offset = 0
+         for num_frames, height, width in grid_thw:
+             merged_h, merged_w = height // merge_size, width // merge_size
+
+             block_rows = torch.arange(merged_h, device=device)  # block row indices
+             block_cols = torch.arange(merged_w, device=device)  # block col indices
+             intra_row = torch.arange(
+                 merge_size, device=device
+             )  # intra-block row offsets
+             intra_col = torch.arange(
+                 merge_size, device=device
+             )  # intra-block col offsets
+
+             # Compute full-resolution positions
+             row_idx = (
+                 block_rows[:, None, None, None] * merge_size
+                 + intra_row[None, None, :, None]
+             )
+             col_idx = (
+                 block_cols[None, :, None, None] * merge_size
+                 + intra_col[None, None, None, :]
+             )
+
+             row_idx = row_idx.expand(
+                 merged_h, merged_w, merge_size, merge_size
+             ).reshape(-1)
+             col_idx = col_idx.expand(
+                 merged_h, merged_w, merge_size, merge_size
+             ).reshape(-1)
+
+             coords = torch.stack((row_idx, col_idx), dim=-1)
+
+             if num_frames > 1:
+                 coords = coords.repeat(num_frames, 1)
+
+             num_tokens = coords.shape[0]
+             pos_ids[offset : offset + num_tokens] = coords
+             offset += num_tokens
+
+         embeddings = freq_table[pos_ids]  # lookup rotary embeddings
+         embeddings = embeddings.flatten(1)
+         return embeddings
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         grid_thw: torch.Tensor,
+         target_mask: torch.Tensor,
+         **kwargs,
+     ) -> torch.Tensor:
+         # mask out the hidden states
+         hidden_states = self.in_proj(hidden_states)
+         hidden_states[target_mask] = self.mask_token
+
+         # apply position embeddings
+         rotary_pos_emb = self.rot_pos_emb(grid_thw)
+         emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+         position_embeddings = (emb.cos(), emb.sin())
+
+         cu_seqlens = torch.repeat_interleave(
+             grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+         ).cumsum(
+             dim=0,
+             # Select dtype based on the following factors:
+             #  - FA2 requires that cu_seqlens_q must have dtype int32
+             #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
+             # See https://github.com/huggingface/transformers/pull/34852 for more information
+             dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+         )
+         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+         for layer_num, blk in enumerate(self.blocks):
+             hidden_states = blk(
+                 hidden_states,
+                 cu_seqlens=cu_seqlens,
+                 position_embeddings=position_embeddings,
+                 **kwargs,
+             )
+
+         # hidden_states = self.merger(hidden_states)
+         hidden_states = self.out_proj(hidden_states)
+
+         return hidden_states
+
+
+ @dataclass
+ class SMBVisionModelOutput(ModelOutput):
+     loss: Optional[torch.FloatTensor] = None
+     mim_loss: Optional[torch.FloatTensor] = None
+     jepa_loss: Optional[torch.FloatTensor] = None
+     hidden_states: Optional[torch.FloatTensor] = None
+     enc_hidden_states: Optional[torch.FloatTensor] = None
+     predicted_hidden_states: Optional[torch.FloatTensor] = None
+
+
+ class SMBVisionPretrainedModel(PreTrainedModel):
+     config: SMBVisionModelConfig
+     base_model_prefix = ""
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["SMBVisionBlock"]
+     _supports_flash_attn = True
+     _supports_sdpa = True
+     _supports_flex_attn = True
+     _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+     _supports_attention_backend = True
+
+     def _init_weights(self, module):
+         """Initialize the weights"""
+
+         init_std = self.config.vision_config.initializer_range
+
+         # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+         # `trunc_normal_cpu` not implemented in `half` issues
+         def trunc_normal_f32_(weight, std):
+             data_float_32 = weight.data.to(torch.float32)
+             data_init = nn.init.trunc_normal_(data_float_32, mean=0.0, std=std)
+             weight.data = data_init.to(weight.dtype)
+
+         if isinstance(module, SMBVisionEncoder):
+             trunc_normal_f32_(module.pos_embed.weight, std=init_std)
+             for i, layer in enumerate(module.blocks, 1):
+                 std = init_std / (i**0.5)
+                 trunc_normal_f32_(layer.attn.proj.weight, std=std)
+                 trunc_normal_f32_(layer.mlp.linear_fc2.weight, std=std)
+             std = init_std / (len(module.blocks) + 1) ** 0.5
+             trunc_normal_f32_(module.merger.linear_fc2.weight, std=std)
+         elif isinstance(module, SMBVisionPredictor):
+             trunc_normal_f32_(module.mask_token, std=init_std)
+             trunc_normal_f32_(module.in_proj.weight, std=init_std)
+             trunc_normal_f32_(module.out_proj.weight, std=init_std)
+             for i, layer in enumerate(module.blocks, 1):
+                 std = init_std / (i**0.5)
+                 trunc_normal_f32_(layer.attn.proj.weight, std=std)
+                 trunc_normal_f32_(layer.mlp.linear_fc2.weight, std=std)
+             std = init_std / (len(module.blocks) + 1) ** 0.5
+             trunc_normal_f32_(module.out_proj.weight, std=std)
+         elif isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)):
+             trunc_normal_f32_(module.weight, std=init_std)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+
+
+ class SMBVisionModel(SMBVisionPretrainedModel):
+     def __init__(self, config, *inputs, **kwargs) -> None:
+         super().__init__(config, *inputs, **kwargs)
+
+         self.encoder = SMBVisionEncoder._from_config(config.vision_config)
+         self.predictor = SMBVisionPredictor._from_config(config.predictor_config)
+         self.to_pixels = nn.Linear(
+             config.vision_config.hidden_size,
+             config.vision_config.patch_size**2
+             * config.vision_config.temporal_patch_size,
+         )
+         self.masking_ratio = config.masking_ratio
+         self.mask_token = nn.Parameter(
+             torch.zeros(
+                 config.vision_config.in_channels
+                 * config.vision_config.temporal_patch_size
+                 * config.vision_config.patch_size**2
+             )
+         )
+
+         self.mim_loss = nn.L1Loss(reduction="mean")
+         self.jepa_loss = nn.MSELoss(reduction="mean")
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward_features(
+         self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs
+     ) -> torch.Tensor:
+         return self.encoder(hidden_states, grid_thw, **kwargs)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         grid_thw: torch.Tensor,
+         context_mask: Optional[torch.Tensor],
+         target_mask: Optional[torch.Tensor],
+         **kwargs,
+     ) -> torch.Tensor:
+         # modeling masked image reconstruction
+         # prepare mask tokens
+         num_masked = int(self.masking_ratio * hidden_states.shape[0])
+         masked_indices = torch.randperm(hidden_states.shape[0])[:num_masked]
+         # replace masked indices with mask tokens
+         inputs_mim = hidden_states.clone()
+         inputs_mim[masked_indices] = self.mask_token.to(hidden_states.dtype)
+         masked_hidden_states, deepstack_feature_lists = self.encoder(
+             inputs_mim, grid_thw, **kwargs
+         )
+         masked_hidden_states = self.to_pixels(masked_hidden_states)
+         # compute mim loss
+         mim_loss = self.mim_loss(
+             masked_hidden_states[masked_indices], hidden_states[masked_indices]
+         )
+
+         # modeling next embedding prediction
+         if context_mask is not None and target_mask is not None:
+             context_mask = context_mask == 1
+             target_mask = target_mask == 1
+             # extend context and target masks
+             lengths = torch.prod(grid_thw, dim=1)
+             extended_context_mask = torch.repeat_interleave(context_mask, lengths)
+             extended_target_mask = torch.repeat_interleave(target_mask, lengths)
+
+             enc_hidden_states, deepstack_feature_lists = self.encoder(
+                 hidden_states[extended_context_mask], grid_thw[context_mask], **kwargs
+             )
+             pred_hidden_states = self.predictor(
+                 enc_hidden_states,
+                 grid_thw[context_mask],
+                 extended_target_mask,
+                 **kwargs,
+             )
+             jepa_loss = self.jepa_loss(
+                 pred_hidden_states[extended_target_mask],
+                 enc_hidden_states[extended_target_mask],
+             )
+
+             loss = mim_loss + jepa_loss
+             return SMBVisionModelOutput(
+                 loss=loss,
+                 mim_loss=mim_loss,
+                 jepa_loss=jepa_loss,
+                 hidden_states=hidden_states,
+                 enc_hidden_states=enc_hidden_states,
+                 predicted_hidden_states=pred_hidden_states,
+             )
+         else:
+             return SMBVisionModelOutput(
+                 loss=mim_loss,
+                 mim_loss=mim_loss,
+                 jepa_loss=None,
+                 hidden_states=hidden_states,
+                 predicted_hidden_states=None,
+             )
+
+
+ __all__ = ["SMBVisionEncoder", "SMBVisionPredictor", "SMBVisionModel"]
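
A rough usage sketch for the model above (shapes inferred from the code: with in_channels=1, temporal_patch_size=16 and patch_size=16, each patch is a flattened vector of 1*16*16*16 = 4096 values; no processor ships with this commit, so the random patches and the placeholder repo id below are only illustrative):

    import torch
    from transformers import AutoModel

    model = AutoModel.from_pretrained("<user>/<repo>", trust_remote_code=True)  # placeholder repo id

    # two single-frame grayscale inputs, each a 4x4 grid of 16x16x16 patches
    grid_thw = torch.tensor([[1, 4, 4], [1, 4, 4]])                          # (num_images, 3) = (t, h, w) in patch units
    num_patches = int(torch.prod(grid_thw, dim=1).sum())                     # 32
    pixel_patches = torch.randn(num_patches, 1 * 16 * 16 * 16, dtype=model.dtype)  # one flattened patch per row

    # encoder features only: per-patch features plus the deepstack merger outputs
    features, deepstack_features = model.forward_features(pixel_patches, grid_thw)

    # masked-image-modeling loss only; passing None for both masks skips the JEPA branch,
    # which additionally expects per-image context/target masks
    out = model(pixel_patches, grid_thw, context_mask=None, target_mask=None)
    print(features.shape, out.loss, out.mim_loss)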
trainer_state.json ADDED
The diff for this file is too large to render. See raw diff
 
training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:eb524eb3bce99caa8ae20f84eb8a9b6eb49534ffab4978a086eaa340f4aef1e7
+ size 7313