liuhaogeng committed on
Commit
b0b3b00
1 Parent(s): 6906069

first commit

README.md CHANGED
@@ -26,7 +26,7 @@ import torch
 from transformers import AutoModelForCausalLM, AutoProcessor
 from open_flamingo.eval.models.cruise_model import EvalModel
 
-processor = AutoProcessor.from_pretrained("/mnt/bn/bohanzhainas1/haogeng/infimm-HD", trust_remote_code=True)
+processor = AutoProcessor.from_pretrained("infimm/infimm-hd", trust_remote_code=True)
 
 prompts = [
     {
@@ -40,7 +40,7 @@ prompts = [
 inputs = processor(prompts)
 # use bf16 and gpu 0
 model = AutoModelForCausalLM.from_pretrained(
-    "/opt/tiger/mlm-tool/infimm-HD",
+    "infimm/infimm-hd",
     local_files_only=True,
     torch_dtype=torch.bfloat16,
     trust_remote_code=True,
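
To round out the README snippet above, a hedged continuation is sketched below. It assumes the processor output is a dict of tensors whose keys line up with the model's generate() arguments (the underlying Flamingo.generate() in flamingo.py takes vision_x, lang_x and attention_mask plus standard Hugging Face generation kwargs), and that the processor exposes a tokenizer for decoding; neither detail is confirmed by this commit.

import torch

model = model.eval().to("cuda:0")
inputs = {k: v.to("cuda:0") if torch.is_tensor(v) else v for k, v in inputs.items()}

with torch.no_grad():
    generated_ids = model.generate(
        **inputs,            # assumed to map onto vision_x / lang_x / attention_mask
        max_new_tokens=64,   # ordinary generation kwargs are forwarded to the language model
        num_beams=1,
    )
print(processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True))  # assumes a .tokenizer attribute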
added_tokens.json ADDED
@@ -0,0 +1,4 @@
+{
+  "<image>": 32001,
+  "<|endofchunk|>": 32000
+}
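
added_tokens.json registers the two special tokens the model depends on: <image> (id 32001) marks where visual features are injected and <|endofchunk|> (id 32000) closes an image-text chunk, which is why config.json below reports vocab_size 32002. As a hedged illustration of how such tokens end up with those ids, the sketch below adds them to the base Vicuna tokenizer named in config.json; the repository's own processor presumably does the equivalent internally.

from transformers import AutoTokenizer

# Illustrative only: the bundled processor is expected to handle special tokens itself.
tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-13b-v1.5")  # base LLM from config.json (vocab size 32000)
tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>"]})
print(tokenizer.convert_tokens_to_ids(["<|endofchunk|>", "<image>"]))  # [32000, 32001], matching this file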
config.json ADDED
@@ -0,0 +1,62 @@
+{
+  "_name_or_path": "./",
+  "architectures": [
+    "InfiMMHDModel"
+  ],
+  "auto_map": {
+    "AutoConfig": "configuration_infimm_hd.InfiMMHDConfig",
+    "AutoModelForCausalLM": "modeling_infimm_hd.InfiMMHDModel"
+  },
+  "model_type": "infimm-hd",
+  "seq_length": 4096,
+  "tokenizer_type": "LlamaTokenizer",
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.35.2",
+  "use_cache": true,
+  "use_flash_attn": false,
+  "cross_attn_every_n_layers": 4,
+  "use_grad_checkpoint": false,
+  "freeze_llm": true,
+  "image_token_id": 32001,
+  "eoc_token_id": 32000,
+  "visual": {
+    "image_size": 448,
+    "layers": 64,
+    "width": 1792,
+    "head_width": 112,
+    "patch_size": 14,
+    "mlp_ratio": 8.571428571428571,
+    "eva_model_name": "eva-clip-4b-14-x",
+    "drop_path_rate": 0.0,
+    "xattn": false,
+    "postnorm": true,
+    "fusedLN": false,
+    "embed_dim": 1024
+  },
+  "language": {
+    "_name_or_path": "lmsys/vicuna-13b-v1.5",
+    "architectures": [
+      "LlamaForCausalLM"
+    ],
+    "bos_token_id": 1,
+    "eos_token_id": 2,
+    "hidden_act": "silu",
+    "hidden_size": 5120,
+    "initializer_range": 0.02,
+    "intermediate_size": 13824,
+    "max_position_embeddings": 4096,
+    "model_type": "llama",
+    "num_attention_heads": 40,
+    "num_hidden_layers": 40,
+    "num_key_value_heads": 40,
+    "pad_token_id": 0,
+    "pretraining_tp": 1,
+    "rms_norm_eps": 1e-05,
+    "rope_scaling": null,
+    "tie_word_embeddings": false,
+    "torch_dtype": "bfloat16",
+    "transformers_version": "4.32.0.dev0",
+    "use_cache": true,
+    "vocab_size": 32002
+  }
+}
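
Since config.json wires the model to its custom classes through auto_map, it can be inspected with the standard AutoConfig API as long as trust_remote_code=True is passed. A short sketch, assuming the checkpoint is available under the hub id used in the README; the nested "visual" and "language" blocks come back as plain dictionaries on the config object.

from transformers import AutoConfig

config = AutoConfig.from_pretrained("infimm/infimm-hd", trust_remote_code=True)
print(config.model_type)                  # "infimm-hd"
print(config.cross_attn_every_n_layers)   # 4: gated cross-attention after every 4th decoder layer
print(config.visual["image_size"], config.visual["patch_size"])     # 448, 14 -> a 32x32 grid of patch tokens
print(config.language["hidden_size"], config.language["num_hidden_layers"])  # 5120, 40 (Vicuna-13B)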
configuration_infimm_hd.py ADDED
@@ -0,0 +1,42 @@
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from transformers import PretrainedConfig
+
+
+class InfiMMHDConfig(PretrainedConfig):
+    model_type = "infimmhd"
+
+    def __init__(
+        self,
+        model_type="infimm-hd",
+        seq_length=1024,
+        tokenizer_type="LlamaTokenizer",
+        torch_dtype="bfloat16",
+        transformers_version="4.28.2",
+        use_cache=True,
+        use_flash_attn=False,
+        cross_attn_every_n_layers=4,
+        use_grad_checkpoint=False,
+        freeze_llm=True,
+        visual=None,
+        language=None,
+        image_token_id=None,
+        eoc_token_id=None,
+        **kwargs,
+    ):
+        self.model_type = model_type
+        self.seq_length = seq_length
+        self.tokenizer_type = tokenizer_type
+        self.torch_dtype = torch_dtype
+        self.transformers_version = transformers_version
+        self.use_cache = use_cache
+        self.use_flash_attn = use_flash_attn
+        self.cross_attn_every_n_layers = cross_attn_every_n_layers
+        self.use_grad_checkpoint = use_grad_checkpoint
+        self.freeze_llm = freeze_llm
+        self.visual = visual
+        self.language = language
+        self.image_token_id = image_token_id
+        self.eoc_token_id = eoc_token_id
+        super().__init__(**kwargs)
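
configuration_infimm_hd.py is a thin PretrainedConfig subclass: each argument is stored verbatim and all serialization comes from the base class. A minimal round-trip sketch, assuming the file is importable from the working directory; the field values are copied from config.json above.

from configuration_infimm_hd import InfiMMHDConfig

config = InfiMMHDConfig(
    seq_length=4096,
    cross_attn_every_n_layers=4,
    image_token_id=32001,
    eoc_token_id=32000,
    visual={"image_size": 448, "patch_size": 14, "width": 1792, "layers": 64},
    language={"_name_or_path": "lmsys/vicuna-13b-v1.5", "hidden_size": 5120},
)
config.save_pretrained("./infimm-hd-config")   # PretrainedConfig writes config.json
print(InfiMMHDConfig.from_pretrained("./infimm-hd-config").seq_length)  # 4096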
eva_vit_model.py ADDED
@@ -0,0 +1,837 @@
1
+ # --------------------------------------------------------
2
+ # Adapted from https://github.com/baaivision/EVA/blob/master/EVA-CLIP/rei/eva_clip/eva_vit_model.py
3
+ # --------------------------------------------------------
4
+ import math
5
+ import os
6
+ import tempfile
7
+ from dataclasses import dataclass
8
+ from functools import partial
9
+ from typing import Optional, Tuple, Union
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import yaml
14
+ from open_clip.transform import image_transform
15
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
16
+
17
+ from open_flamingo.src.util.download_utils import download_pretrained_weights_from_hdfs
18
+ from open_flamingo.src.visual_encoder.rope import VisionRotaryEmbeddingFast
19
+ from open_flamingo.src.visual_encoder.transformer import Attention, PatchDropout
20
+ from open_flamingo.src.xperf_training import FTFlashAttention, FTLayerNorm, FTLinear
21
+
22
+ if os.getenv("ENV_TYPE") == "deepspeed":
23
+ try:
24
+ from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
25
+ except ImportError:
26
+ from torch.utils.checkpoint import checkpoint
27
+ else:
28
+ from torch.utils.checkpoint import checkpoint
29
+
30
+ from .utils import resize_eva_pos_embed
31
+
32
+ class DropPath(nn.Module):
33
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
34
+
35
+ def __init__(self, drop_prob=None):
36
+ super(DropPath, self).__init__()
37
+ self.drop_prob = drop_prob
38
+
39
+ def forward(self, x):
40
+ return drop_path(x, self.drop_prob, self.training)
41
+
42
+ def extra_repr(self) -> str:
43
+ return "p={}".format(self.drop_prob)
44
+
45
+
46
+ class Mlp(nn.Module):
47
+ def __init__(
48
+ self,
49
+ in_features,
50
+ hidden_features=None,
51
+ out_features=None,
52
+ act_layer=nn.GELU,
53
+ norm_layer=nn.LayerNorm,
54
+ drop=0.0,
55
+ subln=False,
56
+ ):
57
+ super().__init__()
58
+ out_features = out_features or in_features
59
+ hidden_features = hidden_features or in_features
60
+
61
+
62
+ use_ft_linear = False
63
+
64
+ if use_ft_linear:
65
+ self.fc1 = FTLinear(in_features, hidden_features)
66
+ else:
67
+ self.fc1 = nn.Linear(in_features, hidden_features)
68
+ self.act = act_layer()
69
+
70
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
71
+
72
+ if use_ft_linear:
73
+ self.fc2 = FTLinear(hidden_features, out_features)
74
+ else:
75
+ self.fc2 = nn.Linear(hidden_features, out_features)
76
+ self.drop = nn.Dropout(drop)
77
+
78
+ def forward(self, x):
79
+ x = self.fc1(x)
80
+ x = self.act(x)
81
+ # x = self.drop(x)
82
+ # commented out to match the original BERT implementation
83
+ x = self.ffn_ln(x)
84
+
85
+ x = self.fc2(x)
86
+ x = self.drop(x)
87
+ return x
88
+
89
+
90
+ class SwiGLU(nn.Module):
91
+ def __init__(
92
+ self,
93
+ in_features,
94
+ hidden_features=None,
95
+ out_features=None,
96
+ act_layer=nn.SiLU,
97
+ drop=0.0,
98
+ norm_layer=nn.LayerNorm,
99
+ subln=False,
100
+ ):
101
+ super().__init__()
102
+ out_features = out_features or in_features
103
+ hidden_features = hidden_features or in_features
104
+
105
+ use_ft_linear = False
106
+
107
+ if use_ft_linear:
108
+ self.w1 = FTLinear(in_features, hidden_features)
109
+ self.w2 = FTLinear(in_features, hidden_features)
110
+ else:
111
+ self.w1 = nn.Linear(in_features, hidden_features)
112
+ self.w2 = nn.Linear(in_features, hidden_features)
113
+
114
+ self.act = act_layer()
115
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
116
+
117
+ if use_ft_linear:
118
+ self.w3 = FTLinear(hidden_features, out_features)
119
+ else:
120
+ self.w3 = nn.Linear(hidden_features, out_features)
121
+
122
+ self.drop = nn.Dropout(drop)
123
+
124
+ def forward(self, x):
125
+ x1 = self.w1(x)
126
+ x2 = self.w2(x)
127
+ hidden = self.act(x1) * x2
128
+ x = self.ffn_ln(hidden)
129
+ x = self.w3(x)
130
+ x = self.drop(x)
131
+ return x
132
+
133
+
134
+ class Attention(nn.Module):
135
+ def __init__(
136
+ self,
137
+ dim,
138
+ num_heads=8,
139
+ qkv_bias=False,
140
+ qk_scale=None,
141
+ attn_drop=0.0,
142
+ proj_drop=0.0,
143
+ window_size=None,
144
+ attn_head_dim=None,
145
+ xattn=False,
146
+ rope=None,
147
+ subln=False,
148
+ norm_layer=nn.LayerNorm,
149
+ ):
150
+ super().__init__()
151
+ self.num_heads = num_heads
152
+ head_dim = dim // num_heads
153
+ if attn_head_dim is not None:
154
+ head_dim = attn_head_dim
155
+ all_head_dim = head_dim * self.num_heads
156
+ self.scale = qk_scale or head_dim**-0.5
157
+
158
+
159
+ self.use_ft_flash_attention = False
160
+
161
+ self.subln = subln
162
+ if self.subln:
163
+ self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
164
+ self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
165
+ self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
166
+
167
+ else:
168
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
169
+
170
+ if qkv_bias:
171
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
172
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
173
+ else:
174
+ self.q_bias = None
175
+ self.v_bias = None
176
+
177
+ if window_size:
178
+ self.window_size = window_size
179
+ self.num_relative_distance = (2 * window_size[0] - 1) * (
180
+ 2 * window_size[1] - 1
181
+ ) + 3
182
+ self.relative_position_bias_table = nn.Parameter(
183
+ torch.zeros(self.num_relative_distance, num_heads)
184
+ ) # 2*Wh-1 * 2*Ww-1, nH
185
+ # cls to token & token 2 cls & cls to cls
186
+
187
+ # get pair-wise relative position index for each token inside the window
188
+ coords_h = torch.arange(window_size[0])
189
+ coords_w = torch.arange(window_size[1])
190
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
191
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
192
+ relative_coords = (
193
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
194
+ ) # 2, Wh*Ww, Wh*Ww
195
+ relative_coords = relative_coords.permute(
196
+ 1, 2, 0
197
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
198
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
199
+ relative_coords[:, :, 1] += window_size[1] - 1
200
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
201
+ relative_position_index = torch.zeros(
202
+ size=(window_size[0] * window_size[1] + 1,) * 2,
203
+ dtype=relative_coords.dtype,
204
+ )
205
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
206
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
207
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
208
+ relative_position_index[0, 0] = self.num_relative_distance - 1
209
+
210
+ self.register_buffer("relative_position_index", relative_position_index)
211
+ else:
212
+ self.window_size = None
213
+ self.relative_position_bias_table = None
214
+ self.relative_position_index = None
215
+
216
+ self.attn_drop = nn.Dropout(attn_drop)
217
+ self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
218
+ # self.proj = nn.Linear(all_head_dim, all_head_dim)
219
+ self.proj = nn.Linear(all_head_dim, dim)
220
+ self.proj_drop = nn.Dropout(proj_drop)
221
+ self.xattn = xattn
222
+ self.xattn_drop = attn_drop
223
+
224
+ if self.use_ft_flash_attention:
225
+ assert FTFlashAttention is not None
226
+ self.ft_flash_attn = FTFlashAttention()
227
+
228
+ self.rope = rope
229
+
230
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
231
+ B, N, C = x.shape
232
+ if self.subln:
233
+ q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
234
+ k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
235
+ v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
236
+
237
+ q = q.reshape(B, N, self.num_heads, -1).permute(
238
+ 0, 2, 1, 3
239
+ ) # B, num_heads, N, C
240
+ k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
241
+ v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
242
+ else:
243
+ qkv_bias = None
244
+ if self.q_bias is not None:
245
+ qkv_bias = torch.cat(
246
+ (
247
+ self.q_bias,
248
+ torch.zeros_like(self.v_bias, requires_grad=False),
249
+ self.v_bias,
250
+ )
251
+ )
252
+
253
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
254
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(
255
+ 2, 0, 3, 1, 4
256
+ ) # 3, B, num_heads, N, C
257
+ q, k, v = qkv[0], qkv[1], qkv[2]
258
+
259
+ if self.rope:
260
+ # slightly fast impl
261
+ q_t = q[:, :, 1:, :]
262
+ ro_q_t = self.rope(q_t)
263
+ q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
264
+
265
+ k_t = k[:, :, 1:, :]
266
+ ro_k_t = self.rope(k_t)
267
+ k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
268
+
269
+ if self.use_ft_flash_attention:
270
+ q = q.permute(0, 2, 1, 3).contiguous()
271
+ q = q.view(
272
+ q.shape[0], q.shape[1], -1
273
+ ) # B, num_heads, N, C -> B, N, num_heads, C
274
+ k = k.permute(0, 2, 1, 3).contiguous()
275
+ k = k.view(k.shape[0], k.shape[1], -1)
276
+ v = v.permute(0, 2, 1, 3).contiguous()
277
+ v = v.view(v.shape[0], v.shape[1], -1)
278
+ x = self.ft_flash_attn(
279
+ [q, k, v],
280
+ self.num_heads,
281
+ attn_mask=None,
282
+ causal=False,
283
+ attention_dropout=self.xattn_drop if self.training else 0.0,
284
+ softmax_scale=self.scale,
285
+ use_rmpad_attn=False,
286
+ )
287
+
288
+ x = self.inner_attn_ln(x)
289
+ x = self.proj(x)
290
+ x = self.proj_drop(x)
291
+
292
+ else:
293
+ q = q * self.scale
294
+ attn = q @ k.transpose(-2, -1)
295
+
296
+ if self.relative_position_bias_table is not None:
297
+ relative_position_bias = self.relative_position_bias_table[
298
+ self.relative_position_index.view(-1)
299
+ ].view(
300
+ self.window_size[0] * self.window_size[1] + 1,
301
+ self.window_size[0] * self.window_size[1] + 1,
302
+ -1,
303
+ ) # Wh*Ww,Wh*Ww,nH
304
+ relative_position_bias = relative_position_bias.permute(
305
+ 2, 0, 1
306
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
307
+ attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
308
+
309
+ if rel_pos_bias is not None:
310
+ attn = attn + rel_pos_bias.type_as(attn)
311
+
312
+ if attn_mask is not None:
313
+ attn_mask = attn_mask.bool()
314
+ attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
315
+
316
+ attn = attn.softmax(dim=-1)
317
+ attn = self.attn_drop(attn)
318
+
319
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
320
+ x = self.inner_attn_ln(x)
321
+ x = self.proj(x)
322
+ x = self.proj_drop(x)
323
+ return x
324
+
325
+
326
+ class Block(nn.Module):
327
+ def __init__(
328
+ self,
329
+ dim,
330
+ num_heads,
331
+ mlp_ratio=4.0,
332
+ qkv_bias=False,
333
+ qk_scale=None,
334
+ drop=0.0,
335
+ attn_drop=0.0,
336
+ drop_path=0.0,
337
+ init_values=None,
338
+ act_layer=nn.GELU,
339
+ norm_layer=nn.LayerNorm,
340
+ window_size=None,
341
+ attn_head_dim=None,
342
+ xattn=False,
343
+ rope=None,
344
+ postnorm=False,
345
+ subln=False,
346
+ naiveswiglu=False,
347
+ ):
348
+ super().__init__()
349
+ self.norm1 = norm_layer(dim)
350
+ self.attn = Attention(
351
+ dim,
352
+ num_heads=num_heads,
353
+ qkv_bias=qkv_bias,
354
+ qk_scale=qk_scale,
355
+ attn_drop=attn_drop,
356
+ proj_drop=drop,
357
+ window_size=window_size,
358
+ attn_head_dim=attn_head_dim,
359
+ xattn=xattn,
360
+ rope=rope,
361
+ subln=subln,
362
+ norm_layer=norm_layer,
363
+ )
364
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
365
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
366
+ self.norm2 = norm_layer(dim)
367
+ mlp_hidden_dim = int(dim * mlp_ratio)
368
+
369
+ if naiveswiglu:
370
+ self.mlp = SwiGLU(
371
+ in_features=dim,
372
+ hidden_features=mlp_hidden_dim,
373
+ subln=subln,
374
+ norm_layer=norm_layer,
375
+ )
376
+ else:
377
+ self.mlp = Mlp(
378
+ in_features=dim,
379
+ hidden_features=mlp_hidden_dim,
380
+ act_layer=act_layer,
381
+ subln=subln,
382
+ drop=drop,
383
+ )
384
+
385
+ if init_values is not None and init_values > 0:
386
+ self.gamma_1 = nn.Parameter(
387
+ init_values * torch.ones((dim)), requires_grad=True
388
+ )
389
+ self.gamma_2 = nn.Parameter(
390
+ init_values * torch.ones((dim)), requires_grad=True
391
+ )
392
+ else:
393
+ self.gamma_1, self.gamma_2 = None, None
394
+
395
+ self.postnorm = postnorm
396
+
397
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
398
+ if self.gamma_1 is None:
399
+ if self.postnorm:
400
+ x = x + self.drop_path(
401
+ self.norm1(
402
+ self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)
403
+ )
404
+ )
405
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
406
+ else:
407
+ x = x + self.drop_path(
408
+ self.attn(
409
+ self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask
410
+ )
411
+ )
412
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
413
+ else:
414
+ if self.postnorm:
415
+ x = x + self.drop_path(
416
+ self.gamma_1
417
+ * self.norm1(
418
+ self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)
419
+ )
420
+ )
421
+ x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
422
+ else:
423
+ x = x + self.drop_path(
424
+ self.gamma_1
425
+ * self.attn(
426
+ self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask
427
+ )
428
+ )
429
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
430
+ return x
431
+
432
+
433
+ class PatchEmbed(nn.Module):
434
+ """Image to Patch Embedding"""
435
+
436
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
437
+ super().__init__()
438
+ img_size = to_2tuple(img_size)
439
+ patch_size = to_2tuple(patch_size)
440
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
441
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
442
+ self.img_size = img_size
443
+ self.patch_size = patch_size
444
+ self.num_patches = num_patches
445
+
446
+ self.proj = nn.Conv2d(
447
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
448
+ )
449
+
450
+ def forward(self, x, **kwargs):
451
+ B, C, H, W = x.shape
452
+ # FIXME look at relaxing size constraints
453
+ assert H == self.img_size[0] and W == self.img_size[1], (
454
+ f"Input image size ({H}*{W}) doesn't match model"
455
+ f" ({self.img_size[0]}*{self.img_size[1]})."
456
+ )
457
+ x = self.proj(x).flatten(2).transpose(1, 2)
458
+ return x
459
+
460
+
461
+ class RelativePositionBias(nn.Module):
462
+ def __init__(self, window_size, num_heads):
463
+ super().__init__()
464
+ self.window_size = window_size
465
+ self.num_relative_distance = (2 * window_size[0] - 1) * (
466
+ 2 * window_size[1] - 1
467
+ ) + 3
468
+ self.relative_position_bias_table = nn.Parameter(
469
+ torch.zeros(self.num_relative_distance, num_heads)
470
+ ) # 2*Wh-1 * 2*Ww-1, nH
471
+ # cls to token & token 2 cls & cls to cls
472
+
473
+ # get pair-wise relative position index for each token inside the window
474
+ coords_h = torch.arange(window_size[0])
475
+ coords_w = torch.arange(window_size[1])
476
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
477
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
478
+ relative_coords = (
479
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
480
+ ) # 2, Wh*Ww, Wh*Ww
481
+ relative_coords = relative_coords.permute(
482
+ 1, 2, 0
483
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
484
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
485
+ relative_coords[:, :, 1] += window_size[1] - 1
486
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
487
+ relative_position_index = torch.zeros(
488
+ size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
489
+ )
490
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
491
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
492
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
493
+ relative_position_index[0, 0] = self.num_relative_distance - 1
494
+
495
+ self.register_buffer("relative_position_index", relative_position_index)
496
+
497
+ def forward(self):
498
+ relative_position_bias = self.relative_position_bias_table[
499
+ self.relative_position_index.view(-1)
500
+ ].view(
501
+ self.window_size[0] * self.window_size[1] + 1,
502
+ self.window_size[0] * self.window_size[1] + 1,
503
+ -1,
504
+ ) # Wh*Ww,Wh*Ww,nH
505
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
506
+
507
+
508
+ class EVAVisionTransformer(nn.Module):
509
+ """Vision Transformer with support for patch or hybrid CNN input stage"""
510
+
511
+ def __init__(
512
+ self,
513
+ img_size=224,
514
+ patch_size=16,
515
+ in_chans=3,
516
+ num_classes=1000,
517
+ embed_dim=768,
518
+ depth=12,
519
+ num_heads=12,
520
+ mlp_ratio=4.0,
521
+ qkv_bias=False,
522
+ qk_scale=None,
523
+ drop_rate=0.0,
524
+ attn_drop_rate=0.0,
525
+ drop_path_rate=0.0,
526
+ norm_layer=nn.LayerNorm,
527
+ init_values=None,
528
+ patch_dropout=0.0,
529
+ use_abs_pos_emb=True,
530
+ use_rel_pos_bias=False,
531
+ use_shared_rel_pos_bias=False,
532
+ rope=False,
533
+ use_mean_pooling=True,
534
+ init_scale=0.001,
535
+ grad_checkpointing=False,
536
+ xattn=False,
537
+ postnorm=False,
538
+ pt_hw_seq_len=16,
539
+ intp_freq=False,
540
+ naiveswiglu=False,
541
+ subln=False,
542
+ ):
543
+ super().__init__()
544
+ self.image_size = img_size
545
+ self.num_classes = num_classes
546
+ self.num_features = (
547
+ self.embed_dim
548
+ ) = embed_dim # num_features for consistency with other models
549
+
550
+ self.patch_embed = PatchEmbed(
551
+ img_size=img_size,
552
+ patch_size=patch_size,
553
+ in_chans=in_chans,
554
+ embed_dim=embed_dim,
555
+ )
556
+ num_patches = self.patch_embed.num_patches
557
+
558
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
559
+ # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
560
+ if use_abs_pos_emb:
561
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
562
+ else:
563
+ self.pos_embed = None
564
+ self.pos_drop = nn.Dropout(p=drop_rate)
565
+
566
+ if use_shared_rel_pos_bias:
567
+ self.rel_pos_bias = RelativePositionBias(
568
+ window_size=self.patch_embed.patch_shape, num_heads=num_heads
569
+ )
570
+ else:
571
+ self.rel_pos_bias = None
572
+
573
+ if rope:
574
+ half_head_dim = embed_dim // num_heads // 2
575
+ hw_seq_len = img_size // patch_size
576
+ self.rope = VisionRotaryEmbeddingFast(
577
+ dim=half_head_dim,
578
+ pt_seq_len=pt_hw_seq_len,
579
+ ft_seq_len=hw_seq_len if intp_freq else None,
580
+ # patch_dropout=patch_dropout
581
+ )
582
+ else:
583
+ self.rope = None
584
+
585
+ self.naiveswiglu = naiveswiglu
586
+
587
+ dpr = [
588
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
589
+ ] # stochastic depth decay rule
590
+ self.use_rel_pos_bias = use_rel_pos_bias
591
+ self.blocks = nn.ModuleList(
592
+ [
593
+ Block(
594
+ dim=embed_dim,
595
+ num_heads=num_heads,
596
+ mlp_ratio=mlp_ratio,
597
+ qkv_bias=qkv_bias,
598
+ qk_scale=qk_scale,
599
+ drop=drop_rate,
600
+ attn_drop=attn_drop_rate,
601
+ drop_path=dpr[i],
602
+ norm_layer=norm_layer,
603
+ init_values=init_values,
604
+ window_size=(
605
+ self.patch_embed.patch_shape if use_rel_pos_bias else None
606
+ ),
607
+ xattn=xattn,
608
+ rope=self.rope,
609
+ postnorm=postnorm,
610
+ subln=subln,
611
+ naiveswiglu=naiveswiglu,
612
+ )
613
+ for i in range(depth)
614
+ ]
615
+ )
616
+ self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
617
+ self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
618
+ self.head = (
619
+ nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
620
+ )
621
+
622
+ if self.pos_embed is not None:
623
+ trunc_normal_(self.pos_embed, std=0.02)
624
+
625
+ trunc_normal_(self.cls_token, std=0.02)
626
+ # trunc_normal_(self.mask_token, std=.02)
627
+
628
+ self.apply(self._init_weights)
629
+ self.fix_init_weight()
630
+
631
+ if isinstance(self.head, nn.Linear):
632
+ trunc_normal_(self.head.weight, std=0.02)
633
+ self.head.weight.data.mul_(init_scale)
634
+ self.head.bias.data.mul_(init_scale)
635
+
636
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
637
+ self.patch_dropout = (
638
+ PatchDropout(patch_dropout) if patch_dropout > 0.0 else nn.Identity()
639
+ )
640
+
641
+ self.grad_checkpointing = grad_checkpointing
642
+
643
+ def fix_init_weight(self):
644
+ def rescale(param, layer_id):
645
+ param.div_(math.sqrt(2.0 * layer_id))
646
+
647
+ for layer_id, layer in enumerate(self.blocks):
648
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
649
+ if self.naiveswiglu:
650
+ rescale(layer.mlp.w3.weight.data, layer_id + 1)
651
+ else:
652
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
653
+
654
+ def get_cast_dtype(self) -> torch.dtype:
655
+ return self.blocks[0].mlp.fc2.weight.dtype
656
+
657
+ def _init_weights(self, m):
658
+ if isinstance(m, nn.Linear):
659
+ trunc_normal_(m.weight, std=0.02)
660
+ if m.bias is not None:
661
+ nn.init.constant_(m.bias, 0)
662
+ elif isinstance(m, nn.LayerNorm):
663
+ nn.init.constant_(m.bias, 0)
664
+ nn.init.constant_(m.weight, 1.0)
665
+
666
+ def get_num_layers(self):
667
+ return len(self.blocks)
668
+
669
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
670
+ assert (
671
+ unlocked_groups == 0
672
+ ), "partial locking not currently supported for this model"
673
+ for param in self.parameters():
674
+ param.requires_grad = False
675
+
676
+ @torch.jit.ignore
677
+ def set_grad_checkpointing(self, enable=True):
678
+ self.grad_checkpointing = enable
679
+
680
+ @torch.jit.ignore
681
+ def no_weight_decay(self):
682
+ return {"pos_embed", "cls_token"}
683
+
684
+ def get_classifier(self):
685
+ return self.head
686
+
687
+ def reset_classifier(self, num_classes, global_pool=""):
688
+ self.num_classes = num_classes
689
+ self.head = (
690
+ nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
691
+ )
692
+
693
+ def forward_features(self, x, return_all_features=False):
694
+ x = self.patch_embed(x)
695
+ batch_size, seq_len, _ = x.size()
696
+
697
+ cls_tokens = self.cls_token.expand(
698
+ batch_size, -1, -1
699
+ ) # stole cls_tokens impl from Phil Wang, thanks
700
+ x = torch.cat((cls_tokens, x), dim=1)
701
+ if self.pos_embed is not None:
702
+ x = x + self.pos_embed
703
+ x = self.pos_drop(x)
704
+
705
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
706
+ if os.getenv("RoPE") == "1":
707
+ if self.training and not isinstance(self.patch_dropout, nn.Identity):
708
+ x, patch_indices_keep = self.patch_dropout(x)
709
+ self.rope.forward = partial(
710
+ self.rope.forward, patch_indices_keep=patch_indices_keep
711
+ )
712
+ else:
713
+ self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
714
+ x = self.patch_dropout(x)
715
+ else:
716
+ x = self.patch_dropout(x)
717
+
718
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
719
+ for blk in self.blocks:
720
+ if self.grad_checkpointing:
721
+ x = checkpoint(blk, x, (rel_pos_bias,))
722
+ else:
723
+ x = blk(x, rel_pos_bias=rel_pos_bias)
724
+
725
+ if not return_all_features:
726
+ x = self.norm(x)
727
+ if self.fc_norm is not None:
728
+ return self.fc_norm(x.mean(1))
729
+ else:
730
+ return x[:, 0]
731
+ return x
732
+
733
+ def forward(self, x, return_all_features=False):
734
+ if return_all_features:
735
+ return self.forward_features(x, return_all_features)
736
+ x = self.forward_features(x)
737
+ x = self.head(x)
738
+ return x
739
+
740
+
741
+ @dataclass
742
+ class CLIPVisionCfg:
743
+ layers: Union[Tuple[int, int, int, int], int] = 12
744
+ width: int = 768
745
+ head_width: int = 64
746
+ mlp_ratio: float = 4.0
747
+ patch_size: int = 16
748
+ image_size: Union[Tuple[int, int], int] = 224
749
+ ls_init_value: Optional[float] = None # layer scale initial value
750
+ patch_dropout: float = 0.0 # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
751
+ global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
752
+ drop_path_rate: Optional[float] = None # drop path rate
753
+ timm_model_name: str = (
754
+ None # a valid model name overrides layers, width, patch_size
755
+ )
756
+ timm_model_pretrained: bool = (
757
+ False # use (imagenet) pretrained weights for named model
758
+ )
759
+ timm_pool: str = ( # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
760
+ "avg"
761
+ )
762
+ timm_proj: str = ( # linear projection for timm model output ('linear', 'mlp', '')
763
+ "linear"
764
+ )
765
+ timm_proj_bias: bool = False # enable bias final projection
766
+ eva_model_name: str = (
767
+ None # a valid eva model name overrides layers, width, patch_size
768
+ )
769
+ qkv_bias: bool = True
770
+ fusedLN: bool = False
771
+ embed_dim: int = 1024
772
+ xattn: bool = False
773
+ postnorm: bool = False
774
+ rope: bool = False
775
+ pt_hw_seq_len: int = 16 # 224/14
776
+ intp_freq: bool = False
777
+ naiveswiglu: bool = False
778
+ subln: bool = False
779
+
780
+
781
+ def load_state_dict(
782
+ checkpoint_path: str,
783
+ map_location: str = "cpu",
784
+ model_key: str = "model|module|state_dict",
785
+ is_openai: bool = False,
786
+ skip_list: list = [],
787
+ ):
788
+ if is_openai:
789
+ model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
790
+ state_dict = model.state_dict()
791
+ for key in ["input_resolution", "context_length", "vocab_size"]:
792
+ state_dict.pop(key, None)
793
+ else:
794
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
795
+ for mk in model_key.split("|"):
796
+ if isinstance(checkpoint, dict) and mk in checkpoint:
797
+ state_dict = checkpoint[mk]
798
+ break
799
+ else:
800
+ state_dict = checkpoint
801
+ if next(iter(state_dict.items()))[0].startswith("module"):
802
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
803
+
804
+ for k in skip_list:
805
+ if k in list(state_dict.keys()):
806
+ print(f"Removing key {k} from pretrained checkpoint")
807
+ del state_dict[k]
808
+
809
+ if os.getenv("RoPE") == "1":
810
+ for k in list(state_dict.keys()):
811
+ if "freqs_cos" in k or "freqs_sin" in k:
812
+ del state_dict[k]
813
+ return state_dict
814
+
815
+
816
+ def load_clip_visual_state_dict(
817
+ checkpoint_path: str,
818
+ map_location: str = "cpu",
819
+ is_openai: bool = False,
820
+ skip_list: list = [],
821
+ ):
822
+ state_dict = load_state_dict(
823
+ checkpoint_path,
824
+ map_location=map_location,
825
+ is_openai=is_openai,
826
+ skip_list=skip_list,
827
+ )
828
+
829
+ for k in list(state_dict.keys()):
830
+ if not k.startswith("visual."):
831
+ del state_dict[k]
832
+ for k in list(state_dict.keys()):
833
+ if k.startswith("visual."):
834
+ new_k = k[7:]
835
+ state_dict[new_k] = state_dict[k]
836
+ del state_dict[k]
837
+ return state_dict
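
eva_vit_model.py only imports inside the full open_flamingo package (it pulls in relative and open_flamingo modules), so the sketch below is hedged: it shows how the "visual" block of config.json maps onto the EVAVisionTransformer constructor, but uses a deliberately shrunken width and depth so the toy forward pass stays cheap; the real checkpoint is EVA-CLIP-4B sized (width 1792, 64 layers).

import torch
from functools import partial
from torch import nn
from eva_vit_model import EVAVisionTransformer  # assumption: the file's package imports resolve

vit = EVAVisionTransformer(
    img_size=448,                 # config.visual["image_size"]
    patch_size=14,                # config.visual["patch_size"] -> 448 / 14 = 32 patches per side
    embed_dim=256,                # real checkpoint: width = 1792
    depth=2,                      # real checkpoint: layers = 64
    num_heads=4,                  # real checkpoint: width // head_width = 1792 // 112 = 16
    mlp_ratio=8.571428571428571,  # config.visual["mlp_ratio"]
    qkv_bias=True,
    postnorm=True,                # config.visual["postnorm"]
    norm_layer=partial(nn.LayerNorm, eps=1e-6),
)

with torch.no_grad():
    tokens = vit(torch.zeros(1, 3, 448, 448), return_all_features=True)
print(tokens.shape)               # (1, 1025, 256): CLS token + 32*32 patch tokens, as consumed by flamingo.py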
flamingo.py ADDED
@@ -0,0 +1,319 @@
1
+ import inspect
2
+ import torch
3
+ import numpy as np
4
+ from einops import rearrange
5
+ from torch import nn
6
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
7
+ try:
8
+ from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
9
+ except ImportError:
10
+ from torch.utils.checkpoint import checkpoint
11
+
12
+ def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
13
+ """ Sinusoid position encoding table """
14
+
15
+ def cal_angle(position, hid_idx):
16
+ return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)
17
+
18
+ def get_posi_angle_vec(position):
19
+ return [cal_angle(position, hid_j) for hid_j in range(d_hid)]
20
+
21
+ sinusoid_table = np.array(
22
+ [get_posi_angle_vec(pos_i) for pos_i in range(n_position)]
23
+ )
24
+
25
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
26
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
27
+
28
+ if padding_idx is not None:
29
+ # zero vector for padding dimension
30
+ sinusoid_table[padding_idx] = 0.0
31
+
32
+ return torch.FloatTensor(sinusoid_table)
33
+
34
+
35
+ def construct_position_encoding(vis_dim, max_pos, rows, cols):
36
+ seq = get_sinusoid_encoding_table(max_pos, int(vis_dim/2))
37
+ y_coords, x_coords = torch.meshgrid(torch.arange(rows), torch.arange(cols), indexing='ij')
38
+
39
+ row_positions = seq[y_coords.flatten(), :]
40
+ col_positions = seq[x_coords.flatten(), :]
41
+
42
+ position_encoding = torch.cat((col_positions, row_positions), dim=-1)
43
+
44
+ return position_encoding
45
+ def unwrap_fsdp(m):
46
+ if isinstance(m, FSDP):
47
+ return unwrap_fsdp(m.module)
48
+ return m
49
+
50
+
51
+ def accepts_parameter(func, parameter_name):
52
+ signature = inspect.signature(func)
53
+ return parameter_name in signature.parameters
54
+
55
+
56
+ class Flamingo(nn.Module):
57
+ def __init__(
58
+ self,
59
+ vision_encoder: nn.Module,
60
+ lang_encoder: nn.Module,
61
+ eoc_token_id: int,
62
+ media_token_id: int,
63
+ vis_dim: int,
64
+ cross_attn_every_n_layers: int = 1,
65
+ gradient_checkpointing: bool = False,
66
+ use_ft_layernorm: bool = False,
67
+ use_ft_flash_attention: bool = False,
68
+ enable_init_network_params: bool = False,
69
+ initializer_range: float = 0.02,
70
+ ):
71
+ """
72
+ Args:
73
+ vision_encoder (nn.Module): HF CLIPModel
74
+ lang_encoder (nn.Module): HF causal language model
75
+ eoc_token_id (int): Token id for <|endofchunk|>
76
+ media_token_id (int): Token id for <image>
77
+ vis_dim (int): Dimension of the visual features.
78
+ Visual features are projected to match this shape along the last dimension.
79
+ cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1.
80
+ """
81
+ super().__init__()
82
+ self.vit_use_grad = False
83
+ self.eoc_token_id = eoc_token_id
84
+ self.media_token_id = media_token_id
85
+ self.vis_dim = vis_dim
86
+ if hasattr(lang_encoder.config, "d_model"):
87
+ self.lang_dim = lang_encoder.config.d_model # mpt uses d_model
88
+ else:
89
+ self.lang_dim = lang_encoder.config.hidden_size
90
+
91
+ self.vision_encoder = (
92
+ vision_encoder.visual
93
+ if hasattr(vision_encoder, "visual")
94
+ else vision_encoder
95
+ )
96
+
97
+ self.lang_encoder = lang_encoder
98
+ self.lang_encoder.init_flamingo(
99
+ media_token_id=media_token_id,
100
+ lang_hidden_size=self.lang_dim,
101
+ vis_hidden_size=self.vis_dim,
102
+ cross_attn_every_n_layers=cross_attn_every_n_layers,
103
+ gradient_checkpointing=gradient_checkpointing,
104
+ use_ft_layernorm=use_ft_layernorm,
105
+ use_ft_flash_attention=use_ft_flash_attention,
106
+ enable_init_network_params=enable_init_network_params,
107
+ initializer_range=initializer_range,
108
+ )
109
+ self._use_gradient_checkpointing = gradient_checkpointing
110
+
111
+ def forward(
112
+ self,
113
+ vision_x: torch.Tensor,
114
+ lang_x: torch.Tensor,
115
+ attention_mask: torch.Tensor = None,
116
+ labels: torch.Tensor = None,
117
+ image_mask: torch.Tensor = None,
118
+ subimage_shape: torch.Tensor = None,
119
+ clear_conditioned_layers: bool = True,
120
+ past_key_values=None,
121
+ use_cache: bool = False,
122
+ ):
123
+ """
124
+ Forward pass of Flamingo.
125
+
126
+ Args:
127
+ vision_x (torch.Tensor): Vision input
128
+ shape (B, T_img, F, C, H, W) with F=1
129
+ lang_x (torch.Tensor): Language input ids
130
+ shape (B, T_txt)
131
+ attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
132
+ labels (torch.Tensor, optional): Labels. Defaults to None.
133
+ clear_conditioned_layers: if True, clear the conditioned layers
134
+ once the forward pass is completed. Set this to false if the
135
+ same set of images will be reused in another subsequent
136
+ forward pass.
137
+ past_key_values: pre-computed values to pass to language model.
138
+ See past_key_values documentation in Hugging Face
139
+ CausalLM models.
140
+ use_cache: whether to use cached key values. See use_cache
141
+ documentation in Hugging Face CausalLM models.
142
+ """
143
+ assert (
144
+ self.lang_encoder.initialized_flamingo
145
+ ), "Flamingo layers are not initialized. Please call `init_flamingo` first."
146
+
147
+ assert (
148
+ self.lang_encoder._use_cached_vision_x or vision_x is not None
149
+ ), "Must provide either vision_x or have precached media using cache_media()."
150
+
151
+ if self.lang_encoder._use_cached_vision_x:
152
+ # Case: use cached; vision_x should be cached and other
153
+ # vision-related inputs should not be provided.
154
+ assert vision_x is None, (
155
+ "Expect vision_x to be None when media has been cached using"
156
+ " cache_media(). Try uncache_media() first."
157
+ )
158
+ assert self.lang_encoder.is_conditioned()
159
+
160
+ else:
161
+ # Case: do not use caching (i.e. this is a standard forward pass);
162
+ self._encode_vision_x(vision_x=vision_x, image_mask=image_mask, subimage_shape=subimage_shape)
163
+ self._condition_media_locations(input_ids=lang_x)
164
+
165
+ output = self.lang_encoder(
166
+ input_ids=lang_x,
167
+ attention_mask=attention_mask,
168
+ labels=labels,
169
+ past_key_values=past_key_values,
170
+ use_cache=use_cache,
171
+ )
172
+
173
+ if clear_conditioned_layers:
174
+ self.lang_encoder.clear_conditioned_layers()
175
+
176
+ return output
177
+
178
+ def generate(
179
+ self,
180
+ vision_x: torch.Tensor,
181
+ lang_x: torch.Tensor,
182
+ attention_mask: torch.Tensor = None,
183
+ **kwargs,
184
+ ):
185
+ """
186
+ Generate text conditioned on vision and language inputs.
187
+
188
+ Args:
189
+ vision_x (torch.Tensor): Vision input
190
+ shape (B, T_img, F, C, H, W)
191
+ images in the same chunk are collated along T_img, and frames are collated along F
192
+ currently only F=1 is supported (single-frame videos)
193
+ lang_x (torch.Tensor): Language input
194
+ shape (B, T_txt)
195
+ **kwargs: see generate documentation in Hugging Face CausalLM models. Some notable kwargs:
196
+ max_length (int, optional): Maximum length of the output. Defaults to None.
197
+ attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
198
+ num_beams (int, optional): Number of beams. Defaults to 1.
199
+ max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
200
+ temperature (float, optional): Temperature. Defaults to 1.0.
201
+ top_k (int, optional): Top k. Defaults to 50.
202
+ top_p (float, optional): Top p. Defaults to 1.0.
203
+ no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
204
+ length_penalty (float, optional): Length penalty. Defaults to 1.0.
205
+ num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
206
+ do_sample (bool, optional): Do sample. Defaults to False.
207
+ early_stopping (bool, optional): Early stopping. Defaults to False.
208
+ Returns:
209
+ torch.Tensor: lang_x with generated tokens appended to it
210
+ """
211
+ subimage_shape = kwargs.pop("subimage_shape", None)
212
+ image_mask = kwargs.pop("image_mask", None)
213
+ num_beams = kwargs.pop("num_beams", 1)
214
+ if num_beams > 1:
215
+ vision_x = vision_x.repeat_interleave(num_beams, dim=0)
216
+ if image_mask is not None:
217
+ image_mask = image_mask.repeat_interleave(num_beams, dim=0)
218
+ if subimage_shape is not None:
219
+ subimage_shape = subimage_shape.repeat_interleave(num_beams, dim=0)
220
+ self.lang_encoder._use_cached_vision_x = True
221
+ self._encode_vision_x(vision_x=vision_x, image_mask=image_mask, subimage_shape=subimage_shape)
222
+
223
+ eos_token_id = kwargs.pop("eos_token_id", self.eoc_token_id)
224
+ output = self.lang_encoder.generate(
225
+ input_ids=lang_x,
226
+ attention_mask=attention_mask,
227
+ eos_token_id=eos_token_id,
228
+ num_beams=num_beams,
229
+ **kwargs,
230
+ )
231
+
232
+ self.lang_encoder.clear_conditioned_layers()
233
+ self.lang_encoder._use_cached_vision_x = False
234
+ return output
235
+
236
+ def _encode_vision_x(self, vision_x: torch.Tensor, image_mask: torch.Tensor=None, subimage_shape: torch.Tensor=None):
237
+ """
238
+ Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
239
+ Args:
240
+ vision_x (torch.Tensor): Vision input
241
+ shape (B, T_img, F, C, H, W)
242
+ Images in the same chunk are collated along T_img, and frames are collated along F
243
+ Currently only F=1 is supported (single-frame videos)
244
+
245
+ rearrange code based on https://github.com/dhansmair/flamingo-mini
246
+ """
247
+
248
+ assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
249
+ b, T, F = vision_x.shape[:3]
250
+ assert F == 1, "Only single frame supported"
251
+
252
+ vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
253
+
254
+ if not self.vit_use_grad:
255
+ with torch.no_grad():
256
+ module_to_inspect = unwrap_fsdp(self.vision_encoder)
257
+ if accepts_parameter(module_to_inspect.forward, "return_all_features"):
258
+ vision_x = self.vision_encoder(vision_x, return_all_features=True)
259
+ else:
260
+ vision_x = self.vision_encoder(vision_x)[1]
261
+ else:
262
+ module_to_inspect = unwrap_fsdp(self.vision_encoder)
263
+ if accepts_parameter(module_to_inspect.forward, "return_all_features"):
264
+ if self.training:
265
+ vision_x = checkpoint(self.vision_encoder, vision_x, True)
266
+ else:
267
+ vision_x = self.vision_encoder(vision_x, return_all_features=True)
268
+
269
+ else:
270
+ vision_x = self.vision_encoder(vision_x)[1]
271
+
272
+ vision_x = rearrange(vision_x, "(b T F) v d -> b (T F) v d", b=b, T=T, F=F)
273
+ pos_emb = torch.zeros((T,self.vis_dim)).to(vision_x.dtype).to(vision_x.device)
274
+ for i in range(subimage_shape.shape[0]):
275
+ cols, rows = int(subimage_shape[i,0]), int(subimage_shape[i,1])
276
+ tmp_pos_emb = construct_position_encoding(vision_x.shape[-1], 20, rows, cols).to(vision_x.dtype).to(vision_x.device)
277
+ pos_emb[1:int(cols*rows)+1,:] = tmp_pos_emb
278
+ vision_x = vision_x + pos_emb.unsqueeze(1).unsqueeze(0).detach()
279
+ for layer in self.lang_encoder._get_decoder_layers():
280
+ layer.condition_vis_x((vision_x, image_mask))
281
+
282
+ def _condition_media_locations(self, input_ids: torch.Tensor):
283
+ """
284
+ Compute the media token locations from lang_x and condition the language model on these.
285
+ Args:
286
+ input_ids (torch.Tensor): Language input
287
+ shape (B, T_txt)
288
+ """
290
+ media_locations = input_ids == self.media_token_id
291
+ # make all of the seq focus on the first fake image to avoid nan
292
+ # media_locations = torch.where(tmp_mask==False, tmp_mask, media_locations)
293
+ for layer in self.lang_encoder._get_decoder_layers():
294
+ layer.condition_media_locations(media_locations)
295
+
296
+ def cache_media(self, input_ids: torch.Tensor, vision_x: torch.Tensor):
297
+ """
298
+ Pre-cache a prompt/sequence of images / text for log-likelihood evaluations.
299
+ All subsequent calls to forward() will generate attending to the LAST
300
+ image in vision_x.
301
+ This is not meant to be used to cache things for generate().
302
+ Args:
303
+ input_ids (torch.Tensor): Language input
304
+ shape (B, T_txt)
305
+ vision_x (torch.Tensor): Vision input
306
+ shape (B, T_img, F, C, H, W)
307
+ Images in the same chunk are collated along T_img, and frames are collated along F
308
+ Currently only F=1 is supported (single-frame videos)
309
+ """
310
+ self._encode_vision_x(vision_x=vision_x)
311
+ self._condition_media_locations(input_ids=input_ids)
312
+ self.lang_encoder._use_cached_vision_x = True
313
+
314
+ def uncache_media(self):
315
+ """
316
+ Clear all conditioning.
317
+ """
318
+ self.lang_encoder.clear_conditioned_layers()
319
+ self.lang_encoder._use_cached_vision_x = False
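
The HD-specific piece of flamingo.py is the 2-D sinusoidal position encoding that _encode_vision_x adds to the features of each high-resolution sub-image. A small self-contained check of construct_position_encoding, assuming flamingo.py is importable; the grid size is illustrative and vis_dim matches the visual width in config.json.

import torch
from flamingo import construct_position_encoding, get_sinusoid_encoding_table

vis_dim, max_pos = 1792, 20   # max_pos=20 mirrors the table size hard-coded in _encode_vision_x
rows, cols = 2, 3             # a hypothetical 2x3 layout of sub-images

pos = construct_position_encoding(vis_dim, max_pos, rows, cols)
print(pos.shape)              # torch.Size([6, 1792]): one vector per sub-image, scanned row by row

# Each vector concatenates the column (x) encoding with the row (y) encoding, each of
# size vis_dim // 2, so crops in the same column share their first half.
table = get_sinusoid_encoding_table(max_pos, vis_dim // 2)
assert torch.allclose(pos[1, : vis_dim // 2], table[1])   # second crop sits in column 1 ...
assert torch.allclose(pos[1, vis_dim // 2 :], table[0])   # ... of row 0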
flamingo_lm.py ADDED
@@ -0,0 +1,414 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from einops import rearrange, repeat
5
+ from torch import einsum, nn
6
+ from einops_exts import rearrange_many
7
+ # from .modules import GatedCrossAttentionBlock
8
+ from .utils import getattr_recursive, setattr_recursive
9
+
10
+ def exists(val):
11
+ return val is not None
12
+
13
+
14
+ def FeedForward(
15
+ dim,
16
+ mult=4,
17
+ use_ft_layernorm=False,
18
+ enable_init_network_params=False,
19
+ initializer_range=0.02,
20
+ ):
21
+ inner_dim = int(dim * mult)
22
+ net = nn.Sequential(
23
+ nn.LayerNorm(dim),
24
+ nn.Linear(dim, inner_dim, bias=False),
25
+ nn.GELU(),
26
+ nn.Linear(inner_dim, dim, bias=False),
27
+ )
28
+
29
+ if use_ft_layernorm and enable_init_network_params:
30
+ # only if use_ft_layernorm is on and enable_init_network_params is set
31
+ # then start the initialization
32
+ net[0].weight.data.normal_(mean=0.0, std=initializer_range)
33
+ net[0].bias.data.zero_()
34
+ net[1].weight.data.normal_(mean=0.0, std=initializer_range)
35
+ net[3].weight.data.normal_(mean=0.0, std=initializer_range)
36
+ return net
37
+
38
+
39
+ # gated cross attention
40
+ class MaskedCrossAttention(nn.Module):
41
+ def __init__(
42
+ self,
43
+ *,
44
+ dim,
45
+ dim_visual,
46
+ dim_head=64,
47
+ heads=8,
48
+ only_attend_immediate_media=True,
49
+ use_ft_layernorm=False,
50
+ use_ft_flash_attention=False,
51
+ enable_init_network_params=False,
52
+ initializer_range=0.02,
53
+ ):
54
+ super().__init__()
55
+ self.scale = dim_head**-0.5
56
+ self.heads = heads
57
+ self.use_ft_flash_attention = False
58
+ self.initializer_range = initializer_range
59
+ inner_dim = dim_head * heads
60
+
61
+ self.norm = nn.LayerNorm(dim)
62
+
63
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
64
+ self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
65
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
66
+
67
+ # whether text attends only to the immediately preceding image, or to all previous images
68
+ self.only_attend_immediate_media = only_attend_immediate_media
69
+
70
+ if enable_init_network_params:
71
+ self.apply(self._init_weights)
72
+
73
+ def _init_weights(self, module):
74
+ if isinstance(module, nn.Linear):
75
+ # Slightly different from the TF version which uses truncated_normal for initialization
76
+ # cf https://github.com/pytorch/pytorch/pull/5617
77
+ module.weight.data.normal_(mean=0.0, std=self.initializer_range)
78
+ if module.bias is not None:
79
+ module.bias.data.zero_()
80
+
81
+ elif isinstance(module, nn.LayerNorm):
82
+ module.bias.data.zero_()
83
+ module.weight.data.fill_(1.0)
84
+
85
+ def forward(self, x, media, media_locations=None, use_cached_media=False, image_mask=None):
86
+ """
87
+ Args:
88
+ x (torch.Tensor): text features
89
+ shape (B, T_txt, D_txt)
90
+ media (torch.Tensor): image features
91
+ shape (B, T_img, n, D_img) where n is the dim of the latents
92
+ media_locations: boolean mask identifying the media tokens in x
93
+ shape (B, T_txt)
94
+ use_cached_media: bool
95
+ If true, treat all of x as if they occur after the last media
96
+ registered in media_locations. T_txt does not need to exactly
97
+ equal media_locations.shape[1] in this case
98
+ """
99
+
100
+ if not use_cached_media:
101
+ assert media_locations.shape[1] == x.shape[1], (
102
+ f"media_location.shape is {media_locations.shape} but x.shape is"
103
+ f" {x.shape}"
104
+ )
105
+
106
+ T_txt = x.shape[1]
107
+ _, T_img, n = media.shape[:3]
108
+ h = self.heads
109
+
110
+ x = self.norm(x.contiguous())
111
+ q = self.to_q(x)
112
+ media = rearrange(media, "b t n d -> b (t n) d")
113
+
114
+ k, v = self.to_kv(media).chunk(2, dim=-1)
115
+
116
+ if exists(media_locations):
117
+ media_time = torch.arange(T_img, device=x.device) + 1
118
+
119
+ if use_cached_media:
120
+ # text time is set to the last cached media location
121
+ text_time = repeat(
122
+ torch.count_nonzero(media_locations, dim=1),
123
+ "b -> b i",
124
+ i=T_txt,
125
+ )
126
+ else:
127
+ # at each boolean of True, increment the time counter (relative to media time)
128
+ text_time = media_locations.cumsum(dim=-1)
129
+
130
+ # text time must equal media time if only attending to most immediate image
131
+ # otherwise, as long as text time is greater than media time (if attending to all previous images / media)
132
+ mask_op = torch.eq if self.only_attend_immediate_media else torch.ge
133
+ text_to_media_mask = mask_op(
134
+ rearrange(text_time, "b i -> b 1 i 1"),
135
+ repeat(media_time, "j -> 1 1 1 (j n)", n=n),
136
+ )
137
+
138
+ if self.only_attend_immediate_media:
139
+ # any text without a preceding media needs to have attention zeroed out
140
+ text_without_media_mask = text_time == 0
141
+ text_without_media_mask = rearrange(
142
+ text_without_media_mask, "b i -> b 1 i 1"
143
+ )
144
+
145
+
146
+ q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
147
+ q = q * self.scale
148
+ sim = einsum("... i d, ... j d -> ... i j", q, k)
149
+
150
+ if exists(image_mask):
151
+ image_mask = image_mask.unsqueeze(1).unsqueeze(1).bool()
152
+ image_mask = image_mask.repeat_interleave(int(sim.shape[3] / image_mask.shape[3]), dim=-1)
153
+ sim = sim.masked_fill(~image_mask, -torch.finfo(sim.dtype).max)
154
+ # if exists(media_locations):
155
+ # sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)
156
+
157
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
158
+ attn = sim.softmax(dim=-1)
159
+
160
+ if exists(media_locations) and self.only_attend_immediate_media:
161
+ # any text without a preceding media needs to have attention zeroed out
162
+ attn = attn.masked_fill(text_without_media_mask, 0.0)
163
+
164
+ out = einsum("... i j, ... j d -> ... i d", attn, v)
165
+ out = rearrange(out, "b h n d -> b n (h d)")
166
+
167
+ return self.to_out(out)
168
+
169
+
170
+
171
+ class GatedCrossAttentionBlock(nn.Module):
172
+ def __init__(
173
+ self,
174
+ *,
175
+ dim,
176
+ dim_visual,
177
+ dim_head=64,
178
+ heads=12,
179
+ ff_mult=1,
180
+ only_attend_immediate_media=True,
181
+ use_ft_layernorm=False,
182
+ use_ft_flash_attention=False,
183
+ enable_init_network_params=False,
184
+ initializer_range=0.02,
185
+ gradient_checkpointing=False,
186
+ ):
187
+ super().__init__()
188
+ self.attn = MaskedCrossAttention(
189
+ dim=dim,
190
+ dim_visual=dim_visual,
191
+ dim_head=dim_head,
192
+ heads=heads,
193
+ only_attend_immediate_media=only_attend_immediate_media,
194
+ use_ft_flash_attention=use_ft_flash_attention,
195
+ use_ft_layernorm=use_ft_layernorm,
196
+ enable_init_network_params=enable_init_network_params,
197
+ initializer_range=initializer_range,
198
+ )
199
+ self.attn_gate = nn.Parameter(torch.zeros(dim))
200
+
201
+ self.ff = FeedForward(dim, mult=ff_mult)
202
+ self.ff_gate = nn.Parameter(torch.zeros(dim))
203
+
204
+ self.gradient_checkpointing = gradient_checkpointing
205
+
206
+ def forward(
207
+ self,
208
+ x,
209
+ media,
210
+ media_locations=None,
211
+ use_cached_media=False,
212
+ image_mask=None,
213
+ ):
214
+
215
+ flag = torch.sum(media_locations, dim=-1)
216
+ flag = torch.where(flag > 0.0, 1.0, 0.0)
217
+ flag = flag.unsqueeze(1).unsqueeze(1).to(torch.bfloat16)
218
+ x = (
219
+ flag
220
+ * self.attn(
221
+ x,
222
+ media,
223
+ media_locations=media_locations,
224
+ use_cached_media=use_cached_media,
225
+ image_mask=image_mask,
226
+ )
227
+ * self.attn_gate.tanh()
228
+ + x
229
+ )
230
+
231
+ x = flag * self.ff(x) * self.ff_gate.tanh() + x
232
+
233
+ return x
234
+
235
+
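# Illustrative check (hedged, not part of this file): because attn_gate and ff_gate above are
# zero-initialized, a GatedCrossAttentionBlock is an exact identity function at initialization,
# so inserting it into a frozen decoder cannot change the language model's outputs before
# training. A toy check with made-up sizes, assuming this module's imports resolve:
import torch
block = GatedCrossAttentionBlock(dim=64, dim_visual=32, dim_head=16, heads=4)
x = torch.randn(2, 7, 64)                              # (B, T_txt, dim)
media = torch.randn(2, 1, 5, 32)                       # (B, T_img, n, dim_visual)
media_locations = torch.zeros(2, 7, dtype=torch.bool)
media_locations[:, 0] = True                           # a single <image> token at position 0
assert torch.allclose(block(x, media, media_locations=media_locations), x)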
236
+ class FlamingoLayer(nn.Module):
237
+ """
238
+ FlamingoLayer is a wrapper around the GatedCrossAttentionBlock and DecoderLayer.
239
+ """
240
+
241
+ def __init__(
242
+ self, gated_cross_attn_layer, decoder_layer, gradient_checkpointing=False
243
+ ):
244
+ super().__init__()
245
+ self.gated_cross_attn_layer = gated_cross_attn_layer
246
+ self.decoder_layer = decoder_layer
247
+ self.vis_x = None
248
+ self.media_locations = None
249
+ if self.gated_cross_attn_layer is not None:
250
+ self.gated_cross_attn_layer._use_gradient_checkpointing = (
251
+ gradient_checkpointing
252
+ )
253
+ self.decoder_layer._use_gradient_checkpointing = gradient_checkpointing
254
+
255
+ def is_conditioned(self) -> bool:
256
+ """Check whether the layer is conditioned."""
257
+ return self.vis_x is not None and self.media_locations is not None
258
+
259
+ # Used this great idea from this implementation of Flamingo (https://github.com/dhansmair/flamingo-mini/)
260
+ def condition_vis_x(self, vis_x):
261
+ if vis_x is not None:
262
+ self.vis_x, self.image_mask = vis_x
263
+ else:
264
+ self.vis_x, self.image_mask = None, None
265
+
266
+ def condition_media_locations(self, media_locations):
267
+ self.media_locations = media_locations
268
+
269
+ def condition_use_cached_media(self, use_cached_media):
270
+ self.use_cached_media = use_cached_media
271
+
272
+ def forward(
273
+ self,
274
+ lang_x,
275
+ attention_mask=None,
276
+ **decoder_layer_kwargs,
277
+ ):
278
+ # Cross attention
279
+ if self.gated_cross_attn_layer is not None:
280
+ if self.vis_x is None:
281
+ raise ValueError("vis_x must be conditioned before forward pass")
282
+
283
+ if self.media_locations is None:
284
+ raise ValueError(
285
+ "media_locations must be conditioned before forward pass"
286
+ )
287
+
288
+ lang_x = self.gated_cross_attn_layer(
289
+ lang_x,
290
+ self.vis_x,
291
+ media_locations=self.media_locations,
292
+ use_cached_media=self.use_cached_media,
293
+ image_mask=self.image_mask,
294
+ )
295
+
296
+ # Normal decoder layer
297
+ lang_x = self.decoder_layer(
298
+ lang_x, attention_mask=attention_mask, **decoder_layer_kwargs
299
+ )
300
+ return lang_x
301
+
302
+
303
+ class FlamingoLMMixin(nn.Module):
304
+ """
305
+ Mixin to add cross-attention layers to a language model.
306
+ """
307
+
308
+ def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
309
+ self.decoder_layers_attr_name = decoder_layers_attr_name
310
+
311
+ def _get_decoder_layers(self):
312
+ return getattr_recursive(self, self.decoder_layers_attr_name)
313
+
314
+ def _set_decoder_layers(self, value):
315
+ setattr_recursive(self, self.decoder_layers_attr_name, value)
316
+
317
+ def init_flamingo(
318
+ self,
319
+ media_token_id,
320
+ lang_hidden_size,
321
+ vis_hidden_size,
322
+ cross_attn_every_n_layers,
323
+ *,
324
+ use_ft_layernorm=False,
325
+ use_ft_flash_attention=False,
326
+ enable_init_network_params=False,
327
+ initializer_range=0.02,
328
+ gradient_checkpointing=False,
329
+ ):
330
+ """
331
+ Initialize Flamingo by adding gated cross-attention blocks to the decoder. Store the media token id for computing the media locations.
332
+ """
333
+ self.old_decoder_blocks = self._get_decoder_layers()
334
+ self.gated_cross_attn_layers = nn.ModuleList(
335
+ [
336
+ (
337
+ GatedCrossAttentionBlock(
338
+ dim=lang_hidden_size,
339
+ dim_visual=vis_hidden_size,
340
+ use_ft_layernorm=use_ft_layernorm,
341
+ use_ft_flash_attention=use_ft_flash_attention,
342
+ enable_init_network_params=enable_init_network_params,
343
+ initializer_range=initializer_range,
344
+ gradient_checkpointing=gradient_checkpointing,
345
+ )
346
+ if (layer_idx + 1) % cross_attn_every_n_layers == 0
347
+ else None
348
+ )
349
+ for layer_idx, _ in enumerate(self._get_decoder_layers())
350
+ ]
351
+ )
352
+ self.init_flamingo_layers(gradient_checkpointing)
353
+ self.media_token_id = media_token_id
354
+ self.initialized_flamingo = True
355
+ self._use_cached_vision_x = False
356
+
357
+ def init_flamingo_layers(self, gradient_checkpointing):
358
+ """
359
+ Re-initializes the FlamingoLayers.
360
+ Propagates any changes made to self.gated_cross_attn_layers or self.old_decoder_blocks.
361
+ """
362
+ self._set_decoder_layers(
363
+ nn.ModuleList(
364
+ [
365
+ FlamingoLayer(
366
+ gated_cross_attn_layer, decoder_layer, gradient_checkpointing
367
+ )
368
+ for gated_cross_attn_layer, decoder_layer in zip(
369
+ self.gated_cross_attn_layers, self.old_decoder_blocks
370
+ )
371
+ ]
372
+ )
373
+ )
374
+
375
+ def forward(self, input_ids, attention_mask, **kwargs):
376
+ """Condition the Flamingo layers on the media locations before forward()"""
377
+ if not self.initialized_flamingo:
378
+ raise ValueError(
379
+ "Flamingo layers are not initialized. Please call `init_flamingo`"
380
+ " first."
381
+ )
382
+ media_locations = input_ids == self.media_token_id
383
+ # make all of the seq focus on the first fake image to avoid nan
384
+ # if there are media already cached and we're generating and there are no media tokens in the input,
385
+ # we'll assume that ALL input tokens should attend to the last previous media that is cached.
386
+ # this is especially important for HF generate() compatibility, since generate() calls forward()
387
+ # repeatedly one token at a time (with no media tokens).
388
+ # without this check, the model would not attend to any images when generating (after the first token)
389
+ use_cached_media_locations = (
390
+ self._use_cached_vision_x
391
+ and self.is_conditioned()
392
+ and not media_locations.any()
393
+ )
394
+
395
+ for layer in self._get_decoder_layers():
396
+ if not use_cached_media_locations:
397
+ layer.condition_media_locations(media_locations)
398
+ layer.condition_use_cached_media(use_cached_media_locations)
399
+
400
+ # package arguments for the other parent's forward. since we don't know the order of the arguments,
401
+ # make them all kwargs
402
+ kwargs["input_ids"] = input_ids
403
+ kwargs["attention_mask"] = attention_mask
404
+ return super().forward(**kwargs) # Call the other parent's forward method
405
+
406
+ def is_conditioned(self) -> bool:
407
+ """Check whether all decoder layers are already conditioned."""
408
+ return all(l.is_conditioned() for l in self._get_decoder_layers())
409
+
410
+ def clear_conditioned_layers(self):
411
+ for layer in self._get_decoder_layers():
412
+ layer.condition_vis_x(None)
413
+ layer.condition_media_locations(None)
414
+ layer.condition_use_cached_media(None)
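Note: in `init_flamingo` above, only every `cross_attn_every_n_layers`-th decoder layer receives a `GatedCrossAttentionBlock`; the other entries are `None`, so those `FlamingoLayer`s run the original decoder block unchanged. A minimal standalone sketch of that interleaving check (toy layer count and interval, not the shipped configuration):

# which decoder layers get a gated cross-attention block
cross_attn_every_n_layers = 4   # hypothetical interval for illustration
num_decoder_layers = 12         # toy depth

layout = [
    "gated-xattn + decoder" if (layer_idx + 1) % cross_attn_every_n_layers == 0 else "decoder only"
    for layer_idx in range(num_decoder_layers)
]
print(layout)  # layers at indices 3, 7, 11 carry the cross-attention block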
modeling_infimm_hd.py ADDED
@@ -0,0 +1,134 @@
1
+ import importlib
2
+ import math
3
+ from functools import partial
4
+ from typing import TYPE_CHECKING, Any, Callable, Generator, List, Optional, Tuple, Union
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import torch.utils.checkpoint
8
+ from torch.cuda.amp import autocast
9
+
10
+ from transformers import GenerationConfig, PreTrainedTokenizer, StoppingCriteriaList
11
+ from transformers.generation.logits_process import LogitsProcessorList
12
+
13
+ if TYPE_CHECKING:
14
+ from transformers.generation.streamers import BaseStreamer
15
+
16
+ from transformers.generation.utils import GenerateOutput
17
+ from transformers.modeling_outputs import (
18
+ BaseModelOutputWithPast,
19
+ CausalLMOutputWithPast,
20
+ )
21
+ from transformers.modeling_utils import PreTrainedModel
22
+ from transformers import AutoModelForCausalLM, AutoTokenizer
23
+ from transformers.utils import logging
24
+
25
+ try:
26
+ from einops import rearrange
27
+ except ImportError:
28
+ rearrange = None
29
+ from torch import nn
30
+
31
+ from .configuration_infimm_hd import InfiMMHDConfig
32
+ from .eva_vit_model import CLIPVisionCfg, EVAVisionTransformer
33
+ from .flamingo import Flamingo
34
+ from .flamingo_lm import FlamingoLMMixin
35
+ from .utils import _infer_decoder_layers_attr_name, extend_instance
36
+
37
+ SUPPORT_CUDA = torch.cuda.is_available()
38
+ SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
39
+ SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
40
+
41
+
42
+ class InfiMMPreTrainedModel(PreTrainedModel):
43
+ config_class = InfiMMHDConfig
44
+ base_model_prefix = "transformer"
45
+ is_parallelizable = False
46
+ supports_gradient_checkpointing = True
47
+
48
+ def __init__(self, *inputs, **kwargs):
49
+ super().__init__(*inputs, **kwargs)
50
+
51
+
52
+ class InfiMMHDModel(InfiMMPreTrainedModel):
53
+ def __init__(self, config):
54
+ super().__init__(config)
55
+
56
+ self.vision_config = config.visual
57
+ vision_encoder = self.build_vision_encoder()
58
+ self.language_config = config.language
59
+ language_encoder = self.build_language_encoder()
60
+
61
+ self.model = self.build_flamingo(vision_encoder, language_encoder)
62
+
63
+ def build_vision_encoder(self, image_size=448):
64
+ vision_cfg = CLIPVisionCfg(**self.vision_config)
65
+
66
+ if image_size:
67
+ vision_cfg.image_size = image_size
68
+ vision_encoder = EVAVisionTransformer(
69
+ img_size=vision_cfg.image_size,
70
+ patch_size=vision_cfg.patch_size,
71
+ num_classes=vision_cfg.embed_dim,
72
+ use_mean_pooling=vision_cfg.global_average_pool, # False
73
+ init_values=vision_cfg.ls_init_value,
74
+ patch_dropout=vision_cfg.patch_dropout,
75
+ embed_dim=vision_cfg.width,
76
+ depth=vision_cfg.layers,
77
+ num_heads=vision_cfg.width // vision_cfg.head_width,
78
+ mlp_ratio=vision_cfg.mlp_ratio,
79
+ qkv_bias=vision_cfg.qkv_bias,
80
+ drop_path_rate=vision_cfg.drop_path_rate,
81
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
82
+ xattn=vision_cfg.xattn,
83
+ rope=vision_cfg.rope,
84
+ postnorm=vision_cfg.postnorm,
85
+ pt_hw_seq_len=vision_cfg.pt_hw_seq_len, # 224/14
86
+ intp_freq=vision_cfg.intp_freq,
87
+ naiveswiglu=vision_cfg.naiveswiglu,
88
+ subln=vision_cfg.subln,
89
+ )
90
+
91
+ return vision_encoder
92
+
93
+ def build_language_encoder(self):
94
+ lang_encoder = AutoModelForCausalLM.from_pretrained(
95
+ self.language_config["_name_or_path"]
96
+ )
97
+ lang_encoder.resize_token_embeddings(self.language_config["vocab_size"])
98
+ return lang_encoder
99
+
100
+ def build_flamingo(self, vision_encoder, lang_encoder):
101
+ extend_instance(lang_encoder, FlamingoLMMixin)
102
+ decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
103
+ lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
104
+ model = Flamingo(
105
+ vision_encoder,
106
+ lang_encoder,
107
+ self.config.eoc_token_id,
108
+ self.config.image_token_id,
109
+ vis_dim=self.vision_config["width"],
110
+ cross_attn_every_n_layers=self.config.cross_attn_every_n_layers,
111
+ gradient_checkpointing=self.config.use_grad_checkpoint,
112
+ )
113
+
114
+ return model
115
+
116
+ def generate(
117
+ self,
118
+ batch_images,
119
+ input_ids,
120
+ attention_mask,
121
+ **kwargs,
122
+ ):
123
+
124
+ with torch.inference_mode():
125
+ outputs = self.model.generate(
126
+ batch_images,
127
+ input_ids,
128
+ attention_mask,
129
+ **kwargs,
130
+ )
131
+
132
+ # Extract only the newly generated tokens
133
+ outputs = outputs[:, len(input_ids[0]) :]
134
+ return outputs
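`InfiMMHDModel.generate` returns only the continuation: it slices the prompt tokens off the front of the model output. A small illustration of that trimming with made-up token ids:

import torch

input_ids = torch.tensor([[1, 319, 13563, 1546, 263]])                  # hypothetical 5-token prompt
full_output = torch.tensor([[1, 319, 13563, 1546, 263, 450, 4996, 2]])  # prompt + 3 generated tokens

new_tokens = full_output[:, len(input_ids[0]):]  # same slice as in generate()
print(new_tokens)  # tensor([[ 450, 4996,    2]])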
modules.py ADDED
@@ -0,0 +1,233 @@
1
+ """
2
+ Based on: https://github.com/lucidrains/flamingo-pytorch
3
+ """
4
+
5
+ import torch
6
+ from einops import rearrange, repeat
7
+ from torch import einsum, nn
8
+ from einops_exts import rearrange_many
9
+
10
+ def exists(val):
11
+ return val is not None
12
+
13
+
14
+ def FeedForward(
15
+ dim,
16
+ mult=4,
17
+ use_ft_layernorm=False,
18
+ enable_init_network_params=False,
19
+ initializer_range=0.02,
20
+ ):
21
+ inner_dim = int(dim * mult)
22
+ net = nn.Sequential(
23
+ nn.LayerNorm(dim),
24
+ nn.Linear(dim, inner_dim, bias=False),
25
+ nn.GELU(),
26
+ nn.Linear(inner_dim, dim, bias=False),
27
+ )
28
+
29
+ if use_ft_layernorm and enable_init_network_params:
30
+ # initialize the weights here only when both use_ft_layernorm
31
+ # and enable_init_network_params are enabled
32
+ net[0].weight.data.normal_(mean=0.0, std=initializer_range)
33
+ net[0].bias.data.zero_()
34
+ net[1].weight.data.normal_(mean=0.0, std=initializer_range)
35
+ net[3].weight.data.normal_(mean=0.0, std=initializer_range)
36
+ return net
37
+
38
+
39
+ # gated cross attention
40
+ class MaskedCrossAttention(nn.Module):
41
+ def __init__(
42
+ self,
43
+ *,
44
+ dim,
45
+ dim_visual,
46
+ dim_head=64,
47
+ heads=8,
48
+ only_attend_immediate_media=True,
49
+ use_ft_layernorm=False,
50
+ use_ft_flash_attention=False,
51
+ enable_init_network_params=False,
52
+ initializer_range=0.02,
53
+ ):
54
+ super().__init__()
55
+ self.scale = dim_head**-0.5
56
+ self.heads = heads
57
+ self.use_ft_flash_attention = False
58
+ self.initializer_range = initializer_range
59
+ inner_dim = dim_head * heads
60
+
61
+ self.norm = nn.LayerNorm(dim)
62
+
63
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
64
+ self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
65
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
66
+
67
+ # whether for text to only attend to immediate preceding image, or all previous images
68
+ self.only_attend_immediate_media = only_attend_immediate_media
69
+
70
+ if enable_init_network_params:
71
+ self.apply(self._init_weights)
72
+
73
+ def _init_weights(self, module):
74
+ if isinstance(module, nn.Linear):
75
+ # Slightly different from the TF version which uses truncated_normal for initialization
76
+ # cf https://github.com/pytorch/pytorch/pull/5617
77
+ module.weight.data.normal_(mean=0.0, std=self.initializer_range)
78
+ if module.bias is not None:
79
+ module.bias.data.zero_()
80
+
81
+ elif isinstance(module, nn.LayerNorm):
82
+ module.bias.data.zero_()
83
+ module.weight.data.fill_(1.0)
84
+
85
+ def forward(self, x, media, media_locations=None, use_cached_media=False, image_mask=None):
86
+ """
87
+ Args:
88
+ x (torch.Tensor): text features
89
+ shape (B, T_txt, D_txt)
90
+ media (torch.Tensor): image features
91
+ shape (B, T_img, n, D_img) where n is the dim of the latents
92
+ media_locations: boolean mask identifying the media tokens in x
93
+ shape (B, T_txt)
94
+ use_cached_media: bool
95
+ If true, treat all of x as if they occur after the last media
96
+ registered in media_locations. T_txt does not need to exactly
97
+ equal media_locations.shape[1] in this case
98
+ """
99
+
100
+ if not use_cached_media:
101
+ assert media_locations.shape[1] == x.shape[1], (
102
+ f"media_location.shape is {media_locations.shape} but x.shape is"
103
+ f" {x.shape}"
104
+ )
105
+
106
+ T_txt = x.shape[1]
107
+ _, T_img, n = media.shape[:3]
108
+ h = self.heads
109
+
110
+ x = self.norm(x.contiguous())
111
+ q = self.to_q(x)
112
+ media = rearrange(media, "b t n d -> b (t n) d")
113
+
114
+ k, v = self.to_kv(media).chunk(2, dim=-1)
115
+
116
+ if exists(media_locations):
117
+ media_time = torch.arange(T_img, device=x.device) + 1
118
+
119
+ if use_cached_media:
120
+ # text time is set to the last cached media location
121
+ text_time = repeat(
122
+ torch.count_nonzero(media_locations, dim=1),
123
+ "b -> b i",
124
+ i=T_txt,
125
+ )
126
+ else:
127
+ # at each boolean of True, increment the time counter (relative to media time)
128
+ text_time = media_locations.cumsum(dim=-1)
129
+
130
+ # text time must equal media time if only attending to most immediate image
131
+ # otherwise, as long as text time is greater than media time (if attending to all previous images / media)
132
+ mask_op = torch.eq if self.only_attend_immediate_media else torch.ge
133
+ text_to_media_mask = mask_op(
134
+ rearrange(text_time, "b i -> b 1 i 1"),
135
+ repeat(media_time, "j -> 1 1 1 (j n)", n=n),
136
+ )
137
+
138
+ if self.only_attend_immediate_media:
139
+ # any text without a preceding media needs to have attention zeroed out
140
+ text_without_media_mask = text_time == 0
141
+ text_without_media_mask = rearrange(
142
+ text_without_media_mask, "b i -> b 1 i 1"
143
+ )
144
+
145
+
146
+ q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
147
+ q = q * self.scale
148
+ sim = einsum("... i d, ... j d -> ... i j", q, k)
149
+
150
+ if exists(image_mask):
151
+ image_mask = image_mask.unsqueeze(1).unsqueeze(1).bool()
152
+ image_mask = image_mask.repeat_interleave(int(sim.shape[3] / image_mask.shape[3]), dim=-1)
153
+ sim = sim.masked_fill(~image_mask, -torch.finfo(sim.dtype).max)
154
+ # if exists(media_locations):
155
+ # sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)
156
+
157
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
158
+ attn = sim.softmax(dim=-1)
159
+
160
+ if exists(media_locations) and self.only_attend_immediate_media:
161
+ # any text without a preceding media needs to have attention zeroed out
162
+ attn = attn.masked_fill(text_without_media_mask, 0.0)
163
+
164
+ out = einsum("... i j, ... j d -> ... i d", attn, v)
165
+ out = rearrange(out, "b h n d -> b n (h d)")
166
+
167
+ return self.to_out(out)
168
+
169
+
170
+
171
+ class GatedCrossAttentionBlock(nn.Module):
172
+ def __init__(
173
+ self,
174
+ *,
175
+ dim,
176
+ dim_visual,
177
+ dim_head=64,
178
+ heads=12,
179
+ ff_mult=1,
180
+ only_attend_immediate_media=True,
181
+ use_ft_layernorm=False,
182
+ use_ft_flash_attention=False,
183
+ enable_init_network_params=False,
184
+ initializer_range=0.02,
185
+ gradient_checkpointing=False,
186
+ ):
187
+ super().__init__()
188
+ self.attn = MaskedCrossAttention(
189
+ dim=dim,
190
+ dim_visual=dim_visual,
191
+ dim_head=dim_head,
192
+ heads=heads,
193
+ only_attend_immediate_media=only_attend_immediate_media,
194
+ use_ft_flash_attention=use_ft_flash_attention,
195
+ use_ft_layernorm=use_ft_layernorm,
196
+ enable_init_network_params=enable_init_network_params,
197
+ initializer_range=initializer_range,
198
+ )
199
+ self.attn_gate = nn.Parameter(torch.zeros(dim))
200
+
201
+ self.ff = FeedForward(dim, mult=ff_mult)
202
+ self.ff_gate = nn.Parameter(torch.zeros(dim))
203
+
204
+ self.gradient_checkpointing = gradient_checkpointing
205
+
206
+ def forward(
207
+ self,
208
+ x,
209
+ media,
210
+ media_locations=None,
211
+ use_cached_media=False,
212
+ image_mask=None,
213
+ ):
214
+
215
+ flag = torch.sum(media_locations, dim=-1)
216
+ flag = torch.where(flag > 0.0, 1.0, 0.0)
217
+ flag = flag.unsqueeze(1).unsqueeze(1).to(torch.bfloat16)
218
+ x = (
219
+ flag
220
+ * self.attn(
221
+ x,
222
+ media,
223
+ media_locations=media_locations,
224
+ use_cached_media=use_cached_media,
225
+ image_mask=image_mask,
226
+ )
227
+ * self.attn_gate.tanh()
228
+ + x
229
+ )
230
+
231
+ x = flag * self.ff(x) * self.ff_gate.tanh() + x
232
+
233
+ return x
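A toy sketch of the time alignment computed inside `MaskedCrossAttention`: `media_locations.cumsum` tags each text position with the index of the most recent preceding image, and with `only_attend_immediate_media=True` a token may only attend to that image (hypothetical two-image sequence, standalone):

import torch

# <image> a b <image> c d  ->  True marks the <image> positions
media_locations = torch.tensor([[True, False, False, True, False, False]])

text_time = media_locations.cumsum(dim=-1)  # tensor([[1, 1, 1, 2, 2, 2]])
media_time = torch.arange(2) + 1            # tensor([1, 2])

# only_attend_immediate_media=True -> equality check, as in the module above
text_to_media_mask = text_time.unsqueeze(-1) == media_time
print(text_to_media_mask[0].int())
# rows: [1, 0], [1, 0], [1, 0], [0, 1], [0, 1], [0, 1]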
preprocessor_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_name_or_path": "./",
3
+ "auto_map": {
4
+ "AutoProcessor": "processing_infimm_hd.InfiMMHDProcessor"
5
+ },
6
+ "image_size": 224
7
+ }
processing_infimm_hd.py ADDED
@@ -0,0 +1,422 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Processor class for InfiMMHD.
17
+ """
18
+
19
+ import random
20
+ from typing import List, Optional, Tuple, Union
21
+ import torch
22
+ import torchvision.transforms.functional as F
23
+ from PIL import Image
24
+ from torchvision.transforms import (
25
+ CenterCrop,
26
+ Compose,
27
+ InterpolationMode,
28
+ Normalize,
29
+ Resize,
30
+ ToTensor,
31
+ )
32
+
33
+ from transformers import AutoTokenizer
34
+ from transformers.image_processing_utils import ImageProcessingMixin
35
+ from transformers.processing_utils import ProcessorMixin
36
+ from transformers.tokenization_utils_base import BatchEncoding
37
+
38
+ IMAGE_TOKEN = "<image>"
39
+ END_OF_CHUNK_TOKEN = "<|endofchunk|>"
40
+ PAD_TOKEN = "<PAD>"
41
+
42
+ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
43
+ OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
44
+
45
+
46
+ def _convert_to_rgb(image):
47
+ return image.convert("RGB")
48
+
49
+
50
+ class ResizeKeepRatio:
51
+ """Resize and Keep Ratio
52
+
53
+ Copy & paste from `timm`
54
+ """
55
+
56
+ def __init__(
57
+ self,
58
+ size,
59
+ longest=0.0,
60
+ interpolation=InterpolationMode.BICUBIC,
61
+ random_scale_prob=0.0,
62
+ random_scale_range=(0.85, 1.05),
63
+ random_aspect_prob=0.0,
64
+ random_aspect_range=(0.9, 1.11),
65
+ ):
66
+ if isinstance(size, (list, tuple)):
67
+ self.size = tuple(size)
68
+ else:
69
+ self.size = (size, size)
70
+ self.interpolation = interpolation
71
+ self.longest = float(longest) # [0, 1] where 0 == shortest edge, 1 == longest
72
+ self.random_scale_prob = random_scale_prob
73
+ self.random_scale_range = random_scale_range
74
+ self.random_aspect_prob = random_aspect_prob
75
+ self.random_aspect_range = random_aspect_range
76
+
77
+ @staticmethod
78
+ def get_params(
79
+ img,
80
+ target_size,
81
+ longest,
82
+ random_scale_prob=0.0,
83
+ random_scale_range=(0.85, 1.05),
84
+ random_aspect_prob=0.0,
85
+ random_aspect_range=(0.9, 1.11),
86
+ ):
87
+ """Get parameters"""
88
+ source_size = img.size[::-1] # h, w
89
+ h, w = source_size
90
+ target_h, target_w = target_size
91
+ ratio_h = h / target_h
92
+ ratio_w = w / target_w
93
+ ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (
94
+ 1.0 - longest
95
+ )
96
+ if random_scale_prob > 0 and random.random() < random_scale_prob:
97
+ ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
98
+ ratio_factor = (ratio_factor, ratio_factor)
99
+ else:
100
+ ratio_factor = (1.0, 1.0)
101
+ if random_aspect_prob > 0 and random.random() < random_aspect_prob:
102
+ aspect_factor = random.uniform(
103
+ random_aspect_range[0], random_aspect_range[1]
104
+ )
105
+ ratio_factor = (
106
+ ratio_factor[0] / aspect_factor,
107
+ ratio_factor[1] * aspect_factor,
108
+ )
109
+ size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)]
110
+ return size
111
+
112
+ def __call__(self, img):
113
+ """
114
+ Args:
115
+ img (PIL Image): Image to be cropped and resized.
116
+
117
+ Returns:
118
+ PIL Image: Resized, padded to at least target size, possibly cropped to exactly target size
119
+ """
120
+ size = self.get_params(
121
+ img,
122
+ self.size,
123
+ self.longest,
124
+ self.random_scale_prob,
125
+ self.random_scale_range,
126
+ self.random_aspect_prob,
127
+ self.random_aspect_range,
128
+ )
129
+ img = F.resize(img, size, self.interpolation)
130
+ return img
131
+
132
+ def __repr__(self):
133
+ format_string = self.__class__.__name__ + "(size={0}".format(self.size)
134
+ format_string += f", interpolation={self.interpolation})"
135
+ format_string += f", longest={self.longest:.3f})"
136
+ return format_string
137
+
138
+
139
+ def image_transform(
140
+ image_size: Union[int, Tuple[int, int]],
141
+ mean: Optional[Tuple[float, ...]] = None,
142
+ std: Optional[Tuple[float, ...]] = None,
143
+ resize_mode: Optional[str] = None,
144
+ interpolation: Optional[str] = None,
145
+ ):
146
+ mean = mean or OPENAI_DATASET_MEAN
147
+ if not isinstance(mean, (list, tuple)):
148
+ mean = (mean,) * 3
149
+
150
+ std = std or OPENAI_DATASET_STD
151
+ if not isinstance(std, (list, tuple)):
152
+ std = (std,) * 3
153
+
154
+ interpolation = interpolation or "bicubic"
155
+ assert interpolation in ["bicubic", "bilinear", "random"]
156
+ # NOTE random is ignored for interpolation_mode, so defaults to BICUBIC for inference if set
157
+ interpolation_mode = (
158
+ InterpolationMode.BILINEAR
159
+ if interpolation == "bilinear"
160
+ else InterpolationMode.BICUBIC
161
+ )
162
+
163
+ resize_mode = resize_mode or "shortest"
164
+ assert resize_mode in ("shortest", "longest", "squash")
165
+
166
+ normalize = Normalize(mean=mean, std=std)
167
+
168
+ assert resize_mode == "shortest"
169
+ if not isinstance(image_size, (tuple, list)):
170
+ image_size = (image_size, image_size)
171
+ if image_size[0] == image_size[1]:
172
+ # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
173
+ transforms = [Resize(image_size[0], interpolation=interpolation_mode)]
174
+ else:
175
+ # resize shortest edge to matching target dim for non-square target
176
+ transforms = [ResizeKeepRatio(image_size)]
177
+ transforms += [CenterCrop(image_size)]
178
+
179
+ transforms.extend(
180
+ [
181
+ _convert_to_rgb,
182
+ ToTensor(),
183
+ normalize,
184
+ ]
185
+ )
186
+ return Compose(transforms)
187
+
188
+
189
+ def get_target_size(width, height, max_image_size, min_image_size):
190
+ target_width = 0
191
+ target_height = 0
192
+ if width < min_image_size:
193
+ target_width = min_image_size
194
+ elif width > max_image_size:
195
+ target_width = max_image_size
196
+
197
+ if height < min_image_size:
198
+ target_height = min_image_size
199
+ elif height > max_image_size:
200
+ target_height = max_image_size
201
+
202
+ if target_width == 0:
203
+ ratio = ((width - min_image_size) + int(0.5*min_image_size))//min_image_size
204
+ target_width = ratio * min_image_size + min_image_size
205
+
206
+ if target_height == 0:
207
+ ratio = ((height - min_image_size) + int(0.5*min_image_size))//min_image_size
208
+ target_height = ratio * min_image_size + min_image_size
209
+
210
+ return target_width, target_height
211
+
212
+ class EVAClipImageProcessor(ImageProcessingMixin):
213
+ def __init__(self, **kwargs) -> None:
214
+ super().__init__(**kwargs)
215
+ self.image_processor = image_transform(image_size=448)
216
+ self.img_size = 448
217
+
218
+ def _prepare_images(self, batch: List[List[Image.Image]]) -> torch.Tensor:
219
+ """
220
+ Convert images to tensors, reshape them, and stack them.
221
+ Args:
222
+ batch: A list of lists of images.
223
+ Returns:
224
+ preprocessed images (tensors) or None
225
+ shape (B, T_img, F, C, H, W)
226
+ None if no images in batch
227
+ """
228
+
229
+ target_image_num = []
230
+ target_shape = []
231
+ for x in batch:
232
+ width, height = x[0].size
233
+ tar_wid, tar_hei = get_target_size(width, height, 1344, self.img_size)
234
+ target_shape.append((tar_wid, tar_hei))
235
+ target_image_num.append(int(tar_wid/self.img_size*tar_hei/self.img_size))
236
+
237
+ images_per_example = max(target_image_num)
238
+ batch_images = None
239
+ image_mask = None
240
+ sub_image_shape = None
241
+ for iexample, example in enumerate(batch):
242
+ for img in example:
243
+ img_ori = img
244
+ tar_wid, tar_hei = target_shape[iexample]
245
+ img_new = img.resize((tar_wid, tar_hei), Image.BILINEAR)
246
+ sub_images = [img_ori]
247
+
248
+ for y in range(0, tar_hei, self.img_size):
249
+ for x in range(0, tar_wid, self.img_size):
250
+ sub_img = img_new.crop((x, y, x + self.img_size, y + self.img_size))
251
+ sub_images.append(sub_img)
252
+
253
+ for iimage, image in enumerate(sub_images):
254
+ preprocessed = self.image_processor(image)
255
+ if batch_images is None:
256
+ batch_images = torch.zeros(
257
+ (len(batch), images_per_example+1, 1) + preprocessed.shape,
258
+ dtype=preprocessed.dtype,
259
+ )
260
+ batch_images[iexample, iimage, 0] = preprocessed
261
+ if not torch.is_tensor(image_mask):
262
+ image_mask = torch.zeros((len(batch), images_per_example+1), dtype=preprocessed.dtype)
263
+ image_mask[iexample,:target_image_num[iexample]+1] = 1.0
264
+ if not torch.is_tensor(sub_image_shape):
265
+ sub_image_shape = torch.zeros((len(batch), 2), dtype=preprocessed.dtype)
266
+ sub_image_shape[iexample, 0], sub_image_shape[iexample, 1] = tar_wid/self.img_size, tar_hei/self.img_size
267
+
268
+ # if batch_images is not None:
269
+ # batch_images = batch_images.to(
270
+ # self.device, dtype=self.cast_dtype, non_blocking=True
271
+ # )
272
+
273
+ # if image_mask is not None:
274
+ # image_mask = image_mask.to(
275
+ # self.device, dtype=self.cast_dtype, non_blocking=True
276
+ # )
277
+
278
+ # if sub_image_shape is not None:
279
+ # sub_image_shape = sub_image_shape.to(
280
+ # self.device, dtype=self.cast_dtype, non_blocking=True
281
+ # )
282
+ return batch_images, image_mask, sub_image_shape
283
+
284
+ def preprocess(self, imgpaths=None):
285
+ if imgpaths is None or len(imgpaths) == 0:
286
+ images = [(Image.new("RGB", (224, 224), color="black"))]
287
+ else:
288
+ images = [Image.open(fp) for fp in imgpaths]
289
+ return self._prepare_images([images])
290
+
291
+
292
+ class InfiMMHDProcessor(ProcessorMixin):
293
+ r"""
294
+ Constructs an InfiMM-HD processor which wraps a tokenizer and an image processor into a single processor.
295
+
296
+ Args:
297
+ image_processor (`EVAClipImageProcessor`):
298
+ An instance of [`EVAClipImageProcessor`]. The image processor is a required input.
299
+ tokenizer (`LlamaTokenizer`):
300
+ An instance of [`LlamaTokenizer`]. The tokenizer is a required input.
301
+ image_size (`int`, *optional*, defaults to 336): Image size (assuming a square image)
302
+ """
303
+
304
+ attributes = ["tokenizer"]
305
+ tokenizer_class = "LlamaTokenizer"
306
+
307
+ def __init__(self, tokenizer=None, **kwargs):
308
+ self.image_processor = EVAClipImageProcessor()
309
+ if tokenizer is None:
310
+ tokenizer = AutoTokenizer.from_pretrained("infimm-hd", verbose=False)
311
+
312
+ super().__init__(tokenizer)
313
+
314
+ def _prepare_text(
315
+ self,
316
+ batch: List[List[str]],
317
+ padding="longest",
318
+ truncation=True,
319
+ max_length=2048,
320
+ ):
321
+ """
322
+ Tokenize the text and stack them.
323
+ Args:
324
+ batch: A list of lists of strings.
325
+ Returns:
326
+ input_ids (tensor)
327
+ shape (B, T_txt)
328
+ attention_mask (tensor)
329
+ shape (B, T_txt)
330
+ """
331
+ batch = [b.strip() for b in batch]
332
+ encodings = self.tokenizer(
333
+ batch,
334
+ padding=padding,
335
+ truncation=truncation,
336
+ return_tensors="pt",
337
+ max_length=max_length,
338
+ )
339
+ input_ids, attention_mask = encodings["input_ids"], encodings["attention_mask"]
340
+ # print(self.tokenizer.convert_ids_to_tokens(input_ids[]))
341
+ return input_ids, attention_mask
342
+
343
+ def __call__(
344
+ self,
345
+ prompts,
346
+ ) -> BatchEncoding:
347
+ """This method takes batched or non-batched prompts made of text and images and converts them into prompts that
348
+ the model was trained on and prepares the image pixel values for the model to process.
349
+ """
350
+ image_paths = self._extract_image_paths(prompts)
351
+ images, image_mask, sub_image_shape = self.image_processor.preprocess(image_paths)
352
+ prompts = self._replace_with_media_tokens(prompts)
353
+ final_prompt = self.apply_template(prompts)
354
+ # system_prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
355
+ # final_prompt = f"{system_prompt} USER: <image>" + prompts + " ASSISTANT:"
356
+ input_ids, attention_mask = self._prepare_text([final_prompt])
357
+ return BatchEncoding(
358
+ data={
359
+ "input_ids": input_ids,
360
+ "attention_mask": attention_mask,
361
+ "batch_images": images,
362
+ "image_mask": image_mask,
363
+ "subimage_shape": sub_image_shape,
364
+ }
365
+ )
366
+
367
+ def _extract_image_paths(self, prompts):
368
+ image_paths = []
369
+ for round in prompts:
370
+ if round["role"] != "user":
371
+ continue
372
+ for piece in round["content"]:
373
+ if isinstance(piece, dict):
374
+ image_paths.append(piece["image"])
375
+ return image_paths
376
+
377
+ def _replace_with_media_tokens(self, prompts):
378
+ new_prompts = []
379
+ is_first_img = True
380
+ for round in prompts:
381
+ if round["role"] != "user":
382
+ new_prompts.append(round)
+ continue  # non-user rounds are passed through unchanged
383
+ new_content = []
384
+ for piece in round["content"]:
385
+ if isinstance(piece, dict):
386
+ new_content.append(
387
+ f"{IMAGE_TOKEN}" if is_first_img
388
+ else f"{END_OF_CHUNK_TOKEN}{IMAGE_TOKEN}"
389
+ )
390
+ is_first_img = False
391
+ else:
392
+ new_content.append(piece)
393
+ new_prompts.append({"role": "user", "content": "".join(new_content)})
394
+ return new_prompts
395
+
396
+ def apply_template(self, messages, task="generation"):
397
+ prompt = self.tokenizer.apply_chat_template(
398
+ messages,
399
+ tokenize=False,
400
+ add_generation_prompt=True if task == "generation" else False,
401
+ )
402
+ return prompt
403
+
404
+ def batch_decode(self, *args, **kwargs):
405
+ """
406
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
407
+ refer to the docstring of this method for more information.
408
+ """
409
+ return self.tokenizer.batch_decode(*args, **kwargs)
410
+
411
+ def decode(self, *args, **kwargs):
412
+ """
413
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
414
+ the docstring of this method for more information.
415
+ """
416
+ return self.tokenizer.decode(*args, **kwargs)
417
+
418
+ @property
419
+ def model_input_names(self):
420
+ tokenizer_input_names = self.tokenizer.model_input_names
421
+ image_processor_input_names = self.image_processor.model_input_names
422
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
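For reference, a quick check of how `get_target_size` snaps an input resolution onto the 448-pixel sub-image grid used by `EVAClipImageProcessor` (capped at 1344 per side, as in `_prepare_images`), and how many crops that yields on top of the global view. This assumes the repo files are importable locally with their dependencies installed:

from processing_infimm_hd import get_target_size

for width, height in [(600, 500), (1300, 800), (2000, 300)]:
    tw, th = get_target_size(width, height, max_image_size=1344, min_image_size=448)
    n_tiles = (tw // 448) * (th // 448)
    print(f"{width}x{height} -> resized to {tw}x{th}: {n_tiles} sub-images + 1 global view")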
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:688e7927fe4f8a80b8d6905d77fdb0922b53f61ed5f7345749408a8654bca4fa
3
+ size 35997587561
special_tokens_map.json ADDED
@@ -0,0 +1,46 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ {
4
+ "content": "<|endofchunk|>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false
9
+ },
10
+ {
11
+ "content": "<image>",
12
+ "lstrip": false,
13
+ "normalized": false,
14
+ "rstrip": false,
15
+ "single_word": false
16
+ }
17
+ ],
18
+ "bos_token": {
19
+ "content": "<s>",
20
+ "lstrip": false,
21
+ "normalized": false,
22
+ "rstrip": false,
23
+ "single_word": false
24
+ },
25
+ "eos_token": {
26
+ "content": "</s>",
27
+ "lstrip": false,
28
+ "normalized": false,
29
+ "rstrip": false,
30
+ "single_word": false
31
+ },
32
+ "pad_token": {
33
+ "content": "<unk>",
34
+ "lstrip": false,
35
+ "normalized": false,
36
+ "rstrip": false,
37
+ "single_word": false
38
+ },
39
+ "unk_token": {
40
+ "content": "<unk>",
41
+ "lstrip": false,
42
+ "normalized": false,
43
+ "rstrip": false,
44
+ "single_word": false
45
+ }
46
+ }
tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
tokenizer_config.json ADDED
@@ -0,0 +1,63 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<unk>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<s>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "32000": {
30
+ "content": "<|endofchunk|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "32001": {
38
+ "content": "<image>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ }
45
+ },
46
+ "additional_special_tokens": [
47
+ "<|endofchunk|>",
48
+ "<image>"
49
+ ],
50
+ "bos_token": "<s>",
51
+ "chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\\'s questions.' %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 %}{{ system_message }}{% endif %}{% if message['role'] == 'user' %}{{ ' USER: ' + message['content'].strip() }}{% elif message['role'] == 'assistant' %}{{ ' ASSISTANT: ' + message['content'].strip() + eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ ' ASSISTANT:' }}{% endif %}",
52
+ "clean_up_tokenization_spaces": false,
53
+ "eos_token": "</s>",
54
+ "legacy": false,
55
+ "model_max_length": 4096,
56
+ "pad_token": "<unk>",
57
+ "padding_side": "left",
58
+ "sp_model_kwargs": {},
59
+ "spaces_between_special_tokens": false,
60
+ "tokenizer_class": "LlamaTokenizer",
61
+ "unk_token": "<unk>",
62
+ "use_default_system_prompt": false
63
+ }
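The `chat_template` above prefixes a Vicuna-style system message, wraps each user turn as ` USER: ...`, and appends ` ASSISTANT:` when a generation prompt is requested. A rough sketch of the rendered prompt for a single-image query (assuming the repository has been cloned into the current directory so its tokenizer files can be loaded):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(".")  # load the tokenizer shipped in this repo
messages = [{"role": "user", "content": "<image>What is in this picture?"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# prompt should look roughly like:
# "A chat between a curious user and an artificial intelligence assistant. The assistant gives
#  helpful, detailed, and polite answers to the user's questions. USER: <image>What is in this
#  picture? ASSISTANT:"
print(prompt)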
utils.py ADDED
@@ -0,0 +1,98 @@
1
+ import torch
+
+
+ def extend_instance(obj, mixin):
2
+ """Apply mixins to a class instance after creation"""
3
+ base_cls = obj.__class__
4
+ base_cls_name = obj.__class__.__name__
5
+ obj.__class__ = type(
6
+ base_cls_name, (mixin, base_cls), {}
7
+ ) # mixin needs to go first for our forward() logic to work
8
+
9
+
10
+ def getattr_recursive(obj, att):
11
+ """
12
+ Return nested attribute of obj
13
+ Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
14
+ """
15
+ if att == "":
16
+ return obj
17
+ i = att.find(".")
18
+ if i < 0:
19
+ return getattr(obj, att)
20
+ else:
21
+ return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
22
+
23
+
24
+ def setattr_recursive(obj, att, val):
25
+ """
26
+ Set nested attribute of obj
27
+ Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
28
+ """
29
+ if "." in att:
30
+ obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
31
+ setattr(obj, att.split(".")[-1], val)
32
+
33
+
34
+ def _infer_decoder_layers_attr_name(model):
35
+ for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
36
+ if k.lower() in model.__class__.__name__.lower():
37
+ return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]
38
+
39
+ raise ValueError(
40
+ "We require the attribute name for the nn.ModuleList in the decoder storing"
41
+ " the transformer block layers. Please supply this string manually."
42
+ )
43
+
44
+
45
+ __KNOWN_DECODER_LAYERS_ATTR_NAMES = {
46
+ "llama": "model.layers",
47
+ "mistral": "model.layers",
48
+ }
49
+
50
+ def resize_eva_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
51
+ # interpolate position embedding
52
+ if "pos_embed" in state_dict:
53
+ pos_embed_checkpoint = state_dict["pos_embed"]
54
+ embedding_size = pos_embed_checkpoint.shape[-1]
55
+ num_patches = model.patch_embed.num_patches
56
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
57
+ # height (== width) for the checkpoint position embedding
58
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
59
+ # height (== width) for the new position embedding
60
+ new_size = int(num_patches**0.5)
61
+ # class_token and dist_token are kept unchanged
62
+ if orig_size != new_size:
63
+ print(
64
+ "Position interpolate from %dx%d to %dx%d"
65
+ % (orig_size, orig_size, new_size, new_size)
66
+ )
67
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
68
+ # only the position tokens are interpolated
69
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
70
+ pos_tokens = pos_tokens.reshape(
71
+ -1, orig_size, orig_size, embedding_size
72
+ ).permute(0, 3, 1, 2)
73
+ # Convert to float for interpolation
74
+ pos_tokens = pos_tokens.float()
75
+
76
+ pos_tokens = torch.nn.functional.interpolate(
77
+ pos_tokens,
78
+ size=(new_size, new_size),
79
+ mode="bicubic",
80
+ align_corners=False,
81
+ )
82
+ # Convert back to Half if needed
83
+ pos_tokens = pos_tokens.half()
84
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
85
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
86
+ state_dict["pos_embed"] = new_pos_embed
87
+
88
+ patch_embed_proj = state_dict["patch_embed.proj.weight"]
89
+ patch_size = model.patch_embed.patch_size
90
+ # Convert to float for interpolation
91
+ patch_embed_proj = patch_embed_proj.float()
92
+ state_dict["patch_embed.proj.weight"] = torch.nn.functional.interpolate(
93
+ patch_embed_proj.float(),
94
+ size=patch_size,
95
+ mode="bicubic",
96
+ align_corners=False,
97
+ )
98
+ state_dict["patch_embed.proj.weight"] = state_dict["patch_embed.proj.weight"].half()
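A small standalone demonstration of `extend_instance` and `getattr_recursive` (assuming they are importable, e.g. `from utils import ...`): `extend_instance` places the mixin ahead of the original class in the MRO, which is what lets `FlamingoLMMixin.forward` wrap the language model's own `forward`:

from utils import extend_instance, getattr_recursive

class Base:
    def forward(self):
        return "base forward"

class Mixin:
    def forward(self):
        return "mixin first, then " + super().forward()

obj = Base()
extend_instance(obj, Mixin)        # Mixin now precedes Base in type(obj).__mro__
print(obj.forward())               # -> "mixin first, then base forward"
print(getattr_recursive(obj, ""))  # an empty attribute path returns the object itself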