feipengma commited on
Commit
f1298e6
1 Parent(s): 47f5163

initialize wemm

Browse files
__init__.py ADDED
File without changes
config.json ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoConfig": "configuration_wemm.WeMMConfig",
4
+ "AutoModel": "modeling_wemm.WemmForConditionalGeneration"
5
+ },
6
+ "architectures": [
7
+ "WemmForConditionalGeneration"
8
+ ],
9
+ "connector_config": {
10
+ "_name_or_path": "",
11
+ "add_cross_attention": false,
12
+ "architectures": null,
13
+ "attention_dropout": 0.0,
14
+ "bad_words_ids": null,
15
+ "begin_suppress_tokens": null,
16
+ "bos_token_id": null,
17
+ "chunk_size_feed_forward": 0,
18
+ "cross_attention_hidden_size": null,
19
+ "decoder_start_token_id": null,
20
+ "diversity_penalty": 0.0,
21
+ "do_sample": false,
22
+ "early_stopping": false,
23
+ "encoder_no_repeat_ngram_size": 0,
24
+ "eos_token_id": null,
25
+ "exponential_decay_length_penalty": null,
26
+ "finetuning_task": null,
27
+ "forced_bos_token_id": null,
28
+ "forced_eos_token_id": null,
29
+ "hidden_act": "silu",
30
+ "hidden_size": 4096,
31
+ "id2label": {
32
+ "0": "LABEL_0",
33
+ "1": "LABEL_1"
34
+ },
35
+ "integrate_sub_images": null,
36
+ "intermediate_size": 14336,
37
+ "is_decoder": false,
38
+ "is_encoder_decoder": false,
39
+ "label2id": {
40
+ "LABEL_0": 0,
41
+ "LABEL_1": 1
42
+ },
43
+ "length_penalty": 1.0,
44
+ "max_length": 20,
45
+ "min_length": 0,
46
+ "model_type": "Idefics2ConnectorConfig",
47
+ "no_repeat_ngram_size": 0,
48
+ "num_beam_groups": 1,
49
+ "num_beams": 1,
50
+ "num_key_value_heads": 4,
51
+ "num_return_sequences": 1,
52
+ "num_sub_images": null,
53
+ "output_attentions": false,
54
+ "output_hidden_states": false,
55
+ "output_scores": false,
56
+ "pad_token_id": null,
57
+ "prefix": null,
58
+ "problem_type": null,
59
+ "pruned_heads": {},
60
+ "remove_invalid_values": false,
61
+ "repetition_penalty": 1.0,
62
+ "resampler_depth": 3,
63
+ "resampler_head_dim": 96,
64
+ "resampler_n_heads": 16,
65
+ "resampler_n_latents": 64,
66
+ "return_dict": true,
67
+ "return_dict_in_generate": false,
68
+ "rms_norm_eps": 1e-05,
69
+ "sep_token_id": null,
70
+ "suppress_tokens": null,
71
+ "task_specific_params": null,
72
+ "temperature": 1.0,
73
+ "tf_legacy_loss": false,
74
+ "tie_encoder_decoder": false,
75
+ "tie_word_embeddings": true,
76
+ "tokenizer_class": null,
77
+ "top_k": 50,
78
+ "top_p": 1.0,
79
+ "torch_dtype": null,
80
+ "torchscript": false,
81
+ "typical_p": 1.0,
82
+ "use_bfloat16": false,
83
+ "vision_hidden_size": 1152
84
+ },
85
+ "do_image_splitting": true,
86
+ "downsampler_config": {
87
+ "_name_or_path": "",
88
+ "add_cross_attention": false,
89
+ "architectures": [
90
+ "DownsamplerModel"
91
+ ],
92
+ "auto_map": {
93
+ "AutoConfig": "configuration_downsampler.DownsamplerConfig",
94
+ "AutoModel": "modeling_downsampler.DownsamplerModel"
95
+ },
96
+ "bad_words_ids": null,
97
+ "begin_suppress_tokens": null,
98
+ "bias": false,
99
+ "bos_token_id": null,
100
+ "chunk_size_feed_forward": 0,
101
+ "cross_attention_hidden_size": null,
102
+ "decoder_start_token_id": null,
103
+ "depth": 2,
104
+ "diversity_penalty": 0.0,
105
+ "do_sample": false,
106
+ "early_stopping": false,
107
+ "encoder_no_repeat_ngram_size": 0,
108
+ "eos_token_id": null,
109
+ "exponential_decay_length_penalty": null,
110
+ "finetuning_task": null,
111
+ "forced_bos_token_id": null,
112
+ "forced_eos_token_id": null,
113
+ "hidden_act": "gelu",
114
+ "id2label": {
115
+ "0": "LABEL_0",
116
+ "1": "LABEL_1"
117
+ },
118
+ "is_decoder": false,
119
+ "is_encoder_decoder": false,
120
+ "kernel_size": 1,
121
+ "label2id": {
122
+ "LABEL_0": 0,
123
+ "LABEL_1": 1
124
+ },
125
+ "length_penalty": 1.0,
126
+ "llm_hidden_size": 4096,
127
+ "max_length": 20,
128
+ "min_length": 0,
129
+ "model_type": "downsampler",
130
+ "no_repeat_ngram_size": 0,
131
+ "num_beam_groups": 1,
132
+ "num_beams": 1,
133
+ "num_return_sequences": 1,
134
+ "output_attentions": false,
135
+ "output_hidden_states": false,
136
+ "output_scores": false,
137
+ "pad_token_id": null,
138
+ "prefix": null,
139
+ "problem_type": null,
140
+ "pruned_heads": {},
141
+ "remove_invalid_values": false,
142
+ "repetition_penalty": 1.0,
143
+ "return_dict": true,
144
+ "return_dict_in_generate": false,
145
+ "sep_token_id": null,
146
+ "stride": 1,
147
+ "suppress_tokens": null,
148
+ "task_specific_params": null,
149
+ "temperature": 1.0,
150
+ "tf_legacy_loss": false,
151
+ "tie_encoder_decoder": false,
152
+ "tie_word_embeddings": true,
153
+ "tokenizer_class": null,
154
+ "top_k": 50,
155
+ "top_p": 1.0,
156
+ "torch_dtype": "float32",
157
+ "torchscript": false,
158
+ "typical_p": 1.0,
159
+ "use_bfloat16": false,
160
+ "visual_hidden_size": 1152
161
+ },
162
+ "image_processor": {
163
+ "do_convert_rgb": true,
164
+ "do_image_splitting": true,
165
+ "do_normalize": true,
166
+ "do_pad": true,
167
+ "do_rescale": true,
168
+ "do_resize": true,
169
+ "image_mean": [
170
+ 0.5,
171
+ 0.5,
172
+ 0.5
173
+ ],
174
+ "image_processor_type": "Idefics2ImageProcessor",
175
+ "image_std": [
176
+ 0.5,
177
+ 0.5,
178
+ 0.5
179
+ ],
180
+ "resample": 2,
181
+ "rescale_factor": 0.00392156862745098,
182
+ "size": {
183
+ "longest_edge": 980,
184
+ "shortest_edge": 378
185
+ }
186
+ },
187
+ "model_type": "wemm_hf",
188
+ "projector_config": {
189
+ "_name_or_path": "",
190
+ "add_cross_attention": false,
191
+ "architectures": [
192
+ "ProjectorModel"
193
+ ],
194
+ "auto_map": {
195
+ "AutoConfig": "configuration_projector.ProjectorConfig",
196
+ "AutoModel": "modeling_projector.ProjectorModel"
197
+ },
198
+ "bad_words_ids": null,
199
+ "begin_suppress_tokens": null,
200
+ "bias": true,
201
+ "bos_token_id": null,
202
+ "chunk_size_feed_forward": 0,
203
+ "cross_attention_hidden_size": null,
204
+ "decoder_start_token_id": null,
205
+ "depth": 2,
206
+ "diversity_penalty": 0.0,
207
+ "do_sample": false,
208
+ "early_stopping": false,
209
+ "encoder_no_repeat_ngram_size": 0,
210
+ "eos_token_id": null,
211
+ "exponential_decay_length_penalty": null,
212
+ "finetuning_task": null,
213
+ "forced_bos_token_id": null,
214
+ "forced_eos_token_id": null,
215
+ "hidden_act": "gelu",
216
+ "id2label": {
217
+ "0": "LABEL_0",
218
+ "1": "LABEL_1"
219
+ },
220
+ "is_decoder": false,
221
+ "is_encoder_decoder": false,
222
+ "label2id": {
223
+ "LABEL_0": 0,
224
+ "LABEL_1": 1
225
+ },
226
+ "length_penalty": 1.0,
227
+ "llm_hidden_size": 4096,
228
+ "max_length": 20,
229
+ "min_length": 0,
230
+ "model_type": "projector",
231
+ "no_repeat_ngram_size": 0,
232
+ "num_beam_groups": 1,
233
+ "num_beams": 1,
234
+ "num_return_sequences": 1,
235
+ "output_attentions": false,
236
+ "output_hidden_states": false,
237
+ "output_scores": false,
238
+ "pad_token_id": null,
239
+ "prefix": null,
240
+ "problem_type": null,
241
+ "pruned_heads": {},
242
+ "remove_invalid_values": false,
243
+ "repetition_penalty": 1.0,
244
+ "return_dict": true,
245
+ "return_dict_in_generate": false,
246
+ "sep_token_id": null,
247
+ "suppress_tokens": null,
248
+ "task_specific_params": null,
249
+ "temperature": 1.0,
250
+ "tf_legacy_loss": false,
251
+ "tie_encoder_decoder": false,
252
+ "tie_word_embeddings": true,
253
+ "tokenizer_class": null,
254
+ "top_k": 50,
255
+ "top_p": 1.0,
256
+ "torch_dtype": "float32",
257
+ "torchscript": false,
258
+ "typical_p": 1.0,
259
+ "use_bfloat16": false,
260
+ "visual_hidden_size": 4096
261
+ },
262
+ "spliter_emb_config": {
263
+ "embedding_dim": 4096,
264
+ "num_embeddings": 12
265
+ },
266
+ "text_config": {
267
+ "_name_or_path": "",
268
+ "add_cross_attention": false,
269
+ "architectures": [
270
+ "InternLM2ForCausalLM"
271
+ ],
272
+ "attn_implementation": "flash_attention_2",
273
+ "auto_map": {
274
+ "AutoConfig": "configuration_internlm2.InternLM2Config",
275
+ "AutoModel": "modeling_internlm2.InternLM2ForCausalLM",
276
+ "AutoModelForCausalLM": "modeling_internlm2.InternLM2ForCausalLM"
277
+ },
278
+ "bad_words_ids": null,
279
+ "begin_suppress_tokens": null,
280
+ "bias": false,
281
+ "bos_token_id": 1,
282
+ "chunk_size_feed_forward": 0,
283
+ "cross_attention_hidden_size": null,
284
+ "decoder_start_token_id": null,
285
+ "diversity_penalty": 0.0,
286
+ "do_sample": false,
287
+ "early_stopping": false,
288
+ "encoder_no_repeat_ngram_size": 0,
289
+ "eos_token_id": 2,
290
+ "exponential_decay_length_penalty": null,
291
+ "finetuning_task": null,
292
+ "forced_bos_token_id": null,
293
+ "forced_eos_token_id": null,
294
+ "hidden_act": "silu",
295
+ "hidden_size": 4096,
296
+ "id2label": {
297
+ "0": "LABEL_0",
298
+ "1": "LABEL_1"
299
+ },
300
+ "initializer_range": 0.02,
301
+ "intermediate_size": 14336,
302
+ "is_decoder": false,
303
+ "is_encoder_decoder": false,
304
+ "label2id": {
305
+ "LABEL_0": 0,
306
+ "LABEL_1": 1
307
+ },
308
+ "length_penalty": 1.0,
309
+ "max_length": 20,
310
+ "max_position_embeddings": 32768,
311
+ "min_length": 0,
312
+ "model_type": "internlm2",
313
+ "no_repeat_ngram_size": 0,
314
+ "num_attention_heads": 32,
315
+ "num_beam_groups": 1,
316
+ "num_beams": 1,
317
+ "num_hidden_layers": 32,
318
+ "num_key_value_heads": 8,
319
+ "num_return_sequences": 1,
320
+ "output_attentions": false,
321
+ "output_hidden_states": false,
322
+ "output_scores": false,
323
+ "pad_token_id": 2,
324
+ "prefix": null,
325
+ "problem_type": null,
326
+ "pruned_heads": {},
327
+ "remove_invalid_values": false,
328
+ "repetition_penalty": 1.0,
329
+ "return_dict": true,
330
+ "return_dict_in_generate": false,
331
+ "rms_norm_eps": 1e-05,
332
+ "rope_scaling": {
333
+ "factor": 2.0,
334
+ "type": "dynamic"
335
+ },
336
+ "rope_theta": 1000000,
337
+ "sep_token_id": null,
338
+ "suppress_tokens": null,
339
+ "task_specific_params": null,
340
+ "temperature": 1.0,
341
+ "tf_legacy_loss": false,
342
+ "tie_encoder_decoder": false,
343
+ "tie_word_embeddings": false,
344
+ "tokenizer_class": null,
345
+ "top_k": 50,
346
+ "top_p": 1.0,
347
+ "torch_dtype": "float16",
348
+ "torchscript": false,
349
+ "typical_p": 1.0,
350
+ "use_bfloat16": false,
351
+ "use_cache": true,
352
+ "vocab_size": 92544
353
+ },
354
+ "torch_dtype": "bfloat16",
355
+ "transformers_version": "4.38.1",
356
+ "vision_config": {
357
+ "_name_or_path": "",
358
+ "add_cross_attention": false,
359
+ "architectures": null,
360
+ "attention_dropout": 0.0,
361
+ "bad_words_ids": null,
362
+ "begin_suppress_tokens": null,
363
+ "bos_token_id": null,
364
+ "chunk_size_feed_forward": 0,
365
+ "cross_attention_hidden_size": null,
366
+ "decoder_start_token_id": null,
367
+ "diversity_penalty": 0.0,
368
+ "do_sample": false,
369
+ "early_stopping": false,
370
+ "encoder_no_repeat_ngram_size": 0,
371
+ "eos_token_id": null,
372
+ "exponential_decay_length_penalty": null,
373
+ "finetuning_task": null,
374
+ "forced_bos_token_id": null,
375
+ "forced_eos_token_id": null,
376
+ "hidden_act": "gelu_pytorch_tanh",
377
+ "hidden_size": 1152,
378
+ "id2label": {
379
+ "0": "LABEL_0",
380
+ "1": "LABEL_1"
381
+ },
382
+ "image_size": 980,
383
+ "initializer_range": 0.02,
384
+ "intermediate_size": 4304,
385
+ "is_decoder": false,
386
+ "is_encoder_decoder": false,
387
+ "label2id": {
388
+ "LABEL_0": 0,
389
+ "LABEL_1": 1
390
+ },
391
+ "layer_norm_eps": 1e-06,
392
+ "length_penalty": 1.0,
393
+ "max_length": 20,
394
+ "min_length": 0,
395
+ "model_type": "Idefics2VisionConfig",
396
+ "no_repeat_ngram_size": 0,
397
+ "num_attention_heads": 16,
398
+ "num_beam_groups": 1,
399
+ "num_beams": 1,
400
+ "num_channels": 3,
401
+ "num_hidden_layers": 27,
402
+ "num_return_sequences": 1,
403
+ "output_attentions": false,
404
+ "output_hidden_states": false,
405
+ "output_scores": false,
406
+ "pad_token_id": null,
407
+ "patch_size": 14,
408
+ "prefix": null,
409
+ "problem_type": null,
410
+ "pruned_heads": {},
411
+ "remove_invalid_values": false,
412
+ "repetition_penalty": 1.0,
413
+ "return_dict": true,
414
+ "return_dict_in_generate": false,
415
+ "sep_token_id": null,
416
+ "suppress_tokens": null,
417
+ "task_specific_params": null,
418
+ "temperature": 1.0,
419
+ "tf_legacy_loss": false,
420
+ "tie_encoder_decoder": false,
421
+ "tie_word_embeddings": true,
422
+ "tokenizer_class": null,
423
+ "top_k": 50,
424
+ "top_p": 1.0,
425
+ "torch_dtype": null,
426
+ "torchscript": false,
427
+ "typical_p": 1.0,
428
+ "use_bfloat16": false
429
+ }
430
+ }
configuration_connector.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig, PreTrainedModel
2
+ import json
3
+
4
+ class Idefics2ConnectorConfig(PretrainedConfig):
5
+ r"""
6
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
7
+ documentation from [`PretrainedConfig`] for more information.
8
+
9
+ Args:
10
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
11
+ The non-linear activation function (function or string) in the perceiver block.
12
+ resampler_n_latents (`int`, *optional*, defaults to 64):
13
+ Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
14
+ resampler_depth (`int`, *optional*, defaults to 3):
15
+ Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (<= 3).
16
+ resampler_n_heads (`int`, *optional*, defaults to 16):
17
+ Number of heads in each Transformer block (for multi-headed self-attention).
18
+ resampler_head_dim (`int`, *optional*, defaults to 96):
19
+ Dimensionality of each head projection in the Transformer block.
20
+ num_key_value_heads (`int`, *optional*, defaults to 4):
21
+ Number of key-value heads in the perceiver attention block.
22
+ attention_dropout (`float`, *optional*, defaults to 0.0):
23
+ The dropout ratio for the attention probabilities.
24
+ """
25
+ _auto_class = 'AutoConfig'
26
+ model_type = "Idefics2ConnectorConfig"
27
+
28
+ def __init__(
29
+ self,
30
+ vision_hidden_size=1152,
31
+ hidden_size=4096,
32
+ hidden_act="silu",
33
+ resampler_n_latents=64,
34
+ resampler_depth=3,
35
+ rms_norm_eps=1e-05,
36
+ resampler_n_heads=16,
37
+ resampler_head_dim=96,
38
+ num_key_value_heads=4,
39
+ attention_dropout=0.0,
40
+ intermediate_size=14336,
41
+ integrate_sub_images=None,
42
+ num_sub_images=None,
43
+ **kwargs,
44
+ ):
45
+ super().__init__(**kwargs)
46
+ self.vision_hidden_size = vision_hidden_size
47
+ self.hidden_size = hidden_size
48
+ self.hidden_act = hidden_act
49
+ self.resampler_n_latents = resampler_n_latents
50
+ self.resampler_depth = resampler_depth
51
+ self.rms_norm_eps = rms_norm_eps
52
+ self.resampler_n_heads = resampler_n_heads
53
+ self.num_key_value_heads = num_key_value_heads
54
+ self.resampler_head_dim = resampler_head_dim
55
+ self.attention_dropout = attention_dropout
56
+ self.intermediate_size = intermediate_size
57
+ self.integrate_sub_images = integrate_sub_images
58
+ self.num_sub_images = num_sub_images
59
+
60
+ if self.num_key_value_heads > self.resampler_n_heads:
61
+ raise ValueError(
62
+ f"num_key_value_heads={self.num_key_value_heads} must be less than or equal to"
63
+ f" resampler_n_heads={self.resampler_n_heads}"
64
+ )
65
+
66
+ @classmethod
67
+ def from_pretrained(cls, config_path, **kwargs) -> "PretrainedConfig":
68
+
69
+ with open(config_path, "r", encoding="utf-8") as f:
70
+ config_dict = json.load(f)
71
+ cls = Idefics2ConnectorConfig(
72
+ vision_hidden_size=config_dict['vision_hidden_size'],
73
+ hidden_size=config_dict['hidden_size'],
74
+ hidden_act="silu",
75
+ resampler_n_latents=config_dict['resampler_n_latents'],
76
+ resampler_depth=config_dict['resampler_depth'],
77
+ rms_norm_eps=config_dict['rms_norm_eps'],
78
+ intermediate_size=config_dict['intermediate_size'],
79
+ integrate_sub_images=config_dict['integrate_sub_images'],
80
+ num_sub_images=config_dict['num_sub_images']
81
+ )
82
+
83
+ return cls
configuration_downsampler.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from transformers import PretrainedConfig
3
+
4
+
5
+ class DownsamplerConfig(PretrainedConfig):
6
+ model_type = 'downsampler'
7
+ _auto_class = 'AutoConfig'
8
+
9
+ def __init__(
10
+ self,
11
+ kernel_size=1,
12
+ stride=1,
13
+ visual_hidden_size=4096,
14
+ llm_hidden_size=4096,
15
+ depth=2,
16
+ hidden_act='gelu',
17
+ bias=False,
18
+ **kwargs,
19
+ ):
20
+ self.visual_hidden_size = visual_hidden_size
21
+ self.llm_hidden_size = llm_hidden_size
22
+ self.depth = depth
23
+ self.hidden_act = hidden_act
24
+ self.bias = bias
25
+ self.kernel_size = kernel_size
26
+ self.stride = stride
27
+ super().__init__(**kwargs)
configuration_internlm2.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on transformers/src/transformers/models/llama/configuration_llama.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """ InternLM2 model configuration"""
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.utils import logging
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+ INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
25
+
26
+
27
+ # Modified from transformers.model.llama.configuration_llama.LlamaConfig
28
+ class InternLM2Config(PretrainedConfig):
29
+ r"""
30
+ This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate
31
+ an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a
32
+ configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.
33
+
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+
37
+
38
+ Args:
39
+ vocab_size (`int`, *optional*, defaults to 32000):
40
+ Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the
41
+ `inputs_ids` passed when calling [`InternLM2Model`]
42
+ hidden_size (`int`, *optional*, defaults to 4096):
43
+ Dimension of the hidden representations.
44
+ intermediate_size (`int`, *optional*, defaults to 11008):
45
+ Dimension of the MLP representations.
46
+ num_hidden_layers (`int`, *optional*, defaults to 32):
47
+ Number of hidden layers in the Transformer encoder.
48
+ num_attention_heads (`int`, *optional*, defaults to 32):
49
+ Number of attention heads for each attention layer in the Transformer encoder.
50
+ num_key_value_heads (`int`, *optional*):
51
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
52
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
53
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
54
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
55
+ by meanpooling all the original heads within that group. For more details checkout [this
56
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
57
+ `num_attention_heads`.
58
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
59
+ The non-linear activation function (function or string) in the decoder.
60
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
61
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
62
+ just in case (e.g., 512 or 1024 or 2048).
63
+ initializer_range (`float`, *optional*, defaults to 0.02):
64
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
65
+ rms_norm_eps (`float`, *optional*, defaults to 1e-12):
66
+ The epsilon used by the rms normalization layers.
67
+ use_cache (`bool`, *optional*, defaults to `True`):
68
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
69
+ relevant if `config.is_decoder=True`.
70
+ tie_word_embeddings(`bool`, *optional*, defaults to `False`):
71
+ Whether to tie weight embeddings
72
+ Example:
73
+
74
+ """
75
+ model_type = "internlm2"
76
+ _auto_class = "AutoConfig"
77
+
78
+ def __init__( # pylint: disable=W0102
79
+ self,
80
+ vocab_size=103168,
81
+ hidden_size=4096,
82
+ intermediate_size=11008,
83
+ num_hidden_layers=32,
84
+ num_attention_heads=32,
85
+ num_key_value_heads=None,
86
+ hidden_act="silu",
87
+ max_position_embeddings=2048,
88
+ initializer_range=0.02,
89
+ rms_norm_eps=1e-6,
90
+ use_cache=True,
91
+ pad_token_id=0,
92
+ bos_token_id=1,
93
+ eos_token_id=2,
94
+ tie_word_embeddings=False,
95
+ bias=True,
96
+ rope_theta=10000,
97
+ rope_scaling=None,
98
+ attn_implementation="eager",
99
+ **kwargs,
100
+ ):
101
+ self.vocab_size = vocab_size
102
+ self.max_position_embeddings = max_position_embeddings
103
+ self.hidden_size = hidden_size
104
+ self.intermediate_size = intermediate_size
105
+ self.num_hidden_layers = num_hidden_layers
106
+ self.num_attention_heads = num_attention_heads
107
+ self.bias = bias
108
+
109
+ if num_key_value_heads is None:
110
+ num_key_value_heads = num_attention_heads
111
+ self.num_key_value_heads = num_key_value_heads
112
+
113
+ self.hidden_act = hidden_act
114
+ self.initializer_range = initializer_range
115
+ self.rms_norm_eps = rms_norm_eps
116
+ self.use_cache = use_cache
117
+ self.rope_theta = rope_theta
118
+ self.rope_scaling = rope_scaling
119
+ self._rope_scaling_validation()
120
+
121
+ self.attn_implementation = attn_implementation
122
+ if self.attn_implementation is None:
123
+ self.attn_implementation = "eager"
124
+ super().__init__(
125
+ pad_token_id=pad_token_id,
126
+ bos_token_id=bos_token_id,
127
+ eos_token_id=eos_token_id,
128
+ tie_word_embeddings=tie_word_embeddings,
129
+ **kwargs,
130
+ )
131
+
132
+ def _rope_scaling_validation(self):
133
+ """
134
+ Validate the `rope_scaling` configuration.
135
+ """
136
+ if self.rope_scaling is None:
137
+ return
138
+
139
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
140
+ raise ValueError(
141
+ "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
142
+ f"got {self.rope_scaling}"
143
+ )
144
+ rope_scaling_type = self.rope_scaling.get("type", None)
145
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
146
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
147
+ raise ValueError(
148
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
149
+ )
150
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor < 1.0:
151
+ raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}")
configuration_projector.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from transformers import PretrainedConfig
3
+
4
+
5
+ class ProjectorConfig(PretrainedConfig):
6
+ model_type = 'projector'
7
+ _auto_class = 'AutoConfig'
8
+
9
+ def __init__(
10
+ self,
11
+ visual_hidden_size=4096,
12
+ llm_hidden_size=4096,
13
+ depth=2,
14
+ hidden_act='gelu',
15
+ bias=True,
16
+ **kwargs,
17
+ ):
18
+ self.visual_hidden_size = visual_hidden_size
19
+ self.llm_hidden_size = llm_hidden_size
20
+ self.depth = depth
21
+ self.hidden_act = hidden_act
22
+ self.bias = bias
23
+ super().__init__(**kwargs)
configuration_vision.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from transformers import PretrainedConfig, PreTrainedModel
3
+ import json
4
+
5
+ class Idefics2VisionConfig(PretrainedConfig):
6
+ model_type = "Idefics2VisionConfig"
7
+
8
+ def __init__(
9
+ self,
10
+ hidden_size=768,
11
+ intermediate_size=3072,
12
+ num_hidden_layers=12,
13
+ num_attention_heads=12,
14
+ num_channels=3,
15
+ image_size=224,
16
+ patch_size=32,
17
+ hidden_act="gelu_pytorch_tanh",
18
+ layer_norm_eps=1e-6,
19
+ attention_dropout=0.0,
20
+ initializer_range=0.02,
21
+ model_type='Idefics2VisionConfig',
22
+ **kwargs,
23
+ ):
24
+
25
+ self.hidden_size = hidden_size
26
+ self.intermediate_size = intermediate_size
27
+ self.num_hidden_layers = num_hidden_layers
28
+ self.num_attention_heads = num_attention_heads
29
+ self.num_channels = num_channels
30
+ self.patch_size = patch_size
31
+ self.image_size = image_size
32
+ self.attention_dropout = attention_dropout
33
+ self.layer_norm_eps = layer_norm_eps
34
+ self.hidden_act = hidden_act
35
+ self.initializer_range = initializer_range
36
+
37
+ super().__init__(**kwargs)
38
+
configuration_wemm.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from typing import List
3
+ import json
4
+ # from transformers import CONFIG_MAPPING
5
+ from peft import PeftConfig
6
+ from .configuration_vision import Idefics2VisionConfig
7
+ from .configuration_internlm2 import InternLM2Config
8
+ from .configuration_projector import ProjectorConfig
9
+ from .configuration_connector import Idefics2ConnectorConfig
10
+ from .image_processor import Idefics2ImageProcessor
11
+ from .configuration_downsampler import DownsamplerConfig
12
+
13
+ class WeMMConfig(PretrainedConfig):
14
+ model_type = "wemm_hf"
15
+
16
+ def __init__(
17
+ self,
18
+ vision_config = None,
19
+ text_config = None,
20
+ projector_config = None,
21
+ connector_config = None,
22
+ adapter_path = None,
23
+ image_processor = None,
24
+ do_image_splitting = False,
25
+ spliter_emb_config = None,
26
+ downsampler_config = None,
27
+ tokenizer_config = None,
28
+ **kwargs
29
+ ):
30
+ # vision_config
31
+ if vision_config is not None:
32
+ self.vision_config = Idefics2VisionConfig(**vision_config)
33
+
34
+
35
+ # text_config
36
+ if text_config is not None:
37
+ self.text_config = InternLM2Config(**text_config)
38
+
39
+ # projector_config
40
+ if projector_config is not None:
41
+ self.projector_config = ProjectorConfig(**projector_config)
42
+
43
+ # connector_config
44
+ if connector_config is not None:
45
+ self.connector_config = Idefics2ConnectorConfig(**connector_config)
46
+
47
+ if image_processor is not None:
48
+ self.image_processor = image_processor
49
+
50
+
51
+ if adapter_path is not None:
52
+ self.adapter_path = adapter_path
53
+
54
+ self.do_image_splitting = do_image_splitting
55
+
56
+ if spliter_emb_config is not None:
57
+ self.spliter_emb_config = spliter_emb_config
58
+
59
+ if downsampler_config is not None:
60
+ self.downsampler_config = DownsamplerConfig(**downsampler_config)
61
+
62
+ if tokenizer_config is not None:
63
+ self.tokenizer_config = tokenizer_config
64
+
65
+ super().__init__(**kwargs)
66
+
67
+ if __name__=="__main__":
68
+ wemm_config_path = "/mnt/csp/mmvision/home/feipengma/projects/wemm_evaluation/WeMM/config.json"
69
+ wemm_config = WeMMConfig.from_pretrained(wemm_config_path)
70
+ print(wemm_config.connector_config)
connector.py ADDED
@@ -0,0 +1,720 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig, PreTrainedModel
2
+
3
+ import inspect
4
+ import math
5
+ from dataclasses import dataclass
6
+ from typing import Dict, List, Optional, Tuple, Union
7
+ import json
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint
12
+ from torch import nn
13
+ from torch.nn import CrossEntropyLoss
14
+
15
+ from transformers.activations import ACT2FN
16
+ from transformers.cache_utils import Cache, DynamicCache
17
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
18
+ from transformers.modeling_outputs import BaseModelOutput, ModelOutput
19
+ from transformers.utils import (
20
+ add_start_docstrings,
21
+ add_start_docstrings_to_model_forward,
22
+ is_flash_attn_2_available,
23
+ is_flash_attn_greater_or_equal_2_10,
24
+ logging,
25
+ replace_return_docstrings,
26
+ )
27
+
28
+ if is_flash_attn_2_available():
29
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
30
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
31
+
32
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
33
+
34
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
35
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
36
+ """
37
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
38
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
39
+ """
40
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
41
+ if n_rep == 1:
42
+ return hidden_states
43
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
44
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
45
+
46
+ class Idefics2ConnectorConfig(PretrainedConfig):
47
+ r"""
48
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
49
+ documentation from [`PretrainedConfig`] for more information.
50
+
51
+ Args:
52
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
53
+ The non-linear activation function (function or string) in the perceiver block.
54
+ resampler_n_latents (`int`, *optional*, defaults to 64):
55
+ Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
56
+ resampler_depth (`int`, *optional*, defaults to 3):
57
+ Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (<= 3).
58
+ resampler_n_heads (`int`, *optional*, defaults to 16):
59
+ Number of heads in each Transformer block (for multi-headed self-attention).
60
+ resampler_head_dim (`int`, *optional*, defaults to 96):
61
+ Dimensionality of each head projection in the Transformer block.
62
+ num_key_value_heads (`int`, *optional*, defaults to 4):
63
+ Number of key-value heads in the perceiver attention block.
64
+ attention_dropout (`float`, *optional*, defaults to 0.0):
65
+ The dropout ratio for the attention probabilities.
66
+ """
67
+ _auto_class = 'AutoConfig'
68
+ model_type = "Idefics2ConnectorConfig"
69
+
70
+ def __init__(
71
+ self,
72
+ vision_hidden_size=1152,
73
+ hidden_size=4096,
74
+ hidden_act="silu",
75
+ resampler_n_latents=64,
76
+ resampler_depth=3,
77
+ rms_norm_eps=1e-05,
78
+ resampler_n_heads=16,
79
+ resampler_head_dim=96,
80
+ num_key_value_heads=4,
81
+ attention_dropout=0.0,
82
+ intermediate_size=14336,
83
+ **kwargs,
84
+ ):
85
+ super().__init__(**kwargs)
86
+ self.vision_hidden_size = vision_hidden_size
87
+ self.hidden_size = hidden_size
88
+ self.hidden_act = hidden_act
89
+ self.resampler_n_latents = resampler_n_latents
90
+ self.resampler_depth = resampler_depth
91
+ self.rms_norm_eps = rms_norm_eps
92
+ self.resampler_n_heads = resampler_n_heads
93
+ self.num_key_value_heads = num_key_value_heads
94
+ self.resampler_head_dim = resampler_head_dim
95
+ self.attention_dropout = attention_dropout
96
+ self.intermediate_size = intermediate_size
97
+ if self.num_key_value_heads > self.resampler_n_heads:
98
+ raise ValueError(
99
+ f"num_key_value_heads={self.num_key_value_heads} must be less than or equal to"
100
+ f" resampler_n_heads={self.resampler_n_heads}"
101
+ )
102
+
103
+
104
+ @classmethod
105
+ def from_pretrained(cls, config_path, **kwargs) -> "PretrainedConfig":
106
+
107
+ with open(config_path, "r", encoding="utf-8") as f:
108
+ config_dict = json.load(f)
109
+ cls = Idefics2ConnectorConfig(
110
+ vision_hidden_size=config_dict['vision_hidden_size'],
111
+ hidden_size=config_dict['hidden_size'],
112
+ hidden_act="silu",
113
+ resampler_n_latents=config_dict['resampler_n_latents'],
114
+ resampler_depth=config_dict['resampler_depth'],
115
+ rms_norm_eps=config_dict['rms_norm_eps'],
116
+ intermediate_size = config_dict['intermediate_size']
117
+ )
118
+
119
+ return cls
120
+
121
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
122
+ def _get_unpad_data(attention_mask):
123
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
124
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
125
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
126
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
127
+ return (
128
+ indices,
129
+ cu_seqlens,
130
+ max_seqlen_in_batch,
131
+ )
132
+
133
+ class Idefics2PerceiverAttention(nn.Module):
134
+ def __init__(self, config, layer_idx: Optional[int] = None) -> None:
135
+ """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
136
+ super().__init__()
137
+
138
+ self.layer_idx = None
139
+ self.hidden_size = config.hidden_size
140
+ self.num_heads = config.resampler_n_heads
141
+ self.head_dim = config.resampler_head_dim
142
+ self.num_key_value_heads = config.num_key_value_heads
143
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
144
+ self.attention_dropout = config.attention_dropout
145
+
146
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
147
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
148
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
149
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
150
+
151
+ self.is_causal = False
152
+
153
+ def forward(
154
+ self,
155
+ latents: torch.Tensor,
156
+ context: torch.Tensor,
157
+ attention_mask: Optional[torch.Tensor] = None,
158
+ position_ids: Optional[torch.LongTensor] = None,
159
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
160
+ output_attentions: bool = False,
161
+ use_cache: bool = False,
162
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
163
+ """
164
+ Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension!
165
+
166
+ Args:
167
+ latents (`torch.Tensor`): Tensor of shape [bsz, n_latents, embed_dim] representing fixed length latents to compress to.
168
+ context (`torch.Tensor`): Tensor of shape [bsz, seq, embed_dim] representing long-form context to resample.
169
+ attention_mask (`torch.Tensor`, *optional*): Tensor of shape [bsz, 1, seq, n_latents] representing attention mask.
170
+ position_ids (`torch.LongTensor`, *optional*): Tensor of shape [bsz, seq] representing position indices of each input token.
171
+ past_key_value (`Tuple[torch.Tensor]`, *optional*): Tuple of tensors containing cached key and value states.
172
+ output_attentions (`bool`, *optional*, defaults to `False`): Whether to return attention weights.
173
+ use_cache (`bool`, *optional*, defaults to `False`): Whether to use past_key_value for caching.
174
+ """
175
+ bsz, q_len, _ = latents.size()
176
+ kv_seq_len = q_len + context.size()[1]
177
+
178
+ hidden_states = torch.concat([context, latents], dim=-2)
179
+
180
+ query_states = self.q_proj(latents)
181
+ key_states = self.k_proj(hidden_states)
182
+ value_states = self.v_proj(hidden_states)
183
+
184
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
185
+ key_states = key_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
186
+ value_states = value_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
187
+
188
+ past_key_value = getattr(self, "past_key_value", past_key_value)
189
+
190
+ if past_key_value is not None:
191
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
192
+
193
+ # repeat k/v heads if n_kv_heads < n_heads
194
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
195
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
196
+
197
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
198
+
199
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
200
+ raise ValueError(
201
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
202
+ f" {attn_weights.size()}"
203
+ )
204
+
205
+ if attention_mask is not None:
206
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
207
+ raise ValueError(
208
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
209
+ )
210
+
211
+ attn_weights = attn_weights + attention_mask
212
+
213
+ # upcast attention to fp32
214
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
215
+ attn_output = torch.matmul(attn_weights, value_states)
216
+
217
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
218
+ raise ValueError(
219
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
220
+ f" {attn_output.size()}"
221
+ )
222
+
223
+ attn_output = attn_output.transpose(1, 2).contiguous()
224
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
225
+
226
+ attn_output = self.o_proj(attn_output)
227
+
228
+ if not output_attentions:
229
+ attn_weights = None
230
+
231
+ return attn_output, attn_weights, past_key_value
232
+
233
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with MistralAttention->Idefics2PerceiverAttention,MistralFlashAttention->Idefics2PerceiverFlashAttention,Mistral->Idefics2
234
+ class Idefics2PerceiverFlashAttention2(Idefics2PerceiverAttention):
235
+ """
236
+ Idefics2 flash attention module. This module inherits from `Idefics2PerceiverAttention` as the weights of the module stays
237
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
238
+ flash attention and deal with padding tokens in case the input contains any of them.
239
+ """
240
+
241
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
242
+ def __init__(self, *args, **kwargs):
243
+ super().__init__(*args, **kwargs)
244
+
245
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
246
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
247
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
248
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
249
+
250
+ # Ignore copy
251
+ def forward(
252
+ self,
253
+ latents: torch.Tensor,
254
+ context: torch.Tensor,
255
+ attention_mask: Optional[torch.LongTensor] = None,
256
+ position_ids: Optional[torch.LongTensor] = None,
257
+ past_key_value: Optional[Cache] = None,
258
+ output_attentions: bool = False,
259
+ use_cache: bool = False,
260
+ **kwargs,
261
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
262
+
263
+ bsz, q_len, _ = latents.size()
264
+ kv_seq_len = q_len + context.size()[1]
265
+
266
+ # Query, Key, Value Projections --> Note that in Flamingo, latents are *concatenated* with context prior to attn!
267
+ # Note: This results in queries w/ `seq = n_latents`, and keys, values with `seq = len(context) + n_latents`
268
+ query_states = self.q_proj(latents)
269
+ key_states = self.k_proj(torch.cat([context, latents], dim=-2))
270
+ value_states = self.v_proj(torch.cat([context, latents], dim=-2))
271
+
272
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
273
+ key_states = key_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
274
+ value_states = value_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
275
+
276
+ kv_seq_len = key_states.shape[-2]
277
+ if past_key_value is not None:
278
+ kv_seq_len += past_key_value[0].shape[-2]
279
+
280
+ if past_key_value is not None:
281
+ # Activate slicing cache only if the config has a value `sliding_windows` attribute
282
+ if hasattr(self.config, "sliding_window") and kv_seq_len > self.config.sliding_window:
283
+ slicing_tokens = kv_seq_len - self.config.sliding_window
284
+
285
+ past_key = past_key_value[0]
286
+ past_value = past_key_value[1]
287
+
288
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
289
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
290
+
291
+ if past_key.shape[-2] != self.config.sliding_window - 1:
292
+ raise ValueError(
293
+ "past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1,"
294
+ f" head_dim`), got {past_key.shape}"
295
+ )
296
+
297
+ past_key_value = (past_key, past_value)
298
+
299
+ if attention_mask is not None:
300
+ attention_mask = attention_mask[:, slicing_tokens:]
301
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
302
+
303
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
304
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
305
+
306
+ past_key_value = (key_states, value_states) if use_cache else None
307
+
308
+ # repeat k/v heads if n_kv_heads < n_heads
309
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
310
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
311
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
312
+
313
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
314
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
315
+ # cast them back in float16 just to be sure everything works as expected.
316
+ input_dtype = query_states.dtype
317
+ if input_dtype == torch.float32:
318
+ if torch.is_autocast_enabled():
319
+ target_dtype = torch.get_autocast_gpu_dtype()
320
+ # Handle the case where the model is quantized
321
+ elif hasattr(self.config, "_pre_quantization_dtype"):
322
+ target_dtype = self.config._pre_quantization_dtype
323
+ else:
324
+ target_dtype = self.q_proj.weight.dtype
325
+
326
+ logger.warning_once(
327
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
328
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
329
+ f" {target_dtype}."
330
+ )
331
+
332
+ query_states = query_states.to(target_dtype)
333
+ key_states = key_states.to(target_dtype)
334
+ value_states = value_states.to(target_dtype)
335
+
336
+ # Reashape to the expected shape for Flash Attention
337
+ query_states = query_states.transpose(1, 2)
338
+ key_states = key_states.transpose(1, 2)
339
+ value_states = value_states.transpose(1, 2)
340
+
341
+ attn_output = self._flash_attention_forward(
342
+ query_states,
343
+ key_states,
344
+ value_states,
345
+ attention_mask,
346
+ q_len,
347
+ dropout=dropout_rate,
348
+ use_sliding_windows=False,
349
+ )
350
+
351
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
352
+ attn_output = self.o_proj(attn_output)
353
+
354
+ if not output_attentions:
355
+ attn_weights = None
356
+
357
+ return attn_output, attn_weights, past_key_value
358
+
359
+ def _flash_attention_forward(
360
+ self,
361
+ query_states,
362
+ key_states,
363
+ value_states,
364
+ attention_mask,
365
+ query_length,
366
+ dropout=0.0,
367
+ softmax_scale=None,
368
+ use_sliding_windows=False,
369
+ ):
370
+ """
371
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
372
+ first unpad the input, then computes the attention scores and pad the final attention scores.
373
+
374
+ Args:
375
+ query_states (`torch.Tensor`):
376
+ Input query states to be passed to Flash Attention API
377
+ key_states (`torch.Tensor`):
378
+ Input key states to be passed to Flash Attention API
379
+ value_states (`torch.Tensor`):
380
+ Input value states to be passed to Flash Attention API
381
+ attention_mask (`torch.Tensor`):
382
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
383
+ position of padding tokens and 1 for the position of non-padding tokens.
384
+ dropout (`float`):
385
+ Attention dropout
386
+ softmax_scale (`float`, *optional*):
387
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
388
+ use_sliding_windows (`bool`, *optional*):
389
+ Whether to activate sliding window attention.
390
+ """
391
+ if not self._flash_attn_uses_top_left_mask:
392
+ causal = self.is_causal
393
+ else:
394
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
395
+ causal = self.is_causal and query_length != 1
396
+
397
+ # Contains at least one padding token in the sequence
398
+ if attention_mask is not None:
399
+ batch_size = query_states.shape[0]
400
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
401
+ query_states, key_states, value_states, attention_mask, query_length
402
+ )
403
+
404
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
405
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
406
+
407
+ if not use_sliding_windows:
408
+ attn_output_unpad = flash_attn_varlen_func(
409
+ query_states,
410
+ key_states,
411
+ value_states,
412
+ cu_seqlens_q=cu_seqlens_q,
413
+ cu_seqlens_k=cu_seqlens_k,
414
+ max_seqlen_q=max_seqlen_in_batch_q,
415
+ max_seqlen_k=max_seqlen_in_batch_k,
416
+ dropout_p=dropout,
417
+ softmax_scale=softmax_scale,
418
+ causal=causal,
419
+ )
420
+ else:
421
+ attn_output_unpad = flash_attn_varlen_func(
422
+ query_states,
423
+ key_states,
424
+ value_states,
425
+ cu_seqlens_q=cu_seqlens_q,
426
+ cu_seqlens_k=cu_seqlens_k,
427
+ max_seqlen_q=max_seqlen_in_batch_q,
428
+ max_seqlen_k=max_seqlen_in_batch_k,
429
+ dropout_p=dropout,
430
+ softmax_scale=softmax_scale,
431
+ causal=causal,
432
+ window_size=(self.config.sliding_window, self.config.sliding_window),
433
+ )
434
+
435
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
436
+ else:
437
+ if not use_sliding_windows:
438
+ attn_output = flash_attn_func(
439
+ query_states,
440
+ key_states,
441
+ value_states,
442
+ dropout,
443
+ softmax_scale=softmax_scale,
444
+ causal=causal,
445
+ )
446
+ else:
447
+ attn_output = flash_attn_func(
448
+ query_states,
449
+ key_states,
450
+ value_states,
451
+ dropout,
452
+ softmax_scale=softmax_scale,
453
+ causal=causal,
454
+ window_size=(self.config.sliding_window, self.config.sliding_window),
455
+ )
456
+
457
+ return attn_output
458
+
459
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
460
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
461
+
462
+ # On the first iteration we need to properly re-create the padding mask
463
+ # by slicing it on the proper place
464
+ if kv_seq_len != attention_mask.shape[-1]:
465
+ attention_mask_num_tokens = attention_mask.shape[-1]
466
+ attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
467
+
468
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
469
+
470
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
471
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
472
+
473
+ if query_length == kv_seq_len:
474
+ query_layer = index_first_axis(
475
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
476
+ )
477
+ cu_seqlens_q = cu_seqlens_k
478
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
479
+ indices_q = indices_k
480
+ elif query_length == 1:
481
+ max_seqlen_in_batch_q = 1
482
+ cu_seqlens_q = torch.arange(
483
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
484
+ ) # There is a memcpy here, that is very bad.
485
+ indices_q = cu_seqlens_q[:-1]
486
+ query_layer = query_layer.squeeze(1)
487
+ else:
488
+ # The -q_len: slice assumes left padding.
489
+ attention_mask = attention_mask[:, -query_length:]
490
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
491
+
492
+ return (
493
+ query_layer,
494
+ key_layer,
495
+ value_layer,
496
+ indices_q,
497
+ (cu_seqlens_q, cu_seqlens_k),
498
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
499
+ )
500
+
501
+ IDEFICS2_PERCEIVER_ATTENTION_CLASSES = {
502
+ "eager": Idefics2PerceiverAttention,
503
+ "flash_attention_2": Idefics2PerceiverFlashAttention2,
504
+ }
505
+
506
+
507
+ class Idefics2MLP(nn.Module):
508
+ def __init__(
509
+ self,
510
+ hidden_size: int,
511
+ intermediate_size: int,
512
+ output_size: int,
513
+ hidden_act: str,
514
+ ):
515
+ super().__init__()
516
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
517
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
518
+ self.down_proj = nn.Linear(intermediate_size, output_size, bias=False)
519
+ self.act_fn = ACT2FN[hidden_act]
520
+
521
+ def forward(self, x):
522
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
523
+
524
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Idefics2
525
+ class Idefics2RMSNorm(nn.Module):
526
+ def __init__(self, hidden_size, eps=1e-6):
527
+ """
528
+ Idefics2RMSNorm is equivalent to T5LayerNorm
529
+ """
530
+ super().__init__()
531
+ self.weight = nn.Parameter(torch.ones(hidden_size))
532
+ self.variance_epsilon = eps
533
+
534
+ def forward(self, hidden_states):
535
+ input_dtype = hidden_states.dtype
536
+ hidden_states = hidden_states.to(torch.float32)
537
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
538
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
539
+ return self.weight * hidden_states.to(input_dtype)
540
+
541
+ class Idefics2PerceiverLayer(nn.Module):
542
+ def __init__(self, config, layer_idx: int):
543
+ super().__init__()
544
+ self.hidden_size = config.hidden_size
545
+ self.n_latents = config.resampler_n_latents
546
+ self.depth = config.resampler_depth
547
+ self.rms_norm_eps = config.rms_norm_eps
548
+
549
+ self.input_latents_norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
550
+ self.input_context_norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
551
+ self.self_attn = IDEFICS2_PERCEIVER_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
552
+ self.post_attention_layernorm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
553
+ self.mlp = Idefics2MLP(
554
+ hidden_size=config.hidden_size,
555
+ intermediate_size=config.hidden_size * 4,
556
+ output_size=config.hidden_size,
557
+ hidden_act=config.hidden_act,
558
+ )
559
+
560
+ def forward(
561
+ self,
562
+ latents: torch.Tensor,
563
+ context: torch.Tensor,
564
+ attention_mask: Optional[torch.Tensor] = None,
565
+ position_ids: Optional[torch.LongTensor] = None,
566
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
567
+ output_attentions: Optional[bool] = False,
568
+ use_cache: Optional[bool] = False,
569
+ **kwargs,
570
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
571
+ """
572
+ Args:
573
+ latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
574
+ context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
575
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
576
+ `(batch, sequence_length)` where padding elements are indicated by 0.
577
+ output_attentions (`bool`, *optional*):
578
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
579
+ returned tensors for more detail.
580
+ use_cache (`bool`, *optional*):
581
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
582
+ (see `past_key_values`).
583
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
584
+ """
585
+ residual = latents
586
+
587
+ latents = self.input_latents_norm(latents)
588
+ context = self.input_context_norm(context)
589
+
590
+ latents, self_attn_weights, present_key_value = self.self_attn(
591
+ latents=latents,
592
+ context=context,
593
+ attention_mask=attention_mask,
594
+ )
595
+ latents = residual + latents
596
+ residual = latents
597
+
598
+ latents = self.post_attention_layernorm(latents)
599
+ latents = self.mlp(latents)
600
+ latents = residual + latents
601
+
602
+ outputs = (latents,)
603
+
604
+ if output_attentions:
605
+ outputs += (self_attn_weights,)
606
+
607
+ if use_cache:
608
+ outputs += (present_key_value,)
609
+
610
+ return outputs
611
+
612
+ class Idefics2Qformer(nn.Module):
613
+
614
+ def __init__(self, config) -> None:
615
+ """
616
+ Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say from a ResNet or ViT or
617
+ MAE) of a given dimension, performs `depth` blocks of cross-attention with a fixed `n_latents` inputs, then
618
+ returns a Tensor of shape [bsz, n_latents, embed_dim]. The Resampler acts as a form of learned pooling and
619
+ is derived from [Perceiver: General Perception with Iterative Attention](https://arxiv.org/abs/2103.03206).
620
+ """
621
+ super().__init__()
622
+ config._attn_implementation = "flash_attention_2"
623
+ self._use_flash_attention_2 = True
624
+
625
+ self.hidden_size = config.hidden_size
626
+ self.hidden_act = config.hidden_act
627
+ self.n_latents = config.resampler_n_latents
628
+ self.depth = config.resampler_depth
629
+ self.rms_norm_eps = config.rms_norm_eps
630
+
631
+ # Create Latents for Perceiver
632
+ self.latents = nn.Parameter(torch.ones(self.n_latents, self.hidden_size))
633
+ # Create Transformer Blocks
634
+ self.layers = nn.ModuleList([Idefics2PerceiverLayer(config, idx) for idx in range(self.depth)])
635
+ self.norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
636
+ 
640
+ def forward(
641
+ self,
642
+ context: torch.Tensor,
643
+ attention_mask,
644
+ ) -> torch.Tensor:
645
+ # seq embed -> bsz seq embed
646
+ latents = self.latents.unsqueeze(0).expand((context.shape[0], *self.latents.size()))
647
+
648
+ latent_attention_mask = torch.ones(
649
+ (attention_mask.size(0), latents.size(1)), dtype=attention_mask.dtype, device=attention_mask.device
650
+ )
651
+ attention_mask = torch.cat([attention_mask, latent_attention_mask], dim=-1)
652
+ attention_mask = (
653
+ _prepare_4d_attention_mask(attention_mask, latents.dtype, tgt_len=self.n_latents)
654
+ if not self._use_flash_attention_2
655
+ else attention_mask
656
+ )
657
+ compressed_context = latents
+ for perceiver_layer in self.layers:
+ if self.training and torch.is_grad_enabled():
+ # Recompute activations during backward to save memory while training.
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ perceiver_layer.__call__,
+ compressed_context,
+ context,
+ attention_mask,
+ None,
+ None,
+ False,
+ False,
+ use_reentrant=True)
+ else:
+ # No activation checkpointing is needed at inference time.
+ layer_outputs = perceiver_layer(
+ compressed_context,
+ context,
+ attention_mask=attention_mask,
+ position_ids=None,
+ past_key_value=None,
+ output_attentions=False,
+ use_cache=False,
+ )
+ compressed_context = layer_outputs[0]
682
+
683
+ compressed_context = self.norm(compressed_context)
684
+
685
+ return compressed_context
686
+
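The resampler keeps a fixed bank of resampler_n_latents (64) learned query vectors, broadcasts them over the batch, and lets them cross-attend to the projected image features, so a context of any length is compressed to 64 tokens per image. A standalone sketch of the shape bookkeeping in the forward above; the batch and context lengths are illustrative, only the hidden size (4096) and latent count (64) come from the bundled config:

import torch

batch, context_len, hidden, n_latents = 2, 300, 4096, 64

latents = torch.ones(n_latents, hidden)                       # nn.Parameter in the module
context = torch.randn(batch, context_len, hidden)             # projected image features
attention_mask = torch.ones(batch, context_len, dtype=torch.long)

# "seq embed -> bsz seq embed", as in Idefics2Qformer.forward
expanded = latents.unsqueeze(0).expand(batch, *latents.size())
print(expanded.shape)                                         # torch.Size([2, 64, 4096])

# The latent queries are always valid, so their mask entries are all ones.
latent_mask = torch.ones(batch, n_latents, dtype=attention_mask.dtype)
full_mask = torch.cat([attention_mask, latent_mask], dim=-1)
print(full_mask.shape)                                        # torch.Size([2, 364])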
687
+ class Idefics2Connector(PreTrainedModel):
688
+ _auto_class = 'AutoModel'
689
+ config_class = Idefics2ConnectorConfig
690
+
691
+ def __init__(self, config):
692
+ super().__init__(config)
693
+ self.modality_projection = Idefics2MLP(
694
+ hidden_size=config.vision_hidden_size,
695
+ intermediate_size=config.intermediate_size,
696
+ output_size=config.hidden_size,
697
+ hidden_act=config.hidden_act,
698
+ )
699
+ self.perceiver_resampler = Idefics2Qformer(config)
700
+ self.config = config
701
+
702
+ def forward(self, image_hidden_states, attention_mask):
703
+ image_hidden_states = self.modality_projection(image_hidden_states)
704
+ image_hidden_states = self.perceiver_resampler(context=image_hidden_states, attention_mask=attention_mask)
705
+
706
+ vision_hidden_size = image_hidden_states.shape[-1]
707
+ num_image = image_hidden_states.shape[0]
708
+ reshaped_image_hidden_states = image_hidden_states.view(num_image, -1, vision_hidden_size)
709
+
710
+ return reshaped_image_hidden_states
711
+
712
+ @classmethod
+ def from_pretrained(cls, config_path="/mnt/csp/mmvision/home/arrayyang/idefics2-8b/idefics2_connector"):
+ config = Idefics2ConnectorConfig.from_pretrained(f'{config_path}/config.json')
+ model = cls(config=config)
+ 
+ state_dict = torch.load(f'{config_path}/connector.pth', map_location='cpu')
+ model.load_state_dict(state_dict, strict=False)
+ print("Loading idefics2 Connector from: {}".format(config_path))
+ return model
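Taken together, the connector projects SigLIP-style vision features of width vision_hidden_size (1152 in the bundled config) to the language-model width (4096) and resamples every image down to 64 tokens. A shape-only sketch using plain PyTorch layers as stand-ins for Idefics2MLP and the Perceiver resampler; the real module forces flash-attention and loads trained weights, so this only illustrates the tensor flow:

import torch
import torch.nn as nn

vision_hidden, llm_hidden, n_latents = 1152, 4096, 64   # from connector_config
num_images, num_patches = 5, 196                        # illustrative patch grid

features = torch.randn(num_images, num_patches, vision_hidden)

# Stand-in for the gated Idefics2MLP modality projection (1152 -> 4096).
projected = nn.Linear(vision_hidden, llm_hidden)(features)        # (5, 196, 4096)

# Stand-in for the Perceiver resampler: 64 learned queries cross-attend to the patches.
latents = torch.randn(num_images, n_latents, llm_hidden)
cross_attn = nn.MultiheadAttention(llm_hidden, num_heads=16, batch_first=True)
resampled, _ = cross_attn(latents, projected, projected)          # (5, 64, 4096)
print(projected.shape, resampled.shape)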
image_processor.py ADDED
@@ -0,0 +1,657 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import json
21
+ import torch
22
+
23
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
24
+ from transformers.image_transforms import PaddingMode, pad, resize, to_channel_dimension_format
25
+ from transformers.image_utils import (
26
+ IMAGENET_STANDARD_MEAN,
27
+ IMAGENET_STANDARD_STD,
28
+ ChannelDimension,
29
+ ImageInput,
30
+ PILImageResampling,
31
+ get_image_size,
32
+ infer_channel_dimension_format,
33
+ is_scaled_image,
34
+ is_valid_image,
35
+ to_numpy_array,
36
+ valid_images,
37
+ validate_preprocess_arguments,
38
+ )
39
+ from transformers.utils import TensorType, is_vision_available, logging
40
+ import PIL
41
+ from PIL import Image
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+
46
+
47
+
48
+ def get_resize_output_image_size(image, size, input_data_format) -> Tuple[int, int]:
49
+ """
50
+ Get the output size of the image after resizing given a dictionary specifying the max and min sizes.
51
+
52
+ Args:
53
+ image (`np.ndarray`):
54
+ Image to resize.
55
+ size (`Dict[str, int]`):
56
+ Size of the output image containing the keys "shortest_edge" and "longest_edge".
57
+ input_data_format (`ChannelDimension` or `str`):
58
+ The channel dimension format of the input image.
59
+
60
+ Returns:
61
+ The output size of the image after resizing.
62
+ """
63
+ height, width = get_image_size(image, channel_dim=input_data_format)
64
+
65
+ min_len = size["shortest_edge"]
66
+ max_len = size["longest_edge"]
67
+ aspect_ratio = width / height
68
+
69
+ if width >= height and width > max_len:
70
+ width = max_len
71
+ height = int(width / aspect_ratio)
72
+ elif height > width and height > max_len:
73
+ height = max_len
74
+ width = int(height * aspect_ratio)
75
+ height = max(height, min_len)
76
+ width = max(width, min_len)
77
+ return height, width
78
+
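A worked example of the sizing rule above with the processor's default {"shortest_edge": 378, "longest_edge": 980}: the long edge is capped at 980 while the aspect ratio is kept, and small images are pushed up to at least 378 on each edge (which can distort very small inputs). The helper below just restates the arithmetic for two concrete cases:

def output_size(height, width, min_len=378, max_len=980):
    # Same arithmetic as get_resize_output_image_size, written out for concrete inputs.
    aspect_ratio = width / height
    if width >= height and width > max_len:
        width = max_len
        height = int(width / aspect_ratio)
    elif height > width and height > max_len:
        height = max_len
        width = int(height * aspect_ratio)
    return max(height, min_len), max(width, min_len)

print(output_size(1000, 2000))  # (490, 980): long edge capped, aspect ratio preserved
print(output_size(50, 100))     # (378, 378): both edges raised to the minimum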
79
+
80
+ def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]:
81
+ """
82
+ Convert a single image or a list of images to a list of numpy arrays.
83
+
84
+ Args:
85
+ images (`ImageInput`):
86
+ A single image or a list of images.
87
+
88
+ Returns:
89
+ A list of numpy arrays.
90
+ """
91
+ # If it's a single image, convert it to a list of lists
92
+ if is_valid_image(images):
93
+ images = [[images]]
94
+ # If it's a list of images, it's a single batch, so convert it to a list of lists
95
+ elif isinstance(images, (list, tuple)) and len(images) > 0 and is_valid_image(images[0]):
96
+ images = [images]
97
+ # If it's a list of batches, it's already in the right format
98
+ elif (
99
+ isinstance(images, (list, tuple))
100
+ and len(images) > 0
101
+ and isinstance(images[0], (list, tuple))
102
+ and is_valid_image(images[0][0])
103
+ ):
104
+ pass
105
+ else:
106
+ raise ValueError(
107
+ "Invalid input type. Must be a single image, a list of images, or a list of batches of images."
108
+ )
109
+ return images
110
+
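make_list_of_images always normalises its input to the nested List[List[image]] layout (a batch of samples, each holding one or more images) that the rest of the processor assumes. A quick illustration with NumPy arrays standing in for images:

import numpy as np

img = np.zeros((32, 32, 3), dtype=np.uint8)          # any valid image type works here

print(len(make_list_of_images(img)))                 # 1 -> [[img]]
print(len(make_list_of_images([img, img])))          # 1 sample holding two images -> [[img, img]]
print(len(make_list_of_images([[img], [img]])))      # 2 samples, already nested -> unchanged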
111
+
112
+ # Copied from transformers.models.detr.image_processing_detr.max_across_indices
113
+ def max_across_indices(values: Iterable[Any]) -> List[Any]:
114
+ """
115
+ Return the maximum value across all indices of an iterable of values.
116
+ """
117
+ return [max(values_i) for values_i in zip(*values)]
118
+
119
+
120
+ def get_max_height_width(
121
+ images_list: List[List[np.ndarray]], input_data_format: Optional[Union[str, ChannelDimension]] = None
122
+ ) -> List[int]:
123
+ """
124
+ Get the maximum height and width across all images in a batch.
125
+ """
126
+ if input_data_format is None:
127
+ input_data_format = infer_channel_dimension_format(images_list[0][0])
128
+
129
+ image_sizes = []
130
+ for images in images_list:
131
+ for image in images:
132
+ image_sizes.append(get_image_size(image, channel_dim=input_data_format))
133
+
134
+ max_height, max_width = max_across_indices(image_sizes)
135
+ return (max_height, max_width)
136
+
137
+
138
+ # Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
139
+ def make_pixel_mask(
140
+ image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
141
+ ) -> np.ndarray:
142
+ """
143
+ Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
144
+
145
+ Args:
146
+ image (`np.ndarray`):
147
+ Image to make the pixel mask for.
148
+ output_size (`Tuple[int, int]`):
149
+ Output size of the mask.
150
+ """
151
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
152
+ mask = np.zeros(output_size, dtype=np.int64)
153
+ mask[:input_height, :input_width] = 1
154
+ return mask
155
+
156
+
157
+ # FIXME Amy: merge this function with the one in image_transforms.py
158
+ def convert_to_rgb(image: ImageInput) -> ImageInput:
159
+ """
160
+ Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
161
+ as is.
162
+ Args:
163
+ image (Image):
164
+ The image to convert.
165
+ """
166
+ if not isinstance(image, PIL.Image.Image):
167
+ return image
168
+
169
+ # `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background
170
+ # for transparent images. The call to `alpha_composite` handles this case
171
+ if image.mode == "RGB":
172
+ return image
173
+
174
+ image_rgba = image.convert("RGBA")
175
+ background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
176
+ alpha_composite = Image.alpha_composite(background, image_rgba)
177
+ alpha_composite = alpha_composite.convert("RGB")
178
+ return alpha_composite
179
+
180
+
181
+ class Idefics2ImageProcessor(BaseImageProcessor):
182
+ r"""
183
+ Constructs an Idefics2 image processor.
184
+
185
+ Args:
186
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
187
+ Whether to convert the image to RGB. This is useful if the input image is of a different format e.g. RGBA.
188
+ Only has an effect if the input image is in the PIL format.
189
+ do_resize (`bool`, *optional*, defaults to `True`):
190
+ Whether to resize the image. The longest edge of the image is resized to be <= `size["longest_edge"]`, with the
191
+ shortest edge resized to keep the input aspect ratio, with a minimum size of `size["shortest_edge"]`.
192
+ size (`Dict`, *optional*):
193
+ Controls the size of the output image. This is a dictionary containing the keys "shortest_edge" and "longest_edge".
194
+ resample (`Resampling`, *optional*, defaults to `Resampling.BILINEAR`):
195
+ Resampling filter to use when resizing the image.
196
+ do_rescale (`bool`, *optional*, defaults to `True`):
197
+ Whether to rescale the image. If set to `True`, the image is rescaled to have pixel values between 0 and 1.
198
+ rescale_factor (`float`, *optional*, defaults to `1/255`):
199
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
200
+ do_normalize (`bool`, *optional*, defaults to `True`):
201
+ Whether to normalize the image. If set to `True`, the image is normalized to have a mean of `image_mean` and
202
+ a standard deviation of `image_std`.
203
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
211
+ do_pad (`bool`, *optional*, defaults to `True`):
212
+ Whether or not to pad the images to the largest height and width in the batch and number of images per
213
+ sample in the batch, such that the returned tensor is of shape (batch_size, max_num_images, num_channels, max_height, max_width).
214
+ do_image_splitting (`bool`, *optional*, defaults to `False`):
215
+ Whether to split the image into a sequence of 4 equal sub-images concatenated with the original image. That
216
+ strategy was first introduced in https://arxiv.org/abs/2311.06607.
217
+ """
218
+
219
+ model_input_names = ["pixel_values"]
220
+
221
+ def __init__(
222
+ self,
223
+ do_convert_rgb: bool = True,
224
+ do_resize: bool = True,
225
+ size: Dict[str, int] = None,
226
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
227
+ do_rescale: bool = True,
228
+ rescale_factor: float = 1 / 255,
229
+ do_normalize: bool = True,
230
+ image_mean: Optional[Union[float, List[float]]] = None,
231
+ image_std: Optional[Union[float, List[float]]] = None,
232
+ do_pad: bool = True,
233
+ do_image_splitting: bool = False,
234
+ **kwargs,
235
+ ) -> None:
236
+ super().__init__(**kwargs)
237
+ self.do_convert_rgb = do_convert_rgb
238
+ self.do_resize = do_resize
239
+ self.size = size if size is not None else {"shortest_edge": 378, "longest_edge": 980}
240
+ self.resample = resample
241
+ self.do_rescale = do_rescale
242
+ self.rescale_factor = rescale_factor
243
+ self.do_normalize = do_normalize
244
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
245
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
246
+ self.do_pad = do_pad
247
+ self.do_image_splitting = do_image_splitting
248
+
249
+ def resize(
250
+ self,
251
+ image: np.ndarray,
252
+ size: Dict[str, int],
253
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
254
+ data_format: Optional[Union[str, ChannelDimension]] = None,
255
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
256
+ **kwargs,
257
+ ) -> np.ndarray:
258
+ """
259
+ Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
260
+ resized to keep the input aspect ratio.
261
+
262
+ Args:
263
+ image (`np.ndarray`):
264
+ Image to resize.
265
+ size (`Dict[str, int]`):
266
+ Size of the output image.
267
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ Resampling filter to use when resizing the image.
269
+ data_format (`str` or `ChannelDimension`, *optional*):
270
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
271
+ input_data_format (`ChannelDimension` or `str`, *optional*):
272
+ The channel dimension format of the input image. If not provided, it will be inferred.
273
+ """
274
+ if "shortest_edge" in size and "longest_edge" in size:
275
+ size = get_resize_output_image_size(image, size, input_data_format)
276
+ elif "height" in size and "width" in size:
277
+ size = (size["height"], size["width"])
278
+ else:
279
+ raise ValueError(
280
+ "size must be a dictionary with keys 'shortest_edge' and 'longest_edge' or 'height' and 'width'."
281
+ )
282
+ try:
+ return resize(
+ image, size, resample=resample, data_format=data_format, input_data_format=input_data_format, **kwargs
+ )
+ except Exception:
+ # Log the offending image shape and re-raise instead of silently resizing a second time.
+ logger.error(f"resize error with image of shape {image.shape}")
+ raise
292
+
293
+ # Copied from transformers.models.vilt.image_processing_vilt.ViltImageProcessor._pad_image
294
+ def _pad_image(
295
+ self,
296
+ image: np.ndarray,
297
+ output_size: Tuple[int, int],
298
+ constant_values: Union[float, Iterable[float]] = 0,
299
+ data_format: Optional[ChannelDimension] = None,
300
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
301
+ ) -> np.ndarray:
302
+ """
303
+ Pad an image with zeros to the given size.
304
+ """
305
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
306
+ output_height, output_width = output_size
307
+
308
+ pad_bottom = output_height - input_height
309
+ pad_right = output_width - input_width
310
+ padding = ((0, pad_bottom), (0, pad_right))
311
+ padded_image = pad(
312
+ image,
313
+ padding,
314
+ mode=PaddingMode.CONSTANT,
315
+ constant_values=constant_values,
316
+ data_format=data_format,
317
+ input_data_format=input_data_format,
318
+ )
319
+ return padded_image
320
+
321
+ def pad(
322
+ self,
323
+ images: List[np.ndarray],
324
+ constant_values: Union[float, Iterable[float]] = 0,
325
+ return_pixel_mask: bool = True,
326
+ return_tensors: Optional[Union[str, TensorType]] = None,
327
+ data_format: Optional[ChannelDimension] = None,
328
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
329
+ ) -> BatchFeature:
330
+ """
331
+ For a list of lists of images, pads each image on the bottom and right with zeros up to the largest height and width in the batch.
+ For each sample in the batch, pads the sample with empty images up to the maximum number of images per sample in the batch. Optionally returns a pixel mask.
333
+
334
+ Args:
335
+ images (`np.ndarray`):
336
+ List of list of images to pad. Pads to the largest height and width in the batch.
337
+ constant_values (`float` or `Iterable[float]`, *optional*):
338
+ The value to use for the padding if `mode` is `"constant"`.
339
+ return_pixel_mask (`bool`, *optional*, defaults to `True`):
340
+ Whether to return a pixel mask.
341
+ return_tensors (`str` or `TensorType`, *optional*):
342
+ The type of tensors to return. Can be one of:
343
+ - Unset: Return a list of `np.ndarray`.
344
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
345
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
346
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
347
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
348
+ data_format (`str` or `ChannelDimension`, *optional*):
349
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
350
+ input_data_format (`ChannelDimension` or `str`, *optional*):
351
+ The channel dimension format of the input image. If not provided, it will be inferred.
352
+ """
353
+ pad_size = get_max_height_width(images, input_data_format=input_data_format)
354
+
355
+ batch_size = len(images)
356
+ max_num_images = max(len(images_) for images_ in images)
357
+ input_data_format = (
358
+ infer_channel_dimension_format(images[0][0]) if input_data_format is None else input_data_format
359
+ )
360
+ data_format = input_data_format if data_format is None else data_format
361
+
362
+ def empty_image(size, input_data_format):
363
+ if input_data_format == ChannelDimension.FIRST:
364
+ return np.zeros((3, *size), dtype=np.uint8)
365
+ elif input_data_format == ChannelDimension.LAST:
366
+ return np.zeros((*size, 3), dtype=np.uint8)
367
+ raise ValueError("Invalid channel dimension format.")
368
+
369
+ padded_images_list = [
370
+ [empty_image(pad_size, data_format) for _ in range(max_num_images)] for _ in range(batch_size)
371
+ ]
372
+ padded_masks = [[np.zeros(pad_size) for _ in range(max_num_images)] for _ in range(batch_size)]
373
+
374
+ for batch_idx in range(batch_size):
375
+ for sample_idx, image in enumerate(images[batch_idx]):
376
+ padded_images_list[batch_idx][sample_idx] = self._pad_image(
377
+ image,
378
+ pad_size,
379
+ constant_values=constant_values,
380
+ data_format=data_format,
381
+ input_data_format=input_data_format,
382
+ )
383
+ padded_masks[batch_idx][sample_idx] = make_pixel_mask(
384
+ image, output_size=pad_size, input_data_format=input_data_format
385
+ )
386
+
387
+ padded_masks = padded_masks if return_pixel_mask else None
388
+ return padded_images_list, padded_masks
389
+
390
+ def _crop(
391
+ self,
392
+ im: np.ndarray,
393
+ w1: int,
394
+ h1: int,
395
+ w2: int,
396
+ h2: int,
397
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
398
+ ) -> np.ndarray:
399
+ if input_data_format == ChannelDimension.FIRST:
400
+ return im[:, h1:h2, w1:w2]
401
+ elif input_data_format == ChannelDimension.LAST:
402
+ return im[h1:h2, w1:w2, :]
403
+
404
+ def split_image(
405
+ self,
406
+ image: np.ndarray,
407
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
408
+ ):
409
+ """
410
+ Split an image into 4 equal sub-images, and then concatenate that sequence with the original image.
411
+ That means that a single image becomes a sequence of 5 images.
412
+ This is a "trick" to spend more compute on each image with no changes in the vision encoder.
413
+
414
+ Args:
415
+ image (`np.ndarray`):
416
+ Images to split.
417
+ input_data_format (`ChannelDimension` or `str`, *optional*):
418
+ The channel dimension format of the input image. If not provided, it will be inferred.
419
+ """
420
+ height, width = get_image_size(image, input_data_format)
421
+
422
+ mid_width = width // 2
423
+ mid_height = height // 2
424
+ image_list = [
425
+ self._crop(image, 0, 0, mid_width, mid_height, input_data_format),
426
+ self._crop(image, mid_width, 0, width, mid_height, input_data_format),
427
+ self._crop(image, 0, mid_height, mid_width, height, input_data_format),
428
+ self._crop(image, mid_width, mid_height, width, height, input_data_format),
429
+ image,
430
+ ]
431
+ return image_list
432
+
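For a channels-last image of height 490 and width 980, the crops are taken at mid_height=245 and mid_width=490, so split_image yields four 245x490 quadrants plus the full image, i.e. five views per input. A small sketch on a dummy array, assuming an Idefics2ImageProcessor instance:

import numpy as np
from transformers.image_utils import ChannelDimension

processor = Idefics2ImageProcessor(do_image_splitting=True)
image = np.zeros((490, 980, 3), dtype=np.uint8)       # (H, W, C)

crops = processor.split_image(image, ChannelDimension.LAST)
print(len(crops))                                     # 5: four quadrants + the original
print(crops[0].shape, crops[-1].shape)                # (245, 490, 3) and (490, 980, 3)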
433
+ def preprocess(
434
+ self,
435
+ images: ImageInput,
436
+ do_convert_rgb: Optional[bool] = None,
437
+ do_resize: Optional[bool] = None,
438
+ size: Optional[Dict[str, int]] = None,
439
+ resample: PILImageResampling = None,
440
+ do_rescale: Optional[bool] = None,
441
+ rescale_factor: Optional[float] = None,
442
+ do_normalize: Optional[bool] = None,
443
+ image_mean: Optional[Union[float, List[float]]] = None,
444
+ image_std: Optional[Union[float, List[float]]] = None,
445
+ do_pad: Optional[bool] = None,
446
+ do_image_splitting: Optional[bool] = None,
447
+ return_tensors: Optional[Union[str, TensorType]] = None,
448
+ input_data_format: Optional[ChannelDimension] = None,
449
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
450
+ ):
451
+ """
452
+ Preprocess a batch of images.
453
+
454
+ Args:
455
+ images (`ImageInput`):
456
+ A list of images to preprocess.
457
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
458
+ Whether to convert the image to RGB.
459
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
460
+ Whether to resize the image.
461
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
462
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
463
+ the longest edge resized to keep the input aspect ratio.
464
+ resample (`int`, *optional*, defaults to `self.resample`):
465
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
466
+ has an effect if `do_resize` is set to `True`.
467
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
468
+ Whether to rescale the image.
469
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
470
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
471
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
472
+ Whether to normalize the image.
473
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
474
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
475
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
476
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
477
+ `True`.
478
+ do_pad (`bool`, *optional*, defaults to `self.do_pad`):
479
+ Whether or not to pad the images to the largest height and width in the batch.
480
+ do_image_splitting (`bool`, *optional*, defaults to `self.do_image_splitting`):
481
+ Whether to split the image into a sequence of 4 equal sub-images concatenated with the original image. That
482
+ strategy was first introduced in https://arxiv.org/abs/2311.06607.
483
+ return_tensors (`str` or `TensorType`, *optional*):
484
+ The type of tensors to return. Can be one of:
485
+ - Unset: Return a list of `np.ndarray`.
486
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
487
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
488
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
489
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
490
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
491
+ The channel dimension format for the output image. Can be one of:
492
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
493
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
494
+ - Unset: Use the channel dimension format of the input image.
495
+ input_data_format (`ChannelDimension` or `str`, *optional*):
496
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
497
+ from the input image. Can be one of:
498
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
499
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
500
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
501
+ """
502
+ do_resize = do_resize if do_resize is not None else self.do_resize
503
+ size = size if size is not None else self.size
504
+ resample = resample if resample is not None else self.resample
505
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
506
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
507
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
508
+ image_mean = image_mean if image_mean is not None else self.image_mean
509
+ image_std = image_std if image_std is not None else self.image_std
510
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
511
+ do_pad = do_pad if do_pad is not None else self.do_pad
512
+ do_image_splitting = do_image_splitting if do_image_splitting is not None else self.do_image_splitting
513
+
514
+ images_list = make_list_of_images(images)
515
+
516
+ if not valid_images(images_list[0]):
517
+ raise ValueError(
518
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
519
+ "torch.Tensor, tf.Tensor or jax.ndarray."
520
+ )
521
+
522
+ validate_preprocess_arguments(
523
+ do_rescale=do_rescale,
524
+ rescale_factor=rescale_factor,
525
+ do_normalize=do_normalize,
526
+ image_mean=image_mean,
527
+ image_std=image_std,
528
+ do_resize=do_resize,
529
+ size=size,
530
+ resample=resample,
531
+ )
532
+
533
+ if do_convert_rgb:
534
+ images_list = [[convert_to_rgb(image) for image in images] for images in images_list]
535
+
536
+ # All transformations expect numpy arrays.
537
+ images_list = [[to_numpy_array(image) for image in images] for images in images_list]
538
+
539
+ if is_scaled_image(images_list[0][0]) and do_rescale:
540
+ logger.warning_once(
541
+ "It looks like you are trying to rescale already rescaled images. If the input"
542
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
543
+ )
544
+
545
+ if input_data_format is None:
546
+ # We assume that all images have the same channel dimension format.
547
+ input_data_format = ChannelDimension.LAST #infer_channel_dimension_format(images_list[0][0])
548
+
549
+ if do_image_splitting:
550
+ new_images_list = []
551
+ for images in images_list:
552
+ new_images = []
553
+ for image in images:
554
+ new_images.extend(self.split_image(image, input_data_format))
555
+ new_images_list.append(new_images)
556
+ images_list = new_images_list
557
+
558
+ if do_resize:
559
+ images_list = [
560
+ [
561
+ self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
562
+ for image in images
563
+ ]
564
+ for images in images_list
565
+ ]
566
+
567
+ if do_rescale:
568
+ images_list = [
569
+ [
570
+ self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
571
+ for image in images
572
+ ]
573
+ for images in images_list
574
+ ]
575
+
576
+ if do_normalize:
577
+ images_list = [
578
+ [
579
+ self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
580
+ for image in images
581
+ ]
582
+ for images in images_list
583
+ ]
584
+
585
+ pixel_attention_mask = None
586
+ if do_pad:
587
+ images_list, pixel_attention_mask = self.pad(
588
+ images_list, return_pixel_mask=True, return_tensors=return_tensors, input_data_format=input_data_format
589
+ )
590
+
591
+ if data_format is not None:
592
+ images_list = [
593
+ [
594
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
595
+ for image in images
596
+ ]
597
+ for images in images_list
598
+ ]
599
+
600
+ data = {"pixel_values": np.array(images_list) if do_pad else images_list} # Faster tensor conversion
601
+ if pixel_attention_mask is not None:
602
+ data["pixel_attention_mask"] = np.array(pixel_attention_mask) if do_pad else pixel_attention_mask
603
+
604
+
605
+ temp_pixel_values = data["pixel_values"].copy()
606
+ temp_pixel_values = torch.from_numpy(temp_pixel_values)
607
+ batch_size, num_images, num_channels, height, width = temp_pixel_values.shape
608
+ temp_pixel_values = temp_pixel_values.view(batch_size * num_images, *temp_pixel_values.shape[2:])
609
+ # Remove padding images - padding images are full 0.
610
+ nb_values_per_image = temp_pixel_values.shape[1:].numel()
611
+ real_images_inds = (temp_pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
612
+ temp_pixel_values = temp_pixel_values[real_images_inds].contiguous()
613
+ # if 'pixel_attention_mask' is not none
614
+ if 'pixel_attention_mask' in data:
615
+ pixel_attention_mask = torch.from_numpy(data['pixel_attention_mask'])
616
+ # Remove padding images from the mask as well.
617
+ pixel_attention_mask = pixel_attention_mask.view(
618
+ batch_size * num_images, *pixel_attention_mask.shape[2:]
619
+ )
620
+ pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
621
+ pixel_attention_mask = pixel_attention_mask.to(torch.bool)
622
+ else:
623
+ pixel_attention_mask = torch.ones(
624
+ size=(temp_pixel_values.size(0), temp_pixel_values.size(2), temp_pixel_values.size(3)),
625
+ dtype=torch.bool,
626
+ device=temp_pixel_values.device,
627
+ )
628
+ patch_size = 14 #self.config.vision_config.patch_size
629
+ patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
630
+ patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
631
+ patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
632
+
633
+ data["navit_pixel_values"] = temp_pixel_values
634
+ data["pixel_attention_mask"] = patch_attention_mask
635
+
636
+ return BatchFeature(data=data, tensor_type=return_tensors)
637
+
638
+ @classmethod
+ def from_pretrained(cls, config_path="/mnt/csp/mmvision/home/arrayyang/idefics2-8b/idefics2_image_processor"):
+ with open(f'{config_path}/config.json', "r", encoding="utf-8") as f:
+ config = json.load(f)
+ 
+ # Use `cls` as the classmethod receiver and return a new processor instance.
+ processor = cls(
+ do_convert_rgb=config['do_convert_rgb'],
+ do_resize=config['do_resize'],
+ size=config['size'],
+ resample=config['resample'],
+ do_rescale=config['do_rescale'],
+ rescale_factor=config['rescale_factor'],
+ do_normalize=config['do_normalize'],
+ image_mean=config['image_mean'],
+ image_std=config['image_std'],
+ do_pad=config['do_pad'],
+ do_image_splitting=config['do_image_splitting']
+ )
+ return processor
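An end-to-end usage sketch of the processor defined above, constructing it directly rather than through the hard-coded from_pretrained path and feeding one dummy 600x800 RGB image; the printed shapes follow from that particular input (five views after splitting, padding to 600x800, and one mask entry per 14x14 patch):

import numpy as np

processor = Idefics2ImageProcessor(do_image_splitting=True)
image = np.random.randint(0, 255, (600, 800, 3), dtype=np.uint8)

out = processor.preprocess([image], return_tensors="pt")
print(out["pixel_values"].shape)           # (1, 5, 3, 600, 800): batch, views, channels-first pixels
print(out["navit_pixel_values"].shape)     # (5, 3, 600, 800): padded-out images removed, batch flattened
print(out["pixel_attention_mask"].shape)   # (5, 42, 57): one boolean per 14x14 patch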
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a63aa34ebb22ce9607601f38bb1304c4c632f8f8708687540ea2effda145648
3
+ size 4916334192
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7ac3a55783b156725e1b85d76aeb8a8d62a3a87f111937d7a1dc055c97d634e2
3
+ size 4959969744
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c21a3cdb95c20ffabb3211af5a24df1c36d87110c1ea5c3b08580b48e1de66f
3
+ size 4993507248
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bda83c8f7b945b4cefc1b0b7abed8842ccc4eec99f07d4ae1c0d3fe7b1674afd
3
+ size 4245939544
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_downsampler.py ADDED
@@ -0,0 +1,62 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ import torch.nn as nn
4
+ from transformers import PreTrainedModel
5
+ from transformers.activations import ACT2FN
6
+
7
+ from .configuration_downsampler import DownsamplerConfig
8
+
9
+
10
+ class DownsamplerModel(PreTrainedModel):
11
+ _auto_class = 'AutoModel'
12
+ config_class = DownsamplerConfig
13
+ base_model_prefix = 'model'
14
+ supports_gradient_checkpointing = True
15
+
16
+ def __init__(self, config: DownsamplerConfig) -> None:
17
+ super().__init__(config)
18
+ self.gradient_checkpointing = False
19
+
20
+ self.group_op = nn.Conv2d(
21
+ in_channels=config.visual_hidden_size,
22
+ out_channels=config.llm_hidden_size,
23
+ bias=config.bias,
24
+ kernel_size=config.kernel_size, stride=config.stride)
25
+ modules = list()
26
+ for _ in range(1, config.depth):
27
+ modules.append(ACT2FN[config.hidden_act])
28
+ modules.append(
29
+ nn.Linear(
30
+ config.llm_hidden_size,
31
+ config.llm_hidden_size,
32
+ bias=config.bias))
33
+ self.linear_model = nn.Sequential(*modules)
34
+
35
+ def enable_input_require_grads(self):
36
+
37
+ def make_inputs_require_grad(module, input, output):
38
+ output.requires_grad_(True)
39
+
40
+ # DownsamplerModel has no `self.model` submodule; hook the module itself so its outputs require grads.
+ self.register_forward_hook(make_inputs_require_grad)
41
+
42
+ def _set_gradient_checkpointing(self, module, value=False):
43
+ if isinstance(module, DownsamplerModel):
44
+ module.gradient_checkpointing = value
45
+
46
+ def _forward(self, x):
47
+
48
+ # (B, FULL_H, FULL_W, D) -> (B, D, FULL_H, FULL_W)
49
+ x = x.permute(0, 3, 1, 2)
50
+ x = self.group_op(x)
51
+ # (B, D, FULL_H, FULL_W) -> (B, FULL_H, FULL_W, D)
52
+ x = x.permute(0, 2, 3, 1)
53
+ x = self.linear_model(x)
54
+
55
+ return x
56
+
57
+ def forward(self, x):
58
+ if self.gradient_checkpointing and self.training:
59
+ layer_outputs = torch.utils.checkpoint.checkpoint(self._forward, x)
60
+ else:
61
+ layer_outputs = self._forward(x)
62
+ return layer_outputs
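The downsampler treats the ViT patch grid as a 2-D feature map: a strided convolution simultaneously projects visual_hidden_size to llm_hidden_size and shrinks the grid by the stride, and the extra linear layers (when depth > 1) refine the result. A stand-in sketch with assumed sizes; the real kernel size, stride, depth and widths come from DownsamplerConfig in configuration_downsampler.py, which is not shown here:

import torch
import torch.nn as nn

visual_hidden, llm_hidden = 1152, 4096      # assumed to match the vision tower / LLM widths
kernel_size = stride = 2                    # assumed downsampling factor

x = torch.randn(1, 26, 26, visual_hidden)   # (B, FULL_H, FULL_W, D) patch grid, illustrative

conv = nn.Conv2d(visual_hidden, llm_hidden, kernel_size=kernel_size, stride=stride)
y = conv(x.permute(0, 3, 1, 2))             # (B, D, H, W) -> (1, 4096, 13, 13)
y = y.permute(0, 2, 3, 1)                   # back to (B, H, W, D), as in DownsamplerModel._forward
print(y.shape)                              # torch.Size([1, 13, 13, 4096])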
modeling_internlm2.py ADDED
@@ -0,0 +1,1494 @@
1
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # This code is based on transformers/src/transformers/models/llama/modeling_llama.py
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ PyTorch InternLM2 model."""
17
+ import math
18
+ import queue
19
+ import threading
20
+ import warnings
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+ import torch.utils.checkpoint
26
+ from einops import rearrange
27
+ from torch import nn
28
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
+ from transformers.activations import ACT2FN
30
+ from transformers.modeling_outputs import (
31
+ BaseModelOutputWithPast,
32
+ CausalLMOutputWithPast,
33
+ SequenceClassifierOutputWithPast,
34
+ )
35
+ from transformers.modeling_utils import PreTrainedModel
36
+ from transformers.utils import (
37
+ add_start_docstrings,
38
+ add_start_docstrings_to_model_forward,
39
+ logging,
40
+ replace_return_docstrings,
41
+ )
42
+
43
+ try:
44
+ from transformers.generation.streamers import BaseStreamer
45
+ except: # noqa # pylint: disable=bare-except
46
+ BaseStreamer = None
47
+
48
+ from .configuration_internlm2 import InternLM2Config
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+ _CONFIG_FOR_DOC = "InternLM2Config"
53
+
54
+ flash_attn_func, flash_attn_varlen_func = None, None
55
+ pad_input, index_first_axis, unpad_input = None, None, None
56
+ def _import_flash_attn():
57
+ global flash_attn_func, flash_attn_varlen_func
58
+ global pad_input, index_first_axis, unpad_input
59
+ try:
60
+ from flash_attn import flash_attn_func as _flash_attn_func, flash_attn_varlen_func as _flash_attn_varlen_func
61
+ from flash_attn.bert_padding import pad_input as _pad_input, index_first_axis as _index_first_axis, unpad_input as _unpad_input
62
+ flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
63
+ pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
64
+ except ImportError:
65
+ raise ImportError("flash_attn is not installed.")
66
+
67
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
68
+ def _get_unpad_data(attention_mask):
69
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
70
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
71
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
72
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
73
+ return (
74
+ indices,
75
+ cu_seqlens,
76
+ max_seqlen_in_batch,
77
+ )
78
+
79
+
80
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
81
+ def _make_causal_mask(
82
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
83
+ ):
84
+ """
85
+ Make causal mask used for bi-directional self-attention.
86
+ """
87
+ bsz, tgt_len = input_ids_shape
88
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
89
+ mask_cond = torch.arange(mask.size(-1), device=device)
90
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
91
+ mask = mask.to(dtype)
92
+
93
+ if past_key_values_length > 0:
94
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
95
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
96
+
97
+
98
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
99
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
100
+ """
101
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
102
+ """
103
+ bsz, src_len = mask.size()
104
+ tgt_len = tgt_len if tgt_len is not None else src_len
105
+
106
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
107
+
108
+ inverted_mask = 1.0 - expanded_mask
109
+
110
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
111
+
112
+
113
+ class PLoRA(nn.Module):
114
+
115
+ def __init__(self,
116
+ in_features: int,
117
+ out_features: int,
118
+ bias: bool = True,
119
+ device=None,
120
+ dtype=None,
121
+ lora_r=8,
122
+ lora_alpha=16,
123
+ lora_dropout=0.05,
124
+ lora_len=0,
125
+ **kwargs) -> None:
126
+ super().__init__()
127
+
128
+ self.original_linear = nn.Linear(in_features, out_features, bias, device, dtype)
129
+
130
+ self.lora_r = lora_r
131
+ self.lora_alpha = lora_alpha
132
+ self.lora_len = lora_len
133
+ if lora_dropout > 0.:
134
+ self.lora_dropout = nn.Dropout(p=lora_dropout)
135
+ else:
136
+ self.lora_dropout = lambda x: x
137
+ self.lora_scaling = self.lora_alpha / self.lora_r
138
+
139
+ self.Plora_A = nn.Linear(
140
+ in_features, self.lora_r, bias=False, device=device, dtype=dtype)
141
+ self.Plora_B = nn.Linear(
142
+ self.lora_r, out_features, bias=False, device=device, dtype=dtype)
143
+
144
+ self.reset_parameters()
145
+
146
+ def reset_parameters(self):
+ # Initialize A the same way as the default for nn.Linear and B to zero so the
+ # adapter starts as a no-op (the attributes are named Plora_A / Plora_B).
+ if hasattr(self, 'Plora_A'):
+ nn.init.kaiming_uniform_(self.Plora_A.weight, a=math.sqrt(5))
+ nn.init.zeros_(self.Plora_B.weight)
151
+
152
+ def forward(self, x, im_mask=None):
153
+ res = self.original_linear(x)
154
+
155
+ if im_mask is not None:
156
+ if torch.sum(im_mask) > 0:
157
+ part_x = x[im_mask]
158
+ res[im_mask] += self.Plora_B(
159
+ self.Plora_A(
160
+ self.lora_dropout(part_x))) * self.lora_scaling
161
+ else:
162
+ part_x = x[:, :1]
163
+ res[:, :1] += self.Plora_B(
164
+ self.Plora_A(self.lora_dropout(part_x))) * 0
165
+ return res
166
+
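PLoRA ("partial LoRA") wraps a regular nn.Linear and adds a low-rank update only at the token positions flagged by im_mask (the image tokens); text tokens pass through the unmodified projection, and the zero-scaled branch in the else-case keeps the adapter parameters in the autograd graph when a batch has no image tokens. A small usage sketch with toy sizes (the projections in this file use lora_r=256 and lora_alpha=256 on the full-width layers):

import torch

plora = PLoRA(in_features=64, out_features=64, bias=False, lora_r=8, lora_alpha=16)

x = torch.randn(2, 10, 64)                       # (batch, seq, hidden)
im_mask = torch.zeros(2, 10, dtype=torch.bool)
im_mask[:, :4] = True                            # pretend the first 4 tokens are image tokens

out = plora(x, im_mask)                          # LoRA delta is added only where im_mask is True
base = plora.original_linear(x)
print(torch.allclose(out[~im_mask], base[~im_mask]))   # True: text tokens see the plain Linear
print(out.shape)                                        # torch.Size([2, 10, 64])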
167
+
168
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->InternLM2
169
+ class InternLM2RMSNorm(nn.Module):
170
+ def __init__(self, hidden_size, eps=1e-6):
171
+ """
172
+ InternLM2RMSNorm is equivalent to T5LayerNorm
173
+ """
174
+ super().__init__()
175
+ self.weight = nn.Parameter(torch.ones(hidden_size))
176
+ self.variance_epsilon = eps
177
+
178
+ def forward(self, hidden_states):
179
+ input_dtype = hidden_states.dtype
180
+ hidden_states = hidden_states.to(torch.float32)
181
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
182
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
183
+ return self.weight * hidden_states.to(input_dtype)
184
+
185
+
186
+ # Copied from transformers.model.llama.modeling_llama.LlamaRotaryEmbedding with Llama->InternLM2
187
+ class InternLM2RotaryEmbedding(nn.Module):
188
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
189
+ super().__init__()
190
+
191
+ self.dim = dim
192
+ self.max_position_embeddings = max_position_embeddings
193
+ self.base = base
194
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
195
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
196
+
197
+ # Build here to make `torch.jit.trace` work.
198
+ self._set_cos_sin_cache(
199
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
200
+ )
201
+
202
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
203
+ self.max_seq_len_cached = seq_len
204
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
205
+
206
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
207
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
208
+ emb = torch.cat((freqs, freqs), dim=-1)
209
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
210
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
211
+
212
+ def forward(self, x, seq_len=None):
213
+ # x: [bs, num_attention_heads, seq_len, head_size]
214
+ if seq_len > self.max_seq_len_cached:
215
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.float32)
216
+
217
+ return (
218
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
219
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
220
+ )
221
+
222
+
223
+ # Copied from transformers.model.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->InternLM2
224
+ class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
225
+ """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
226
+
227
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
228
+ self.scaling_factor = scaling_factor
229
+ super().__init__(dim, max_position_embeddings, base, device)
230
+
231
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
232
+ self.max_seq_len_cached = seq_len
233
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
234
+ t = t / self.scaling_factor
235
+
236
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
237
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
238
+ emb = torch.cat((freqs, freqs), dim=-1)
239
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
240
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
241
+
242
+
243
+ # Copied from transformers.model.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->InternLM2
244
+ class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
245
+ """InternLM2RotaryEmbedding extended with Dynamic NTK scaling.
246
+ Credits to the Reddit users /u/bloc97 and /u/emozilla.
247
+ """
248
+
249
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
250
+ self.scaling_factor = scaling_factor
251
+ super().__init__(dim, max_position_embeddings, base, device)
252
+
253
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
254
+ self.max_seq_len_cached = seq_len
255
+
256
+ if seq_len > self.max_position_embeddings:
257
+ base = self.base * (
258
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
259
+ ) ** (self.dim / (self.dim - 2))
260
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
261
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
262
+
263
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
264
+
265
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
266
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
267
+ emb = torch.cat((freqs, freqs), dim=-1)
268
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
269
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
270
+
271
+
272
+ # Copied from transformers.model.llama.modeling_llama.rotate_half
273
+ def rotate_half(x):
274
+ """Rotates half the hidden dims of the input."""
275
+ x1 = x[..., : x.shape[-1] // 2]
276
+ x2 = x[..., x.shape[-1] // 2 :]
277
+ return torch.cat((-x2, x1), dim=-1)
278
+
279
+
280
+ # Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb
281
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
282
+ """Applies Rotary Position Embedding to the query and key tensors."""
283
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
284
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
285
+ q_embed = (q * cos) + (rotate_half(q) * sin)
286
+ k_embed = (k * cos) + (rotate_half(k) * sin)
287
+ return q_embed, k_embed
288
+
289
+
290
+ class InternLM2MLP(nn.Module):
291
+ def __init__(self, config):
292
+ super().__init__()
293
+ self.config = config
294
+ self.hidden_size = config.hidden_size
295
+ self.intermediate_size = config.intermediate_size
296
+ # self.w1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
297
+ # self.w3 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
298
+ # self.w2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
299
+
300
+ self.w1 = PLoRA(
301
+ self.hidden_size,
302
+ self.intermediate_size,
303
+ bias=False,
304
+ lora_r=256,
305
+ lora_alpha=256,
306
+ lora_len=576)
307
+ self.w3 = PLoRA(
308
+ self.hidden_size,
309
+ self.intermediate_size,
310
+ bias=False,
311
+ lora_r=256,
312
+ lora_alpha=256,
313
+ lora_len=576)
314
+ self.w2 = PLoRA(
315
+ self.intermediate_size,
316
+ self.hidden_size,
317
+ bias=False,
318
+ lora_r=256,
319
+ lora_alpha=256,
320
+ lora_len=576)
321
+
322
+ self.act_fn = ACT2FN[config.hidden_act]
323
+
324
+ def forward(self, x, im_mask):
325
+ down_proj = self.w2(self.act_fn(self.w1(x, im_mask)) * self.w3(x, im_mask), im_mask)
326
+ return down_proj
327
+
328
+
329
+ # Copied from transformers.model.llama.modeling_llama.repeat_kv
330
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
331
+ """
332
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
333
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
334
+ """
335
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
336
+ if n_rep == 1:
337
+ return hidden_states
338
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
339
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
340
+
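As the docstring notes, repeat_kv just expands the grouped key/value heads so they line up with the query heads. A quick check that the expand-and-reshape is equivalent to torch.repeat_interleave along the head dimension, with illustrative head counts:

import torch

kv = torch.randn(2, 8, 16, 128)        # (batch, num_key_value_heads, seq_len, head_dim)
n_rep = 32 // 8                        # num_attention_heads // num_key_value_heads

expanded = repeat_kv(kv, n_rep)
print(expanded.shape)                  # torch.Size([2, 32, 16, 128])
print(torch.equal(expanded, torch.repeat_interleave(kv, repeats=n_rep, dim=1)))  # True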
341
+
342
+ # Modified from transformers.model.llama.modeling_llama.LlamaAttention
343
+ class InternLM2Attention(nn.Module):
344
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
345
+
346
+ def __init__(self, config: InternLM2Config):
347
+ super().__init__()
348
+ self.config = config
349
+ self.hidden_size = config.hidden_size
350
+ self.num_heads = config.num_attention_heads
351
+ self.head_dim = self.hidden_size // self.num_heads
352
+ self.num_key_value_heads = config.num_key_value_heads
353
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
354
+ self.max_position_embeddings = config.max_position_embeddings
355
+ self.is_causal = True
356
+
357
+ if (self.head_dim * self.num_heads) != self.hidden_size:
358
+ raise ValueError(
359
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
360
+ f" and `num_heads`: {self.num_heads})."
361
+ )
362
+
363
+ # self.wqkv = nn.Linear(
364
+ # self.hidden_size,
365
+ # (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
366
+ # bias=config.bias,
367
+ # )
368
+ #
369
+ # self.wo = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
370
+
371
+ self.wqkv = PLoRA(
372
+ self.hidden_size,
373
+ (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
374
+ bias=config.bias,
375
+ lora_r=256,
376
+ lora_alpha=256,
377
+ lora_len=576)
378
+
379
+ self.wo = PLoRA(
380
+ self.num_heads * self.head_dim,
381
+ self.hidden_size,
382
+ bias=config.bias,
383
+ lora_r=256,
384
+ lora_alpha=256,
385
+ lora_len=576)
386
+ self._init_rope()
387
+
388
+ def _init_rope(self):
389
+ if self.config.rope_scaling is None:
390
+ self.rotary_emb = InternLM2RotaryEmbedding(
391
+ self.head_dim,
392
+ max_position_embeddings=self.max_position_embeddings,
393
+ base=self.config.rope_theta,
394
+ )
395
+ else:
396
+ scaling_type = self.config.rope_scaling["type"]
397
+ scaling_factor = self.config.rope_scaling["factor"]
398
+ if scaling_type == "dynamic":
399
+ self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding(
400
+ self.head_dim,
401
+ max_position_embeddings=self.max_position_embeddings,
402
+ base=self.config.rope_theta,
403
+ scaling_factor=scaling_factor,
404
+ )
405
+ elif scaling_type == "linear":
406
+ self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(
407
+ self.head_dim,
408
+ max_position_embeddings=self.max_position_embeddings,
409
+ base=self.config.rope_theta,
410
+ scaling_factor=scaling_factor,
411
+ )
412
+ else:
413
+ raise ValueError("Currently we only support rotary embedding's type being 'dynamic' or 'linear'.")
414
+ return self.rotary_emb
415
+
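+ # Hedged config sketch: `rope_scaling` in the model config drives the branches above, e.g.
+ #   rope_scaling = {"type": "dynamic", "factor": 2.0}   # dynamic NTK scaling
+ #   rope_scaling = {"type": "linear", "factor": 2.0}    # linear position interpolation
+ # and leaving it as None keeps the plain rotary embedding.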
416
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
417
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
418
+
419
+ def forward(
420
+ self,
421
+ hidden_states: torch.Tensor,
422
+ attention_mask: Optional[torch.Tensor] = None,
423
+ position_ids: Optional[torch.LongTensor] = None,
424
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
425
+ output_attentions: bool = False,
426
+ use_cache: bool = False,
427
+ im_mask: Optional[Tuple[torch.Tensor]] = None,
428
+ **kwargs,
429
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
430
+ if "padding_mask" in kwargs:
431
+ warnings.warn(
432
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. "
433
+ "Please make sure use `attention_mask` instead.`"
434
+ )
435
+
436
+ bsz, q_len, _ = hidden_states.size()
437
+ qkv_states = self.wqkv(hidden_states, im_mask)
438
+
439
+ qkv_states = rearrange(
440
+ qkv_states,
441
+ "b q (h gs d) -> b q h gs d",
442
+ gs=2 + self.num_key_value_groups,
443
+ d=self.head_dim,
444
+ )
445
+
446
+ query_states = qkv_states[..., : self.num_key_value_groups, :]
447
+ query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d")
448
+ key_states = qkv_states[..., -2, :]
449
+ value_states = qkv_states[..., -1, :]
450
+
451
+ query_states = query_states.transpose(1, 2)
452
+ key_states = key_states.transpose(1, 2)
453
+ value_states = value_states.transpose(1, 2)
454
+
455
+ kv_seq_len = key_states.shape[-2]
456
+ if past_key_value is not None:
457
+ kv_seq_len += past_key_value[0].shape[-2]
458
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
459
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
460
+
461
+ if past_key_value is not None:
462
+ # reuse k, v, self_attention
463
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
464
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
465
+
466
+ past_key_value = (key_states, value_states) if use_cache else None
467
+
468
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
469
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
470
+
471
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
472
+
473
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
474
+ raise ValueError(
475
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
476
+ f" {attn_weights.size()}"
477
+ )
478
+
479
+ if attention_mask is not None:
480
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
481
+ raise ValueError(
482
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
483
+ )
484
+ attn_weights = attn_weights + attention_mask
485
+
486
+ # upcast attention to fp32
487
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
488
+ attn_output = torch.matmul(attn_weights, value_states)
489
+
490
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
491
+ raise ValueError(
492
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
493
+ f" {attn_output.size()}"
494
+ )
495
+
496
+ attn_output = attn_output.transpose(1, 2).contiguous()
497
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
498
+
499
+ attn_output = self.wo(attn_output, im_mask)
500
+
501
+ if not output_attentions:
502
+ attn_weights = None
503
+
504
+ return attn_output, attn_weights, past_key_value
505
+
506
+
507
+ # Modified from transformers.models.llama.modeling_llama.LlamaFlashAttention2
508
+ class InternLM2FlashAttention2(InternLM2Attention):
509
+ """
510
+ InternLM2 flash attention module. This module inherits from `InternLM2Attention`, as the weights of the module stay
511
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
512
+ flash attention and deal with padding tokens in case the input contains any of them.
513
+ """
514
+
515
+ def forward(
516
+ self,
517
+ hidden_states: torch.Tensor,
518
+ attention_mask: Optional[torch.LongTensor] = None,
519
+ position_ids: Optional[torch.LongTensor] = None,
520
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
521
+ output_attentions: bool = False,
522
+ use_cache: bool = False,
523
+ im_mask: Optional[Tuple[torch.Tensor]] = None,
524
+ **kwargs,
525
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
526
+ # InternLM2FlashAttention2 attention does not support output_attentions
527
+ if "padding_mask" in kwargs:
528
+ warnings.warn(
529
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. "
530
+ "Please make sure use `attention_mask` instead.`"
531
+ )
532
+
533
+ # overwrite attention_mask with padding_mask
534
+ attention_mask = kwargs.pop("padding_mask")
535
+
536
+ output_attentions = False
537
+
538
+ bsz, q_len, _ = hidden_states.size()
539
+ qkv_states = self.wqkv(hidden_states, im_mask)
540
+
541
+ qkv_states = rearrange(
542
+ qkv_states,
543
+ "b q (h gs d) -> b q h gs d",
544
+ gs=2 + self.num_key_value_groups,
545
+ d=self.head_dim,
546
+ )
547
+
548
+ query_states = qkv_states[..., : self.num_key_value_groups, :]
549
+ query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d")
550
+ key_states = qkv_states[..., -2, :]
551
+ value_states = qkv_states[..., -1, :]
552
+
553
+ query_states = query_states.transpose(1, 2)
554
+ key_states = key_states.transpose(1, 2)
555
+ value_states = value_states.transpose(1, 2)
556
+
557
+ kv_seq_len = key_states.shape[-2]
558
+ if past_key_value is not None:
559
+ kv_seq_len += past_key_value[0].shape[-2]
560
+
561
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
562
+
563
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
564
+
565
+ if past_key_value is not None:
566
+ # reuse k, v, self_attention
567
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
568
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
569
+
570
+ past_key_value = (key_states, value_states) if use_cache else None
571
+
572
+ query_states = query_states.transpose(1, 2)
573
+ key_states = key_states.transpose(1, 2)
574
+ value_states = value_states.transpose(1, 2)
575
+
576
+ attn_output = self._flash_attention_forward(
577
+ query_states, key_states, value_states, attention_mask, q_len
578
+ )
579
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
580
+ attn_output = self.wo(attn_output, im_mask)
581
+
582
+ if not output_attentions:
583
+ attn_weights = None
584
+
585
+ return attn_output, attn_weights, past_key_value
586
+
587
+ def _flash_attention_forward(
588
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
589
+ ):
590
+ """
591
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
592
+ the input is first unpadded, the attention scores are computed, and the output is padded back to the original shape.
593
+
594
+ Args:
595
+ query_states (`torch.Tensor`):
596
+ Input query states to be passed to Flash Attention API
597
+ key_states (`torch.Tensor`):
598
+ Input key states to be passed to Flash Attention API
599
+ value_states (`torch.Tensor`):
600
+ Input value states to be passed to Flash Attention API
601
+ attention_mask (`torch.Tensor`):
602
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
603
+ position of padding tokens and 1 for the position of non-padding tokens.
604
+ dropout (`float`, *optional*):
605
+ Attention dropout
606
+ softmax_scale (`float`, *optional*):
607
+ The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
608
+ """
609
+ # Contains at least one padding token in the sequence
610
+ causal = self.is_causal and query_length != 1
611
+ if attention_mask is not None:
612
+ batch_size = query_states.shape[0]
613
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input(
614
+ query_states, key_states, value_states, attention_mask, query_length
615
+ )
616
+
617
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
618
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
619
+
620
+ attn_output_unpad = flash_attn_varlen_func(
621
+ query_states,
622
+ key_states,
623
+ value_states,
624
+ cu_seqlens_q=cu_seqlens_q,
625
+ cu_seqlens_k=cu_seqlens_k,
626
+ max_seqlen_q=max_seqlen_in_batch_q,
627
+ max_seqlen_k=max_seqlen_in_batch_k,
628
+ dropout_p=dropout,
629
+ softmax_scale=softmax_scale,
630
+ causal=causal,
631
+ )
632
+
633
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
634
+ else:
635
+ attn_output = flash_attn_func(
636
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
637
+ )
638
+
639
+ return attn_output
640
+
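+ # Summary (comment only): when a padding mask is present, q/k/v are unpadded to flat
+ # (total_tokens, num_heads, head_dim) tensors and flash_attn_varlen_func is called with the
+ # cumulative sequence lengths; otherwise the dense flash_attn_func path is used and no
+ # re-padding is needed.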
641
+ def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
642
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
643
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
644
+
645
+ key_layer = index_first_axis(
646
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
647
+ )
648
+ value_layer = index_first_axis(
649
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
650
+ )
651
+
652
+ if query_length == kv_seq_len:
653
+ query_layer = index_first_axis(
654
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
655
+ )
656
+ cu_seqlens_q = cu_seqlens_k
657
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
658
+ indices_q = indices_k
659
+ elif query_length == 1:
660
+ max_seqlen_in_batch_q = 1
661
+ cu_seqlens_q = torch.arange(
662
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
663
+ ) # There is a memcpy here, that is very bad.
664
+ indices_q = cu_seqlens_q[:-1]
665
+ query_layer = query_layer.squeeze(1)
666
+ else:
667
+ # The -q_len: slice assumes left padding.
668
+ attention_mask = attention_mask[:, -query_length:]
669
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
670
+
671
+ return (
672
+ query_layer,
673
+ key_layer,
674
+ value_layer,
675
+ indices_q.to(torch.int64),
676
+ (cu_seqlens_q, cu_seqlens_k),
677
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
678
+ )
679
+
680
+ INTERNLM2_ATTENTION_CLASSES = {
681
+ "eager": InternLM2Attention,
682
+ "flash_attention_2": InternLM2FlashAttention2,
683
+ }
684
+
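+ # Hedged usage sketch: the decoder layer below selects its attention implementation from this
+ # mapping, e.g. INTERNLM2_ATTENTION_CLASSES[config.attn_implementation](config=config), so
+ # loading the model with attn_implementation="flash_attention_2" swaps in the FlashAttention path.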
685
+ # Modified from transformers.model.llama.modeling_llama.LlamaDecoderLayer
686
+ class InternLM2DecoderLayer(nn.Module):
687
+ def __init__(self, config: InternLM2Config):
688
+ super().__init__()
689
+ self.hidden_size = config.hidden_size
690
+
691
+ self.attention = INTERNLM2_ATTENTION_CLASSES[config.attn_implementation](config=config)
692
+ self.feed_forward = InternLM2MLP(config)
693
+ self.attention_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
694
+ self.ffn_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
695
+
696
+ def forward(
697
+ self,
698
+ hidden_states: torch.Tensor,
699
+ attention_mask: Optional[torch.Tensor] = None,
700
+ position_ids: Optional[torch.LongTensor] = None,
701
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
702
+ output_attentions: Optional[bool] = False,
703
+ use_cache: Optional[bool] = False,
704
+ im_mask: Optional[Tuple[torch.Tensor]] = None,
705
+ **kwargs,
706
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
707
+ """
708
+ Args:
709
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
710
+ attention_mask (`torch.FloatTensor`, *optional*):
711
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
712
+ query_sequence_length, key_sequence_length)` if default attention is used.
713
+ output_attentions (`bool`, *optional*):
714
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
715
+ returned tensors for more detail.
716
+ use_cache (`bool`, *optional*):
717
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
718
+ (see `past_key_values`).
719
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
720
+ """
721
+ if "padding_mask" in kwargs:
722
+ warnings.warn(
723
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. "
724
+ "Please make sure use `attention_mask` instead.`"
725
+ )
726
+
727
+ residual = hidden_states
728
+
729
+ hidden_states = self.attention_norm(hidden_states)
730
+ # Self Attention
731
+ hidden_states, self_attn_weights, present_key_value = self.attention(
732
+ hidden_states=hidden_states,
733
+ attention_mask=attention_mask,
734
+ position_ids=position_ids,
735
+ past_key_value=past_key_value,
736
+ output_attentions=output_attentions,
737
+ use_cache=use_cache,
738
+ im_mask=im_mask,
739
+ **kwargs,
740
+ )
741
+ hidden_states = residual + hidden_states
742
+
743
+ # Fully Connected
744
+ residual = hidden_states
745
+ hidden_states = self.ffn_norm(hidden_states)
746
+ hidden_states = self.feed_forward(hidden_states, im_mask)
747
+ hidden_states = residual + hidden_states
748
+
749
+ outputs = (hidden_states,)
750
+
751
+ if output_attentions:
752
+ outputs += (self_attn_weights,)
753
+
754
+ if use_cache:
755
+ outputs += (present_key_value,)
756
+
757
+ return outputs
758
+
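+ # Summary (comment only): each decoder layer is pre-norm with residual connections:
+ #   h = x + Attention(attention_norm(x));  out = h + FeedForward(ffn_norm(h))
+ # with `im_mask` threaded through so the PLoRA branches act only on image-token positions.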
759
+
760
+ InternLM2_START_DOCSTRING = r"""
761
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
762
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
763
+ etc.)
764
+
765
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
766
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
767
+ and behavior.
768
+
769
+ Parameters:
770
+ config ([`InternLM2Config`]):
771
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
772
+ load the weights associated with the model, only the configuration. Check out the
773
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
774
+ """
775
+
776
+
777
+ # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->InternLM2
778
+ @add_start_docstrings(
779
+ "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
780
+ InternLM2_START_DOCSTRING,
781
+ )
782
+ class InternLM2PreTrainedModel(PreTrainedModel):
783
+ config_class = InternLM2Config
784
+ base_model_prefix = "model"
785
+ supports_gradient_checkpointing = True
786
+ _no_split_modules = ["InternLM2DecoderLayer"]
787
+ _skip_keys_device_placement = "past_key_values"
788
+
789
+ def _init_weights(self, module):
790
+ std = self.config.initializer_range
791
+ if isinstance(module, nn.Linear):
792
+ module.weight.data.normal_(mean=0.0, std=std)
793
+ if module.bias is not None:
794
+ module.bias.data.zero_()
795
+ elif isinstance(module, nn.Embedding):
796
+ module.weight.data.normal_(mean=0.0, std=std)
797
+ if module.padding_idx is not None:
798
+ module.weight.data[module.padding_idx].zero_()
799
+
800
+
801
+ InternLM2_INPUTS_DOCSTRING = r"""
802
+ Args:
803
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
804
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
805
+ it.
806
+
807
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
808
+ [`PreTrainedTokenizer.__call__`] for details.
809
+
810
+ [What are input IDs?](../glossary#input-ids)
811
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
812
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
813
+
814
+ - 1 for tokens that are **not masked**,
815
+ - 0 for tokens that are **masked**.
816
+
817
+ [What are attention masks?](../glossary#attention-mask)
818
+
819
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
820
+ [`PreTrainedTokenizer.__call__`] for details.
821
+
822
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
823
+ `past_key_values`).
824
+
825
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
826
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
827
+ information on the default strategy.
828
+
829
+ - 1 indicates the head is **not masked**,
830
+ - 0 indicates the head is **masked**.
831
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
832
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
833
+ config.n_positions - 1]`.
834
+
835
+ [What are position IDs?](../glossary#position-ids)
836
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or
837
+ when `config.use_cache=True`):
838
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
839
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
840
+ `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`.
841
+
842
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
843
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
844
+
845
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
846
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
847
+ of shape `(batch_size, sequence_length)`.
848
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
849
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
850
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
851
+ model's internal embedding lookup matrix.
852
+ use_cache (`bool`, *optional*):
853
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
854
+ `past_key_values`).
855
+ output_attentions (`bool`, *optional*):
856
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
857
+ tensors for more detail.
858
+ output_hidden_states (`bool`, *optional*):
859
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
860
+ more detail.
861
+ return_dict (`bool`, *optional*):
862
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
863
+ """
864
+
865
+
866
+ # Modified from transformers.model.llama.modeling_llama.LlamaModel
867
+ @add_start_docstrings(
868
+ "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
869
+ InternLM2_START_DOCSTRING,
870
+ )
871
+ class InternLM2Model(InternLM2PreTrainedModel):
872
+ """
873
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLM2DecoderLayer`]
874
+
875
+ Args:
876
+ config: InternLM2Config
877
+ """
878
+
879
+ _auto_class = "AutoModel"
880
+
881
+ def __init__(self, config: InternLM2Config):
882
+ super().__init__(config)
883
+ self.padding_idx = config.pad_token_id
884
+ self.vocab_size = config.vocab_size
885
+ self.config = config
886
+
887
+ self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
888
+
889
+ self.layers = nn.ModuleList([InternLM2DecoderLayer(config) for _ in range(config.num_hidden_layers)])
890
+ self.norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
891
+
892
+ self.gradient_checkpointing = False
893
+ # Initialize weights and apply final processing
894
+ self.post_init()
895
+
896
+ def get_input_embeddings(self):
897
+ return self.tok_embeddings
898
+
899
+ def set_input_embeddings(self, value):
900
+ self.tok_embeddings = value
901
+
902
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
903
+ # create causal mask
904
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
905
+ combined_attention_mask = None
906
+ if input_shape[-1] > 1:
907
+ combined_attention_mask = _make_causal_mask(
908
+ input_shape,
909
+ inputs_embeds.dtype,
910
+ device=inputs_embeds.device,
911
+ past_key_values_length=past_key_values_length,
912
+ )
913
+
914
+ if attention_mask is not None:
915
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
916
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
917
+ inputs_embeds.device
918
+ )
919
+ combined_attention_mask = (
920
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
921
+ )
922
+
923
+ return combined_attention_mask
924
+
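+ # Illustration (assumed values): for seq_len=4 with no padding, the combined mask is a
+ # (bsz, 1, 4, 4) tensor whose strictly upper-triangular entries hold the dtype minimum
+ # (acting as -inf) and whose lower triangle, including the diagonal, holds 0, so token i
+ # can only attend to tokens 0..i.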
925
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
926
+ def forward(
927
+ self,
928
+ input_ids: torch.LongTensor = None,
929
+ attention_mask: Optional[torch.Tensor] = None,
930
+ position_ids: Optional[torch.LongTensor] = None,
931
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
932
+ inputs_embeds: Optional[torch.FloatTensor] = None,
933
+ use_cache: Optional[bool] = None,
934
+ output_attentions: Optional[bool] = None,
935
+ output_hidden_states: Optional[bool] = None,
936
+ return_dict: Optional[bool] = None,
937
+ **kwargs
938
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
939
+
940
+ im_mask = kwargs.get('im_mask', None)
941
+
942
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
943
+ output_hidden_states = (
944
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
945
+ )
946
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
947
+
948
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
949
+
950
+ if self.config.attn_implementation == "flash_attention_2":
951
+ _import_flash_attn()
952
+
953
+ # retrieve input_ids and inputs_embeds
954
+ if input_ids is not None and inputs_embeds is not None:
955
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
956
+ elif input_ids is not None:
957
+ batch_size, seq_length = input_ids.shape[:2]
958
+ elif inputs_embeds is not None:
959
+ batch_size, seq_length = inputs_embeds.shape[:2]
960
+ else:
961
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
962
+
963
+ seq_length_with_past = seq_length
964
+ past_key_values_length = 0
965
+ if past_key_values is not None:
966
+ past_key_values_length = past_key_values[0][0].shape[2]
967
+ seq_length_with_past = seq_length_with_past + past_key_values_length
968
+
969
+ if position_ids is None:
970
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
971
+ position_ids = torch.arange(
972
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
973
+ )
974
+ position_ids = position_ids.unsqueeze(0)
975
+
976
+ if inputs_embeds is None:
977
+ inputs_embeds = self.tok_embeddings(input_ids)
978
+ im_mask = torch.zeros(inputs_embeds.shape[:2]).to(
979
+ inputs_embeds.device).bool()
980
+
981
+ if self.config.attn_implementation == "flash_attention_2":
982
+ # 2d mask is passed through the layers
983
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
984
+ else:
985
+ if attention_mask is None:
986
+ attention_mask = torch.ones(
987
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
988
+ )
989
+ attention_mask = self._prepare_decoder_attention_mask(
990
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
991
+ )
992
+
993
+ # embed positions
994
+ hidden_states = inputs_embeds
995
+
996
+ if self.gradient_checkpointing and self.training:
997
+ if use_cache:
998
+ logger.warning_once(
999
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1000
+ )
1001
+ use_cache = False
1002
+
1003
+ # decoder layers
1004
+ all_hidden_states = () if output_hidden_states else None
1005
+ all_self_attns = () if output_attentions else None
1006
+ next_decoder_cache = () if use_cache else None
1007
+
1008
+ for idx, decoder_layer in enumerate(self.layers):
1009
+ if output_hidden_states:
1010
+ all_hidden_states += (hidden_states,)
1011
+
1012
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
1013
+
1014
+ if self.gradient_checkpointing and self.training:
1015
+
1016
+ def create_custom_forward(module):
1017
+ def custom_forward(*inputs):
1018
+ # None for past_key_value
1019
+ return module(*inputs, output_attentions, None, im_mask)
1020
+
1021
+ return custom_forward
1022
+
1023
+ layer_outputs = torch.utils.checkpoint.checkpoint(
1024
+ create_custom_forward(decoder_layer),
1025
+ hidden_states,
1026
+ attention_mask,
1027
+ position_ids,
1028
+ None,
1029
+ )
1030
+ else:
1031
+ layer_outputs = decoder_layer(
1032
+ hidden_states,
1033
+ attention_mask=attention_mask,
1034
+ position_ids=position_ids,
1035
+ past_key_value=past_key_value,
1036
+ output_attentions=output_attentions,
1037
+ use_cache=use_cache,
1038
+ im_mask=im_mask,
1039
+ )
1040
+
1041
+ hidden_states = layer_outputs[0]
1042
+
1043
+ if use_cache:
1044
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
1045
+
1046
+ if output_attentions:
1047
+ all_self_attns += (layer_outputs[1],)
1048
+
1049
+ hidden_states = self.norm(hidden_states)
1050
+
1051
+ # add hidden states from the last decoder layer
1052
+ if output_hidden_states:
1053
+ all_hidden_states += (hidden_states,)
1054
+
1055
+ next_cache = next_decoder_cache if use_cache else None
1056
+ if not return_dict:
1057
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1058
+ return BaseModelOutputWithPast(
1059
+ last_hidden_state=hidden_states,
1060
+ past_key_values=next_cache,
1061
+ hidden_states=all_hidden_states,
1062
+ attentions=all_self_attns,
1063
+ )
1064
+
1065
+
1066
+ # Modified from transformers.model.llama.modeling_llama.LlamaForCausalLM
1067
+ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
1068
+ _auto_class = "AutoModelForCausalLM"
1069
+
1070
+ _tied_weights_keys = ["output.weight"]
1071
+
1072
+ def __init__(self, config):
1073
+ super().__init__(config)
1074
+ self.model = InternLM2Model(config)
1075
+ self.vocab_size = config.vocab_size
1076
+ self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1077
+
1078
+ # Initialize weights and apply final processing
1079
+ self.post_init()
1080
+
1081
+ def get_input_embeddings(self):
1082
+ return self.model.tok_embeddings
1083
+
1084
+ def set_input_embeddings(self, value):
1085
+ self.model.tok_embeddings = value
1086
+
1087
+ def get_output_embeddings(self):
1088
+ return self.output
1089
+
1090
+ def set_output_embeddings(self, new_embeddings):
1091
+ self.output = new_embeddings
1092
+
1093
+ def set_decoder(self, decoder):
1094
+ self.model = decoder
1095
+
1096
+ def get_decoder(self):
1097
+ return self.model
1098
+
1099
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
1100
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1101
+ def forward(
1102
+ self,
1103
+ input_ids: torch.LongTensor = None,
1104
+ attention_mask: Optional[torch.Tensor] = None,
1105
+ position_ids: Optional[torch.LongTensor] = None,
1106
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1107
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1108
+ labels: Optional[torch.LongTensor] = None,
1109
+ use_cache: Optional[bool] = None,
1110
+ output_attentions: Optional[bool] = None,
1111
+ output_hidden_states: Optional[bool] = None,
1112
+ return_dict: Optional[bool] = None,
1113
+ im_mask: Optional[torch.Tensor] = None,
1114
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1115
+ r"""
1116
+ Args:
1117
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1118
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1119
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1120
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1121
+
1122
+ Returns:
1123
+
1124
+ Example:
1125
+
1126
+ ```python
1127
+ >>> from transformers import AutoTokenizer, InternLM2ForCausalLM
1128
+
1129
+ >>> model = InternLM2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1130
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1131
+
1132
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1133
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1134
+
1135
+ >>> # Generate
1136
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1137
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1138
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1139
+ ```"""
1140
+
1141
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1142
+ output_hidden_states = (
1143
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1144
+ )
1145
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1146
+
1147
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1148
+ outputs = self.model(
1149
+ input_ids=input_ids,
1150
+ attention_mask=attention_mask,
1151
+ position_ids=position_ids,
1152
+ past_key_values=past_key_values,
1153
+ inputs_embeds=inputs_embeds,
1154
+ use_cache=use_cache,
1155
+ output_attentions=output_attentions,
1156
+ output_hidden_states=output_hidden_states,
1157
+ return_dict=return_dict,
1158
+ im_mask=im_mask,
1159
+ )
1160
+
1161
+ hidden_states = outputs[0]
1162
+ logits = self.output(hidden_states)
1163
+ logits = logits.float()
1164
+
1165
+ loss = None
1166
+ if labels is not None:
1167
+ # Shift so that tokens < n predict n
1168
+ shift_logits = logits[..., :-1, :].contiguous()
1169
+ shift_labels = labels[..., 1:].contiguous()
1170
+ # Flatten the tokens
1171
+ loss_fct = CrossEntropyLoss()
1172
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1173
+ shift_labels = shift_labels.view(-1)
1174
+ # Enable model parallelism
1175
+ shift_labels = shift_labels.to(shift_logits.device)
1176
+ loss = loss_fct(shift_logits, shift_labels)
1177
+
1178
+ if not return_dict:
1179
+ output = (logits,) + outputs[1:]
1180
+ return (loss,) + output if loss is not None else output
1181
+
1182
+ return CausalLMOutputWithPast(
1183
+ loss=loss,
1184
+ logits=logits,
1185
+ past_key_values=outputs.past_key_values,
1186
+ hidden_states=outputs.hidden_states,
1187
+ attentions=outputs.attentions,
1188
+ )
1189
+
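+ # Worked example (comment only): with labels = [t0, t1, t2, t3], the shift above pairs the
+ # logits at positions 0..2 with targets t1..t3, so each position predicts the next token;
+ # positions labelled -100 are ignored by CrossEntropyLoss.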
1190
+ def prepare_inputs_for_generation(
1191
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1192
+ ):
1193
+ if past_key_values is not None:
1194
+ past_length = past_key_values[0][0].shape[2]
1195
+
1196
+ # Some generation methods already pass only the last input ID
1197
+ if input_ids.shape[1] > past_length:
1198
+ remove_prefix_length = past_length
1199
+ else:
1200
+ # Default to old behavior: keep only final ID
1201
+ remove_prefix_length = input_ids.shape[1] - 1
1202
+
1203
+ input_ids = input_ids[:, remove_prefix_length:]
1204
+
1205
+ position_ids = kwargs.get("position_ids", None)
1206
+ if attention_mask is not None and position_ids is None:
1207
+ # create position_ids on the fly for batch generation
1208
+ position_ids = attention_mask.long().cumsum(-1) - 1
1209
+ position_ids.masked_fill_(attention_mask == 0, 1)
1210
+ if past_key_values:
1211
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1212
+
1213
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1214
+ if inputs_embeds is not None and past_key_values is None:
1215
+ model_inputs = {"inputs_embeds": inputs_embeds}
1216
+ else:
1217
+ model_inputs = {"input_ids": input_ids}
1218
+
1219
+ model_inputs.update(
1220
+ {
1221
+ "position_ids": position_ids,
1222
+ "past_key_values": past_key_values,
1223
+ "use_cache": kwargs.get("use_cache"),
1224
+ "attention_mask": attention_mask,
1225
+ "im_mask": kwargs.get("im_mask", None),
1226
+ }
1227
+ )
1228
+ return model_inputs
1229
+
1230
+ @staticmethod
1231
+ def _reorder_cache(past_key_values, beam_idx):
1232
+ reordered_past = ()
1233
+ for layer_past in past_key_values:
1234
+ reordered_past += (
1235
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1236
+ )
1237
+ return reordered_past
1238
+
1239
+ def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = [], meta_instruction=""):
1240
+ if tokenizer.add_bos_token:
1241
+ prompt = ""
1242
+ else:
1243
+ prompt = tokenizer.bos_token
1244
+ if meta_instruction:
1245
+ prompt += f"""<|im_start|>system\n{meta_instruction}<|im_end|>\n"""
1246
+ for record in history:
1247
+ prompt += f"""<|im_start|>user\n{record[0]}<|im_end|>\n<|im_start|>assistant\n{record[1]}<|im_end|>\n"""
1248
+ prompt += f"""<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"""
1249
+ return tokenizer([prompt], return_tensors="pt")
1250
+
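+ # Example of the resulting prompt string (illustrative; no history, with a system prompt):
+ #   <|im_start|>system\n{meta_instruction}<|im_end|>\n<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n
+ # i.e. the ChatML-style template that `chat()` below feeds to `generate()`.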
1251
+ @torch.no_grad()
1252
+ def chat(
1253
+ self,
1254
+ tokenizer,
1255
+ query: str,
1256
+ history: List[Tuple[str, str]] = [],
1257
+ streamer: Optional[BaseStreamer] = None,
1258
+ max_new_tokens: int = 1024,
1259
+ do_sample: bool = True,
1260
+ temperature: float = 0.8,
1261
+ top_p: float = 0.8,
1262
+ meta_instruction: str = "You are an AI assistant whose name is InternLM (书生·浦语).\n"
1263
+ "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
1264
+ "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.",
1265
+ **kwargs,
1266
+ ):
1267
+ inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
1268
+ inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
1269
+ # also add end-of-assistant token in eos token id to avoid unnecessary generation
1270
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(["<|im_end|>"])[0]]
1271
+ outputs = self.generate(
1272
+ **inputs,
1273
+ streamer=streamer,
1274
+ max_new_tokens=max_new_tokens,
1275
+ do_sample=do_sample,
1276
+ temperature=temperature,
1277
+ top_p=top_p,
1278
+ eos_token_id=eos_token_id,
1279
+ **kwargs,
1280
+ )
1281
+ outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :]
1282
+ response = tokenizer.decode(outputs, skip_special_tokens=True)
1283
+ response = response.split("<|im_end|>")[0]
1284
+ history = history + [(query, response)]
1285
+ return response, history
1286
+
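+ # Hedged usage sketch (paths and names are placeholders):
+ #   tokenizer = AutoTokenizer.from_pretrained(PATH, trust_remote_code=True)
+ #   model = AutoModelForCausalLM.from_pretrained(PATH, trust_remote_code=True).cuda().eval()
+ #   response, history = model.chat(tokenizer, "Hello", history=[])
+ # The returned `history` can be passed back in for multi-turn conversation.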
1287
+ @torch.no_grad()
1288
+ def stream_chat(
1289
+ self,
1290
+ tokenizer,
1291
+ query: str,
1292
+ history: List[Tuple[str, str]] = [],
1293
+ max_new_tokens: int = 1024,
1294
+ do_sample: bool = True,
1295
+ temperature: float = 0.8,
1296
+ top_p: float = 0.8,
1297
+ **kwargs,
1298
+ ):
1299
+ """
1300
+ Return a generator that yields (response, history) tuples as the reply is streamed.
1301
+ E.g.
1302
+ ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
1303
+ ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')])
1304
+ """
1305
+ if BaseStreamer is None:
1306
+ raise ModuleNotFoundError(
1307
+ "The version of `transformers` is too low. Please make sure "
1308
+ "that you have installed `transformers>=4.28.0`."
1309
+ )
1310
+
1311
+ response_queue = queue.Queue(maxsize=20)
1312
+
1313
+ class ChatStreamer(BaseStreamer):
1314
+ def __init__(self, tokenizer) -> None:
1315
+ super().__init__()
1316
+ self.tokenizer = tokenizer
1317
+ self.queue = response_queue
1318
+ self.query = query
1319
+ self.history = history
1320
+ self.response = ""
1321
+ self.cache = []
1322
+ self.received_inputs = False
1323
+ self.queue.put((self.response, history + [(self.query, self.response)]))
1324
+
1325
+ def put(self, value):
1326
+ if len(value.shape) > 1 and value.shape[0] > 1:
1327
+ raise ValueError("ChatStreamer only supports batch size 1")
1328
+ elif len(value.shape) > 1:
1329
+ value = value[0]
1330
+
1331
+ if not self.received_inputs:
1332
+ # The first received value is input_ids, ignore here
1333
+ self.received_inputs = True
1334
+ return
1335
+
1336
+ self.cache.extend(value.tolist())
1337
+ token = self.tokenizer.decode(self.cache, skip_special_tokens=True)
1338
+ if token.strip() != "<|im_end|>":
1339
+ self.response = self.response + token
1340
+ history = self.history + [(self.query, self.response)]
1341
+ self.queue.put((self.response, history))
1342
+ self.cache = []
1343
+ else:
1344
+ self.end()
1345
+
1346
+ def end(self):
1347
+ self.queue.put(None)
1348
+
1349
+ def stream_producer():
1350
+ return self.chat(
1351
+ tokenizer=tokenizer,
1352
+ query=query,
1353
+ streamer=ChatStreamer(tokenizer=tokenizer),
1354
+ history=history,
1355
+ max_new_tokens=max_new_tokens,
1356
+ do_sample=do_sample,
1357
+ temperature=temperature,
1358
+ top_p=top_p,
1359
+ **kwargs,
1360
+ )
1361
+
1362
+ def consumer():
1363
+ producer = threading.Thread(target=stream_producer)
1364
+ producer.start()
1365
+ while True:
1366
+ res = response_queue.get()
1367
+ if res is None:
1368
+ return
1369
+ yield res
1370
+
1371
+ return consumer()
1372
+
1373
+
1374
+ # Copied from transformers.model.llama.modeling_llama.LlamaForSequenceClassification with Llama->InternLM2
1375
+ @add_start_docstrings(
1376
+ """
1377
+ The InternLM2 Model transformer with a sequence classification head on top (linear layer).
1378
+
1379
+ [`InternLM2ForSequenceClassification`] uses the last token in order to do the classification,
1380
+ as other causal models (e.g. GPT-2) do.
1381
+
1382
+ Since it does classification on the last token, it needs to know the position of the last token. If a
1383
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1384
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1385
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1386
+ each row of the batch).
1387
+ """,
1388
+ InternLM2_START_DOCSTRING,
1389
+ )
1390
+ class InternLM2ForSequenceClassification(InternLM2PreTrainedModel):
1391
+ def __init__(self, config):
1392
+ super().__init__(config)
1393
+ self.num_labels = config.num_labels
1394
+ self.model = InternLM2Model(config)
1395
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1396
+
1397
+ # Initialize weights and apply final processing
1398
+ self.post_init()
1399
+
1400
+ def get_input_embeddings(self):
1401
+ return self.model.tok_embeddings
1402
+
1403
+ def set_input_embeddings(self, value):
1404
+ self.model.tok_embeddings = value
1405
+
1406
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
1407
+ def forward(
1408
+ self,
1409
+ input_ids: torch.LongTensor = None,
1410
+ attention_mask: Optional[torch.Tensor] = None,
1411
+ position_ids: Optional[torch.LongTensor] = None,
1412
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1413
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1414
+ labels: Optional[torch.LongTensor] = None,
1415
+ use_cache: Optional[bool] = None,
1416
+ output_attentions: Optional[bool] = None,
1417
+ output_hidden_states: Optional[bool] = None,
1418
+ return_dict: Optional[bool] = None,
1419
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1420
+ r"""
1421
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1422
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1423
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1424
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1425
+ """
1426
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1427
+
1428
+ transformer_outputs = self.model(
1429
+ input_ids,
1430
+ attention_mask=attention_mask,
1431
+ position_ids=position_ids,
1432
+ past_key_values=past_key_values,
1433
+ inputs_embeds=inputs_embeds,
1434
+ use_cache=use_cache,
1435
+ output_attentions=output_attentions,
1436
+ output_hidden_states=output_hidden_states,
1437
+ return_dict=return_dict,
1438
+ )
1439
+ hidden_states = transformer_outputs[0]
1440
+ logits = self.score(hidden_states)
1441
+
1442
+ if input_ids is not None:
1443
+ batch_size = input_ids.shape[0]
1444
+ else:
1445
+ batch_size = inputs_embeds.shape[0]
1446
+
1447
+ if self.config.pad_token_id is None and batch_size != 1:
1448
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1449
+ if self.config.pad_token_id is None:
1450
+ sequence_lengths = -1
1451
+ else:
1452
+ if input_ids is not None:
1453
+ sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
1454
+ logits.device
1455
+ )
1456
+ else:
1457
+ sequence_lengths = -1
1458
+
1459
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1460
+
1461
+ loss = None
1462
+ if labels is not None:
1463
+ labels = labels.to(logits.device)
1464
+ if self.config.problem_type is None:
1465
+ if self.num_labels == 1:
1466
+ self.config.problem_type = "regression"
1467
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1468
+ self.config.problem_type = "single_label_classification"
1469
+ else:
1470
+ self.config.problem_type = "multi_label_classification"
1471
+
1472
+ if self.config.problem_type == "regression":
1473
+ loss_fct = MSELoss()
1474
+ if self.num_labels == 1:
1475
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1476
+ else:
1477
+ loss = loss_fct(pooled_logits, labels)
1478
+ elif self.config.problem_type == "single_label_classification":
1479
+ loss_fct = CrossEntropyLoss()
1480
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1481
+ elif self.config.problem_type == "multi_label_classification":
1482
+ loss_fct = BCEWithLogitsLoss()
1483
+ loss = loss_fct(pooled_logits, labels)
1484
+ if not return_dict:
1485
+ output = (pooled_logits,) + transformer_outputs[1:]
1486
+ return ((loss,) + output) if loss is not None else output
1487
+
1488
+ return SequenceClassifierOutputWithPast(
1489
+ loss=loss,
1490
+ logits=pooled_logits,
1491
+ past_key_values=transformer_outputs.past_key_values,
1492
+ hidden_states=transformer_outputs.hidden_states,
1493
+ attentions=transformer_outputs.attentions,
1494
+ )
modeling_projector.py ADDED
@@ -0,0 +1,51 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ import torch.nn as nn
4
+ from transformers import PreTrainedModel
5
+ from transformers.activations import ACT2FN
6
+
7
+ from .configuration_projector import ProjectorConfig
8
+
9
+
10
+ class ProjectorModel(PreTrainedModel):
11
+ _auto_class = 'AutoModel'
12
+ config_class = ProjectorConfig
13
+ base_model_prefix = 'model'
14
+ supports_gradient_checkpointing = True
15
+
16
+ def __init__(self, config: ProjectorConfig) -> None:
17
+ super().__init__(config)
18
+ self.gradient_checkpointing = False
19
+
20
+ modules = [
21
+ nn.Linear(
22
+ config.visual_hidden_size,
23
+ config.llm_hidden_size,
24
+ bias=config.bias)
25
+ ]
26
+ for _ in range(1, config.depth):
27
+ modules.append(ACT2FN[config.hidden_act])
28
+ modules.append(
29
+ nn.Linear(
30
+ config.llm_hidden_size,
31
+ config.llm_hidden_size,
32
+ bias=config.bias))
33
+ self.model = nn.Sequential(*modules)
34
+
35
+ def enable_input_require_grads(self):
36
+
37
+ def make_inputs_require_grad(module, input, output):
38
+ output.requires_grad_(True)
39
+
40
+ self.model.register_forward_hook(make_inputs_require_grad)
41
+
42
+ def _set_gradient_checkpointing(self, module, value=False):
43
+ if isinstance(module, ProjectorModel):
44
+ module.gradient_checkpointing = value
45
+
46
+ def forward(self, x):
47
+ if self.gradient_checkpointing and self.training:
48
+ layer_outputs = torch.utils.checkpoint.checkpoint(self.model, x)
49
+ else:
50
+ layer_outputs = self.model(x)
51
+ return layer_outputs
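+ # Sketch of the resulting module (assuming depth=2 and hidden_act="gelu"): the projector is a
+ # small MLP, roughly
+ #   nn.Sequential(nn.Linear(visual_hidden_size, llm_hidden_size), nn.GELU(),
+ #                 nn.Linear(llm_hidden_size, llm_hidden_size)),
+ # mapping vision features into the LLM embedding space.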
modeling_wemm.py ADDED
@@ -0,0 +1,479 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional, Tuple, Union
3
+ import torch
4
+ import torch.utils.checkpoint
5
+ from torch import nn
6
+ from transformers import PreTrainedModel
7
+ from transformers.activations import ACT2FN
8
+ from transformers.cache_utils import Cache
9
+ from transformers.modeling_outputs import ModelOutput
10
+ from transformers.utils import (
11
+ add_start_docstrings,
12
+ add_start_docstrings_to_model_forward,
13
+ logging,
14
+ replace_return_docstrings,
15
+ )
16
+ from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, AutoConfig
17
+ from .configuration_wemm import WeMMConfig
18
+ from .vision_model import Idefics2VisionTransformer
19
+ from .connector import Idefics2Connector
20
+ from .image_processor import Idefics2ImageProcessor
21
+ from .modeling_downsampler import DownsamplerModel
22
+ from .modeling_projector import ProjectorModel
23
+ from .modeling_internlm2 import InternLM2ForCausalLM
24
+ from .tokenization_internlm2 import InternLM2Tokenizer
25
+ from peft import PeftModel
26
+ from peft import PeftConfig
27
+ import os
28
+ from PIL import Image
29
+ import numpy as np
30
+ IMAGE_TOKEN_INDEX = -200
31
+ DEFAULT_IMAGE_TOKEN = "<image>"
32
+ IGNORE_INDEX = -100
33
+ from transformers import StoppingCriteria
34
+ from transformers import PreTrainedTokenizerFast, StoppingCriteriaList
35
+ import torch.nn.functional as F
36
+ class StopWordStoppingCriteria(StoppingCriteria):
37
+ """StopWord stopping criteria."""
38
+ def __init__(self, tokenizer, stop_word):
39
+ self.tokenizer = tokenizer
40
+ self.stop_word = stop_word
41
+ self.length = len(self.stop_word)
42
+ def __call__(self, input_ids, *args, **kwargs) -> bool:
43
+ cur_text = self.tokenizer.decode(input_ids[0])
44
+ cur_text = cur_text.replace('\r', '').replace('\n', '')
45
+ return cur_text[-self.length:] == self.stop_word
46
+ def get_stop_criteria(
47
+ tokenizer,
48
+ stop_words=[],
49
+ ):
50
+ stop_criteria = StoppingCriteriaList()
51
+ for word in stop_words:
52
+ stop_criteria.append(StopWordStoppingCriteria(tokenizer, word))
53
+ return stop_criteria
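+ # Hedged usage sketch: stop generation at the chat end marker, e.g.
+ #   stop_criteria = get_stop_criteria(tokenizer, stop_words=['<|im_end|>'])
+ #   model.generate(..., stopping_criteria=stop_criteria)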
54
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
55
+ assert embed_dim % 2 == 0
56
+ # use half of dimensions to encode grid_h
57
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H, W, D/2)
58
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H, W, D/2)
59
+ emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D)
60
+ return emb
61
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
62
+ """
63
+ embed_dim: output dimension for each position
64
+ pos: grid of positions to be encoded, shape (1, H, W) (squeezed to (H, W) below)
65
+ out: (H, W, D)
66
+ """
67
+ assert embed_dim % 2 == 0
68
+ omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float was removed in recent NumPy; float64 keeps the original behaviour
69
+ omega /= embed_dim / 2.
70
+ omega = 1. / 10000**omega # (D/2,)
71
+ pos = np.squeeze(pos) # (1, H, W) -> (H, W)
72
+ out = np.einsum('hw,d->hwd', pos, omega) # (H, W, D/2), outer product
73
+ emb_sin = np.sin(out) # (H, W, D/2)
74
+ emb_cos = np.cos(out) # (H, W, D/2)
75
+ emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D)
76
+ return emb
77
+ # 2D sine-cosine position embedding
78
+ # References:
79
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
80
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
81
+ # --------------------------------------------------------
82
+ def get_2d_sincos_pos_embed(embed_dim, grid_size_h, grid_size_w, cls_token=False):
83
+ """
84
+ grid_size_h, grid_size_w: ints giving the grid height and width
85
+ return:
86
+ pos_embed: [grid_size_h, grid_size_w, embed_dim] (this 2-D variant expects cls_token=False)
87
+ """
88
+ grid_h = np.arange(grid_size_h, dtype=np.float32)
89
+ grid_w = np.arange(grid_size_w, dtype=np.float32)
90
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
91
+ grid = np.stack(grid, axis=0)
92
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
93
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
94
+ if cls_token:
95
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
96
+ return pos_embed
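+ # Illustration (assumed sizes): get_2d_sincos_pos_embed(4096, grid_size_h=3, grid_size_w=5)
+ # returns a (3, 5, 4096) array; half of the channels encode the row index and half the column
+ # index with fixed sine/cosine frequencies, so nothing here is learned.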
97
+ def recover_navit_subimages_with_pos_emb(
98
+ sub_image_hidden_states,
99
+ attention_mask,
100
+ num_sub_images,
101
+ visual_embedding_group,
102
+ pos_hidden_size,
103
+ thumbnail_only=False):
104
+ _slice = int(np.sqrt(num_sub_images))
105
+ N, L, D = sub_image_hidden_states.shape
106
+ _, H, W = attention_mask.shape
107
+ if thumbnail_only is True:
108
+ num_sub_images += 1
109
+ sub_image_hidden_states = sub_image_hidden_states.reshape(-1, num_sub_images, H, W, D)
110
+ attention_mask = attention_mask.reshape(-1, num_sub_images, H, W)
111
+ if thumbnail_only is True:
112
+ sub_image_hidden_states = sub_image_hidden_states[:, -1:, :, :, :]
113
+ attention_mask = attention_mask[:, -1:, :, :]
114
+ _slice = 1
115
+ def _infer_ori_image_patch_shape(sub_image_attention_mask):
116
+ ind_h, ind_w = torch.where(sub_image_attention_mask > 0)
117
+ return torch.max(ind_h) + 1, torch.max(ind_w) + 1
118
+ def _pad_to_same(image_hidden):
119
+ _dtype = image_hidden.dtype
120
+ visual_downsample_stride = int(np.sqrt(visual_embedding_group))
121
+ full_h, full_w, _ = image_hidden.shape
122
+ target_h, target_w = H * _slice, W * _slice
123
+ # ensure all contents are included during downsampling
124
+ to_pad_h = (target_h - full_h) + (
125
+ visual_downsample_stride - target_h % visual_downsample_stride) % visual_downsample_stride
126
+ to_pad_w = (target_w - full_w) + (
127
+ visual_downsample_stride - target_w % visual_downsample_stride) % visual_downsample_stride
128
+ # (H,W,D) -> (1,D,H,W) to support replicate padding
129
+ image_hidden = image_hidden.permute(2, 0, 1).unsqueeze(0)
130
+ pad_size = (0, to_pad_w, 0, to_pad_h)
131
+ # (1,D,H,W) -> (H,W,D)
132
+ image_hidden = F.pad(image_hidden.to(torch.float32), pad_size, mode='replicate').squeeze(0).permute(1, 2, 0)
133
+ return image_hidden.to(_dtype)
134
+ image_hidden_states = list()
135
+ valid_image_token = list()
136
+ image_2d_pos = list()
137
+ for batch_id in range(len(sub_image_hidden_states)):
138
+ ori_h, ori_w = _infer_ori_image_patch_shape(attention_mask[batch_id][0])
139
+ full_h, full_w = ori_h * _slice, ori_w * _slice
140
+ # (S,H,W,D) -> (S_h,S_w,H,W,D) -> (S_h,H,S_w,W,D) -> (S_h*H,S_w*W,D)
141
+ this_image_hidden = sub_image_hidden_states[batch_id][:, 0:ori_h, 0:ori_w, :] \
142
+ .view(_slice, _slice, ori_h, ori_w, D).permute(0, 2, 1, 3, 4).contiguous().view(full_h, full_w, D)
143
+ pos_emb = get_2d_sincos_pos_embed(pos_hidden_size, grid_size_h=full_h,
144
+ grid_size_w=full_w) # (H, W, D)
145
+ pos_emb = torch.tensor(pos_emb, dtype=this_image_hidden.dtype, device=this_image_hidden.device)
146
+ image_hidden_states.append(_pad_to_same(this_image_hidden))
147
+ image_2d_pos.append(_pad_to_same(pos_emb))
148
+ valid_image_token.append([full_h, full_w])
149
+ image_hidden_states = torch.stack(image_hidden_states)
150
+ image_2d_pos = torch.stack(image_2d_pos)
151
+ valid_image_token = torch.tensor(valid_image_token, dtype=torch.int64)
152
+ return image_hidden_states, image_2d_pos, valid_image_token
153
+ def visiual_token_downsample(
154
+ visual_downsampler,
155
+ image_hidden_states,
156
+ valid_image_token,
157
+ visual_embedding_group,
158
+ image_2d_pos):
159
+ if image_2d_pos is not None:
160
+ image_hidden_states = image_hidden_states + image_2d_pos
161
+ image_hidden_states = visual_downsampler(image_hidden_states)
162
+ valid_image_token = torch.ceil(valid_image_token / np.sqrt(visual_embedding_group)).to(torch.int64)
163
+ return image_hidden_states, valid_image_token
164
+ def merge_native_qformer(
165
+ clip_embeddings_native_patch,
166
+ valid_image_token_shape,
167
+ clip_embeddings_qformer,
168
+ visual_source_spliter,
169
+ num_sub_images):
170
+ assert clip_embeddings_native_patch.size(0) == valid_image_token_shape.size(0) == clip_embeddings_qformer.size(0)
171
+ def add_split_token_for_qformer_token(qformer_emb):
172
+ # + 1 for thumbnail
173
+ len_per_token = int(qformer_emb.size(0) // (num_sub_images + 1))
174
+ qformer_emb_with_spliter = list()
175
+ for i in range(num_sub_images + 1):
176
+ qformer_emb_with_spliter.append(
177
+ visual_source_spliter(torch.tensor([2 * i]).to(visual_source_spliter.weight.device))
178
+ )
179
+ qformer_emb_with_spliter.append(qformer_emb[i * len_per_token:(i + 1) * len_per_token])
180
+ qformer_emb_with_spliter.append(
181
+ visual_source_spliter(torch.tensor([2 * i + 1]).to(visual_source_spliter.weight.device))
182
+ )
183
+ return torch.cat(qformer_emb_with_spliter, dim=0)
184
+ merged_visual_embeddings = list()
185
+ for batch_id in range(clip_embeddings_native_patch.size(0)):
186
+ h, w = valid_image_token_shape[batch_id]
187
+ native_patch_emb = clip_embeddings_native_patch[batch_id][:h, :w, :].reshape(h*w, -1)
188
+ qformer_emb = clip_embeddings_qformer[batch_id]
189
+ qformer_emb = add_split_token_for_qformer_token(qformer_emb)
190
+ merged_visual_embeddings.append(
191
+ torch.cat(
192
+ [visual_source_spliter(torch.tensor([10]).to(visual_source_spliter.weight.device)),
193
+ native_patch_emb,
194
+ visual_source_spliter(torch.tensor([11]).to(visual_source_spliter.weight.device)),
195
+ qformer_emb],
196
+ dim=0))
197
+ return merged_visual_embeddings
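+ # Layout sketch (comment only, chunk ordering assumed): each merged visual sequence is roughly
+ #   [SPLIT_10] native-resolution patch tokens (h*w of them) [SPLIT_11]
+ #   [SPLIT_0] resampled chunk 0 [SPLIT_1] ... [SPLIT_8] resampled chunk 4 [SPLIT_9]
+ # where SPLIT_i are learned rows of the `visual_source_spliter` embedding, letting the LLM
+ # distinguish the native-patch stream from the Q-Former-style resampled stream.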
198
+ class WemmForConditionalGeneration(PreTrainedModel):
199
+ config_class = WeMMConfig
200
+ def __init__(self, config: WeMMConfig):
201
+ super().__init__(config)
202
+ self.vision_tower = Idefics2VisionTransformer(config.vision_config)
203
+ self.image_processor = Idefics2ImageProcessor(config.image_processor)
204
+ self.connector = Idefics2Connector(config.connector_config)
205
+ self.projector = ProjectorModel(config.projector_config)
206
+ self.language_model = InternLM2ForCausalLM(config.text_config)
207
+ self.tokenizer = AutoTokenizer.from_pretrained("internlm/internlm2-chat-7b", trust_remote_code=True, encode_special_tokens=True)
208
+ self.downsampler = DownsamplerModel(config.downsampler_config)
209
+ self.visual_source_spliter_emb = torch.nn.Embedding(**config.spliter_emb_config)
210
+ self.gen_config = GenerationConfig(
211
+ max_new_tokens=512,
212
+ do_sample=False,
213
+ eos_token_id=self.tokenizer.eos_token_id,
214
+ pad_token_id=self.tokenizer.pad_token_id
215
+ if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id,
216
+ )
217
+ self.do_image_splitting = config.do_image_splitting
218
+ self.stop_criteria = get_stop_criteria(
219
+ tokenizer=self.tokenizer, stop_words=['<|im_end|>'])
220
+ self.config = config
221
+ def mm_generate(self, image_path, prompt, gen_config=None):
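+ # Single-image inference: format the chat prompt, preprocess the image (optionally splitting it into
+ # sub-images), encode the patches with the vision tower, build native-patch and Q-Former (connector +
+ # projector) visual embeddings, merge them, splice the result in at the <image> placeholder and generate.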
222
+ prompt = "<image>" + '\n' + prompt
223
+ prompt = f"<|im_start|>user\n{prompt}<|im_end|><|im_start|>assistant\n"
224
+ image = Image.open(image_path).convert('RGB')
225
+ navit980_images = self.image_processor([[image]], return_tensors="pt", do_image_splitting=self.do_image_splitting)
226
+ batch_size_navit = navit980_images['pixel_values'].shape[0]
227
+ navit_pixel_values = navit980_images['navit_pixel_values'].cuda()
228
+ navit_patch_attention_mask = navit980_images["pixel_attention_mask"].cuda()
229
+ clip_visual_outputs = self.vision_tower(pixel_values=navit_pixel_values, patch_attention_mask=navit_patch_attention_mask).last_hidden_state
230
+ super_image_hidden_states, image_2d_pos, valid_image_token_shape = \
231
+ recover_navit_subimages_with_pos_emb(
232
+ clip_visual_outputs, navit_patch_attention_mask, num_sub_images=4,
233
+ visual_embedding_group=1,
234
+ pos_hidden_size=4096,
235
+ thumbnail_only=True
236
+ )
237
+ clip_embeddings_native_patch, valid_image_token_shape = visiual_token_downsample(
238
+ self.downsampler,
239
+ super_image_hidden_states, valid_image_token_shape,
240
+ visual_embedding_group=1, image_2d_pos=None
241
+ )
242
+ clip_embeddings_qformer = self.connector(clip_visual_outputs, attention_mask=navit_patch_attention_mask.view(navit_pixel_values.size(0), -1))
243
+ hidden_size = clip_embeddings_qformer.shape[-1]
244
+ clip_embeddings_qformer = clip_embeddings_qformer.view(batch_size_navit, -1, hidden_size)
245
+ clip_embeddings_qformer = self.projector(clip_embeddings_qformer)
246
+ merged_visual_embeddings = \
247
+ merge_native_qformer(
248
+ clip_embeddings_native_patch,
249
+ valid_image_token_shape,
250
+ clip_embeddings_qformer,
251
+ visual_source_spliter=self.visual_source_spliter_emb,
252
+ num_sub_images=4
253
+ )
254
+ chunk_encode = []
255
+ for idx, chunk in enumerate(prompt.split(DEFAULT_IMAGE_TOKEN)):
256
+ if idx == 0:
257
+ cur_encode = self.tokenizer.encode(chunk)
258
+ else:
259
+ cur_encode = self.tokenizer.encode(chunk, add_special_tokens=False)
260
+ chunk_encode.append(cur_encode)
261
+ assert len(chunk_encode) == 2
262
+ ids = []
263
+ for idx, cur_chunk_encode in enumerate(chunk_encode):
264
+ ids.extend(cur_chunk_encode)
265
+ if idx != len(chunk_encode) - 1:
266
+ ids.append(IMAGE_TOKEN_INDEX)
267
+ ids = torch.tensor(ids).cuda().unsqueeze(0)
268
+ pixel_values = None
269
+ mm_inputs = self.prepare_inputs_labels_for_multimodal(
270
+ llm=self.language_model, input_ids=ids, pixel_values=pixel_values, clip_embeddings=merged_visual_embeddings)
271
+ generate_output = self.language_model.generate(
272
+ **mm_inputs,
273
+ generation_config=gen_config if gen_config is not None else self.gen_config,
274
+ streamer=None,
275
+ bos_token_id=self.tokenizer.bos_token_id,
276
+ stopping_criteria=self.stop_criteria
277
+ )
278
+ predict = self.tokenizer.decode(
279
+ generate_output[0], skip_special_tokens=True).strip()
280
+ return predict
281
+ def get_valid_visual_embedding(self, embedding, valid_token_shape):
282
+ if valid_token_shape is None:
283
+ return embedding
284
+ h, w = valid_token_shape
285
+ return embedding[:h, :w, :].reshape(h*w, -1)
286
+ # Modified from https://github.com/haotian-liu/LLaVA/blob/82fc5e0e5f4393a4c26851fa32c69ab37ea3b146/llava/model/llava_arch.py#L99 # noqa: E501
287
+ def prepare_inputs_labels_for_multimodal(
288
+ self,
289
+ llm: PreTrainedModel,
290
+ input_ids: torch.LongTensor = None,
291
+ position_ids: Optional[torch.LongTensor] = None,
292
+ attention_mask: Optional[torch.Tensor] = None,
293
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
294
+ labels: Optional[torch.LongTensor] = None,
295
+ pixel_values: Optional[torch.FloatTensor] = None,
296
+ clip_embeddings: Optional[torch.FloatTensor] = None,
297
+ hard_coded_max_len: Optional[int] = None,
298
+ **kwargs):
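+ # Splits each sequence at the IMAGE_TOKEN_INDEX placeholders, embeds the text chunks with the LLM's
+ # embedding table, splices in the visual embeddings (their labels are set to IGNORE_INDEX), then pads the
+ # batch to a common length and rebuilds attention_mask / position_ids accordingly.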
299
+ if pixel_values is None and clip_embeddings is None:
300
+ return {
301
+ 'input_ids': input_ids,
302
+ 'position_ids': position_ids,
303
+ 'attention_mask': attention_mask,
304
+ 'past_key_values': past_key_values,
305
+ 'inputs_embeds': None,
306
+ 'labels': labels
307
+ }
308
+ valid_image_token_shape = kwargs.get('valid_image_token_shape', None)
309
+ _labels = labels
310
+ _position_ids = position_ids
311
+ _attention_mask = attention_mask
312
+ if attention_mask is None:
313
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
314
+ else:
315
+ attention_mask = attention_mask.bool()
316
+ if position_ids is None:
317
+ position_ids = torch.arange(
318
+ 0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
319
+ if labels is None:
320
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
321
+ # remove the padding using attention_mask -- TODO: double check
322
+ input_ids = [
323
+ cur_input_ids[cur_attention_mask]
324
+ for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
325
+ ]
326
+ labels = [
327
+ cur_labels[cur_attention_mask]
328
+ for cur_labels, cur_attention_mask in zip(labels, attention_mask)
329
+ ]
330
+ new_inputs_embeds = []
331
+ new_labels = []
332
+ new_img_masks = []
333
+ cur_image_idx = 0
334
+ for batch_idx, cur_input_ids in enumerate(input_ids):
335
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
336
+ if num_images == 0:
337
+ cur_pixel_values = pixel_values[cur_image_idx] if pixel_values is not None else None
338
+ cur_clip_emb = self.get_valid_visual_embedding(clip_embeddings[cur_image_idx], valid_image_token_shape[cur_image_idx]) if clip_embeddings is not None else None
339
+ cur_inputs_embeds_1 = llm.get_input_embeddings()(cur_input_ids)
340
+ if cur_clip_emb is not None and cur_pixel_values is not None:
341
+ cur_inputs_embeds = torch.cat(
342
+ [cur_inputs_embeds_1, cur_pixel_values[0:0], cur_clip_emb[0:0]], dim=0)
343
+ elif cur_pixel_values is not None:
344
+ cur_inputs_embeds = torch.cat(
345
+ [cur_inputs_embeds_1, cur_pixel_values[0:0]], dim=0)
346
+ elif cur_clip_emb is not None:
347
+ cur_inputs_embeds = torch.cat(
348
+ [cur_inputs_embeds_1, cur_clip_emb[0:0]], dim=0)
349
+ else:
350
+ raise ValueError
351
+ new_inputs_embeds.append(cur_inputs_embeds)
352
+ new_labels.append(labels[batch_idx])
353
+ new_img_masks.append(torch.zeros(
354
+ cur_inputs_embeds.shape[0], device=cur_inputs_embeds.device).bool())
355
+ cur_image_idx += 1
356
+ continue
357
+ image_token_indices = [-1] + torch.where(
358
+ cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [
359
+ cur_input_ids.shape[0]
360
+ ]
361
+ cur_input_ids_noim = []
362
+ cur_labels = labels[batch_idx]
363
+ cur_labels_noim = []
364
+ for i in range(len(image_token_indices) - 1):
365
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] +
366
+ 1:image_token_indices[i +
367
+ 1]])
368
+ cur_labels_noim.append(cur_labels[image_token_indices[i] +
369
+ 1:image_token_indices[i + 1]])
370
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
371
+ cur_inputs_embeds = llm.get_input_embeddings()(
372
+ torch.cat(cur_input_ids_noim))
373
+ cur_inputs_embeds_no_im = torch.split(
374
+ cur_inputs_embeds, split_sizes, dim=0)
375
+ cur_new_inputs_embeds = []
376
+ cur_new_labels = []
377
+ cur_img_masks = []
378
+ for i in range(num_images + 1):
379
+ cur_new_inputs_embeds.append(cur_inputs_embeds_no_im[i])
380
+ cur_new_labels.append(cur_labels_noim[i])
381
+ cur_img_masks.append(torch.zeros(
382
+ cur_inputs_embeds_no_im[i].shape[0], device=cur_inputs_embeds_no_im[i].device).bool())
383
+ if i < num_images:
384
+ cur_pixel_values = pixel_values[cur_image_idx] if pixel_values is not None else None
385
+ if valid_image_token_shape is not None:
386
+ cur_clip_emb = \
387
+ self.get_valid_visual_embedding(clip_embeddings[cur_image_idx], valid_image_token_shape[cur_image_idx]) \
388
+ if clip_embeddings is not None else None
389
+ else:
390
+ cur_clip_emb = clip_embeddings[cur_image_idx] if clip_embeddings is not None else None
391
+ cur_image_idx += 1
392
+ # discrete token embeddings
393
+ if cur_pixel_values is not None:
394
+ cur_new_inputs_embeds.append(cur_pixel_values)
395
+ cur_img_masks.append(torch.ones(
396
+ cur_pixel_values.shape[0], device=cur_pixel_values.device).bool())
397
+ cur_new_labels.append(
398
+ torch.full((cur_pixel_values.shape[0], ),
399
+ IGNORE_INDEX,
400
+ device=cur_labels.device,
401
+ dtype=cur_labels.dtype))
402
+ # clip embeddings
403
+ if cur_clip_emb is not None:
404
+ cur_new_inputs_embeds.append(cur_clip_emb)
405
+ cur_img_masks.append(torch.zeros(
406
+ cur_clip_emb.shape[0], device=cur_clip_emb.device).bool())
407
+ cur_new_labels.append(
408
+ torch.full((cur_clip_emb.shape[0],),
409
+ IGNORE_INDEX,
410
+ device=cur_labels.device,
411
+ dtype=cur_labels.dtype))
412
+ cur_new_inputs_embeds = torch.cat(cur_new_inputs_embeds)
413
+ cur_new_labels = torch.cat(cur_new_labels)
414
+ cur_img_masks = torch.cat(cur_img_masks)
415
+ new_inputs_embeds.append(cur_new_inputs_embeds)
416
+ new_labels.append(cur_new_labels)
417
+ new_img_masks.append(cur_img_masks)
418
+ # Combine them
419
+ max_len = max(x.shape[0] for x in new_inputs_embeds)
420
+ if hard_coded_max_len is not None:
421
+ max_len = min(max_len, hard_coded_max_len)
422
+ batch_size = len(new_inputs_embeds)
423
+ new_inputs_embeds_padded = []
424
+ new_labels_padded = torch.full((batch_size, max_len),
425
+ IGNORE_INDEX,
426
+ dtype=new_labels[0].dtype,
427
+ device=new_labels[0].device)
428
+ attention_mask = torch.zeros((batch_size, max_len),
429
+ dtype=attention_mask.dtype,
430
+ device=attention_mask.device)
431
+ position_ids = torch.zeros((batch_size, max_len),
432
+ dtype=position_ids.dtype,
433
+ device=position_ids.device)
434
+ new_img_masks_padded = torch.zeros((batch_size, max_len), device=new_img_masks[0].device).bool()
435
+ for i, (cur_new_embed,
436
+ cur_new_labels, cur_new_img_masks) in enumerate(zip(new_inputs_embeds, new_labels, new_img_masks)):
437
+ cur_new_embed = cur_new_embed[:max_len]
438
+ cur_new_labels = cur_new_labels[:max_len]
439
+ cur_new_img_masks = cur_new_img_masks[:max_len]
440
+ cur_len = cur_new_embed.shape[0]
441
+ new_inputs_embeds_padded.append(
442
+ torch.cat((cur_new_embed,
443
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]),
444
+ dtype=cur_new_embed.dtype,
445
+ device=cur_new_embed.device)),
446
+ dim=0))
447
+ if cur_len > 0:
448
+ new_labels_padded[i, :cur_len] = cur_new_labels
449
+ attention_mask[i, :cur_len] = True
450
+ position_ids[i, :cur_len] = torch.arange(
451
+ 0,
452
+ cur_len,
453
+ dtype=position_ids.dtype,
454
+ device=position_ids.device)
455
+ new_img_masks_padded[i, :cur_len] = cur_new_img_masks
456
+ new_inputs_embeds = torch.stack(new_inputs_embeds_padded, dim=0)
457
+ if _labels is None:
458
+ new_labels = None
459
+ else:
460
+ new_labels = new_labels_padded
461
+ if _attention_mask is None:
462
+ attention_mask = None
463
+ else:
464
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
465
+ if _position_ids is None:
466
+ position_ids = None
467
+ prepared_data = {
468
+ 'input_ids': None,
469
+ 'position_ids': position_ids,
470
+ 'attention_mask': attention_mask,
471
+ 'past_key_values': past_key_values,
472
+ 'inputs_embeds': new_inputs_embeds,
473
+ 'labels': new_labels,
474
+ }
475
+ if pixel_values is not None:
476
+ prepared_data.update({'im_mask': new_img_masks_padded})
477
+ return prepared_data
478
+ AutoConfig.register("wemm_hf", WeMMConfig)
479
+ AutoModel.register(WeMMConfig, WemmForConditionalGeneration)
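A minimal usage sketch for the model assembled above, assuming the files in this commit are published to a Hub repository (the repo id below is a placeholder) and that a CUDA device is available, since `mm_generate` moves the image tensors to the GPU:

```python
import torch
from transformers import AutoModel

repo_id = "your-namespace/WeMM"  # placeholder, replace with the actual Hub repo id
model = AutoModel.from_pretrained(repo_id, torch_dtype=torch.bfloat16, trust_remote_code=True)
model = model.cuda().eval()

# mm_generate wraps the prompt in the chat format, encodes the image and decodes the answer.
answer = model.mm_generate("example.jpg", "Describe this image.")
print(answer)
```

`trust_remote_code=True` is required because the configuration and model classes are resolved through the `AutoConfig`/`AutoModel` registrations above rather than shipping inside `transformers`.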
special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "bos_token": "<s>",
3
+ "eos_token": "</s>",
4
+ "pad_token": "</s>",
5
+ "unk_token": "<unk>"
6
+ }
tokenization_internlm2.py ADDED
@@ -0,0 +1,236 @@
1
+ # coding=utf-8
2
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on transformers/src/transformers/models/llama/tokenization_llama.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Tokenization classes for InternLM."""
19
+ import os
20
+ from shutil import copyfile
21
+ from typing import Any, Dict, List, Optional, Tuple
22
+
23
+ import sentencepiece as spm
24
+ from transformers.tokenization_utils import PreTrainedTokenizer
25
+ from transformers.utils import logging
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+ VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"}
30
+
31
+ PRETRAINED_VOCAB_FILES_MAP = {}
32
+
33
+
34
+ # Modified from transformers.model.llama.tokenization_llama.LlamaTokenizer
35
+ class InternLM2Tokenizer(PreTrainedTokenizer):
36
+ """
37
+ Construct an InternLM2 tokenizer. Based on byte-level Byte-Pair-Encoding.
38
+
39
+ Args:
40
+ vocab_file (`str`):
41
+ Path to the vocabulary file.
42
+ """
43
+
44
+ vocab_files_names = VOCAB_FILES_NAMES
45
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
46
+ model_input_names = ["input_ids", "attention_mask"]
47
+ _auto_class = "AutoTokenizer"
48
+
49
+ def __init__(
50
+ self,
51
+ vocab_file,
52
+ unk_token="<unk>",
53
+ bos_token="<s>",
54
+ eos_token="</s>",
55
+ pad_token="</s>",
56
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
57
+ add_bos_token=True,
58
+ add_eos_token=False,
59
+ decode_with_prefix_space=False,
60
+ clean_up_tokenization_spaces=False,
61
+ **kwargs,
62
+ ):
63
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
64
+ self.vocab_file = vocab_file
65
+ self.add_bos_token = add_bos_token
66
+ self.add_eos_token = add_eos_token
67
+ self.decode_with_prefix_space = decode_with_prefix_space
68
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
69
+ self.sp_model.Load(vocab_file)
70
+ self._no_prefix_space_tokens = None
71
+ super().__init__(
72
+ bos_token=bos_token,
73
+ eos_token=eos_token,
74
+ unk_token=unk_token,
75
+ pad_token=pad_token,
76
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
77
+ **kwargs,
78
+ )
79
+
80
+ @property
81
+ def no_prefix_space_tokens(self):
82
+ if self._no_prefix_space_tokens is None:
83
+ vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
84
+ self._no_prefix_space_tokens = {i for i, tok in enumerate(vocab) if not tok.startswith("▁")}
85
+ return self._no_prefix_space_tokens
86
+
87
+ @property
88
+ def vocab_size(self):
89
+ """Returns vocab size"""
90
+ return self.sp_model.get_piece_size()
91
+
92
+ @property
93
+ def bos_token_id(self) -> Optional[int]:
94
+ return self.sp_model.bos_id()
95
+
96
+ @property
97
+ def eos_token_id(self) -> Optional[int]:
98
+ return self.sp_model.eos_id()
99
+
100
+ def get_vocab(self):
101
+ """Returns vocab as a dict"""
102
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
103
+ vocab.update(self.added_tokens_encoder)
104
+ return vocab
105
+
106
+ def _tokenize(self, text):
107
+ """Returns a tokenized string."""
108
+ return self.sp_model.encode(text, out_type=str)
109
+
110
+ def _convert_token_to_id(self, token):
111
+ """Converts a token (str) to an id using the vocab."""
112
+ return self.sp_model.piece_to_id(token)
113
+
114
+ def _convert_id_to_token(self, index):
115
+ """Converts an index (integer) to a token (str) using the vocab."""
116
+ token = self.sp_model.IdToPiece(index)
117
+ return token
118
+
119
+ def _maybe_add_prefix_space(self, tokens, decoded):
120
+ if tokens and tokens[0] not in self.no_prefix_space_tokens:
121
+ return " " + decoded
122
+ else:
123
+ return decoded
124
+
125
+ def convert_tokens_to_string(self, tokens):
126
+ """Converts a sequence of tokens (string) to a single string."""
127
+ current_sub_tokens = []
128
+ out_string = ""
129
+ prev_is_special = False
130
+ for token in tokens:
131
+ # make sure that special tokens are not decoded using sentencepiece model
132
+ if token in self.all_special_tokens:
133
+ if not prev_is_special:
134
+ out_string += " "
135
+ out_string += self.sp_model.decode(current_sub_tokens) + token
136
+ prev_is_special = True
137
+ current_sub_tokens = []
138
+ else:
139
+ current_sub_tokens.append(token)
140
+ prev_is_special = False
141
+ out_string += self.sp_model.decode(current_sub_tokens)
142
+ out_string = self.clean_up_tokenization(out_string)
143
+ out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)
144
+ return out_string[1:]
145
+
146
+ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
147
+ """
148
+ Save the vocabulary and special tokens file to a directory.
149
+
150
+ Args:
151
+ save_directory (`str`):
152
+ The directory in which to save the vocabulary.
153
+
154
+ Returns:
155
+ `Tuple(str)`: Paths to the files saved.
156
+ """
157
+ if not os.path.isdir(save_directory):
158
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
159
+ return
160
+ out_vocab_file = os.path.join(
161
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
162
+ )
163
+
164
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
165
+ copyfile(self.vocab_file, out_vocab_file)
166
+ elif not os.path.isfile(self.vocab_file):
167
+ with open(out_vocab_file, "wb") as fi:
168
+ content_spiece_model = self.sp_model.serialized_model_proto()
169
+ fi.write(content_spiece_model)
170
+
171
+ return (out_vocab_file,)
172
+
173
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
174
+ if self.add_bos_token:
175
+ bos_token_ids = [self.bos_token_id]
176
+ else:
177
+ bos_token_ids = []
178
+
179
+ output = bos_token_ids + token_ids_0
180
+
181
+ if token_ids_1 is not None:
182
+ output = output + token_ids_1
183
+
184
+ if self.add_eos_token:
185
+ output = output + [self.eos_token_id]
186
+
187
+ return output
188
+
189
+ def get_special_tokens_mask(
190
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
191
+ ) -> List[int]:
192
+ """
193
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
194
+ special tokens using the tokenizer `prepare_for_model` method.
195
+
196
+ Args:
197
+ token_ids_0 (`List[int]`):
198
+ List of IDs.
199
+ token_ids_1 (`List[int]`, *optional*):
200
+ Optional second list of IDs for sequence pairs.
201
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
202
+ Whether or not the token list is already formatted with special tokens for the model.
203
+
204
+ Returns:
205
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
206
+ """
207
+ if already_has_special_tokens:
208
+ return super().get_special_tokens_mask(
209
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
210
+ )
211
+
212
+ if token_ids_1 is None:
213
+ return [1] + ([0] * len(token_ids_0)) + [1]
214
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
215
+
216
+ def create_token_type_ids_from_sequences(
217
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
218
+ ) -> List[int]:
219
+ """
220
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. InternLM2 does not make
221
+ use of token type ids, therefore a list of zeros is returned.
222
+
223
+ Args:
224
+ token_ids_0 (`List[int]`):
225
+ List of IDs.
226
+ token_ids_1 (`List[int]`, *optional*):
227
+ Optional second list of IDs for sequence pairs.
228
+
229
+ Returns:
230
+ `List[int]`: List of zeros.
231
+ """
232
+ eos = [self.eos_token_id]
233
+
234
+ if token_ids_1 is None:
235
+ return len(token_ids_0 + eos) * [0]
236
+ return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
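A hedged sketch of driving this slow tokenizer directly (it is normally resolved through `AutoTokenizer` via the `auto_map` entry in `tokenizer_config.json`; the local `tokenizer.model` path below is an assumption):

```python
# Assumes tokenization_internlm2.py and tokenizer.model from this repo sit in the working directory.
from tokenization_internlm2 import InternLM2Tokenizer

tok = InternLM2Tokenizer(vocab_file="./tokenizer.model")
ids = tok.encode("Hello, world!")  # add_bos_token=True by default, so <s> is prepended
print(ids)
print(tok.decode(ids))
```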
tokenization_internlm2_fast.py ADDED
@@ -0,0 +1,214 @@
1
+ # coding=utf-8
2
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on transformers/src/transformers/models/llama/tokenization_llama_fast.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Tokenization Fast class for InternLM."""
19
+ import os
20
+ from shutil import copyfile
21
+ from typing import Any, Dict, Optional, Tuple
22
+
23
+ from tokenizers import processors, decoders, Tokenizer, normalizers
24
+ from tokenizers.models import BPE
25
+
26
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
27
+ from transformers.utils import logging
28
+
29
+ from transformers.convert_slow_tokenizer import (
30
+ SLOW_TO_FAST_CONVERTERS,
31
+ SpmConverter,
32
+ SentencePieceExtractor,
33
+ )
34
+
35
+ from .tokenization_internlm2 import InternLM2Tokenizer
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+ VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"}
40
+
41
+ # Modified from transformers.convert_slow_tokenizer.LlamaConverter
42
+ class InternLM2Converter(SpmConverter):
43
+ handle_byte_fallback = True
44
+
45
+ def vocab(self, proto):
46
+ vocab = [
47
+ ("<unk>", 0.0),
48
+ ("<s>", 0.0),
49
+ ("</s>", 0.0),
50
+ ]
51
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
52
+ return vocab
53
+
54
+ def unk_id(self, proto):
55
+ unk_id = 0
56
+ return unk_id
57
+
58
+ def decoder(self, replacement, add_prefix_space):
59
+ return decoders.Sequence(
60
+ [
61
+ decoders.Replace("▁", " "),
62
+ decoders.ByteFallback(),
63
+ decoders.Fuse(),
64
+ decoders.Strip(content=" ", left=1),
65
+ ]
66
+ )
67
+
68
+ def tokenizer(self, proto):
69
+ model_type = proto.trainer_spec.model_type
70
+ vocab_scores = self.vocab(proto)
71
+ # special tokens
72
+ added_tokens = self.original_tokenizer.added_tokens_decoder
73
+ for i in range(len(vocab_scores)):
74
+ piece, score = vocab_scores[i]
75
+ if i in added_tokens:
76
+ vocab_scores[i] = (added_tokens[i].content, score)
77
+ if model_type == 1:
78
+ raise RuntimeError("InternLM2 is supposed to be a BPE model!")
79
+
80
+ elif model_type == 2:
81
+ _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
82
+ bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
83
+ tokenizer = Tokenizer(
84
+ BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True)
85
+ )
86
+ tokenizer.add_special_tokens(
87
+ [ added_token for index, added_token in added_tokens.items()]
88
+ )
89
+ else:
90
+ raise Exception(
91
+ "You're trying to run a `Unigram` model but your file was trained with a different algorithm"
92
+ )
93
+
94
+ return tokenizer
95
+
96
+ def normalizer(self, proto):
97
+ normalizers_list = []
98
+ if proto.normalizer_spec.add_dummy_prefix:
99
+ normalizers_list.append(normalizers.Prepend(prepend="▁"))
100
+ normalizers_list.append(normalizers.Replace(pattern=" ", content="▁"))
101
+ return normalizers.Sequence(normalizers_list)
102
+
103
+ def pre_tokenizer(self, replacement, add_prefix_space):
104
+ return None
105
+
106
+ SLOW_TO_FAST_CONVERTERS["InternLM2Tokenizer"] = InternLM2Converter
107
+
108
+
109
+ # Modified from transformers.model.llama.tokenization_llama_fast.LlamaTokenizerFast -> InternLM2TokenizerFast
110
+ class InternLM2TokenizerFast(PreTrainedTokenizerFast):
111
+ vocab_files_names = VOCAB_FILES_NAMES
112
+ slow_tokenizer_class = InternLM2Tokenizer
113
+ padding_side = "left"
114
+ model_input_names = ["input_ids", "attention_mask"]
115
+ _auto_class = "AutoTokenizer"
116
+
117
+ def __init__(
118
+ self,
119
+ vocab_file,
120
+ unk_token="<unk>",
121
+ bos_token="<s>",
122
+ eos_token="</s>",
123
+ pad_token="</s>",
124
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
125
+ add_bos_token=True,
126
+ add_eos_token=False,
127
+ decode_with_prefix_space=False,
128
+ clean_up_tokenization_spaces=False,
129
+ **kwargs,
130
+ ):
131
+ super().__init__(
132
+ vocab_file=vocab_file,
133
+ unk_token=unk_token,
134
+ bos_token=bos_token,
135
+ eos_token=eos_token,
136
+ pad_token=pad_token,
137
+ sp_model_kwargs=sp_model_kwargs,
138
+ add_bos_token=add_bos_token,
139
+ add_eos_token=add_eos_token,
140
+ decode_with_prefix_space=decode_with_prefix_space,
141
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
142
+ **kwargs,
143
+ )
144
+ self._add_bos_token = add_bos_token
145
+ self._add_eos_token = add_eos_token
146
+ self.update_post_processor()
147
+ self.vocab_file = vocab_file
148
+
149
+ @property
150
+ def can_save_slow_tokenizer(self) -> bool:
151
+ return os.path.isfile(self.vocab_file) if self.vocab_file else False
152
+
153
+ def update_post_processor(self):
154
+ """
155
+ Updates the underlying post processor with the current `bos_token` and `eos_token`.
156
+ """
157
+ bos = self.bos_token
158
+ bos_token_id = self.bos_token_id
159
+ if bos is None and self.add_bos_token:
160
+ raise ValueError("add_bos_token = True but bos_token = None")
161
+
162
+ eos = self.eos_token
163
+ eos_token_id = self.eos_token_id
164
+ if eos is None and self.add_eos_token:
165
+ raise ValueError("add_eos_token = True but eos_token = None")
166
+
167
+ single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
168
+ pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
169
+
170
+ special_tokens = []
171
+ if self.add_bos_token:
172
+ special_tokens.append((bos, bos_token_id))
173
+ if self.add_eos_token:
174
+ special_tokens.append((eos, eos_token_id))
175
+ self._tokenizer.post_processor = processors.TemplateProcessing(
176
+ single=single, pair=pair, special_tokens=special_tokens
177
+ )
178
+
179
+ @property
180
+ def add_eos_token(self):
181
+ return self._add_eos_token
182
+
183
+ @property
184
+ def add_bos_token(self):
185
+ return self._add_bos_token
186
+
187
+ @add_eos_token.setter
188
+ def add_eos_token(self, value):
189
+ self._add_eos_token = value
190
+ self.update_post_processor()
191
+
192
+ @add_bos_token.setter
193
+ def add_bos_token(self, value):
194
+ self._add_bos_token = value
195
+ self.update_post_processor()
196
+
197
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
198
+ if not self.can_save_slow_tokenizer:
199
+ raise ValueError(
200
+ "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
201
+ "tokenizer."
202
+ )
203
+
204
+ if not os.path.isdir(save_directory):
205
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
206
+ return
207
+ out_vocab_file = os.path.join(
208
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
209
+ )
210
+
211
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
212
+ copyfile(self.vocab_file, out_vocab_file)
213
+
214
+ return (out_vocab_file,)
tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f868398fc4e05ee1e8aeba95ddf18ddcc45b8bce55d5093bead5bbf80429b48b
3
+ size 1477754
tokenizer_config.json ADDED
@@ -0,0 +1,90 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoTokenizer": [
4
+ "tokenization_internlm2.InternLM2Tokenizer",
5
+ "tokenization_internlm2_fast.InternLM2TokenizerFast"
6
+ ]
7
+ },
8
+ "bos_token": "<s>",
9
+ "clean_up_tokenization_spaces": false,
10
+ "eos_token": "</s>",
11
+ "model_max_length": 1000000000000000019884624838656,
12
+ "pad_token": "</s>",
13
+ "tokenizer_class": "InternLM2Tokenizer",
14
+ "unk_token": "<unk>",
15
+ "added_tokens_decoder": {
16
+ "0": {
17
+ "content": "<unk>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false,
22
+ "special": true
23
+ },
24
+ "1": {
25
+ "content": "<s>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false,
30
+ "special": true
31
+ },
32
+ "2": {
33
+ "content": "</s>",
34
+ "lstrip": false,
35
+ "normalized": false,
36
+ "rstrip": false,
37
+ "single_word": false,
38
+ "special": true
39
+ },
40
+ "92543": {
41
+ "content": "<|im_start|>",
42
+ "lstrip": false,
43
+ "normalized": false,
44
+ "rstrip": false,
45
+ "single_word": false,
46
+ "special": true
47
+ },
48
+ "92542": {
49
+ "content": "<|im_end|>",
50
+ "lstrip": false,
51
+ "normalized": false,
52
+ "rstrip": false,
53
+ "single_word": false,
54
+ "special": true
55
+ },
56
+ "92541": {
57
+ "content": "<|action_start|>",
58
+ "lstrip": false,
59
+ "normalized": false,
60
+ "rstrip": false,
61
+ "single_word": false,
62
+ "special": true
63
+ },
64
+ "92540": {
65
+ "content": "<|action_end|>",
66
+ "lstrip": false,
67
+ "normalized": false,
68
+ "rstrip": false,
69
+ "single_word": false,
70
+ "special": true
71
+ },
72
+ "92539": {
73
+ "content": "<|interpreter|>",
74
+ "lstrip": false,
75
+ "normalized": false,
76
+ "rstrip": false,
77
+ "single_word": false,
78
+ "special": true
79
+ },
80
+ "92538": {
81
+ "content": "<|plugin|>",
82
+ "lstrip": false,
83
+ "normalized": false,
84
+ "rstrip": false,
85
+ "single_word": false,
86
+ "special": true
87
+ }
88
+ },
89
+ "chat_template": "{{ bos_token }}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
90
+ }
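The `chat_template` above can be exercised with `apply_chat_template`; a small sketch (the repo id is a placeholder and loading relies on the `auto_map` entry together with `trust_remote_code`):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("your-namespace/WeMM", trust_remote_code=True)
messages = [{"role": "user", "content": "<image>\nDescribe this image."}]
prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# -> "<s><|im_start|>user\n<image>\nDescribe this image.<|im_end|>\n<|im_start|>assistant\n"
print(prompt)
```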
vision_model.py ADDED
@@ -0,0 +1,728 @@
1
+ from transformers import PretrainedConfig, PreTrainedModel
2
+
3
+ import inspect
4
+ import math
5
+ from dataclasses import dataclass
6
+ from typing import Dict, List, Optional, Tuple, Union
7
+ import json
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint
12
+ from torch import nn
13
+ from torch.nn import CrossEntropyLoss
14
+
15
+ from transformers.activations import ACT2FN
16
+ from transformers.cache_utils import Cache, DynamicCache
17
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
18
+ from transformers.modeling_outputs import BaseModelOutput, ModelOutput
19
+ from transformers.utils import (
20
+ add_start_docstrings,
21
+ add_start_docstrings_to_model_forward,
22
+ is_flash_attn_2_available,
23
+ is_flash_attn_greater_or_equal_2_10,
24
+ logging,
25
+ replace_return_docstrings,
26
+ )
27
+
28
+ if is_flash_attn_2_available():
29
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
30
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
31
+
32
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
33
+
34
+
35
+ class Idefics2VisionConfig(PretrainedConfig):
36
+ r"""
37
+ This is the configuration class to store the configuration of a [`Idefics2VisionModel`]. It is used to instantiate a
38
+ Idefics2 vision encoder according to the specified arguments, defining the model architecture. Instantiating a
39
+ configuration with the defaults will yield a similar configuration to that of the SigLIP checkpoint
40
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) used in the Idefics2 model
41
+ [HuggingFaceM4/idefics2-8b](https://huggingface.co/HuggingFaceM4/idefics2-8b).
42
+
43
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
44
+ documentation from [`PretrainedConfig`] for more information.
45
+
46
+ Args:
47
+ hidden_size (`int`, *optional*, defaults to 768):
48
+ Dimensionality of the encoder layers and the pooler layer.
49
+ intermediate_size (`int`, *optional*, defaults to 3072):
50
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
51
+ num_hidden_layers (`int`, *optional*, defaults to 12):
52
+ Number of hidden layers in the Transformer encoder.
53
+ num_attention_heads (`int`, *optional*, defaults to 12):
54
+ Number of attention heads for each attention layer in the Transformer encoder.
55
+ num_channels (`int`, *optional*, defaults to 3):
56
+ Number of channels in the input images.
57
+ image_size (`int`, *optional*, defaults to 224):
58
+ The size (resolution) of each image.
59
+ patch_size (`int`, *optional*, defaults to 32):
60
+ The size (resolution) of each patch.
61
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
62
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
63
+ `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
64
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
65
+ The epsilon used by the layer normalization layers.
66
+ attention_dropout (`float`, *optional*, defaults to 0.0):
67
+ The dropout ratio for the attention probabilities.
68
+ initializer_range (`float`, *optional*, defaults to 0.02):
69
+ The standard deviation for initializing all weight matrices in the model.
70
+
71
+ Example:
72
+
73
+ ```python
74
+ >>> from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer
75
+ >>> from transformers.models.idefics2.configuration_idefics2 import Idefics2VisionConfig
76
+
77
+ >>> # Initializing a Idefics2VisionConfig with google/siglip-base-patch16-224 style configuration
78
+ >>> configuration = Idefics2VisionConfig()
79
+
80
+ >>> # Initializing a Idefics2VisionTransformer (with random weights) from the google/siglip-base-patch16-224 style configuration
81
+ >>> model = Idefics2VisionTransformer(configuration)
82
+
83
+ >>> # Accessing the model configuration
84
+ >>> configuration = model.config
85
+ ```"""
86
+ _auto_class = 'AutoConfig'
87
+ model_type = "Idefics2VisionConfig"
88
+
89
+ def __init__(
90
+ self,
91
+ hidden_size=768,
92
+ intermediate_size=3072,
93
+ num_hidden_layers=12,
94
+ num_attention_heads=12,
95
+ num_channels=3,
96
+ image_size=224,
97
+ patch_size=32,
98
+ hidden_act="gelu_pytorch_tanh",
99
+ layer_norm_eps=1e-6,
100
+ attention_dropout=0.0,
101
+ initializer_range=0.02,
102
+ model_type='Idefics2VisionConfig',
103
+ **kwargs,
104
+ ):
105
+ super().__init__(**kwargs)
106
+
107
+ self.hidden_size = hidden_size
108
+ self.intermediate_size = intermediate_size
109
+ self.num_hidden_layers = num_hidden_layers
110
+ self.num_attention_heads = num_attention_heads
111
+ self.num_channels = num_channels
112
+ self.patch_size = patch_size
113
+ self.image_size = image_size
114
+ self.attention_dropout = attention_dropout
115
+ self.layer_norm_eps = layer_norm_eps
116
+ self.hidden_act = hidden_act
117
+ self.initializer_range = initializer_range
118
+ """
119
+ @classmethod
120
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "PretrainedConfig":
121
+
122
+ with open(pretrained_model_name_or_path, "r", encoding="utf-8") as f:
123
+ config_dict = json.load(f)
124
+
125
+ cls = Idefics2VisionConfig(
126
+ hidden_size=config_dict["hidden_size"],
127
+ image_size=config_dict["image_size"],
128
+ intermediate_size = config_dict["intermediate_size"],
129
+ model_type=config_dict["model_type"],
130
+ num_attention_heads = config_dict["num_attention_heads"],
131
+ num_hidden_layers = config_dict["num_hidden_layers"],
132
+ patch_size = config_dict["patch_size"]
133
+ )
134
+
135
+ return cls
136
+ """
137
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
138
+ def _get_unpad_data(attention_mask):
139
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
140
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
141
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
142
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
143
+ return (
144
+ indices,
145
+ cu_seqlens,
146
+ max_seqlen_in_batch,
147
+ )
148
+
149
+ # Copied from transformers.models.siglip.modeling_siglip.SiglipAttention with Siglip->Idefics2Vision
150
+ class Idefics2VisionAttention(nn.Module):
151
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
152
+
153
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
154
+ def __init__(self, config):
155
+ super().__init__()
156
+ self.config = config
157
+ self.embed_dim = config.hidden_size
158
+ self.num_heads = config.num_attention_heads
159
+ self.head_dim = self.embed_dim // self.num_heads
160
+ if self.head_dim * self.num_heads != self.embed_dim:
161
+ raise ValueError(
162
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
163
+ f" {self.num_heads})."
164
+ )
165
+ self.scale = self.head_dim**-0.5
166
+ self.dropout = config.attention_dropout
167
+
168
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
169
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
170
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
171
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
172
+
173
+ # Ignore copy
174
+ self.is_causal = False
175
+
176
+ def forward(
177
+ self,
178
+ hidden_states: torch.Tensor,
179
+ attention_mask: Optional[torch.Tensor] = None,
180
+ output_attentions: Optional[bool] = False,
181
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
182
+ """Input shape: Batch x Time x Channel"""
183
+
184
+ batch_size, q_len, _ = hidden_states.size()
185
+
186
+ query_states = self.q_proj(hidden_states)
187
+ key_states = self.k_proj(hidden_states)
188
+ value_states = self.v_proj(hidden_states)
189
+
190
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
191
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
192
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
193
+
194
+ k_v_seq_len = key_states.shape[-2]
195
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
196
+
197
+ if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
198
+ raise ValueError(
199
+ f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
200
+ f" {attn_weights.size()}"
201
+ )
202
+
203
+ if attention_mask is not None:
204
+ if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
205
+ raise ValueError(
206
+ f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
207
+ )
208
+ attn_weights = attn_weights + attention_mask
209
+
210
+ # upcast attention to fp32
211
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
212
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
213
+ attn_output = torch.matmul(attn_weights, value_states)
214
+
215
+ if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
216
+ raise ValueError(
217
+ f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
218
+ f" {attn_output.size()}"
219
+ )
220
+
221
+ attn_output = attn_output.transpose(1, 2).contiguous()
222
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
223
+
224
+ attn_output = self.out_proj(attn_output)
225
+
226
+ return attn_output, attn_weights
227
+
228
+
229
+ class Idefics2VisionFlashAttention2(Idefics2VisionAttention):
230
+ """
231
+ Idefics2Vision flash attention module. This module inherits from `Idefics2VisionAttention` as the weights of the module stay
232
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
233
+ flash attention and deal with padding tokens in case the input contains any of them.
234
+ """
235
+
236
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
237
+ def __init__(self, *args, **kwargs):
238
+ super().__init__(*args, **kwargs)
239
+
240
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
241
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
242
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
243
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
244
+
245
+ def forward(
246
+ self,
247
+ hidden_states: torch.Tensor,
248
+ attention_mask: Optional[torch.LongTensor] = None,
249
+ position_ids: Optional[torch.LongTensor] = None,
250
+ past_key_value: Optional[Cache] = None,
251
+ output_attentions: bool = False,
252
+ use_cache: bool = False,
253
+ **kwargs,
254
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
255
+
256
+
257
+ output_attentions = False
258
+
259
+ bsz, q_len, _ = hidden_states.size()
260
+
261
+ query_states = self.q_proj(hidden_states)
262
+ key_states = self.k_proj(hidden_states)
263
+ value_states = self.v_proj(hidden_states)
264
+
265
+ # Flash attention requires the input to have the shape
266
+ # batch_size x seq_length x head_dim x hidden_dim
267
+ # therefore we just need to keep the original shape
268
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
269
+ key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
270
+ value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
271
+
272
+ kv_seq_len = key_states.shape[-2]
273
+ if past_key_value is not None:
274
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
275
+
276
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
277
+ # to be able to avoid many of these transpose/reshape/view.
278
+ query_states = query_states.transpose(1, 2)
279
+ key_states = key_states.transpose(1, 2)
280
+ value_states = value_states.transpose(1, 2)
281
+
282
+ dropout_rate = self.dropout if self.training else 0.0
283
+
284
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons;
285
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
286
+ # cast them back to the correct dtype just to be sure everything works as expected.
287
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
288
+ # in fp32. (Idefics2VisionRMSNorm handles it correctly)
289
+
290
+ input_dtype = query_states.dtype
291
+ if input_dtype == torch.float32:
292
+ if torch.is_autocast_enabled():
293
+ target_dtype = torch.get_autocast_gpu_dtype()
294
+ # Handle the case where the model is quantized
295
+ elif hasattr(self.config, "_pre_quantization_dtype"):
296
+ target_dtype = self.config._pre_quantization_dtype
297
+ else:
298
+ target_dtype = self.q_proj.weight.dtype
299
+
300
+ logger.warning_once(
301
+ f"The input hidden states seem to be silently cast to float32; this might be related to"
302
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
303
+ f" {target_dtype}."
304
+ )
305
+
306
+ query_states = query_states.to(target_dtype)
307
+ key_states = key_states.to(target_dtype)
308
+ value_states = value_states.to(target_dtype)
309
+
310
+ attn_output = self._flash_attention_forward(
311
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
312
+ )
313
+
314
+ attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
315
+ attn_output = self.out_proj(attn_output)
316
+
317
+ if not output_attentions:
318
+ attn_weights = None
319
+
320
+ return attn_output, attn_weights
321
+
322
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
323
+ def _flash_attention_forward(
324
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
325
+ ):
326
+ """
327
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
328
+ first unpad the input, then computes the attention scores and pad the final attention scores.
329
+
330
+ Args:
331
+ query_states (`torch.Tensor`):
332
+ Input query states to be passed to Flash Attention API
333
+ key_states (`torch.Tensor`):
334
+ Input key states to be passed to Flash Attention API
335
+ value_states (`torch.Tensor`):
336
+ Input value states to be passed to Flash Attention API
337
+ attention_mask (`torch.Tensor`):
338
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
339
+ position of padding tokens and 1 for the position of non-padding tokens.
340
+ dropout (`float`):
341
+ Attention dropout
342
+ softmax_scale (`float`, *optional*):
343
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
344
+ """
345
+ if not self._flash_attn_uses_top_left_mask:
346
+ causal = self.is_causal
347
+ else:
348
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
349
+ causal = self.is_causal and query_length != 1
350
+
351
+ # Contains at least one padding token in the sequence
352
+ if attention_mask is not None:
353
+ batch_size = query_states.shape[0]
354
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
355
+ query_states, key_states, value_states, attention_mask, query_length
356
+ )
357
+
358
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
359
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
360
+
361
+ attn_output_unpad = flash_attn_varlen_func(
362
+ query_states,
363
+ key_states,
364
+ value_states,
365
+ cu_seqlens_q=cu_seqlens_q,
366
+ cu_seqlens_k=cu_seqlens_k,
367
+ max_seqlen_q=max_seqlen_in_batch_q,
368
+ max_seqlen_k=max_seqlen_in_batch_k,
369
+ dropout_p=dropout,
370
+ softmax_scale=softmax_scale,
371
+ causal=causal,
372
+ )
373
+
374
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
375
+ else:
376
+ attn_output = flash_attn_func(
377
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
378
+ )
379
+
380
+ return attn_output
381
+
382
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
383
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
384
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
385
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
386
+
387
+ key_layer = index_first_axis(
388
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
389
+ )
390
+ value_layer = index_first_axis(
391
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
392
+ )
393
+ if query_length == kv_seq_len:
394
+ query_layer = index_first_axis(
395
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
396
+ )
397
+ cu_seqlens_q = cu_seqlens_k
398
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
399
+ indices_q = indices_k
400
+ elif query_length == 1:
401
+ max_seqlen_in_batch_q = 1
402
+ cu_seqlens_q = torch.arange(
403
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
404
+ ) # There is a memcpy here, that is very bad.
405
+ indices_q = cu_seqlens_q[:-1]
406
+ query_layer = query_layer.squeeze(1)
407
+ else:
408
+ # The -q_len: slice assumes left padding.
409
+ attention_mask = attention_mask[:, -query_length:]
410
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
411
+
412
+ return (
413
+ query_layer,
414
+ key_layer,
415
+ value_layer,
416
+ indices_q,
417
+ (cu_seqlens_q, cu_seqlens_k),
418
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
419
+ )
420
+
421
+ IDEFICS_VISION_ATTENTION_CLASSES = {
422
+ "eager": Idefics2VisionAttention,
423
+ "flash_attention_2": Idefics2VisionFlashAttention2,
424
+ }
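+ # The encoder layers below index this mapping with config._attn_implementation, so either the eager or the
+ # FlashAttention-2 implementation is picked depending on how the vision config was instantiated.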
425
+
426
+ # Copied from transformers.models.siglip.modeling_siglip.SiglipMLP with Siglip->Idefics2Vision
427
+ class Idefics2VisionMLP(nn.Module):
428
+ def __init__(self, config):
429
+ super().__init__()
430
+ self.config = config
431
+ self.activation_fn = ACT2FN[config.hidden_act]
432
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
433
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
434
+
435
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
436
+ hidden_states = self.fc1(hidden_states)
437
+ hidden_states = self.activation_fn(hidden_states)
438
+ hidden_states = self.fc2(hidden_states)
439
+ return hidden_states
440
+
441
+ class Idefics2EncoderLayer(nn.Module):
442
+ def __init__(self, config: Idefics2VisionConfig):
443
+ super().__init__()
444
+ self.embed_dim = config.hidden_size
445
+ self.self_attn = IDEFICS_VISION_ATTENTION_CLASSES[config._attn_implementation](config)
446
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
447
+ self.mlp = Idefics2VisionMLP(config)
448
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
449
+
450
+ # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
451
+ def forward(
452
+ self,
453
+ hidden_states: torch.Tensor,
454
+ attention_mask: torch.Tensor,
455
+ output_attentions: Optional[bool] = False,
456
+ ) -> Tuple[torch.FloatTensor]:
457
+ """
458
+ Args:
459
+ hidden_states (`torch.FloatTensor`):
460
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
461
+ attention_mask (`torch.FloatTensor`):
462
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
463
+ output_attentions (`bool`, *optional*, defaults to `False`):
464
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
465
+ returned tensors for more detail.
466
+ """
467
+ residual = hidden_states
468
+
469
+ hidden_states = self.layer_norm1(hidden_states)
470
+ hidden_states, attn_weights = self.self_attn(
471
+ hidden_states=hidden_states,
472
+ attention_mask=attention_mask,
473
+ output_attentions=output_attentions,
474
+ )
475
+ hidden_states = residual + hidden_states
476
+
477
+ residual = hidden_states
478
+ hidden_states = self.layer_norm2(hidden_states)
479
+ hidden_states = self.mlp(hidden_states)
480
+ hidden_states = residual + hidden_states
481
+
482
+ outputs = (hidden_states,)
483
+
484
+ if output_attentions:
485
+ outputs += (attn_weights,)
486
+
487
+ return outputs
488
+
489
+ # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoder with Siglip->Idefics2
+ class Idefics2Encoder(nn.Module):
+     """
+     Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+     [`Idefics2EncoderLayer`].
+
+     Args:
+         config: Idefics2VisionConfig
+     """
+
+     def __init__(self, config: Idefics2VisionConfig):
+         super().__init__()
+         self.config = config
+         self.layers = nn.ModuleList([Idefics2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
+         self.gradient_checkpointing = False
+
+     # Ignore copy
+     def forward(
+         self,
+         inputs_embeds,
+         attention_mask: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutput]:
+         r"""
+         Args:
+             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                 Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                 This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                 than the model's internal embedding lookup matrix.
+             attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                 Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                 - 1 for tokens that are **not masked**,
+                 - 0 for tokens that are **masked**.
+
+                 [What are attention masks?](../glossary#attention-mask)
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                 returned tensors for more detail.
+             output_hidden_states (`bool`, *optional*):
+                 Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                 for more detail.
+             return_dict (`bool`, *optional*):
+                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+         """
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         encoder_states = () if output_hidden_states else None
+         all_attentions = () if output_attentions else None
+
+         hidden_states = inputs_embeds
+         for encoder_layer in self.layers:
+             if output_hidden_states:
+                 encoder_states = encoder_states + (hidden_states,)
+             if self.gradient_checkpointing and self.training:
+                 layer_outputs = self._gradient_checkpointing_func(
+                     encoder_layer.__call__,
+                     hidden_states,
+                     attention_mask,
+                     output_attentions,
+                 )
+             else:
+                 layer_outputs = encoder_layer(
+                     hidden_states,
+                     attention_mask,
+                     output_attentions=output_attentions,
+                 )
+
+             hidden_states = layer_outputs[0]
+
+             if output_attentions:
+                 all_attentions = all_attentions + (layer_outputs[1],)
+
+         if output_hidden_states:
+             encoder_states = encoder_states + (hidden_states,)
+
+         if not return_dict:
+             return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+         return BaseModelOutput(
+             last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+         )
+
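A hedged usage sketch of the encoder in isolation. The tiny config values below are invented for illustration (the real vision tower is much larger), and the sketch assumes the `Idefics2VisionConfig` and eager attention class defined earlier in this file behave like their upstream `transformers` counterparts.

import torch

# Hypothetical, deliberately small settings -- not the values shipped with WeMM.
config = Idefics2VisionConfig(hidden_size=64, intermediate_size=128,
                              num_hidden_layers=2, num_attention_heads=4)
config._attn_implementation = "eager"   # assumption: use the non-flash attention path

encoder = Idefics2Encoder(config)
inputs_embeds = torch.randn(1, 16, config.hidden_size)   # (batch, seq_len, hidden_size)

out = encoder(inputs_embeds=inputs_embeds, output_hidden_states=True, return_dict=True)
print(out.last_hidden_state.shape)   # torch.Size([1, 16, 64])
print(len(out.hidden_states))        # num_hidden_layers + 1 = 3 (embeddings + each layer)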
+ class Idefics2VisionEmbeddings(nn.Module):
+     """
+     This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings` that enables images of variable
+     resolution.
+
+     The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304),
+     which allows treating images in their native aspect ratio without resizing them to a single fixed size.
+     In particular, we start from the original pre-trained SigLIP model (which uses fixed-size square images)
+     and adapt it by training on images of variable resolutions.
+     """
+
+     def __init__(self, config: Idefics2VisionConfig):
+         super().__init__()
+         self.embed_dim = config.hidden_size
+         self.image_size = config.image_size
+         self.patch_size = config.patch_size
+
+         self.patch_embedding = nn.Conv2d(
+             in_channels=config.num_channels,
+             out_channels=self.embed_dim,
+             kernel_size=self.patch_size,
+             stride=self.patch_size,
+             padding="valid",
+         )
+
+         self.num_patches_per_side = self.image_size // self.patch_size
+         self.num_patches = self.num_patches_per_side**2
+         self.num_positions = self.num_patches
+         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+     def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
+         batch_size, _, max_im_h, max_im_w = pixel_values.shape
+
+         patch_embeds = self.patch_embedding(pixel_values)
+         embeddings = patch_embeds.flatten(2).transpose(1, 2)
+
+         max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
+         boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
+         position_ids = torch.full(size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0)
+
+         for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+             nb_patches_h = p_attn_mask[:, 0].sum()
+             nb_patches_w = p_attn_mask[0].sum()
+
+             fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
+             fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
+
+             bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
+             bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
+
+             pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
+             position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
+
+         position_ids = position_ids.to(self.position_embedding.weight.device)
+         embeddings = embeddings + self.position_embedding(position_ids)
+         return embeddings
+
+
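The bucketized position ids are the NaViT-style piece worth pausing on: every real patch gets a fractional (row, column) coordinate in [0, 1), which is then mapped onto the fixed `num_patches_per_side` x `num_patches_per_side` grid that the pre-trained SigLIP position embeddings were trained for. A self-contained sketch with invented sizes:

import torch

num_patches_per_side = 4                      # pretend the SigLIP grid is 4x4 (16 positions)
boundaries = torch.arange(1 / num_patches_per_side, 1.0, 1 / num_patches_per_side)

# An image whose valid region covers only 2 patch rows and 3 patch columns.
nb_patches_h, nb_patches_w = 2, 3
frac_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)    # tensor([0.0000, 0.5000])
frac_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)    # tensor([0.0000, 0.3333, 0.6667])

bucket_h = torch.bucketize(frac_h, boundaries, right=True)   # rows 0 and 2 of the 4x4 grid
bucket_w = torch.bucketize(frac_w, boundaries, right=True)   # cols 0, 1 and 2

pos_ids = (bucket_h[:, None] * num_patches_per_side + bucket_w).flatten()
print(pos_ids)    # tensor([ 0,  1,  2,  8,  9, 10])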
+ class Idefics2VisionTransformer(PreTrainedModel):
+     _auto_class = 'AutoModel'
+     config_class = Idefics2VisionConfig
+     supports_gradient_checkpointing = True
+
+     def __init__(self, config: Idefics2VisionConfig):
+         super().__init__(config)
+         embed_dim = config.hidden_size
+
+         # The attention implementation is hard-coded to FlashAttention-2 here,
+         # overriding whatever the incoming config specifies.
+         config._attn_implementation = "flash_attention_2"
+         self._use_flash_attention_2 = True
+         self.config = config
+         self.embeddings = Idefics2VisionEmbeddings(config)
+         self.encoder = Idefics2Encoder(config)
+         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+     def get_input_embeddings(self):
+         return self.embeddings
+
+     def set_input_embeddings(self, value):
+         self.embeddings = value
+
+     def forward(
+         self,
+         pixel_values,
+         patch_attention_mask: Optional[torch.BoolTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutput]:
+         # Pixel values are cast to bfloat16 unconditionally; FlashAttention-2 only supports fp16/bf16 inputs.
+         pixel_values = pixel_values.to(torch.bfloat16)
+
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         batch_size = pixel_values.size(0)
+         if patch_attention_mask is None:
+             patch_size = self.config.patch_size
+             patch_attention_mask = torch.ones(
+                 (
+                     batch_size,
+                     pixel_values.size(2) // patch_size,
+                     pixel_values.size(3) // patch_size,
+                 )
+             )
+             patch_attention_mask = patch_attention_mask.to(dtype=torch.bool, device=pixel_values.device)
+
+         hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
+
+         patch_attention_mask = patch_attention_mask.view(batch_size, -1)
+         # The call to `_upad_input` in `_flash_attention_forward` is expensive,
+         # so when `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence)
+         # we avoid passing the attention_mask, which is equivalent to attending to the full sequence.
+         if not torch.any(~patch_attention_mask):
+             patch_attention_mask = None
+         elif not self._use_flash_attention_2:
+             patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
+
+         encoder_outputs = self.encoder(
+             inputs_embeds=hidden_states,
+             attention_mask=patch_attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         last_hidden_state = encoder_outputs[0]
+         last_hidden_state = self.post_layernorm(last_hidden_state)
+
+         if not return_dict:
+             return (last_hidden_state,) + encoder_outputs[1:]
+
+         return BaseModelOutput(
+             last_hidden_state=last_hidden_state,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+         )
+
+     """
+     @classmethod
+     def from_pretrained(self, config_path="/mnt/csp/mmvision/home/arrayyang/idefics2-8b/idefics2_vision_model"):
+         config = Idefics2VisionConfig.from_pretrained(f'{config_path}/config.json')
+         cls = Idefics2VisionTransformer(config=config)
+
+         state_dict = torch.load(f'{config_path}/vision_model.pth', map_location='cpu')
+         ret = cls.load_state_dict(state_dict, strict=False)
+         print("Loading idefics2 Vision Model: {}".format(config_path))
+         return cls
+     """