VictorSanh committed on
Commit
217780a
1 Parent(s): 15d7f9b

Update visualization

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. README.md +19 -5
  2. app.py +793 -0
  3. app_bis.py +857 -0
  4. app_dialogue.py +395 -0
  5. elon_musk.md +23 -0
  6. images/bear.jpg +0 -0
  7. m4/__init__.py +1 -0
  8. m4/models/__init__.py +28 -0
  9. m4/models/common.py +104 -0
  10. m4/models/custom_modules.py +337 -0
  11. m4/models/perceiver/perceiver.py +141 -0
  12. m4/models/vbloom/__init__.py +0 -0
  13. m4/models/vbloom/configuration_vbloom.py +235 -0
  14. m4/models/vbloom/modeling_vbloom.py +1396 -0
  15. m4/models/vgpt2/__init__.py +0 -0
  16. m4/models/vgpt2/configuration_vgpt2.py +288 -0
  17. m4/models/vgpt2/modeling_vgpt2.py +1384 -0
  18. m4/models/vgpt_neo/__init__.py +0 -0
  19. m4/models/vgpt_neo/configuration_vgpt_neo.py +250 -0
  20. m4/models/vgpt_neo/modeling_vgpt_neo.py +1182 -0
  21. m4/models/vllama/configuration_vllama.py +204 -0
  22. m4/models/vllama/make_tiny_llama.py +51 -0
  23. m4/models/vllama/make_tiny_model.py +114 -0
  24. m4/models/vllama/modeling_vllama.py +1260 -0
  25. m4/models/vopt/__init__.py +0 -0
  26. m4/models/vopt/configuration_vopt.py +250 -0
  27. m4/models/vopt/make_tiny_model.py +114 -0
  28. m4/models/vopt/modeling_vopt.py +1513 -0
  29. m4/models/vt5/__init__.py +0 -0
  30. m4/models/vt5/configuration_vt5.py +218 -0
  31. m4/models/vt5/modeling_vt5.py +0 -0
  32. m4/models/zero_checkpoint_to_hf.py +87 -0
  33. m4/scripts/cleanup-checkpoints.py +156 -0
  34. m4/scripts/convert-checkpoints.py +124 -0
  35. m4/scripts/s3-upload-checkpoints.py +194 -0
  36. m4/scripts/s3_checkpoint_download_convert_upload.py +171 -0
  37. m4/scripts/s3_checkpoint_download_convert_upload.slurm +51 -0
  38. m4/scripts/s3_downloaded_checkpoints_cleanup.slurm +54 -0
  39. m4/scripts/schedule-evals.py +87 -0
  40. m4/testing_utils.py +1116 -0
  41. m4/training/__init__.py +0 -0
  42. m4/training/config.py +545 -0
  43. m4/training/dataset_utils.py +352 -0
  44. m4/training/debug_utils.py +34 -0
  45. m4/training/packing.py +755 -0
  46. m4/training/setup_language_model.py +38 -0
  47. m4/training/setup_vision_model.py +33 -0
  48. m4/training/types.py +13 -0
  49. m4/training/utils.py +539 -0
  50. m4/utils/__init__.py +0 -0
README.md CHANGED
@@ -1,12 +1,26 @@
1
  ---
2
- title: M4 Dialogue
3
  emoji: 🐨
4
- colorFrom: blue
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 3.38.0
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: m4-dialogue
3
  emoji: 🐨
4
+ colorFrom: red
5
+ colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 3.12.0
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
+ # M4 Visualization
13
+
14
+ For visualizations, we have a main [app](https://huggingface.co/spaces/HuggingFaceM4/m4-demo) which calls multiple child apps to retrieve generations via [Gradio API](https://gradio.app/using-blocks-like-functions/). This allows us to parallelize calls to multiple models at the same time instead of running them sequentially.
15
+
16
+
17
+ ## How to?
18
+
19
+ The process of adding a model to the main space:
20
+
21
+ - Use `huggingface-cli login` to log in with an auth token that has read/write access to the `HuggingFaceM4` org on the hub.
22
+ - Use the `./upload_checkpoint_to_hub_gcs.sh` script to upload a checkpoint from the GCP store to the hub. An example command to upload the checkpoint for step 3000 of `tr_121ter`: `./m4/visualization/upload_checkpoint_to_hub_gcs.sh gs://hf-science-m4-cold/local_experiment_dir/tr_121ter/opt_step-3000`. This creates a model repo under the `HuggingFaceM4` org on the hub. If you are on the cluster, use `./upload_checkpoint_to_hub_s3.sh` instead. I recommend running this on a compute node to avoid disk space issues (uploading to the hub consists of downloading the checkpoint locally, creating a repo on the hub, copying it locally, filling it with the weights and committing the weights to the hub repo).
23
+ - [MANUAL] Go to the hub and create a repo of type `space` with the same name as the model. In the space's settings, add a secret `HF_AUTH_TOKEN` with a token that has read access to the `HuggingFaceM4` org. This step could potentially be automated in the future.
24
+ - [MANUAL] Edit the three dictionaries in `m4/visualization/app.py` (`MODEL_TO_DISPLAY_NAME`, `MODEL_TO_MODEL_CLASS` and `MODEL_TO_CONFIG_CLASS`) to include your model, following the existing format of those dictionaries.
25
+ - Run `m4/visualization/sync-repo.sh <name_of_the_space_on_the_hub>` to sync the repo with the local setting. This will automatically update the space to have the latest code as in the `m4/visualization/app.py`.
26
+ - Run `m4/visualization/sync-repo.sh main` to update the main repo as well with the new model.
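The fan-out described at the top of this README (one main app querying several child Spaces at the same time) can be sketched with a thread pool. This is only an illustration under stated assumptions: `query_child_space` is a hypothetical stand-in for the real Gradio API call, and the model names are placeholders rather than actual repos.

```python
from concurrent.futures import ThreadPoolExecutor


def query_child_space(model_name, prompt):
    # Hypothetical stand-in: a real implementation would call the child
    # Space's Gradio API here and return its generation.
    return f"[{model_name}] {prompt}"


def query_all(model_names, prompt):
    # One worker per child Space, so a slow model does not block the others.
    with ThreadPoolExecutor(max_workers=max(1, len(model_names))) as pool:
        futures = [pool.submit(query_child_space, name, prompt) for name in model_names]
        # Collect in submission order so outputs line up with model_names.
        return [future.result() for future in futures]


if __name__ == "__main__":
    print(query_all(["model-a", "model-b"], "Describe this image."))
```

Threads are a reasonable fit here because each call mostly waits on the network rather than on local computation.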
app.py ADDED
@@ -0,0 +1,793 @@
1
+ import logging
2
+ import os
3
+ import re
4
+ import time
5
+ from io import BytesIO
6
+
7
+ import gradio as gr
8
+ import requests
9
+ import torch
10
+ import transformers
11
+ from accelerate.utils import get_max_memory
12
+ from joblib import Parallel, delayed
13
+ from PIL import Image
14
+ from transformers import AutoTokenizer
15
+
16
+ from m4.models.vbloom import configuration_vbloom, modeling_vbloom
17
+ from m4.models.vgpt2 import configuration_vgpt2, modeling_vgpt2
18
+ from m4.models.vgpt_neo import configuration_vgpt_neo, modeling_vgpt_neo
19
+ from m4.models.vllama import configuration_vllama, modeling_vllama
20
+ from m4.models.vopt import configuration_vopt, modeling_vopt
21
+ from m4.training.packing import image_attention_mask_for_packed_input_ids, incremental_to_binary_attention_mask
22
+ from m4.training.utils import build_image_transform
23
+
24
+
25
+ logging.basicConfig(level=logging.INFO)
26
+ logger = logging.getLogger()
27
+
28
+ CURRENT_MODEL = "<replace_model>"
29
+ # CURRENT_MODEL = "tr_177_4datasets_alpha_baseline_opt_step-2000"
30
+
31
+ MAX_TRIES = 3
32
+ TOKENIZER_FAST = True
33
+ MAX_SEQ_LEN = 1024
34
+ model, tokenizer = None, None
35
+
36
+
37
+ MODEL_TO_DISPLAY_NAME = {
38
+ "tr_199_w_xattn_opt_step-65000": "VLlama - tr_199_w_xattn_opt_step-65000",
39
+ # "tr_201_sft_on_lrv_opt_step-15000": "VLlama - tr_201_sft_on_lrv_opt_step-15000",
40
+ # "tr_202bis_ift_llava_all_unfrozen_opt_step-14128": "VLlama - tr_202bis_ift_llava_all_unfrozen_opt_step-14128",
41
+ # "tr_203_ift_m3it_opt_step-50000": "VLlama - tr_203_ift_m3it_opt_step-50000",
42
+ # "tr_205_sft_ultrachat_opt_step-20000": "VLlama - tr_205_sft_ultrachat_opt_step-20000",
43
+ # "tr_207_ift_svit_opt_step-14627": "VLlama - tr_207_ift_svit_opt_step-14627",
44
+ "tr_209_ift_mixture_opt_step-14000": "VLlama - tr_209_ift_mixture_opt_step-14000",
45
+ }
46
+ MODEL_TO_MODEL_CLASS = {
47
+ "tr_199_w_xattn_opt_step-65000": "VLlamaForCausalLM",
48
+ # "tr_201_sft_on_lrv_opt_step-15000": "VLlamaForCausalLM",
49
+ # "tr_202bis_ift_llava_all_unfrozen_opt_step-14128": "VLlamaForCausalLM",
50
+ # "tr_203_ift_m3it_opt_step-50000": "VLlamaForCausalLM",
51
+ # "tr_205_sft_ultrachat_opt_step-20000": "VLlamaForCausalLM",
52
+ # "tr_207_ift_svit_opt_step-14627": "VLlamaForCausalLM",
53
+ "tr_209_ift_mixture_opt_step-14000": "VLlamaForCausalLM",
54
+ }
55
+
56
+ MODEL_TO_CONFIG_CLASS = {
57
+ "tr_199_w_xattn_opt_step-65000": "VLlamaConfig",
58
+ # "tr_201_sft_on_lrv_opt_step-15000": "VLlamaConfig",
59
+ # "tr_202bis_ift_llava_all_unfrozen_opt_step-14128": "VLlamaConfig",
60
+ # "tr_203_ift_m3it_opt_step-50000": "VLlamaConfig",
61
+ # "tr_205_sft_ultrachat_opt_step-20000": "VLlamaConfig",
62
+ # "tr_207_ift_svit_opt_step-14627": "VLlamaConfig",
63
+ "tr_209_ift_mixture_opt_step-14000": "VLlamaConfig",
64
+ }
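Because a model has to appear in all three dictionaries above, a small consistency check can catch a forgotten entry before the Space is deployed. The following is a minimal sketch, not part of the commit; `check_model_registries` is a hypothetical helper name and the sample entries are placeholders.

```python
def check_model_registries(display_names, model_classes, config_classes):
    """Return the model names that are missing from at least one registry."""
    all_names = set(display_names) | set(model_classes) | set(config_classes)
    complete = set(display_names) & set(model_classes) & set(config_classes)
    return sorted(all_names - complete)


if __name__ == "__main__":
    # Placeholder registries: "model-b" was never given a config class.
    display = {"model-a": "VLlama - model-a", "model-b": "VLlama - model-b"}
    classes = {"model-a": "VLlamaForCausalLM", "model-b": "VLlamaForCausalLM"}
    configs = {"model-a": "VLlamaConfig"}
    print(check_model_registries(display, classes, configs))
```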
65
+
66
+
67
+ def load_tokenizer_model(model_name, model_class):
68
+ tokenizer = AutoTokenizer.from_pretrained(
69
+ model_name,
70
+ use_fast=TOKENIZER_FAST,  # `use_fast=False` for 1B3 OPT, True for all the other models
71
+ use_auth_token=os.getenv("HF_AUTH_TOKEN", True),
72
+ )
73
+ tokenizer.padding_side = "left"
74
+ config_class = MODEL_TO_CONFIG_CLASS[model_name.split("/")[-1]]
75
+
76
+ # assert tokenizer.is_fast
77
+
78
+ supported_custom_modules = {
79
+ "vgpt2": modeling_vgpt2,
80
+ "vbloom": modeling_vbloom,
81
+ "vgptneo": modeling_vgpt_neo,
82
+ "vopt": modeling_vopt,
83
+ "vllama": modeling_vllama,
84
+ }
85
+ supported_custom_configs = {
86
+ "vgpt2": configuration_vgpt2,
87
+ "vbloom": configuration_vbloom,
88
+ "vgptneo": configuration_vgpt_neo,
89
+ "vopt": configuration_vopt,
90
+ "vllama": configuration_vllama,
91
+ }
92
+ parent_config_class = (
93
+ [v for k, v in supported_custom_configs.items() if k in model_class.lower()] + [transformers]
94
+ )[0]
95
+ parent_model_class = (
96
+ [v for k, v in supported_custom_modules.items() if k in model_class.lower()] + [transformers]
97
+ )[0]
98
+ config_class = getattr(parent_config_class, config_class)
99
+ model_class = getattr(parent_model_class, model_class)
100
+ config = config_class.from_pretrained(model_name, use_auth_token=os.getenv("HF_AUTH_TOKEN", True))
101
+ max_memory_map = get_max_memory()
102
+ for key in max_memory_map.keys():
103
+ if key != "cpu":
104
+ # Get this in GB
105
+ max_memory_map[key] = max_memory_map[key] // (1024 * 1024 * 1024)
106
+ # Subtract 2 GiB for PyTorch overhead and 2 GiB for the forward pass, to be safe
107
+ max_memory_map[key] = f"{max_memory_map[key] - 4} GiB"
108
+ model = model_class.from_pretrained(
109
+ model_name,
110
+ use_auth_token=os.getenv("HF_AUTH_TOKEN", True),
111
+ device_map="auto",
112
+ offload_folder="./offload",
113
+ torch_dtype=config.torch_dtype,
114
+ max_memory=max_memory_map,
115
+ )
116
+ model.eval()
117
+ print("Current device map:", model.hf_device_map)
118
+ print("Model default generation config:", model.generation_config)
119
+ # TODO: the device_map looks very inefficient right now and could be improved.
120
+ # It typically looks like this:
121
+ # {
122
+ # 'model.embed_tokens': 0,
123
+ # 'model.vision_model': 0,
124
+ # 'model.layers.0': 0,
125
+ # 'model.layers.1': 0,
126
+ # 'model.layers.2': 0,
127
+ # 'model.layers.3': 0,
128
+ # 'model.layers.4': 0,
129
+ # 'model.layers.5': 0,
130
+ # 'model.layers.6': 1,
131
+ # 'model.layers.7': 1,
132
+ # 'model.layers.8': 1,
133
+ # 'model.layers.9': 1,
134
+ # 'model.layers.10': 1,
135
+ # 'model.layers.11': 1,
136
+ # 'model.layers.12': 1,
137
+ # 'model.layers.13': 1,
138
+ # 'model.layers.14': 1,
139
+ # 'model.layers.15': 1,
140
+ # 'model.layers.16': 1,
141
+ # 'model.layers.17': 2,
142
+ # 'model.layers.18': 2,
143
+ # 'model.layers.19': 2,
144
+ # 'model.layers.20': 2,
145
+ # 'model.layers.21': 2,
146
+ # 'model.layers.22': 2,
147
+ # 'model.layers.23': 2,
148
+ # 'model.layers.24': 2,
149
+ # 'model.layers.25': 2,
150
+ # 'model.layers.26': 2,
151
+ # 'model.layers.27': 2,
152
+ # 'model.layers.28': 3,
153
+ # 'model.layers.29': 3,
154
+ # 'model.layers.30': 3,
155
+ # 'model.layers.31': 3,
156
+ # 'model.gated_cross_attn_layers.0': 3,
157
+ # 'model.gated_cross_attn_layers.1': 3,
158
+ # 'model.gated_cross_attn_layers.2': 3,
159
+ # 'model.gated_cross_attn_layers.3': 3,
160
+ # 'model.gated_cross_attn_layers.4': 3,
161
+ # 'model.gated_cross_attn_layers.5': 3,
162
+ # 'model.gated_cross_attn_layers.6': 3,
163
+ # 'model.gated_cross_attn_layers.7': 3,
164
+ # 'model.gated_cross_attn_layers.8': 4,
165
+ # 'model.gated_cross_attn_layers.9': 4,
166
+ # 'model.gated_cross_attn_layers.10': 4,
167
+ # 'model.gated_cross_attn_layers.11': 4,
168
+ # 'model.gated_cross_attn_layers.12': 4,
169
+ # 'model.gated_cross_attn_layers.13': 4,
170
+ # 'model.gated_cross_attn_layers.14': 4,
171
+ # 'model.gated_cross_attn_layers.15': 4,
172
+ # 'model.norm': 4,
173
+ # 'lm_head': 4
174
+ # } which means a lot of data moves between devices, because the gated cross-attention layers and the LM layers are placed on different GPUs...
175
+ return tokenizer, model
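The memory budgeting inside `load_tokenizer_model` (convert each GPU's byte count to whole GiB, then hold back 4 GiB) can be isolated into a pure function that is easy to test. This is a sketch under assumptions: `budget_max_memory` is a hypothetical name, and clamping at zero is an addition that the code above does not perform.

```python
GIB = 1024 ** 3


def budget_max_memory(raw_max_memory, headroom_gib=4):
    """Turn a {device: bytes} map into a {device: "N GiB"} budget with headroom."""
    budget = {}
    for device, num_bytes in raw_max_memory.items():
        if device == "cpu":
            # The CPU entry is passed through unchanged, as in the code above.
            budget[device] = num_bytes
            continue
        # Assumption: clamp at zero so a small device never gets a negative budget.
        usable_gib = max(num_bytes // GIB - headroom_gib, 0)
        budget[device] = f"{usable_gib} GiB"
    return budget


if __name__ == "__main__":
    print(budget_max_memory({0: 80 * GIB, 1: 24 * GIB, "cpu": 512 * GIB}))
```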
176
+
177
+
178
+ MODEL_TO_SPACE_MAPPING = {}
179
+ IS_MAIN_SPACE = CURRENT_MODEL not in MODEL_TO_MODEL_CLASS
180
+ if IS_MAIN_SPACE:
181
+ for model in MODEL_TO_MODEL_CLASS:
182
+ MODEL_TO_SPACE_MAPPING[model] = gr.Blocks.load(
183
+ name=f"spaces/HuggingFaceM4/{model}", api_key=os.getenv("HF_AUTH_TOKEN", True)
184
+ )
185
+ else:
186
+ model_path = f"HuggingFaceM4/{CURRENT_MODEL}"
187
+ tokenizer, model = load_tokenizer_model(model_path, MODEL_TO_MODEL_CLASS[CURRENT_MODEL])
188
+
189
+
190
+ def fetch_images(url_images):
191
+ images = []
192
+ for url in url_images:
193
+ if isinstance(url, str):
194
+ images.append(Image.open(BytesIO(requests.get(url, stream=True).content)))
195
+ else:
196
+ images.append(url)
197
+ return images
198
+
199
+
200
+ def model_generation(
201
+ prompt,
202
+ images,
203
+ tokenizer,
204
+ model,
205
+ temperature,
206
+ no_repeat_ngram_size,
207
+ max_new_tokens,
208
+ min_length,
209
+ ban_tokens,
210
+ eos_tokens,
211
+ force_words,
212
+ repetition_penalty,
213
+ hide_special_tokens,
214
+ decoding_strategy,
215
+ num_beams,
216
+ length_penalty,
217
+ top_k,
218
+ top_p,
219
+ penalty_alpha,
220
+ ):
221
+ # Preparing inputs
222
+ tokens = tokenizer(
223
+ [prompt],
224
+ truncation=True,
225
+ max_length=MAX_SEQ_LEN,
226
+ padding=True,
227
+ add_special_tokens=False,
228
+ )
229
+
230
+ input_ids = torch.tensor([[tokenizer.bos_token_id] + tokens.input_ids[0]])
231
+ attention_mask = torch.tensor([[1] + tokens.attention_mask[0]])
232
+
233
+ image_attention_mask = [
234
+ incremental_to_binary_attention_mask(
235
+ image_attention_mask_for_packed_input_ids(input_ids[0].unsqueeze(0), tokenizer)[0], num_classes=len(images)
236
+ )
237
+ ]
238
+
239
+ image_transform = build_image_transform(eval=True)
240
+ pixel_values = [torch.stack([image_transform(img) for img in images])]
241
+
242
+ input_ids = input_ids.to(0)
243
+ attention_mask = attention_mask.to(0)
244
+ pixel_values = torch.stack(pixel_values).to(0)
245
+ image_attention_mask = torch.cat(image_attention_mask, 0).to(0)
246
+
247
+ # Excluding some words from the generation
248
+ bad_words_ids = None
249
+ ban_tokens = ban_tokens.replace("\\n", "\n")
250
+ bad_words = [word for word in ban_tokens.split(";") if word]  # drop empty entries: "".split(";") == [""]
251
+ if len(bad_words) > 0:
252
+ bad_words_ids = tokenizer(bad_words, add_special_tokens=False).input_ids
253
+
254
+ # Forcing some words in the generation
255
+ force_words_ids = None
256
+ if force_words != "":
257
+ force_words = force_words.replace("\\n", "\n")
258
+ force_words = force_words.split(";")
259
+ if len(force_words) > 0:
260
+ force_words_ids = tokenizer(force_words, add_special_tokens=False).input_ids
261
+
262
+ eos_token_ids = None
263
+ if eos_tokens != "":
264
+ eos_tokens = eos_tokens.replace("\\n", "\n")
265
+ eos_tokens = eos_tokens.split(";")
266
+ if len(eos_tokens) > 0:
267
+ eos_token_ids = []
268
+ for eos_token in eos_tokens:
269
+ tokenized_eos_token = tokenizer(eos_token, add_special_tokens=False).input_ids
270
+ if len(tokenized_eos_token) > 1:
271
+ raise ValueError(
272
+ f"eos_tokens should be one token, here {eos_token} is {len(tokenized_eos_token)} tokens:"
273
+ f" {tokenized_eos_token}"
274
+ )
275
+ eos_token_ids += tokenized_eos_token
276
+
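The three text fields above (banned tokens, forced words, EOS tokens) share the same parsing convention: a literal `\n` typed by the user becomes a real newline, and entries are separated by `;`. A minimal sketch of that convention as one helper follows; `parse_token_list` is a hypothetical name, and dropping empty entries is an assumption about the desired behaviour.

```python
def parse_token_list(raw, separator=";"):
    """Split a user-typed field into entries, unescaping literal backslash-n."""
    # The UI delivers the two characters backslash + n; turn them into a newline.
    unescaped = raw.replace("\\n", "\n")
    # Assumption: empty entries (empty field, trailing separator) are ignored.
    return [piece for piece in unescaped.split(separator) if piece]


if __name__ == "__main__":
    print(parse_token_list("<image>;<fake_token_around_image>"))
    print(parse_token_list(""))
```

Filtering empty entries matters because `"".split(";")` returns `[""]`, which would otherwise be tokenized into an empty id list.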
277
+ # Inputs
278
+ input_args = {
279
+ "input_ids": input_ids,
280
+ "attention_mask": attention_mask,
281
+ "pixel_values": pixel_values,
282
+ "image_attention_mask": image_attention_mask,
283
+ }
284
+ # Common parameters to all decoding strategies
285
+ # This documentation is useful to read: https://huggingface.co/docs/transformers/main/en/generation_strategies
286
+ generation_args = {
287
+ "temperature": temperature,
288
+ "no_repeat_ngram_size": no_repeat_ngram_size,
289
+ "max_new_tokens": max_new_tokens,
290
+ "min_length": min_length,
291
+ "bad_words_ids": bad_words_ids,
292
+ "force_words_ids": force_words_ids,
293
+ "repetition_penalty": repetition_penalty,
294
+ "eos_token_id": eos_token_ids,
295
+ }
296
+
297
+ assert decoding_strategy in [
298
+ "greedy",
299
+ "beam_search",
300
+ "beam_sampling",
301
+ "sampling_top_k",
302
+ "sampling_top_p",
303
+ "contrastive_sampling",
304
+ ]
305
+ if decoding_strategy == "greedy":
306
+ pass
307
+ elif decoding_strategy == "beam_search":
308
+ generation_args["num_beams"] = num_beams
309
+ generation_args["length_penalty"] = length_penalty
310
+ assert generation_args["num_beams"] > 1
311
+ elif decoding_strategy == "beam_sampling":
312
+ generation_args["num_beams"] = num_beams
313
+ generation_args["length_penalty"] = length_penalty
314
+ generation_args["do_sample"] = True
315
+ assert generation_args["num_beams"] > 1
316
+ elif decoding_strategy == "sampling_top_k":
317
+ generation_args["do_sample"] = True
318
+ generation_args["top_k"] = top_k
319
+ elif decoding_strategy == "sampling_top_p":
320
+ generation_args["do_sample"] = True
321
+ generation_args["top_p"] = top_p
322
+ elif decoding_strategy == "contrastive_sampling":
323
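+ generation_args["do_sample"] = False  # contrastive search needs sampling off; with do_sample=True, generate() falls back to plain sampling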
324
+ generation_args["penalty_alpha"] = penalty_alpha
325
+ generation_args["top_k"] = top_k
326
+
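The branch above maps a strategy name to the extra keyword arguments passed to `generate`. The same mapping can be written as a standalone function, which makes each strategy's arguments easy to inspect and test. This is a sketch, not the commit's code: `strategy_kwargs` is a hypothetical name, the defaults are illustrative, and the contrastive branch keeps sampling off on the assumption that `transformers` selects contrastive search only when `penalty_alpha > 0`, `top_k > 1` and sampling is disabled.

```python
def strategy_kwargs(decoding_strategy, num_beams=3, length_penalty=1.0,
                    top_k=50, top_p=0.95, penalty_alpha=0.95):
    """Return the strategy-specific keyword arguments for a generate() call."""
    if decoding_strategy == "greedy":
        return {}
    if decoding_strategy in ("beam_search", "beam_sampling"):
        if num_beams <= 1:
            raise ValueError("beam strategies need num_beams > 1")
        kwargs = {"num_beams": num_beams, "length_penalty": length_penalty}
        if decoding_strategy == "beam_sampling":
            kwargs["do_sample"] = True
        return kwargs
    if decoding_strategy == "sampling_top_k":
        return {"do_sample": True, "top_k": top_k}
    if decoding_strategy == "sampling_top_p":
        return {"do_sample": True, "top_p": top_p}
    if decoding_strategy == "contrastive_sampling":
        # Assumption: do_sample stays unset (False) for contrastive search.
        return {"penalty_alpha": penalty_alpha, "top_k": top_k}
    raise ValueError(f"unknown decoding strategy: {decoding_strategy!r}")


if __name__ == "__main__":
    for name in ("greedy", "beam_search", "sampling_top_p", "contrastive_sampling"):
        print(name, strategy_kwargs(name))
```

Raising `ValueError` instead of using `assert` keeps the checks active even when Python runs with optimizations enabled.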
327
+ generated_tokens = model.generate(
328
+ **input_args,
329
+ **generation_args,
330
+ )
331
+ tokens = tokenizer.convert_ids_to_tokens(generated_tokens[0])
332
+ decoded_skip_special_tokens = repr(
333
+ tokenizer.batch_decode(generated_tokens, skip_special_tokens=hide_special_tokens)[0]
334
+ )
335
+ decoded = repr(tokenizer.batch_decode(generated_tokens)[0])
336
+ logger.info(
337
+ "Result: \n"
338
+ f"Prompt: `{prompt}`\n"
339
+ f"Tokens ids from prompt + generation: `{generated_tokens[0].tolist()}`\n"
340
+ f"Tokens (converted) from prompt + generation: `{tokens}`\n"
341
+ f"String decoded with skipped special tokens: `{decoded_skip_special_tokens}`\n"
342
+ f"String decoded: `{decoded}`\n"
343
+ f"Generation mode: `{decoding_strategy}`\n"
344
+ f"Generation parameters: `{generation_args}`\n"
345
+ )
346
+
347
+ original_prompt = generated_tokens[:, : input_ids.shape[-1]]
348
+ actual_generated_tokens = generated_tokens[:, input_ids.shape[-1] :]
349
+
350
+ first_end_token = len(actual_generated_tokens[0])
351
+ actual_generated_tokens = actual_generated_tokens[:, :first_end_token]
352
+ displayed_tokens = torch.cat([original_prompt, actual_generated_tokens], dim=-1)
353
+ generated_text = tokenizer.batch_decode(displayed_tokens, skip_special_tokens=hide_special_tokens)[0]
354
+ return generated_text
355
+
356
+
357
+ def model_inference(
358
+ files,
359
+ prompt,
360
+ temperature,
361
+ no_repeat_ngram_size,
362
+ max_new_tokens,
363
+ min_length,
364
+ ban_tokens,
365
+ eos_tokens,
366
+ force_words,
367
+ repetition_penalty,
368
+ hide_special_tokens,
369
+ decoding_strategy,
370
+ num_beams,
371
+ length_penalty,
372
+ top_k,
373
+ top_p,
374
+ penalty_alpha,
375
+ ):
376
+ if isinstance(files, str) and len(files) == 0:
377
+ files = None
378
+
379
+ prompt = prompt.strip()
380
+ prompt = prompt.replace("\\n", "\n")
381
+ file_idx = 0
382
+ url_images = re.findall(r"<image(.*?)>", prompt)
383
+ for idx, url_image in enumerate(url_images):
384
+ if len(url_image) == 0:
385
+ url_images[idx] = Image.open(files[file_idx].name if hasattr(files[file_idx], "name") else files[file_idx])
386
+ file_idx += 1
387
+ else:
388
+ prompt = prompt.replace(url_image, "")
389
+ url_images[idx] = url_images[idx][1:]
390
+ images = fetch_images(url_images)
391
+
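The prompt preprocessing above turns every `<image:URL>` tag into a bare `<image>` token and collects the URLs, while a bare `<image>` tag refers to the next uploaded file. A minimal sketch of that parsing step in isolation follows; `split_prompt_and_sources` is a hypothetical name, the URL is a placeholder, and `None` is used here to mark "take the next uploaded file".

```python
import re

# Same pattern as above: captures whatever sits between "<image" and ">".
IMAGE_TAG = re.compile(r"<image(.*?)>")


def split_prompt_and_sources(prompt):
    """Return (prompt with bare <image> tokens, list of URLs or None per image)."""
    sources = []

    def _strip_url(match):
        payload = match.group(1)
        # ":https://..." -> URL without the leading colon; "" -> uploaded file.
        sources.append(payload[1:] if payload else None)
        return "<image>"

    cleaned = IMAGE_TAG.sub(_strip_url, prompt)
    return cleaned, sources


if __name__ == "__main__":
    text = "User:<image:https://example.com/cat.jpg>Describe it.<image>And this one?"
    print(split_prompt_and_sources(text))
```

Doing the substitution per match, rather than with `str.replace` on the captured text, leaves other occurrences of the same URL in the prompt untouched.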
392
+ global model, tokenizer
393
+
394
+ generated_text = model_generation(
395
+ prompt=prompt,
396
+ images=images,
397
+ tokenizer=tokenizer,
398
+ model=model,
399
+ temperature=temperature,
400
+ no_repeat_ngram_size=no_repeat_ngram_size,
401
+ max_new_tokens=max_new_tokens,
402
+ min_length=min_length,
403
+ ban_tokens=ban_tokens,
404
+ eos_tokens=eos_tokens,
405
+ force_words=force_words,
406
+ repetition_penalty=repetition_penalty,
407
+ hide_special_tokens=hide_special_tokens,
408
+ decoding_strategy=decoding_strategy,
409
+ num_beams=num_beams,
410
+ length_penalty=length_penalty,
411
+ top_k=top_k,
412
+ top_p=top_p,
413
+ penalty_alpha=penalty_alpha,
414
+ )
415
+ return generated_text.strip()
416
+
417
+
418
+ def try_model_inference(
419
+ model,
420
+ files,
421
+ prompt,
422
+ temperature,
423
+ no_repeat_ngram_size,
424
+ max_new_tokens,
425
+ min_length,
426
+ ban_tokens,
427
+ eos_tokens,
428
+ force_words,
429
+ repetition_penalty,
430
+ hide_special_tokens,
431
+ decoding_strategy,
432
+ num_beams,
433
+ length_penalty,
434
+ top_k,
435
+ top_p,
436
+ penalty_alpha,
437
+ ):
438
+ count = 0
439
+ while count < MAX_TRIES:
440
+ try:
441
+ return MODEL_TO_SPACE_MAPPING[model](
442
+ files,
443
+ prompt,
444
+ temperature,
445
+ no_repeat_ngram_size,
446
+ max_new_tokens,
447
+ min_length,
448
+ ban_tokens,
449
+ eos_tokens,
450
+ force_words,
451
+ repetition_penalty,
452
+ hide_special_tokens,
453
+ decoding_strategy,
454
+ num_beams,
455
+ length_penalty,
456
+ top_k,
457
+ top_p,
458
+ penalty_alpha,
459
+ api_name="model_inference",
460
+ )
461
+ except KeyError:
462
+ # Gradio sometimes returns {'error': None}; wait and retry.
463
+ time.sleep(3)
464
+ count += 1
465
+ pass
466
+
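The retry loop in `try_model_inference` returns `None` once all attempts are used up, so a permanently failing child Space is indistinguishable from an empty answer. A sketch of the same bounded-retry idea that fails loudly instead is shown below; `call_with_retries` is a hypothetical helper, and raising `RuntimeError` is a design choice of this sketch rather than the commit's behaviour.

```python
import time


def call_with_retries(fn, max_tries=3, delay_seconds=3.0, retry_on=(KeyError,)):
    """Call fn() up to max_tries times, sleeping between failed attempts."""
    last_error = None
    for attempt in range(max_tries):
        try:
            return fn()
        except retry_on as error:
            last_error = error
            # No need to wait after the final attempt.
            if attempt < max_tries - 1:
                time.sleep(delay_seconds)
    raise RuntimeError(f"call failed after {max_tries} attempts") from last_error


if __name__ == "__main__":
    attempts = {"count": 0}

    def flaky():
        # Fails twice, then succeeds, to exercise the retry path.
        attempts["count"] += 1
        if attempts["count"] < 3:
            raise KeyError("error")
        return "ok"

    print(call_with_retries(flaky, delay_seconds=0.0))
```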
467
+
468
+ def all_model_inference(
469
+ prompt,
470
+ temperature,
471
+ no_repeat_ngram_size,
472
+ max_new_tokens,
473
+ min_length,
474
+ ban_tokens,
475
+ eos_tokens,
476
+ force_words,
477
+ repetition_penalty,
478
+ hide_special_tokens,
479
+ decoding_strategy,
480
+ num_beams,
481
+ length_penalty,
482
+ top_k,
483
+ top_p,
484
+ penalty_alpha,
485
+ ):
486
+ outputs = []
487
+ print(
488
+ prompt,
489
+ temperature,
490
+ no_repeat_ngram_size,
491
+ max_new_tokens,
492
+ min_length,
493
+ ban_tokens,
494
+ eos_tokens,
495
+ force_words,
496
+ repetition_penalty,
497
+ hide_special_tokens,
498
+ decoding_strategy,
499
+ num_beams,
500
+ length_penalty,
501
+ top_k,
502
+ top_p,
503
+ penalty_alpha,
504
+ )
505
+ outputs = Parallel(n_jobs=len(MODEL_TO_SPACE_MAPPING), backend="threading")(
506
+ delayed(try_model_inference)(
507
+ model,
508
+ os.path.join(os.path.dirname(__file__), "images", "bear.jpg"),
509
+ prompt,
510
+ temperature,
511
+ no_repeat_ngram_size,
512
+ max_new_tokens,
513
+ min_length,
514
+ ban_tokens,
515
+ eos_tokens,
516
+ force_words,
517
+ repetition_penalty,
518
+ hide_special_tokens,
519
+ decoding_strategy,
520
+ num_beams,
521
+ length_penalty,
522
+ top_k,
523
+ top_p,
524
+ penalty_alpha,
525
+ )
526
+ for model in MODEL_TO_SPACE_MAPPING
527
+ )
528
+ if len(outputs) == 1:
529
+ outputs = outputs[0]
530
+ return outputs
531
+
532
+
533
+ examples = [
534
+ [
535
+ None,
536
+ """The following is a conversation between a highly knowledgeable and intelligent AI assistant, called Assistant, and a human user, called User. In the following interactions, User and Assistant will converse in natural language, and Assistant will do its best to answer User’s questions. Assistant was built to be respectful, polite and inclusive. It knows a lot, and always tells the truth. When prompted with an image, it does not make up facts.
537
+
538
+ The conversation begins:
539
+ User:<fake_token_around_image><image:https://i1.pickpik.com/photos/515/906/355/kittens-cats-pet-cute-preview.jpg><fake_token_around_image>Describe this image.</s>
540
+ Assistant: Five kittens are standing together in the center of the photograph. From the left to right, there is one orange kitten, two kittens with white and black stripes, and two brown kittens with an orange head. They are in bright green grass and it looks like they are moving forward.</s>
541
+ User:How about this image?<fake_token_around_image><image:https://cdn.pixabay.com/photo/2017/09/25/13/12/puppy-2785074_1280.jpg><fake_token_around_image>Can you describe it too?</s>
542
+ Assistant: A dog is lying on the floor, looking at the camera. It is looking directly at you. The dog has a white body and brown patches on its face and ears. Its eyes are dark. Its nose is black, and it has long, floppy ears, white paws, long fur, big eyes.</s>
543
+ User: What kind of breed is it?</s>
544
+ Assistant: It looks like a Cavalier King Charles Spaniel. These breeds are known for their gentle, affectionate nature. Their long ears and big, brown eyes give them a sweet, expressive face. Cavaliers are small dogs weighing between 13 and 20 pounds. They have a silky, medium-length coat that comes in various colors, including black and tan, tricolor, red and white, or black.</s>
545
+ User: What can you tell me about this breed of dogs?</s>
546
+ Assistant: One specific characteristic of the Cavalier King Charles Spaniel is its friendly and affectionate nature. This breed is known for being extremely sociable and forming strong bonds with their owners. They are often described as "velcro dogs" because they love to be close to their human companions, whether it's sitting on their laps, cuddling on the couch, or simply following them around the house.</s>
547
+ User: ghjkhjabnufs</s>Assistant: That doesn’t seem to be a word. Could you ask me another way?</s>
548
+ User: Do you like Cavalier King Charles Spaniel?</s>
549
+ Assistant: I do not have personal opinions as I’m just a computer program. However, cavaliers are known for being excellent family pets due to their gentle and patient demeanor, making them great with children and other pets. Their affectionate nature also makes them ideal therapy dogs, as they can provide comfort and emotional support to those in need.</s>
550
+ User: How many dogs do you see in this image?<fake_token_around_image><image:https://i.dailymail.co.uk/i/pix/2011/07/01/article-2010308-0CD22A8300000578-496_634x414.jpg><fake_token_around_image></s>
551
+ Assistant: There are no dogs in this image. The picture shows a tennis player jumping to volley the ball.</s>
552
+ User:<fake_token_around_image><image:https://m.media-amazon.com/images/M/MV5BMjE4MTcwMTM1Nl5BMl5BanBnXkFtZTcwMTIwMzMzMw@@._V1_.jpg><fake_token_around_image>Describe all of the parts of this image.</s>
553
+ Assistant:""",
554
+ 1.0,
555
+ 0,
556
+ 256,
557
+ 10,
558
+ "<image>;<fake_token_around_image>",
559
+ "</s>;User",
560
+ "",
561
+ 1.0,
562
+ False,
563
+ "greedy",
564
+ 1,
565
+ 1.0,
566
+ 50,
567
+ 0.5,
568
+ 0.95,
569
+ ],
570
+ # [
571
+ # None,
572
+ # """This is a conversation between a human, User, and an intelligent visual AI, Bot. User sends images, and Bot answer the questions from the user.
573
+ # User: <fake_token_around_image><image:https://m.media-amazon.com/images/M/MV5BMjE4MTcwMTM1Nl5BMl5BanBnXkFtZTcwMTIwMzMzMw@@._V1_.jpg><fake_token_around_image>
574
+ # Describe this image.
575
+ # Bot:""",
576
+ # 1,
577
+ # 2,
578
+ # 64,
579
+ # 10,
580
+ # "<image>;<fake_token_around_image>;User;user;Bot;bot;Question;question;Answer;answer;\n",
581
+ # False,
582
+ # False,
583
+ # True,
584
+ # ],
585
+ # [
586
+ # None,
587
+ # """This is a conversation between a human, User, and an intelligent visual AI, Bot. User sends images, and Bot answer the questions from the user.
588
+ # User: <fake_token_around_image><image:https://i.redd.it/hsktcp4nv1g01.jpg><fake_token_around_image>
589
+ # Why do people find this image funny?
590
+ # Bot:""",
591
+ # 1,
592
+ # 2,
593
+ # 64,
594
+ # 10,
595
+ # "<image>;<fake_token_around_image>;User;user;Bot;bot;Question;question;Answer;answer;\n",
596
+ # False,
597
+ # False,
598
+ # True,
599
+ # ],
600
+ # [
601
+ # None,
602
+ # """This is a conversation between a human, User, and an intelligent visual AI, Bot. User sends images, and Bot answer the questions from the user.
603
+ # User: <fake_token_around_image><image:https://pbs.twimg.com/media/FooD7oyakAIU5_Q?format=jpg&name=large><fake_token_around_image>
604
+ # Describe what's in this image.
605
+ # Bot:""",
606
+ # 1,
607
+ # 2,
608
+ # 64,
609
+ # 10,
610
+ # "<image>;<fake_token_around_image>;User;user;Bot;bot;Question;question;Answer;answer;\n",
611
+ # False,
612
+ # False,
613
+ # True,
614
+ # ],
615
+ # [
616
+ # None,
617
+ # """This is a conversation between a human, User, and an intelligent visual AI, Bot. User sends images, and Bot answer the questions from the user.
618
+ # User: <fake_token_around_image><image:https://www.tutorialride.com/images/non-verbal-analogy-questions/non-verbal-analogy-logical-reasoning-1.jpg><fake_token_around_image>
619
+ # What's the correct answer? A, B, C or D?
620
+ # Bot:""",
621
+ # 1,
622
+ # 2,
623
+ # 64,
624
+ # 10,
625
+ # "<image>;<fake_token_around_image>;User;user;Bot;bot;Question;question;Answer;answer;\n",
626
+ # False,
627
+ # False,
628
+ # True,
629
+ # ],
630
+ ]
631
+
632
+
633
+ MSG_MAIN = """
634
+ # Text generation with VLlama models
635
+
636
+ ### How to write prompts
637
+
638
+ Put the image URLs inside the image tokens; they will be converted into real image tokens. Put <fake_token_around_image> immediately before and after each image token, WITHOUT spaces. Occurrences of \\n in the text will be converted into real newline characters. See the examples and additional details below.
639
+ """
640
+ # MSG_DETAILS = """
641
+ # ### Additional details
642
+ # - if the model was trained with the template 1 (`\\n\\n<image>\\n\\n`), then `<fake_token_around_image>` will be replaced with `\\n\\n`. This is particularly useful if you are comparing the performance of different models trained with different templates.
643
+ # - special tokens are not automatically added to the prompt, so add them manually.
644
+ # - with the first template `\\n\\n<image>\\n\\n` , the sequence isn't necessary tokenized as `["\\n\\n", "<image>", "\\n\\n"]` to enforce this behavior, you can use the "Integrate image sequence as ids" parameter.
645
+ # """
646
+ # if not IS_MAIN_SPACE:
647
+ # MSG_DETAILS += (
648
+ # "- alternatively, you can upload images and then directly specify them via \<image\> tag in the prompt."
649
+ # )
650
+
651
+ with gr.Blocks() as demo:
652
+ gr.Markdown(MSG_MAIN)
653
+ with gr.Row():
654
+ with gr.Column():
655
+ gr.Markdown("## Input")
656
+ if not IS_MAIN_SPACE:
657
+ images = gr.File(label="Images", file_count="multiple")
658
+ prompt = gr.Textbox(label="Prompt", placeholder="Enter the prompt here")
659
+
660
+ gr.Markdown("## Parameters common to all decoding strategies")
661
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=1.0, label="Softmax temperature")
662
+ no_repeat_ngram_size = gr.Slider(
663
+ minimum=0,
664
+ maximum=10,
665
+ step=1,
666
+ value=0,
667
+ label="The size of an n-gram that cannot occur more than once (0=infinity)",
668
+ )
669
+ max_new_tokens = gr.Slider(
670
+ minimum=0, maximum=512, step=1, value=256, label="Maximum number of new tokens to generate"
671
+ )
672
+ min_length = gr.Slider(
673
+ minimum=0, maximum=512, step=1, value=16, label="Minimum length of the sequence to be generated"
674
+ )
675
+ ban_tokens = gr.Textbox(
676
+ label='Tokens to prevent from being generated (separated by ";")',
677
+ value="<image>;<fake_token_around_image>",
678
+ )
679
+ eos_tokens = gr.Textbox(label="EOS tokens", value="</s>")
680
+ force_words = gr.Textbox(label='Force words to be generated (separated by ";")', value="")
681
+ repetition_penalty = gr.Slider(
682
+ minimum=0, maximum=10, step=0.01, value=1, label="repetition_penalty. CTRL paper suggests 1.2."
683
+ )
684
+ hide_special_tokens = gr.Checkbox(label="Hide special tokens in the text", value=False)
685
+
686
+ gr.Markdown("## Decoding strategy and its specific parameters")
687
+ decoding_strategy = gr.Dropdown(
688
+ ["greedy", "beam_search", "beam_sampling", "sampling_top_k", "sampling_top_p", "contrastive_sampling"],
689
+ label="Decoding strategy",
690
+ value="greedy",
691
+ )
692
+ num_beams = gr.Slider(
693
+ minimum=0,
694
+ maximum=10,
695
+ step=1,
696
+ value=3,
697
+ label="Beam size",
698
+ info="Only used if `decoding_strategy` is `beam_search` or `beam_sampling`",
699
+ )
700
+ length_penalty = gr.Slider(
701
+ minimum=-1000,
702
+ maximum=1000,
703
+ step=0.1,
704
+ value=1,
705
+ label=(
706
+ "length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter"
707
+ " sequences. Only used if `decoding_strategy` is `beam_search` or `beam_sampling`"
708
+ ),
709
+ )
710
+ top_k = gr.Slider(
711
+ minimum=0,
712
+ maximum=500,
713
+ step=1,
714
+ value=50,
715
+ label="Top k",
716
+ info="Only used if `decoding_strategy` is `sampling_top_k` or `contrastive_sampling`",
717
+ )
718
+ top_p = gr.Slider(
719
+ minimum=0,
720
+ maximum=1,
721
+ step=0.01,
722
+ value=0.95,
723
+ label="Top p",
724
+ info="Only used if `decoding_strategy` is `sampling_top_p`",
725
+ )
726
+ penalty_alpha = gr.Slider(
727
+ minimum=0,
728
+ maximum=1,
729
+ step=0.01,
730
+ value=0.95,
731
+ label="Penalty alpha",
732
+ info="Only used if `decoding_strategy` is `contrastive_sampling`",
733
+ )
734
+
735
+ submit = gr.Button(label="Generate")
736
+
737
+ with gr.Column():
738
+ if IS_MAIN_SPACE:
739
+ outputs = [
740
+ gr.Textbox(label=MODEL_TO_DISPLAY_NAME[model], multiline=True, readonly=True)
741
+ for model in MODEL_TO_MODEL_CLASS
742
+ ]
743
+ inference_func = all_model_inference
744
+ inputs = [
745
+ prompt,
746
+ temperature,
747
+ no_repeat_ngram_size,
748
+ max_new_tokens,
749
+ min_length,
750
+ ban_tokens,
751
+ eos_tokens,
752
+ force_words,
753
+ repetition_penalty,
754
+ hide_special_tokens,
755
+ decoding_strategy,
756
+ num_beams,
757
+ length_penalty,
758
+ top_k,
759
+ top_p,
760
+ penalty_alpha,
761
+ ]
762
+
763
+ examples = [example[1:] for example in examples]
764
+ else:
765
+ outputs = gr.Textbox(label="Generated text", interactive=False)
766
+ inference_func = model_inference
767
+ inputs = [
768
+ images,
769
+ prompt,
770
+ temperature,
771
+ no_repeat_ngram_size,
772
+ max_new_tokens,
773
+ min_length,
774
+ ban_tokens,
775
+ eos_tokens,
776
+ force_words,
777
+ repetition_penalty,
778
+ hide_special_tokens,
779
+ decoding_strategy,
780
+ num_beams,
781
+ length_penalty,
782
+ top_k,
783
+ top_p,
784
+ penalty_alpha,
785
+ ]
786
+ with gr.Row():
787
+ gr.Examples(inputs=inputs, examples=examples)
788
+ # gr.Markdown(MSG_DETAILS)
789
+
790
+ submit.click(inference_func, inputs=inputs, outputs=outputs, api_name="model_inference")
791
+
792
+ demo.queue()
793
+ demo.launch()
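
A note on the prompt convention described in the help text of `app.py`: image URLs are written inline as `<image:URL>`, and literal `\n` sequences are turned into real newlines before tokenization. The following is a minimal sketch of that preprocessing, not the code from this commit. `split_prompt` and `IMAGE_TAG` are hypothetical names, and returning `None` for a bare `<image>` tag (read here as "use the next uploaded file") is an assumption made for illustration.

```python
import re

# Matches "<image>" as well as "<image:https://...>".
# It does not match "<fake_token_around_image>", because the literal
# "<image" must start right at the opening angle bracket.
IMAGE_TAG = re.compile(r"<image(.*?)>")


def split_prompt(prompt):
    """Return (cleaned_prompt, image_sources) for a raw prompt string.

    Every "<image:URL>" tag is reduced to a bare "<image>" token and its URL
    is collected in order. A bare "<image>" tag yields None (assumption: the
    caller then substitutes the next uploaded file).
    """
    # Turn the two-character sequence backslash + "n" into a real newline.
    prompt = prompt.strip().replace("\\n", "\n")
    sources = []

    def _collect(match):
        payload = match.group(1)
        # payload is ":URL" for URL tags and "" for bare tags.
        sources.append(payload[1:] if payload else None)
        return "<image>"

    cleaned = IMAGE_TAG.sub(_collect, prompt)
    return cleaned, sources


if __name__ == "__main__":
    raw = (
        "User:<fake_token_around_image><image:https://example.com/cat.jpg>"
        "<fake_token_around_image>Describe this image.\\nAssistant:"
    )
    cleaned, sources = split_prompt(raw)
    print(cleaned)
    print(sources)
```

Collecting the sources inside the substitution callback keeps the order of the URLs aligned with the order of the `<image>` tokens in the cleaned prompt, which is what the image attention mask relies on.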
app_bis.py ADDED
@@ -0,0 +1,857 @@
1
+ import logging
2
+ import os
3
+ import re
4
+
5
+ import time
6
+ from io import BytesIO
7
+
8
+ import gradio as gr
9
+ import requests
10
+ import torch
11
+ import transformers
12
+ from accelerate.utils import get_max_memory
13
+
14
+ from joblib import Parallel, delayed
15
+ from PIL import Image
16
+ from transformers import AutoTokenizer
17
+
18
+ from m4.models.vbloom import configuration_vbloom, modeling_vbloom
19
+ from m4.models.vgpt2 import configuration_vgpt2, modeling_vgpt2
20
+ from m4.models.vgpt_neo import configuration_vgpt_neo, modeling_vgpt_neo
21
+ from m4.models.vllama import configuration_vllama, modeling_vllama
22
+ from m4.models.vopt import configuration_vopt, modeling_vopt
23
+ from m4.training.packing import image_attention_mask_for_packed_input_ids, incremental_to_binary_attention_mask
24
+ from m4.training.utils import build_image_transform
25
+
26
+
27
+ logging.basicConfig(level=logging.INFO)
28
+ logger = logging.getLogger()
29
+
30
+ CURRENT_MODEL = "tr_209_ift_mixture_opt_step-2000"
31
+
32
+ MAX_TRIES = 3
33
+ TOKENIZER_FAST = True
34
+ MAX_SEQ_LEN = 1024
35
+ model, tokenizer = None, None
36
+
37
+
38
+ MODEL_TO_DISPLAY_NAME = {
39
+ "tr_199_w_xattn_opt_step-65000": "VLlama - tr_199_w_xattn_opt_step-65000",
40
+ "tr_201_sft_on_lrv_opt_step-15000": "VLlama - tr_201_sft_on_lrv_opt_step-15000",
41
+ "tr_202bis_ift_llava_all_unfrozen_opt_step-14128": "VLlama - tr_202bis_ift_llava_all_unfrozen_opt_step-14128",
42
+ "tr_203_ift_m3it_opt_step-50000": "VLlama - tr_203_ift_m3it_opt_step-50000",
43
+ "tr_205_sft_ultrachat_opt_step-20000": "VLlama - tr_205_sft_ultrachat_opt_step-20000",
44
+ "tr_207_ift_svit_opt_step-14627": "VLlama - tr_207_ift_svit_opt_step-14627",
45
+ "tr_209_ift_mixture_opt_step-2000": "VLlama - tr_209_ift_mixture_opt_step-2000",
46
+ }
47
+ MODEL_TO_MODEL_CLASS = {
48
+ "tr_199_w_xattn_opt_step-65000": "VLlamaForCausalLM",
49
+ "tr_201_sft_on_lrv_opt_step-15000": "VLlamaForCausalLM",
50
+ "tr_202bis_ift_llava_all_unfrozen_opt_step-14128": "VLlamaForCausalLM",
51
+ "tr_203_ift_m3it_opt_step-50000": "VLlamaForCausalLM",
52
+ "tr_205_sft_ultrachat_opt_step-20000": "VLlamaForCausalLM",
53
+ "tr_207_ift_svit_opt_step-14627": "VLlamaForCausalLM",
54
+ "tr_209_ift_mixture_opt_step-2000": "VLlamaForCausalLM",
55
+ }
56
+
57
+ MODEL_TO_CONFIG_CLASS = {
58
+ "tr_199_w_xattn_opt_step-65000": "VLlamaConfig",
59
+ "tr_201_sft_on_lrv_opt_step-15000": "VLlamaConfig",
60
+ "tr_202bis_ift_llava_all_unfrozen_opt_step-14128": "VLlamaConfig",
61
+ "tr_203_ift_m3it_opt_step-50000": "VLlamaConfig",
62
+ "tr_205_sft_ultrachat_opt_step-20000": "VLlamaConfig",
63
+ "tr_207_ift_svit_opt_step-14627": "VLlamaConfig",
64
+ "tr_209_ift_mixture_opt_step-2000": "VLlamaConfig",
65
+ }
66
+
67
+
68
+ def load_tokenizer_model(model_name, model_class):
69
+ tokenizer = AutoTokenizer.from_pretrained(
70
+ model_name,
71
+ use_fast=TOKENIZER_FAST, # `use_fast=False` for 1B3 OPT, True for all the other models
72
+ use_auth_token=os.getenv("HF_AUTH_TOKEN", True),
73
+ )
74
+ tokenizer.padding_side = "left"
75
+ config_class = MODEL_TO_CONFIG_CLASS[model_name.split("/")[-1]]
76
+
77
+ # assert tokenizer.is_fast
78
+
79
+ supported_custom_modules = {
80
+ "vgpt2": modeling_vgpt2,
81
+ "vbloom": modeling_vbloom,
82
+ "vgptneo": modeling_vgpt_neo,
83
+ "vopt": modeling_vopt,
84
+ "vllama": modeling_vllama,
85
+ }
86
+ supported_custom_configs = {
87
+ "vgpt2": configuration_vgpt2,
88
+ "vbloom": configuration_vbloom,
89
+ "vgptneo": configuration_vgpt_neo,
90
+ "vopt": configuration_vopt,
91
+ "vllama": configuration_vllama,
92
+ }
93
+ parent_config_class = (
94
+ [v for k, v in supported_custom_configs.items() if k in model_class.lower()] + [transformers]
95
+ )[0]
96
+ parent_model_class = (
97
+ [v for k, v in supported_custom_modules.items() if k in model_class.lower()] + [transformers]
98
+ )[0]
99
+ config_class = getattr(parent_config_class, config_class)
100
+ model_class = getattr(parent_model_class, model_class)
101
+ config = config_class.from_pretrained(model_name, use_auth_token=os.getenv("HF_AUTH_TOKEN", True))
102
+ max_memory_map = get_max_memory()
103
+ for key in max_memory_map.keys():
104
+ if key != "cpu":
105
+ # Get this in GB
106
+ max_memory_map[key] = max_memory_map[key] // (1024 * 1024 * 1024)
107
+ # Subtract 4 GiB to be safe: 2 for PyTorch overhead and 2 for the forward pass
108
+ max_memory_map[key] = f"{max_memory_map[key] - 4} GiB"
109
+ model = model_class.from_pretrained(
110
+ model_name,
111
+ use_auth_token=os.getenv("HF_AUTH_TOKEN", True),
112
+ device_map="auto",
113
+ offload_folder="./offload",
114
+ torch_dtype=config.torch_dtype,
115
+ max_memory=max_memory_map,
116
+ )
117
+ model.eval()
118
+ print("Current device map:", model.hf_device_map)
119
+ print("Model default generation config:", model.generation_config)
120
+ # TODO: the device_map looks very inefficient right now; this could be improved.
121
+ # It typically looks like this:
122
+ # {
123
+ # 'model.embed_tokens': 0,
124
+ # 'model.vision_model': 0,
125
+ # 'model.layers.0': 0,
126
+ # 'model.layers.1': 0,
127
+ # 'model.layers.2': 0,
128
+ # 'model.layers.3': 0,
129
+ # 'model.layers.4': 0,
130
+ # 'model.layers.5': 0,
131
+ # 'model.layers.6': 1,
132
+ # 'model.layers.7': 1,
133
+ # 'model.layers.8': 1,
134
+ # 'model.layers.9': 1,
135
+ # 'model.layers.10': 1,
136
+ # 'model.layers.11': 1,
137
+ # 'model.layers.12': 1,
138
+ # 'model.layers.13': 1,
139
+ # 'model.layers.14': 1,
140
+ # 'model.layers.15': 1,
141
+ # 'model.layers.16': 1,
142
+ # 'model.layers.17': 2,
143
+ # 'model.layers.18': 2,
144
+ # 'model.layers.19': 2,
145
+ # 'model.layers.20': 2,
146
+ # 'model.layers.21': 2,
147
+ # 'model.layers.22': 2,
148
+ # 'model.layers.23': 2,
149
+ # 'model.layers.24': 2,
150
+ # 'model.layers.25': 2,
151
+ # 'model.layers.26': 2,
152
+ # 'model.layers.27': 2,
153
+ # 'model.layers.28': 3,
154
+ # 'model.layers.29': 3,
155
+ # 'model.layers.30': 3,
156
+ # 'model.layers.31': 3,
157
+ # 'model.gated_cross_attn_layers.0': 3,
158
+ # 'model.gated_cross_attn_layers.1': 3,
159
+ # 'model.gated_cross_attn_layers.2': 3,
160
+ # 'model.gated_cross_attn_layers.3': 3,
161
+ # 'model.gated_cross_attn_layers.4': 3,
162
+ # 'model.gated_cross_attn_layers.5': 3,
163
+ # 'model.gated_cross_attn_layers.6': 3,
164
+ # 'model.gated_cross_attn_layers.7': 3,
165
+ # 'model.gated_cross_attn_layers.8': 4,
166
+ # 'model.gated_cross_attn_layers.9': 4,
167
+ # 'model.gated_cross_attn_layers.10': 4,
168
+ # 'model.gated_cross_attn_layers.11': 4,
169
+ # 'model.gated_cross_attn_layers.12': 4,
170
+ # 'model.gated_cross_attn_layers.13': 4,
171
+ # 'model.gated_cross_attn_layers.14': 4,
172
+ # 'model.gated_cross_attn_layers.15': 4,
173
+ # 'model.norm': 4,
174
+ # 'lm_head': 4
175
+ # } which means there is a lot of back-and-forth between devices, since the gated cross-attention layers and the LM layers are placed on different ones.
176
+ return tokenizer, model
177
+
178
+
179
+ MODEL_TO_SPACE_MAPPING = {}
180
+ IS_MAIN_SPACE = CURRENT_MODEL not in MODEL_TO_MODEL_CLASS
181
+ if IS_MAIN_SPACE:
182
+ for model in MODEL_TO_MODEL_CLASS:
183
+ MODEL_TO_SPACE_MAPPING[model] = gr.Blocks.load(
184
+ name=f"spaces/HuggingFaceM4/{model}", api_key=os.getenv("HF_AUTH_TOKEN", True)
185
+ )
186
+ else:
187
+ model_path = f"HuggingFaceM4/{CURRENT_MODEL}"
188
+ tokenizer, model = load_tokenizer_model(model_path, MODEL_TO_MODEL_CLASS[CURRENT_MODEL])
189
+
190
+
191
+ def fetch_images(url_images):
192
+ images = []
193
+ for url in url_images:
194
+ if isinstance(url, str):
195
+ images.append(Image.open(BytesIO(requests.get(url, stream=True).content)))
196
+ else:
197
+ images.append(url)
198
+ return images
199
+
200
+
201
+ def model_generation(
202
+ prompt,
203
+ images,
204
+ tokenizer,
205
+ model,
206
+ temperature,
207
+ no_repeat_ngram_size,
208
+ max_new_tokens,
209
+ min_length,
210
+ ban_tokens,
211
+ forced_eos_token_id,
212
+ eos_tokens,
213
+ force_words,
214
+ length_penalty,
215
+ repetition_penalty,
216
+ hide_special_tokens,
217
+ stop_generation,
218
+ decoding_strategy,
219
+ num_beams,
220
+ top_k,
221
+ top_p,
222
+ penalty_alpha,
223
+ ):
224
+ # Preparing inputs
225
+ tokens = tokenizer(
226
+ [prompt],
227
+ truncation=True,
228
+ max_length=MAX_SEQ_LEN,
229
+ padding=True,
230
+ add_special_tokens=False,
231
+ )
232
+
233
+ input_ids = torch.tensor([[tokenizer.bos_token_id] + tokens.input_ids[0]])
234
+ attention_mask = torch.tensor([[1] + tokens.attention_mask[0]])
235
+
236
+ image_attention_mask = [
237
+ incremental_to_binary_attention_mask(
238
+ image_attention_mask_for_packed_input_ids(input_ids[0].unsqueeze(0), tokenizer)[0], num_classes=len(images)
239
+ )
240
+ ]
241
+
242
+ image_transform = build_image_transform(eval=True)
243
+ pixel_values = [torch.stack([image_transform(img) for img in images])]
244
+
245
+ input_ids = input_ids.to(0)
246
+ attention_mask = attention_mask.to(0)
247
+ pixel_values = torch.stack(pixel_values).to(0)
248
+ image_attention_mask = torch.cat(image_attention_mask, 0).to(0)
249
+
250
+ # Excluding some words from the generation
251
+ bad_words_ids = None
252
+ ban_tokens = ban_tokens.replace("\\n", "\n")
253
+ bad_words = ban_tokens.split(";")
254
+ if len(bad_words) > 0:
255
+ bad_words_ids = tokenizer(bad_words, add_special_tokens=False).input_ids
256
+
257
+ # Forcing some words in the generation
258
+ force_words_ids = None
259
+ if force_words != "":
260
+ force_words = force_words.replace("\\n", "\n")
261
+ force_words = force_words.split(";")
262
+ if len(force_words) > 0:
263
+ force_words_ids = tokenizer(force_words, add_special_tokens=False).input_ids
264
+
265
+ # eos_token_ids = None
266
+ # if eos_tokens != "":
267
+ # eos_tokens = eos_tokens.replace("\\n", "\n")
268
+ # eos_tokens = eos_tokens.split(";")
269
+ # if len(eos_tokens) > 0:
270
+ # eos_token_ids = []
271
+ # for eos_token in eos_tokens:
272
+ # tokenized_eos_token = tokenizer(eos_token, add_special_tokens=False).input_ids
273
+ # if len(tokenized_eos_token) > 1:
274
+ # raise ValueError(
275
+ # f"eos_tokens should be one token, here {eos_token} is {len(tokenized_eos_token)} tokens:"
276
+ # f" {tokenized_eos_token}"
277
+ # )
278
+ # eos_token_ids += tokenized_eos_token
279
+
280
+ # if forced_eos_token_id and eos_token_ids is None:
281
+ # raise ValueError("You can't use forced_eos_token_id without eos_tokens")
282
+ # elif forced_eos_token_id:
283
+ # forced_eos_token_id = eos_token_ids
284
+ # else:
285
+ # forced_eos_token_id = None
286
+
287
+ # Inputs
288
+ input_args = {
289
+ "input_ids": input_ids,
290
+ "attention_mask": attention_mask,
291
+ "pixel_values": pixel_values,
292
+ "image_attention_mask": image_attention_mask,
293
+ }
294
+ # Common parameters to all decoding strategies
295
+ # This documentation is useful to read: https://huggingface.co/docs/transformers/main/en/generation_strategies
296
+ generation_args = {
297
+ "temperature": temperature,
298
+ "no_repeat_ngram_size": no_repeat_ngram_size,
299
+ "max_new_tokens": max_new_tokens,
300
+ "min_length": min_length,
301
+ "bad_words_ids": bad_words_ids,
302
+ # "forced_eos_token_id": forced_eos_token_id,
303
+ "force_words_ids": force_words_ids,
304
+ "length_penalty": length_penalty,
305
+ "repetition_penalty": repetition_penalty,
306
+ "eos_token_id": tokenizer.eos_token_id,
307
+ }
308
+
309
+ assert decoding_strategy in [
310
+ "greedy",
311
+ "beam_search",
312
+ "beam_sampling",
313
+ "sampling_top_k",
314
+ "sampling_top_p",
315
+ "contrastive_sampling",
316
+ ]
317
+ if decoding_strategy == "greedy":
318
+ pass
319
+ elif decoding_strategy == "beam_search":
320
+ generation_args["num_beams"] = num_beams
321
+ assert generation_args["num_beams"] > 1
322
+ elif decoding_strategy == "beam_sampling":
323
+ generation_args["num_beams"] = num_beams
324
+ generation_args["do_sample"] = True
325
+ assert generation_args["num_beams"] > 1
326
+ elif decoding_strategy == "sampling_top_k":
327
+ generation_args["do_sample"] = True
328
+ generation_args["top_k"] = top_k
329
+ elif decoding_strategy == "sampling_top_p":
330
+ generation_args["do_sample"] = True
331
+ generation_args["top_p"] = top_p
332
+ elif decoding_strategy == "contrastive_sampling":
333
+ generation_args["do_sample"] = True
334
+ generation_args["penalty_alpha"] = penalty_alpha
335
+ generation_args["top_k"] = top_k
336
+
337
+ generated_tokens = model.generate(
338
+ **input_args,
339
+ **generation_args,
340
+ )
341
+ tokens = tokenizer.convert_ids_to_tokens(generated_tokens[0])
342
+ decoded_skip_special_tokens = repr(
343
+ tokenizer.batch_decode(generated_tokens, skip_special_tokens=hide_special_tokens)[0]
344
+ )
345
+ decoded = repr(tokenizer.batch_decode(generated_tokens)[0])
346
+ logger.info(
347
+ "Result: \n"
348
+ f"Prompt: `{prompt}`\n"
349
+ f"Tokens ids from prompt + generation: `{generated_tokens[0].tolist()}`\n"
350
+ f"Tokens (converted) from prompt + generation: `{tokens}`\n"
351
+ f"String decoded with skipped special tokens: `{decoded_skip_special_tokens}`\n"
352
+ f"String decoded: `{decoded}`\n"
353
+ f"Generation mode: `{decoding_strategy}`\n"
354
+ f"Generation parameters: `{generation_args}`\n"
355
+ )
356
+
357
+ original_prompt = generated_tokens[:, : input_ids.shape[-1]]
358
+ actual_generated_tokens = generated_tokens[:, input_ids.shape[-1] :]
359
+
360
+ if stop_generation:
361
+ # Additional stopping criteria: generating <image> token, <end_of_text> token or <begin_of_text> token
362
+ assert tokenizer.additional_special_tokens[-1] == "<image>"
363
+ image_token_id = tokenizer.additional_special_tokens_ids[-1]
364
+ end_of_text_token_id = tokenizer.eos_token_id
365
+ begin_of_text_token_id = tokenizer.bos_token_id
366
+
367
+ image_token_ids = (actual_generated_tokens == image_token_id).nonzero(as_tuple=True)[1]
368
+ end_of_text_token_ids = (actual_generated_tokens == end_of_text_token_id).nonzero(as_tuple=True)[1]
369
+ begin_of_text_token_ids = (actual_generated_tokens == begin_of_text_token_id).nonzero(as_tuple=True)[1]
370
+
371
+ first_end_token = min(
372
+ image_token_ids[0] if len(image_token_ids) else len(actual_generated_tokens[0]),
373
+ end_of_text_token_ids[0] if len(end_of_text_token_ids) else len(actual_generated_tokens[0]),
374
+ begin_of_text_token_ids[0] if len(begin_of_text_token_ids) else len(actual_generated_tokens[0]),
375
+ )
376
+ else:
377
+ first_end_token = len(actual_generated_tokens[0])
378
+
379
+ actual_generated_tokens = actual_generated_tokens[:, :first_end_token]
380
+ displayed_tokens = torch.cat([original_prompt, actual_generated_tokens], dim=-1)
381
+ generated_text = tokenizer.batch_decode(displayed_tokens, skip_special_tokens=hide_special_tokens)[0]
382
+ return generated_text
383
+
384
+
385
+ def model_inference(
386
+ files,
387
+ prompt,
388
+ temperature,
389
+ no_repeat_ngram_size,
390
+ max_new_tokens,
391
+ min_length,
392
+ ban_tokens,
393
+ forced_eos_token_id,
394
+ eos_tokens,
395
+ force_words,
396
+ length_penalty,
397
+ repetition_penalty,
398
+ hide_special_tokens,
399
+ stop_generation,
400
+ decoding_strategy,
401
+ num_beams,
402
+ top_k,
403
+ top_p,
404
+ penalty_alpha,
405
+ ):
406
+ if isinstance(files, str) and len(files) == 0:
407
+ files = None
408
+
409
+ prompt = prompt.strip()
410
+ prompt = prompt.replace("\\n", "\n")
411
+ file_idx = 0
412
+ url_images = re.findall(r"<image(.*?)>", prompt)
413
+ for idx, url_image in enumerate(url_images):
414
+ if len(url_image) == 0:
415
+ url_images[idx] = Image.open(files[file_idx].name if hasattr(files[file_idx], "name") else files[file_idx])
416
+ file_idx += 1
417
+ else:
418
+ prompt = prompt.replace(url_image, "")
419
+ url_images[idx] = url_images[idx][1:]
420
+ images = fetch_images(url_images)
421
+
422
+ global model, tokenizer
423
+
424
+ generated_text = model_generation(
425
+ prompt=prompt,
426
+ images=images,
427
+ tokenizer=tokenizer,
428
+ model=model,
429
+ temperature=temperature,
430
+ no_repeat_ngram_size=no_repeat_ngram_size,
431
+ max_new_tokens=max_new_tokens,
432
+ min_length=min_length,
433
+ ban_tokens=ban_tokens,
434
+ forced_eos_token_id=forced_eos_token_id,
435
+ eos_tokens=eos_tokens,
436
+ force_words=force_words,
437
+ length_penalty=length_penalty,
438
+ repetition_penalty=repetition_penalty,
439
+ hide_special_tokens=hide_special_tokens,
440
+ stop_generation=stop_generation,
441
+ decoding_strategy=decoding_strategy,
442
+ num_beams=num_beams,
443
+ top_k=top_k,
444
+ top_p=top_p,
445
+ penalty_alpha=penalty_alpha,
446
+ )
447
+ return generated_text.strip()
448
+
449
+
450
+ def try_model_inference(
451
+ model,
452
+ files,
453
+ prompt,
454
+ temperature,
455
+ no_repeat_ngram_size,
456
+ max_new_tokens,
457
+ min_length,
458
+ ban_tokens,
459
+ forced_eos_token_id,
460
+ eos_tokens,
461
+ force_words,
462
+ length_penalty,
463
+ repetition_penalty,
464
+ hide_special_tokens,
465
+ stop_generation,
466
+ decoding_strategy,
467
+ num_beams,
468
+ top_k,
469
+ top_p,
470
+ penalty_alpha,
471
+ ):
472
+ count = 0
473
+ while count < MAX_TRIES:
474
+ try:
475
+ return MODEL_TO_SPACE_MAPPING[model](
476
+ files,
477
+ prompt,
478
+ temperature,
479
+ no_repeat_ngram_size,
480
+ max_new_tokens,
481
+ min_length,
482
+ ban_tokens,
483
+ forced_eos_token_id,
484
+ eos_tokens,
485
+ force_words,
486
+ length_penalty,
487
+ repetition_penalty,
488
+ hide_special_tokens,
489
+ stop_generation,
490
+ decoding_strategy,
491
+ num_beams,
492
+ top_k,
493
+ top_p,
494
+ penalty_alpha,
495
+ api_name="model_inference",
496
+ )
497
+ except KeyError:
498
+ # Gradio sometimes returns {'error': None}.
499
+ time.sleep(3)
500
+ count += 1
501
+ pass
502
+
503
+
504
+ def all_model_inference(
505
+ prompt,
506
+ temperature,
507
+ no_repeat_ngram_size,
508
+ max_new_tokens,
509
+ min_length,
510
+ ban_tokens,
511
+ forced_eos_token_id,
512
+ eos_tokens,
513
+ force_words,
514
+ length_penalty,
515
+ repetition_penalty,
516
+ hide_special_tokens,
517
+ stop_generation,
518
+ decoding_strategy,
519
+ num_beams,
520
+ top_k,
521
+ top_p,
522
+ penalty_alpha,
523
+ ):
524
+ outputs = []
525
+ print(
526
+ prompt,
527
+ temperature,
528
+ no_repeat_ngram_size,
529
+ max_new_tokens,
530
+ min_length,
531
+ ban_tokens,
532
+ forced_eos_token_id,
533
+ eos_tokens,
534
+ force_words,
535
+ length_penalty,
536
+ repetition_penalty,
537
+ hide_special_tokens,
538
+ stop_generation,
539
+ decoding_strategy,
540
+ num_beams,
541
+ top_k,
542
+ top_p,
543
+ penalty_alpha,
544
+ )
545
+ outputs = Parallel(n_jobs=len(MODEL_TO_SPACE_MAPPING), backend="threading")(
546
+ delayed(try_model_inference)(
547
+ model,
548
+ os.path.join(os.path.dirname(__file__), "images", "bear.jpg"),
549
+ prompt,
550
+ temperature,
551
+ no_repeat_ngram_size,
552
+ max_new_tokens,
553
+ min_length,
554
+ ban_tokens,
555
+ forced_eos_token_id,
556
+ eos_tokens,
557
+ force_words,
558
+ length_penalty,
559
+ repetition_penalty,
560
+ hide_special_tokens,
561
+ stop_generation,
562
+ decoding_strategy,
563
+ num_beams,
564
+ top_k,
565
+ top_p,
566
+ penalty_alpha,
567
+ )
568
+ for model in MODEL_TO_SPACE_MAPPING
569
+ )
570
+ if len(outputs) == 1:
571
+ outputs = outputs[0]
572
+ return outputs
573
+
574
+
575
+ examples = [
576
+ [
577
+ None,
578
+ """This is a conversation between a human, User, and an intelligent visual AI, Assistant. User sends images, and Assistant answers the questions from the user. The assistant should be friendly, informative and should not change the topic if it's not asked to.
579
+
580
+ Here's an example of a conversation:
581
+ User:<fake_token_around_image><image:https://hips.hearstapps.com/hmg-prod/images/cute-photos-of-cats-in-grass-1593184777.jpg><fake_token_around_image>Describe this image.
582
+ Assistant: Two kittens are cuddling in the center of the photograph. They are surrounded by tall, bright green grass, and the background is blurred.
583
+ User:<fake_token_around_image><image:https://cdn.pixabay.com/photo/2017/09/25/13/12/puppy-2785074_1280.jpg><fake_token_around_image>How about this image? Can you describe it too?
584
+ Assistant: A dog is lying on the floor, looking at the camera. It is looking directly at you, tilting its head to the side. The dog has a white body and brown patches on its face and ears. Its eyes are brown. Its nose is black, and it has long, floppy ears, short legs, white paws, long fur, big eyes, and black eyebrows.
585
+ User: What kind of breed is it?
586
+ Assistant: It looks like a Cavalier King Charles Spaniel. These breeds are known for their gentle, affectionate nature. Their long ears and big, brown eyes give them a sweet, expressive face. Cavaliers are small dogs weighing between 13 and 20 pounds. They have a silky, medium-length coat that comes in various colors, including black and tan, tricolor, red and white, or black.
587
+ ---
588
+ User:<fake_token_around_image><image:https://m.media-amazon.com/images/M/MV5BMjE4MTcwMTM1Nl5BMl5BanBnXkFtZTcwMTIwMzMzMw@@._V1_.jpg><fake_token_around_image>Describe all of the parts of this image.
589
+ Assistant:""",
590
+ 1.0,
591
+ 0,
592
+ 256,
593
+ 10,
594
+ "<image>;<fake_token_around_image>",
595
+ False,
596
+ "</s>",
597
+ "",
598
+ 1.0,
599
+ 1.0,
600
+ False,
601
+ True,
602
+ "greedy",
603
+ 1,
604
+ 50,
605
+ 0.5,
606
+ 0.95,
607
+ ],
608
+ # [
609
+ # None,
610
+ # """This is a conversation between a human, User, and an intelligent visual AI, Bot. User sends images, and Bot answer the questions from the user.
611
+ # User: <fake_token_around_image><image:https://m.media-amazon.com/images/M/MV5BMjE4MTcwMTM1Nl5BMl5BanBnXkFtZTcwMTIwMzMzMw@@._V1_.jpg><fake_token_around_image>
612
+ # Describe this image.
613
+ # Bot:""",
614
+ # 1,
615
+ # 2,
616
+ # 64,
617
+ # 10,
618
+ # "<image>;<fake_token_around_image>;User;user;Bot;bot;Question;question;Answer;answer;\n",
619
+ # False,
620
+ # False,
621
+ # True,
622
+ # ],
623
+ # [
624
+ # None,
625
+ # """This is a conversation between a human, User, and an intelligent visual AI, Bot. User sends images, and Bot answer the questions from the user.
626
+ # User: <fake_token_around_image><image:https://i.redd.it/hsktcp4nv1g01.jpg><fake_token_around_image>
627
+ # Why do people find this image funny?
628
+ # Bot:""",
629
+ # 1,
630
+ # 2,
631
+ # 64,
632
+ # 10,
633
+ # "<image>;<fake_token_around_image>;User;user;Bot;bot;Question;question;Answer;answer;\n",
634
+ # False,
635
+ # False,
636
+ # True,
637
+ # ],
638
+ # [
639
+ # None,
640
+ # """This is a conversation between a human, User, and an intelligent visual AI, Bot. User sends images, and Bot answer the questions from the user.
641
+ # User: <fake_token_around_image><image:https://pbs.twimg.com/media/FooD7oyakAIU5_Q?format=jpg&name=large><fake_token_around_image>
642
+ # Describe what's in this image.
643
+ # Bot:""",
644
+ # 1,
645
+ # 2,
646
+ # 64,
647
+ # 10,
648
+ # "<image>;<fake_token_around_image>;User;user;Bot;bot;Question;question;Answer;answer;\n",
649
+ # False,
650
+ # False,
651
+ # True,
652
+ # ],
653
+ # [
654
+ # None,
655
+ # """This is a conversation between a human, User, and an intelligent visual AI, Bot. User sends images, and Bot answer the questions from the user.
656
+ # User: <fake_token_around_image><image:https://www.tutorialride.com/images/non-verbal-analogy-questions/non-verbal-analogy-logical-reasoning-1.jpg><fake_token_around_image>
657
+ # What's the correct answer? A, B, C or D?
658
+ # Bot:""",
659
+ # 1,
660
+ # 2,
661
+ # 64,
662
+ # 10,
663
+ # "<image>;<fake_token_around_image>;User;user;Bot;bot;Question;question;Answer;answer;\n",
664
+ # False,
665
+ # False,
666
+ # True,
667
+ # ],
668
+ ]
669
+
670
+
671
+ title = """<head><title><h1 align='center'>🔮✍️ Text generation with IDEFICS models 🦙📚</h1></title></head>"""
672
+
673
+
674
+ MSG_MAIN = """
675
+ # Text generation with Vllama models
676
+
677
+ ### Help to write prompts:
678
+
679
+ Put the URL of each image inside its image token (`<image:URL>`); it will be converted into a real image token. Put <fake_token_around_image> immediately before and after each image token, WITHOUT any space. Any literal \\n in the text will be converted into a real newline character. See the examples and additional details below.
680
+ """
681
+ # MSG_DETAILS = """
682
+ # ### Additional details
683
+ # - if the model was trained with the template 1 (`\\n\\n<image>\\n\\n`), then `<fake_token_around_image>` will be replaced with `\\n\\n`. This is particularly useful if you are comparing the performance of different models trained with different templates.
684
+ # - special tokens are not automatically added to the prompt, so add them manually.
685
+ # - with the first template `\\n\\n<image>\\n\\n`, the sequence isn't necessarily tokenized as `["\\n\\n", "<image>", "\\n\\n"]`; to enforce this behavior, use the "Integrate image sequence as ids" parameter.
686
+ # """
687
+ # if not IS_MAIN_SPACE:
688
+ # MSG_DETAILS += (
689
+ # "- alternatively, you can upload images and then directly specify them via \<image\> tag in the prompt."
690
+ # )
691
+
692
+ with gr.Blocks() as demo:
693
+ gr.HTML(title)
694
+ gr.HTML("""<h3 align='center'>Help to write prompts:🙌</h3><br>
695
+ <p>Put the URL of each image inside its image token;
696
+ it will be converted into a real image token.
697
+ Put &lt;fake_token_around_image&gt; before and after each
698
+ image token WITHOUT space. The texts \\n will be
699
+ converted into real newline characters.
700
+ See examples and additional details below.</p>""")
701
+
702
+ #gr.HTML("<h3 align='center'>Help to write prompts:🙌</h3><br>Put the urls to the images inside the image tokens, it will be converted into the real image tokens. Put <fake_token_around_image> before and after each image token WITHOUT space. The texts \\n will be converted into real newline characters. See examples and additional details below.")
703
+ #gr.Markdown(MSG_MAIN)
704
+ #with gr.Row():
705
+ #with gr.Column():
706
+ gr.Markdown("## Input")
707
+ with gr.Row():
708
+ if not IS_MAIN_SPACE:
709
+ images = gr.File(label="Images", file_count="multiple")
710
+ prompt = gr.Textbox(label="Prompt", placeholder="Enter the prompt here", lines=5)
711
+
712
+ #gr.Markdown("## Common parameters to all decoding strategy")
713
+ with gr.Row():
714
+ with gr.Accordion("Parameters common to all decoding strategies", open=False, elem_id="common_params"):
715
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=1.0, label="Softmax temperature")
716
+ no_repeat_ngram_size = gr.Slider(
717
+ minimum=0,
718
+ maximum=10,
719
+ step=1,
720
+ value=0,
721
+ label="The size of an n-gram that cannot occur more than once (0 = no restriction)",
722
+ )
723
+ max_new_tokens = gr.Slider(
724
+ minimum=0, maximum=512, step=1, value=256, label="Maximum number of new tokens to generate"
725
+ )
726
+ min_length = gr.Slider(
727
+ minimum=0, maximum=512, step=1, value=16, label="Minimum length of the sequence to be generated"
728
+ )
729
+ ban_tokens = gr.Textbox(
730
+ label='Tokens to prevent from being generated (separated by ";")',
731
+ value="<image>;<fake_token_around_image>",
732
+ )
733
+ forced_eos_token_id = gr.Checkbox(label="Forced eos token id", value=False)
734
+ eos_tokens = gr.Textbox(label="EOS tokens", value="</s>")
735
+ force_words = gr.Textbox(label='Force words to be generated (separated by ";")', value="")
736
+ length_penalty = gr.Slider(
737
+ minimum=-1000,
738
+ maximum=1000,
739
+ step=0.1,
740
+ value=1,
741
+ label=(
742
+ "length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter"
743
+ " sequences."
744
+ ),
745
+ )
746
+ repetition_penalty = gr.Slider(
747
+ minimum=0, maximum=10, step=0.01, value=1, label="repetition_penalty. CTRL paper suggests 1.2."
748
+ )
749
+ hide_special_tokens = gr.Checkbox(label="Hide special tokens in the text", value=False)
750
+ stop_generation = gr.Checkbox(
751
+ label="Stop generation when an image token, a BOS token, or an EOS token is generated", value=False
752
+ )
753
+
754
+ #gr.Markdown("## Decoding strategy and its specific parameters")
755
+ with gr.Accordion("Decoding strategy and its specific parameters", open=False, elem_id="decoding_params"):
756
+ decoding_strategy = gr.Dropdown(
757
+ ["greedy", "beam_search", "beam_sampling", "sampling_top_k", "sampling_top_p", "contrastive_sampling"],
758
+ label="Decoding strategy",
759
+ value="greedy",
760
+ )
761
+ num_beams = gr.Slider(
762
+ minimum=0,
763
+ maximum=10,
764
+ step=1,
765
+ value=3,
766
+ label="Beam size",
767
+ info="Only used if `decoding_strategy` is `beam_search` or `beam_sampling`",
768
+ )
769
+ top_k = gr.Slider(
770
+ minimum=0,
771
+ maximum=500,
772
+ step=1,
773
+ value=50,
774
+ label="Top k",
775
+ info="Only used if `decoding_strategy` is `sampling_top_k` or `contrastive_sampling`",
776
+ )
777
+ top_p = gr.Slider(
778
+ minimum=0,
779
+ maximum=1,
780
+ step=0.01,
781
+ value=0.95,
782
+ label="Top p",
783
+ info="Only used if `decoding_strategy` is `sampling_top_p`",
784
+ )
785
+ penalty_alpha = gr.Slider(
786
+ minimum=0,
787
+ maximum=1,
788
+ step=0.01,
789
+ value=0.95,
790
+ label="Penalty alpha",
791
+ info="Only used if `decoding_strategy` is `contrastive_sampling`",
792
+ )
793
+
794
+ submit = gr.Button(value="Generate")
795
+
796
+ #with gr.Column():
797
+ with gr.Row():
798
+ if IS_MAIN_SPACE:
799
+ outputs = [
800
+ gr.Textbox(label=MODEL_TO_DISPLAY_NAME[model], lines=5, interactive=False)
801
+ for model in MODEL_TO_MODEL_CLASS
802
+ ]
803
+ inference_func = all_model_inference
804
+ inputs = [
805
+ prompt,
806
+ temperature,
807
+ no_repeat_ngram_size,
808
+ max_new_tokens,
809
+ min_length,
810
+ ban_tokens,
811
+ forced_eos_token_id,
812
+ eos_tokens,
813
+ force_words,
814
+ length_penalty,
815
+ repetition_penalty,
816
+ hide_special_tokens,
817
+ stop_generation,
818
+ decoding_strategy,
819
+ num_beams,
820
+ top_k,
821
+ top_p,
822
+ penalty_alpha,
823
+ ]
824
+
825
+ # examples = [example[1:] for example in examples]
826
+ else:
827
+ outputs = gr.Textbox(label="Generated text", interactive=False, lines=5)
828
+ inference_func = model_inference
829
+ inputs = [
830
+ images,
831
+ prompt,
832
+ temperature,
833
+ no_repeat_ngram_size,
834
+ max_new_tokens,
835
+ min_length,
836
+ ban_tokens,
837
+ forced_eos_token_id,
838
+ eos_tokens,
839
+ force_words,
840
+ length_penalty,
841
+ repetition_penalty,
842
+ hide_special_tokens,
843
+ stop_generation,
844
+ decoding_strategy,
845
+ num_beams,
846
+ top_k,
847
+ top_p,
848
+ penalty_alpha,
849
+ ]
850
+ with gr.Row():
851
+ gr.Examples(inputs=inputs, examples=examples)
852
+ # gr.Markdown(MSG_DETAILS)
853
+
854
+ submit.click(inference_func, inputs=inputs, outputs=outputs, api_name="model_inference")
855
+
856
+ demo.queue()
857
+ demo.launch()
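The `info` strings on the sliders above say which parameter each decoding strategy reads. A minimal sketch of that mapping follows, assuming the keyword names of Hugging Face `generate` (`num_beams`, `length_penalty`, `do_sample`, `top_k`, `top_p`, `penalty_alpha`). The helper name `build_generation_args` is hypothetical and is not part of this commit. The sketch also assumes that contrastive search is selected by `penalty_alpha` and `top_k` alone, so it leaves `do_sample` unset for `contrastive_sampling`.

```python
def build_generation_args(
    decoding_strategy,
    num_beams=3,
    length_penalty=1.0,
    top_k=50,
    top_p=0.95,
    penalty_alpha=0.95,
):
    """Return only the keyword arguments that the chosen strategy uses.

    Hypothetical helper: the defaults mirror the slider defaults in the UI.
    """
    if decoding_strategy == "greedy":
        return {}
    if decoding_strategy in ("beam_search", "beam_sampling"):
        # Beam strategies only make sense with more than one beam.
        if num_beams <= 1:
            raise ValueError("beam strategies need num_beams > 1")
        args = {"num_beams": num_beams, "length_penalty": length_penalty}
        if decoding_strategy == "beam_sampling":
            args["do_sample"] = True
        return args
    if decoding_strategy == "sampling_top_k":
        return {"do_sample": True, "top_k": top_k}
    if decoding_strategy == "sampling_top_p":
        return {"do_sample": True, "top_p": top_p}
    if decoding_strategy == "contrastive_sampling":
        # Assumption: sampling stays off so that contrastive search is used.
        return {"penalty_alpha": penalty_alpha, "top_k": top_k}
    raise ValueError(f"unknown decoding strategy: {decoding_strategy!r}")
```

Returning a fresh dictionary per strategy keeps unused sliders (for example `top_p` during beam search) from leaking into the call, which is what the `info` texts promise to the user.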
app_dialogue.py ADDED
@@ -0,0 +1,395 @@
1
+ import os
2
+ import gradio as gr
3
+ import requests
4
+
5
+
6
+ models = [
7
+ "HuggingFaceM4/tr_209_ift_mixture_opt_step-14000"
8
+ ]
9
+
10
+ SYSTEM_PROMPT = """The following is a conversation between a highly knowledgeable and intelligent AI assistant, called Assistant, and a human user, called User. In the following interactions, User and Assistant will converse in natural language, and Assistant will do its best to answer User’s questions. Assistant was built to be respectful, polite and inclusive. It knows a lot, and always tells the truth. When prompted with an image, it does not make up facts.
11
+
12
+ The conversation begins:
13
+ User:<fake_token_around_image><image:https://i1.pickpik.com/photos/515/906/355/kittens-cats-pet-cute-preview.jpg><fake_token_around_image>Describe this image.</s>
14
+ Assistant: Five kittens are standing together in the center of the photograph. From left to right, there is one orange kitten, two kittens with white and black stripes, and two brown kittens with an orange head. They are in bright green grass and it looks like they are moving forward.</s>
15
+ User:How about this image?<fake_token_around_image><image:https://cdn.pixabay.com/photo/2017/09/25/13/12/puppy-2785074_1280.jpg><fake_token_around_image>Can you describe it too?</s>
16
+ Assistant: A dog is lying on the floor, looking at the camera. It is looking directly at you. The dog has a white body and brown patches on its face and ears. Its eyes are dark. Its nose is black, and it has long, floppy ears, white paws, long fur, big eyes.</s>
17
+ User: What kind of breed is it?</s>
18
+ Assistant: It looks like a Cavalier King Charles Spaniel. These breeds are known for their gentle, affectionate nature. Their long ears and big, brown eyes give them a sweet, expressive face. Cavaliers are small dogs weighing between 13 and 20 pounds. They have a silky, medium-length coat that comes in various colors, including black and tan, tricolor, red and white, or black.</s>
19
+ User: What can you tell me about this breed of dogs?</s>
20
+ Assistant: One specific characteristic of the Cavalier King Charles Spaniel is its friendly and affectionate nature. This breed is known for being extremely sociable and forming strong bonds with their owners. They are often described as "velcro dogs" because they love to be close to their human companions, whether it's sitting on their laps, cuddling on the couch, or simply following them around the house.</s>
21
+ User: ghjkhjabnufs</s>
22
+ Assistant: That doesn’t seem to be a word. Could you ask me another way?</s>
23
+ User: Do you like Cavalier King Charles Spaniel?</s>
24
+ Assistant: I do not have personal opinions as I’m just a computer program. However, cavaliers are known for being excellent family pets due to their gentle and patient demeanor, making them great with children and other pets. Their affectionate nature also makes them ideal therapy dogs, as they can provide comfort and emotional support to those in need.</s>
25
+ User: How many dogs do you see in this image?<fake_token_around_image><image:https://i.dailymail.co.uk/i/pix/2011/07/01/article-2010308-0CD22A8300000578-496_634x414.jpg><fake_token_around_image></s>
26
+ Assistant: There are no dogs in this image. The picture shows a tennis player jumping to volley the ball.</s>"""
27
+
28
+ BAN_TOKENS = "<image>;<fake_token_around_image>"
29
+ EOS_TOKENS = "</s>;User"
30
+
31
+ from accelerate.utils import get_max_memory
32
+ from transformers import AutoTokenizer
33
+ from m4.models.vllama.configuration_vllama import VLlamaConfig
34
+ from m4.models.vllama.modeling_vllama import VLlamaForCausalLM
35
+
36
+ import logging
37
+ from PIL import Image
38
+ from io import BytesIO
39
+ from m4.training.packing import image_attention_mask_for_packed_input_ids, incremental_to_binary_attention_mask
40
+ from m4.training.utils import build_image_transform
41
+ import torch
42
+ import re
43
+
44
+ TOKENIZER_FAST = True
45
+ MAX_SEQ_LEN = 1024
46
+
47
+ logging.basicConfig(level=logging.INFO)
48
+ logger = logging.getLogger()
49
+
50
+
51
+ def load_tokenizer_model(model_name):
52
+ tokenizer = AutoTokenizer.from_pretrained(
53
+ model_name,
54
+ use_fast=TOKENIZER_FAST,  # `use_fast=False` for 1B3 OPT, True for all the other models
55
+ use_auth_token=os.getenv("HF_AUTH_TOKEN", True),
56
+ )
57
+ tokenizer.padding_side = "left"
58
+
59
+ config = VLlamaConfig.from_pretrained(model_name, use_auth_token=os.getenv("HF_AUTH_TOKEN", True))
60
+ max_memory_map = get_max_memory()
61
+
62
+ for key in max_memory_map.keys():
63
+ if key != "cpu":
64
+ # Convert the byte count to GiB
65
+ max_memory_map[key] = max_memory_map[key] // (1024 * 1024 * 1024)
66
+ # Reserve 2 GiB for PyTorch overhead and 2 GiB for the forward pass, to be safe
67
+ max_memory_map[key] = f"{max_memory_map[key] - 4} GiB"
68
+
69
+ model = VLlamaForCausalLM.from_pretrained(
70
+ model_name,
71
+ use_auth_token=os.getenv("HF_AUTH_TOKEN", True),
72
+ device_map="auto",
73
+ offload_folder="./offload",
74
+ torch_dtype=config.torch_dtype,
75
+ max_memory=max_memory_map,
76
+ )
77
+ model.eval()
78
+ print("Current device map:", model.hf_device_map)
79
+ print("Model default generation config:", model.generation_config)
80
+ # TODO: the device_map looks very inefficient right now; that could be improved
81
+ return tokenizer, model
82
+
83
+
84
+ def fetch_images(url_images):
85
+ images = []
86
+ for url in url_images:
87
+ if isinstance(url, str):
88
+ images.append(Image.open(BytesIO(requests.get(url, stream=True).content)))
89
+ else:
90
+ images.append(url)
91
+ return images
92
+
93
+
94
+ def model_generation(
95
+ prompt,
96
+ images,
97
+ tokenizer,
98
+ model,
99
+ temperature,
100
+ no_repeat_ngram_size,
101
+ max_new_tokens,
102
+ min_length,
103
+ ban_tokens,
104
+ eos_tokens,
105
+ force_words,
106
+ repetition_penalty,
107
+ hide_special_tokens,
108
+ decoding_strategy,
109
+ num_beams,
110
+ length_penalty,
111
+ top_k,
112
+ top_p,
113
+ penalty_alpha,
114
+ ):
115
+ # Preparing inputs
116
+ tokens = tokenizer(
117
+ [prompt],
118
+ truncation=True,
119
+ max_length=MAX_SEQ_LEN,
120
+ padding=True,
121
+ add_special_tokens=False,
122
+ )
123
+
124
+ input_ids = torch.tensor([[tokenizer.bos_token_id] + tokens.input_ids[0]])
125
+ attention_mask = torch.tensor([[1] + tokens.attention_mask[0]])
126
+
127
+ image_attention_mask = [
128
+ incremental_to_binary_attention_mask(
129
+ image_attention_mask_for_packed_input_ids(input_ids[0].unsqueeze(0), tokenizer)[0], num_classes=len(images)
130
+ )
131
+ ]
132
+
133
+ image_transform = build_image_transform(eval=True)
134
+ pixel_values = [torch.stack([image_transform(img) for img in images])]
135
+
136
+ input_ids = input_ids.to(0)
137
+ attention_mask = attention_mask.to(0)
138
+ pixel_values = torch.stack(pixel_values).to(0)
139
+ image_attention_mask = torch.cat(image_attention_mask, 0).to(0)
140
+
141
+ # Excluding some words from the generation
142
+ bad_words_ids = None
143
+ ban_tokens = ban_tokens.replace("\\n", "\n")
144
+ bad_words = ban_tokens.split(";")
145
+ if len(bad_words) > 0:
146
+ bad_words_ids = tokenizer(bad_words, add_special_tokens=False).input_ids
147
+
148
+ # Forcing some words in the generation
149
+ force_words_ids = None
150
+ if force_words != "":
151
+ force_words = force_words.replace("\\n", "\n")
152
+ force_words = force_words.split(";")
153
+ if len(force_words) > 0:
154
+ force_words_ids = tokenizer(force_words, add_special_tokens=False).input_ids
155
+
156
+ eos_token_ids = None
157
+ if eos_tokens != "":
158
+ eos_tokens = eos_tokens.replace("\\n", "\n")
159
+ eos_tokens = eos_tokens.split(";")
160
+ if len(eos_tokens) > 0:
161
+ eos_token_ids = []
162
+ for eos_token in eos_tokens:
163
+ tokenized_eos_token = tokenizer(eos_token, add_special_tokens=False).input_ids
164
+ if len(tokenized_eos_token) > 1:
165
+ raise ValueError(
166
+ f"eos_tokens should be one token, here {eos_token} is {len(tokenized_eos_token)} tokens:"
167
+ f" {tokenized_eos_token}"
168
+ )
169
+ eos_token_ids += tokenized_eos_token
170
+
171
+ # Inputs
172
+ input_args = {
173
+ "input_ids": input_ids,
174
+ "attention_mask": attention_mask,
175
+ "pixel_values": pixel_values,
176
+ "image_attention_mask": image_attention_mask,
177
+ }
178
+ # Common parameters to all decoding strategies
179
+ # This documentation is useful to read: https://huggingface.co/docs/transformers/main/en/generation_strategies
180
+ generation_args = {
181
+ "temperature": temperature,
182
+ "no_repeat_ngram_size": no_repeat_ngram_size,
183
+ "max_new_tokens": max_new_tokens,
184
+ "min_length": min_length,
185
+ "bad_words_ids": bad_words_ids,
186
+ "force_words_ids": force_words_ids,
187
+ "repetition_penalty": repetition_penalty,
188
+ "eos_token_id": eos_token_ids,
189
+ }
190
+
191
+ assert decoding_strategy in [
192
+ "greedy",
193
+ "beam_search",
194
+ "beam_sampling",
195
+ "sampling_top_k",
196
+ "sampling_top_p",
197
+ "contrastive_sampling",
198
+ ]
199
+ if decoding_strategy == "greedy":
200
+ pass
201
+ elif decoding_strategy == "beam_search":
202
+ generation_args["num_beams"] = num_beams
203
+ generation_args["length_penalty"] = length_penalty
204
+ assert generation_args["num_beams"] > 1
205
+ elif decoding_strategy == "beam_sampling":
206
+ generation_args["num_beams"] = num_beams
207
+ generation_args["length_penalty"] = length_penalty
208
+ generation_args["do_sample"] = True
209
+ assert generation_args["num_beams"] > 1
210
+ elif decoding_strategy == "sampling_top_k":
211
+ generation_args["do_sample"] = True
212
+ generation_args["top_k"] = top_k
213
+ elif decoding_strategy == "sampling_top_p":
214
+ generation_args["do_sample"] = True
215
+ generation_args["top_p"] = top_p
216
+ elif decoding_strategy == "contrastive_sampling":
217
+ generation_args["do_sample"] = False  # Hugging Face only selects contrastive search when sampling is off
218
+ generation_args["penalty_alpha"] = penalty_alpha
219
+ generation_args["top_k"] = top_k
220
+
221
+ generated_tokens = model.generate(
222
+ **input_args,
223
+ **generation_args,
224
+ )
225
+
226
+ tokens = tokenizer.convert_ids_to_tokens(generated_tokens[0])
227
+ decoded_skip_special_tokens = repr(
228
+ tokenizer.batch_decode(generated_tokens, skip_special_tokens=hide_special_tokens)[0]
229
+ )
230
+
231
+ actual_generated_tokens = generated_tokens[:, input_ids.shape[-1] :]
232
+ first_end_token = len(actual_generated_tokens[0])
233
+ actual_generated_tokens = actual_generated_tokens[:, :first_end_token]
234
+ generated_text = tokenizer.batch_decode(actual_generated_tokens, skip_special_tokens=hide_special_tokens)[0]
235
+
236
+ logger.info(
237
+ "Result: \n"
238
+ f"----Prompt: `{prompt}`\n"
239
+ f"----Tokens ids - prompt + generation: `{generated_tokens[0].tolist()}`\n"
240
+ f"----Tokens converted - prompt + generation: `{tokens}`\n"
241
+ f"----String decoded with skipped special tokens - prompt + generation: `{decoded_skip_special_tokens}`\n"
242
+ f"----Token ids - generation: `{actual_generated_tokens[0].tolist()}`\n"
243
+ f"----Tokens converted - generation: `{tokenizer.convert_ids_to_tokens(actual_generated_tokens[0])}`\n"
244
+ f"----String decoded with skipped special tokens - generation: `{generated_text}`\n"
245
+ f"----Generation mode: `{decoding_strategy}`\n"
246
+ f"----Generation parameters: `{generation_args}`\n"
247
+ )
248
+
249
+ return generated_text
250
+
251
+
252
+ textbox = gr.Textbox(
253
+ show_label=False,
254
+ value="<fake_token_around_image><image:https://m.media-amazon.com/images/M/MV5BMjE4MTcwMTM1Nl5BMl5BanBnXkFtZTcwMTIwMzMzMw@@._V1_.jpg><fake_token_around_image>Describe all of the parts of this image.",
255
+ placeholder="To input images, use the following syntax: `<fake_token_around_image><image:URL_IMAGE><fake_token_around_image>textexttext`",
256
+ visible=True,
257
+ container=False
258
+ )
259
+ with gr.Blocks(title="Idefics", theme=gr.themes.Base()) as demo:
260
+ # state = gr.State()
261
+
262
+ with gr.Row():
263
+ with gr.Column(scale=3):
264
+ with gr.Row(elem_id="model_selector_row"):
265
+ model_selector = gr.Dropdown(
266
+ choices=models,
267
+ value=models[0] if len(models) > 0 else "",
268
+ interactive=True,
269
+ show_label=False,
270
+ container=False)
271
+ tokenizer, model = load_tokenizer_model(model_selector.value)
272
+
273
+ imagebox = gr.Image(type="pil", label="Image input - This image box is not supported yet! To include images, do so through the text by adding `<fake_token_around_image><image:IMAGE_URL><fake_token_around_image>`. The backend takes care of parsing `<image:URL>` and downloading the corresponding image. That way, you can technically interleave as many images and texts as you want. No need to add spaces before and after `<fake_token_around_image>`")
274
+
275
+ with gr.Accordion("Parameters", open=False, visible=True) as parameter_row:
276
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
277
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
278
+ max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
279
+
280
+ with gr.Column(scale=6):
281
+ chatbot = gr.Chatbot(
282
+ elem_id="chatbot",
283
+ label="Idefics Chatbot",
284
+ visible=True,
285
+ height=550,
286
+ value=[
287
+ [
288
+ "Where can we typically find animals like this one?<fake_token_around_image><image:https://upload.wikimedia.org/wikipedia/commons/thumb/d/db/Alpaca_%2831562329701%29.jpg/640px-Alpaca_%2831562329701%29.jpg><fake_token_around_image>",
289
+ "Animals like the one in the image, which is a llama, can typically be found in rural areas, such as farms or ranches. Llamas are often used as pack animals in mountainous regions, as they are well-adapted to the challenging terrain and can carry heavy loads. They are also valued for their wool, which is used to make clothing and other textiles. In addition, llamas are sometimes kept as pets or for their therapeutic benefits, as they are known to be calm and gentle animals."
290
+ ],
291
+ [
292
+ "How many of these animals can we fit into an engine like that<fake_token_around_image><image:https://upload.wikimedia.org/wikipedia/commons/thumb/4/4e/Nassau_County_Police_Bell_407.jpg/1200px-Nassau_County_Police_Bell_407.jpg><fake_token_around_image>?",
293
+ "The image shows a helicopter with a large engine, but it is not possible to determine the exact number of animals that can fit into it based on the image alone. The size and capacity of the helicopter's engine would depend on various factors, such as the size of the animals, the weight of the animals, and the size of the helicopter itself. However, it is safe to assume that the helicopter is designed to carry a limited number of animals, and it is not intended to be used as a means of transporting large groups of animals."
294
+ ]
295
+ ]
296
+ )
297
+ with gr.Row():
298
+ with gr.Column(scale=8):
299
+ textbox.render()
300
+ with gr.Column(scale=1, min_width=60):
301
+ submit_btn = gr.Button(value="Submit", visible=True)
302
+ with gr.Column(scale=1, min_width=20):
303
+ clear_btn = gr.ClearButton([textbox, chatbot])
304
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
305
+ # gr.Examples(examples=[
306
+ # [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
307
+ # [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
308
+ # ], inputs=[textbox])
309
+ # gr.Examples(
310
+ # examples=[
311
+ # [
312
+ # ("How many of these animals can we fit into an engine like that<fake_token_around_image><image:https://upload.wikimedia.org/wikipedia/commons/thumb/4/4e/Nassau_County_Police_Bell_407.jpg/1200px-Nassau_County_Police_Bell_407.jpg><fake_token_around_image>?", "The image shows a helicopter with a large engine, but it is not possible to determine the exact number of animals that can fit into it based on the image alone. The size and capacity of the helicopter's engine would depend on various factors, such as the size of the animals, the weight of the animals, and the size of the helicopter itself. However, it is safe to assume that the helicopter is designed to carry a limited number of animals, and it is not intended to be used as a means of transporting large groups of animals."),
313
+ # ],
314
+ # ],
315
+ # inputs = [chatbot]
316
+ # )
317
+
318
+ def format_prompt_with_history_and_system_conditioning(current_user_prompt, history):
319
+ resulting_text = SYSTEM_PROMPT
320
+ for turn in history:
321
+ user_utterance, assistant_utterance = turn
322
+ resulting_text += f"\nUser:{user_utterance}</s>\nAssistant:{assistant_utterance}"
323
+ resulting_text += f"\nUser:{current_user_prompt}</s>\nAssistant:"
324
+ return resulting_text
325
+
326
+ def model_inference(
327
+ user_prompt,
328
+ chat_history,
329
+ ):
330
+ global model, tokenizer
331
+
332
+ temperature = 1.0
333
+ no_repeat_ngram_size = 0
334
+ max_new_tokens = 256
335
+ min_length = 16
336
+ force_words = ""
337
+ repetition_penalty = 1.0
338
+ hide_special_tokens = False
339
+ decoding_strategy = "greedy"
340
+ num_beams = 3
341
+ length_penalty = 1.
342
+ top_k = 50
343
+ top_p = 0.95
344
+ penalty_alpha = 0.95
345
+
346
+ formated_prompt = format_prompt_with_history_and_system_conditioning(
347
+ current_user_prompt=user_prompt.strip(),
348
+ history=chat_history,
349
+ )
350
+
351
+ url_images = re.findall(r"<image(.*?)>", formated_prompt)
352
+ for idx, url_image in enumerate(url_images):
353
+ formated_prompt = formated_prompt.replace(url_image, "")
354
+ url_images[idx] = url_images[idx][1:]
355
+ images = fetch_images(url_images)
356
+
357
+ generated_text = model_generation(
358
+ prompt=formated_prompt,
359
+ images=images,
360
+ tokenizer=tokenizer,
361
+ model=model,
362
+ temperature=temperature,
363
+ no_repeat_ngram_size=no_repeat_ngram_size,
364
+ max_new_tokens=max_new_tokens,
365
+ min_length=min_length,
366
+ ban_tokens=BAN_TOKENS,
367
+ eos_tokens=EOS_TOKENS,
368
+ force_words=force_words,
369
+ repetition_penalty=repetition_penalty,
370
+ hide_special_tokens=hide_special_tokens,
371
+ decoding_strategy=decoding_strategy,
372
+ num_beams=num_beams,
373
+ length_penalty=length_penalty,
374
+ top_k=top_k,
375
+ top_p=top_p,
376
+ penalty_alpha=penalty_alpha,
377
+ )
378
+
379
+ chat_history.append((user_prompt, generated_text.strip()))
380
+ return "", chat_history
381
+
382
+
383
+ textbox.submit(
384
+ fn=model_inference,
385
+ inputs=[textbox, chatbot],
386
+ outputs=[textbox, chatbot],
387
+ )
388
+ submit_btn.click(
389
+ fn=model_inference,
390
+ inputs=[textbox, chatbot],
391
+ outputs=[textbox, chatbot],
392
+ )
393
+
394
+ demo.queue()
395
+ demo.launch(share=True)
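The `model_inference` function above pulls image URLs out of `<image:URL>` tags and leaves bare `<image>` tokens in the prompt. A minimal sketch of the same idea follows, under two assumptions that differ from the diff: the pattern requires the colon (so a bare `<image>` is never matched as an empty URL), and the substitution is done by the regex itself instead of `str.replace`. The names `IMAGE_TAG` and `split_image_urls` are hypothetical.

```python
import re

# Matches "<image:URL>" and captures the URL; a bare "<image>" does not match.
IMAGE_TAG = re.compile(r"<image:(.*?)>")


def split_image_urls(prompt):
    """Return the prompt with bare <image> tokens and the URLs in order."""
    urls = IMAGE_TAG.findall(prompt)
    cleaned = IMAGE_TAG.sub("<image>", prompt)
    return cleaned, urls


if __name__ == "__main__":
    text = (
        "User:<fake_token_around_image><image:https://example.com/a.jpg>"
        "<fake_token_around_image>Describe this image.</s>"
    )
    cleaned, urls = split_image_urls(text)
    print(cleaned)
    print(urls)
```

Replacing through the compiled pattern only touches text inside the tags, so a URL that also appears elsewhere in the conversation is left untouched.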
elon_musk.md ADDED
@@ -0,0 +1,23 @@
1
+ This is a conversation between a human, User, and an intelligent visual AI, Bot. User sends images, and Bot answers the questions from the user. The bot should be friendly, informative and should not change the topic if it's not asked to.
2
+
3
+ Here's an example of a conversation:
4
+ User: <fake_token_around_image> <image> <fake_token_around_image> Describe this image.
5
+ Bot: Two kittens are cuddling in the center of the photograph. They are surrounded by tall, bright green grass, and the background is blurred.
6
+ User: <fake_token_around_image> <image> <fake_token_around_image> How about this image? Can you describe it too?
7
+ Bot: A dog is lying on the floor, looking at the camera. It is looking directly at you, tilting its head to the side. The dog has a white body and brown patches on its face and ears. Its eyes are brown. Its nose is black, and it has long, floppy ears, short legs, white paws, long fur, big eyes, and black eyebrows.
8
+ User: What kind of breed is it?
9
+ Bot: It looks like a Cavalier King Charles Spaniel. These breeds are known for their gentle, affectionate nature. Their long ears and big, brown eyes give them a sweet, expressive face. Cavaliers are small dogs weighing between 13 and 20 pounds. They have a silky, medium-length coat that comes in various colors, including black and tan, tricolor, red and white, or black.
10
+ ---
11
+ User: <fake_token_around_image> <image> <fake_token_around_image>
12
+ Describe this image.
13
+ Bot: This image shows a robot named Wall-E. He is standing in front of a pile of garbage, holding a plant in his hands. His eyes and mouth are open, as if he is surprised. Wall E is made of metal and has two arms and two legs. There is also a small light on his head. In the distance, there are some buildings and a large mountain. This scene takes place on Earth, which is covered in trash.
14
+ User: What happens in the film? Explain in a few sentences.
15
+ Bot: WallE is an animated film about a lonely robot who falls in love with another robot. When they meet, they discover that they have been programmed to clean up the Earth. Together, the robots work to restore the planet to its former glory. Along the way, Wall e learns about friendship, love, responsibility, sacrifice, hope, courage, determination, perseverance, teamwork, loyalty, honesty, trust, forgiveness, compassion, empathy, selflessness, kindness and generosity. Ultimately, he realizes that he can make a difference in this world.
16
+ User: In which year was the film released?
17
+ Bot: The film was released in 3D in June 08, 4 years ago.
18
+ User: Is a future in which the earth is not liveable anymore (similarly as in the movie) conceivable?
19
+ Bot: Yes, it is possible. If we continue to pollute the environment and destroy the natural resources, we will eventually make the world uninhabitable. We need to take action now to protect the future of our planet. One way to do this is to reduce our carbon footprint by using less energy and recycling more. Another way is by supporting companies that are working to develop sustainable technologies. Finally, you can also get involved in local environmental initiatives, such as planting trees or cleaning up parks. By taking these steps, each of us can help to make sure that our children and grandchildren will be able to enjoy a healthy and beautiful planet for many years to come.
20
+ User: I think this is a lost cause. <fake_token_around_image> <image> <fake_token_around_image> Like Elon Musk, I think we should instead all move to another planet now.
21
+ Bot: El ón Mus ́k is the CEO of SpaceX and Tesla Motors. Space X is developing a reusable rocket system that could be used to transport people and cargo to Mars. Mus k believes that colonizing Mars is essential for the survival of humanity, because it would provide a backup plan in case something goes wrong on earth. However, many people are skeptical about the feasibility of such a project. Some argue that it will take too long and cost too much money to establish a colony on Mars, while others say that the risks are too great. Nevertheless, Mus K remains optimistic about his vision for Mars colonization. "I think there is good reason to believe that we can build a self-sustaining city on mars," he said in an interview with Time magazine. “It would be quite an adventure.”
22
+ ----
23
+ The project is still in development. I'm currently working on improving the accuracy of image recognition and adding more features to it.</s>
images/bear.jpg ADDED
m4/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from m4.utils import logging
m4/models/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ from m4.models.custom_modules import DecoupledEmbedding, DecoupledLinear
2
+ from m4.models.vbloom.configuration_vbloom import VBloomConfig
3
+ from m4.models.vbloom.modeling_vbloom import VBloomForCausalLM
4
+ from m4.models.vgpt2.configuration_vgpt2 import VGPT2Config
5
+ from m4.models.vgpt2.modeling_vgpt2 import VGPT2LMHeadModel
6
+ from m4.models.vllama.configuration_vllama import VLlamaConfig
7
+ from m4.models.vllama.modeling_vllama import VLlamaForCausalLM
8
+ from m4.models.vopt.configuration_vopt import VOPTConfig
9
+ from m4.models.vopt.modeling_vopt import VOPTForCausalLM
10
+ from m4.models.vt5.configuration_vt5 import VT5Config
11
+ from m4.models.vt5.modeling_vt5 import VT5ForConditionalGeneration
12
+
13
+
14
+ _SUPPORTED_MODELS = {
15
+ "vgpt2": VGPT2Config,
16
+ "vt5": VT5Config,
17
+ "vbloom": VBloomConfig,
18
+ "vopt": VOPTConfig,
19
+ "vllama": VLlamaConfig,
20
+ }
21
+
22
+ model_type_to_modeling_class = {
23
+ "vgpt2": VGPT2LMHeadModel,
24
+ "vt5": VT5ForConditionalGeneration,
25
+ "vbloom": VBloomForCausalLM,
26
+ "vopt": VOPTForCausalLM,
27
+ "vllama": VLlamaForCausalLM,
28
+ }
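The two dictionaries above act as a registry keyed by model type. A minimal sketch of how a caller might resolve both classes with a readable error is given below. The stand-in classes and the helper `resolve_model_classes` are hypothetical, used only so the example runs without the `m4` package; the real dictionaries hold the imported config and modeling classes.

```python
# Stand-in classes so the sketch runs on its own (assumption: in the
# repository these would be the real config and modeling classes).
class FakeConfig:
    pass


class FakeModel:
    pass


SUPPORTED_CONFIGS = {"vopt": FakeConfig, "vllama": FakeConfig}
MODELING_CLASSES = {"vopt": FakeModel, "vllama": FakeModel}


def resolve_model_classes(model_type):
    """Return (config_class, modeling_class) for a registered model type."""
    try:
        return SUPPORTED_CONFIGS[model_type], MODELING_CLASSES[model_type]
    except KeyError:
        supported = ", ".join(sorted(SUPPORTED_CONFIGS))
        raise ValueError(
            f"Unsupported model type {model_type!r}; expected one of: {supported}"
        ) from None
```

Keeping both lookups behind one function makes it harder for the two dictionaries to drift apart unnoticed, since a key missing from either one raises the same error.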
m4/models/common.py ADDED
@@ -0,0 +1,104 @@
1
+ import torch
2
+
3
+
4
+ def expand_inputs_for_generation(
5
+ input_ids,
6
+ expand_size=1,
7
+ is_encoder_decoder=False,
8
+ attention_mask=None,
9
+ encoder_outputs=None,
10
+ **model_kwargs,
11
+ ):
12
+ expanded_return_idx = (
13
+ torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
14
+ )
15
+ input_ids = input_ids.index_select(0, expanded_return_idx)
16
+
17
+ if "token_type_ids" in model_kwargs:
18
+ token_type_ids = model_kwargs["token_type_ids"]
19
+ model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)
20
+
21
+ if attention_mask is not None:
22
+ model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)
23
+ model_kwargs["image_attention_mask"] = model_kwargs["image_attention_mask"].index_select(
24
+ 0, expanded_return_idx
25
+ )
26
+ model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select(0, expanded_return_idx)
27
+
28
+ if is_encoder_decoder:
29
+ if encoder_outputs is None:
30
+ raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
31
+ encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
32
+ 0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
33
+ )
34
+ model_kwargs["encoder_outputs"] = encoder_outputs
35
+ return input_ids, model_kwargs
36
+
37
+
38
+ def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
39
+ # must have this key set to at least None
40
+ model_kwargs["past_key_values"] = model_kwargs.get("past_key_values", None)
41
+
42
+ # update past
43
+ if "past_key_values" in outputs:
44
+ model_kwargs["past"] = outputs.past_key_values
45
+ elif "mems" in outputs:
46
+ model_kwargs["past"] = outputs.mems
47
+ elif "past_buckets_states" in outputs:
48
+ model_kwargs["past"] = outputs.past_buckets_states
49
+ else:
50
+ model_kwargs["past"] = None
51
+
52
+ # update token_type_ids with last value
53
+ if "token_type_ids" in model_kwargs:
54
+ token_type_ids = model_kwargs["token_type_ids"]
55
+ model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
56
+
57
+ # update attention masks
58
+ if not is_encoder_decoder:
59
+ if "attention_mask" in model_kwargs:
60
+ attention_mask = model_kwargs["attention_mask"]
61
+ model_kwargs["attention_mask"] = torch.cat(
62
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
63
+ )
64
+ if "image_attention_mask" in model_kwargs:
65
+ image_attention_mask = model_kwargs["image_attention_mask"]
66
+ last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
67
+ model_kwargs["image_attention_mask"] = last_mask
68
+
69
+ return model_kwargs
70
+
71
+
72
+ def prepare_inputs_for_generation(input_ids, past=None, **kwargs):
73
+ token_type_ids = kwargs.get("token_type_ids", None)
74
+ # only last token for inputs_ids if past is defined in kwargs
75
+ if past:
76
+ input_ids = input_ids[:, -1].unsqueeze(-1)
77
+ if token_type_ids is not None:
78
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
79
+
80
+ attention_mask = kwargs.get("attention_mask", None)
81
+ position_ids = kwargs.get("position_ids", None)
82
+
83
+ if attention_mask is not None and position_ids is None:
84
+ # create position_ids on the fly for batch generation
85
+ position_ids = attention_mask.long().cumsum(-1) - 1
86
+ position_ids.masked_fill_(attention_mask == 0, 1)
87
+ if past:
88
+ position_ids = position_ids[:, -1].unsqueeze(-1)
89
+
90
+ pixel_values = kwargs.get("pixel_values", None)
91
+ image_attention_mask = kwargs.get("image_attention_mask", None)
92
+ if pixel_values is None or image_attention_mask is None:
93
+ raise ValueError("pixel values and image attention mask cannot be None")
94
+
95
+ return {
96
+ "input_ids": input_ids,
97
+ "past_key_values": past,
98
+ "use_cache": kwargs.get("use_cache"),
99
+ "position_ids": position_ids,
100
+ "attention_mask": attention_mask,
101
+ "token_type_ids": token_type_ids,
102
+ "pixel_values": pixel_values,
103
+ "image_attention_mask": image_attention_mask,
104
+ }
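Reviewer note: the on-the-fly `position_ids` computation in `prepare_inputs_for_generation` above can be hard to picture for left-padded batches. The following is a minimal sketch in plain Python, assuming nested lists in place of `torch` tensors; the helper name `position_ids_from_mask` is hypothetical and not part of the commit. It mirrors `attention_mask.long().cumsum(-1) - 1` followed by `masked_fill_(attention_mask == 0, 1)`.

```python
from itertools import accumulate


def position_ids_from_mask(attention_mask):
    """List-based analogue of `mask.cumsum(-1) - 1` then `masked_fill_(mask == 0, 1)`."""
    position_ids = []
    for row in attention_mask:
        running = accumulate(row)  # cumulative count of real (unmasked) tokens
        # Real tokens get a 0-based position; padded slots get the filler value 1.
        position_ids.append([count - 1 if keep else 1 for count, keep in zip(running, row)])
    return position_ids


# First row is left-padded with two pad tokens, second row has no padding.
mask = [[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]]
print(position_ids_from_mask(mask))  # [[1, 1, 0, 1, 2], [0, 1, 2, 3, 4]]
```

The filler value for padded slots is arbitrary, since those positions are masked out of attention anyway; what matters is that the first real token of each row starts at position 0.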
m4/models/custom_modules.py ADDED
@@ -0,0 +1,337 @@
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
7
+ from transformers.utils import ContextManagers
8
+
9
+ from m4.training.setup_vision_model import vision_model_name_to_model
10
+ from m4.training.utils import (
11
+ deepspeed_zero_init_disabled_context_manager,
12
+ is_deepspeed_zero_init_enabled,
13
+ load_state_dict_into_model,
14
+ )
15
+
16
+
17
+ # from pathlib import Path
18
+
19
+
20
+ class VLOOMPreTrainedModelBase(PreTrainedModel):
21
+ # The problem we are trying to solve is 2 nested zero.Init thanks to fetching from_pretrained(vision_model_name)
22
+ # and then one more zero.Init to override from_pretrained(vision_model_name) once again as it was done in the original - this breaks deepspeed zero3 w/ zero.Init
23
+ # So one solution is this:
24
+ # a. replace from_pretrained(vision_model_name) with from_config(vision_model_name) while hacking to disable zero.Init context
25
+ # b. instead of straight replacement of model.vision_model = from_pretrained(vision_model_name) when it gets updated, we first do from_pretrained(vision_model_name) and then update the existing model with weights using the already zero.Init'ed pre-sharded weights
26
+ #
27
+ # there are a few variations to get_vision_model_from_config - all need to bypass zero.Init under zero3
28
+ # 1. one variant is to hack into accelerate's deepspeed_plugin and turn off zero.Init while loading the vision model
29
+ # 2. the other variant is to override _from_config method with our version that doesn't do zero.Init
30
+
31
+ @classmethod
32
+ def override_vision_model(cls, model, vision_model_name, vision_model_params, torch_dtype):
33
+ # 1. fetch the pretrained vision model w/o zero.Init
34
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
35
+ vision_model = AutoModel.from_pretrained(vision_model_name, **vision_model_params, torch_dtype=torch_dtype)
36
+
37
+ # this extracts the desired submodule if the part we want is nested (e.g. as in clip)
38
+ real_vision_model = vision_model_name_to_model(vision_model_name, vision_model)
39
+
40
+ # 2. now override the weights already sharded by zero.Init with the weights from the real_vision_model
41
+ # by gradually gathering sharded weights and replacing with new weights
42
+ if is_deepspeed_zero_init_enabled():
43
+ state_dict = real_vision_model.state_dict()
44
+ load_state_dict_into_model(model.vision_model, state_dict, start_prefix="")
45
+ else:
46
+ model.vision_model = real_vision_model
47
+
48
+ @classmethod
49
+ def from_config(cls, config, **kwargs):
50
+ # torch_dtype is crucial for using the minimal amount of memory at load time
51
+ torch_dtype = kwargs.get("torch_dtype", None)
52
+
53
+ vision_model_name = config.vision_model_name
54
+ vision_model_params = eval(config.vision_model_params)
55
+
56
+ # 1. create an uninitialized vision_model to insert into the main model.
57
+ # It has to be created outside lm's `from_pretrained` and w/o zero.Init so that zero3+zero.Init works
58
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
59
+ vision_model_config = AutoConfig.from_pretrained(vision_model_name, **vision_model_params)
60
+ vision_model_from_config = AutoModel.from_config(vision_model_config, torch_dtype=torch_dtype)
61
+ # this extracts the desired submodule if the part we want is nested (e.g. as in clip)
62
+ kwargs["vision_model"] = vision_model_name_to_model(vision_model_name, vision_model_from_config)
63
+
64
+ # 2. create the main class's model, passing the uninitialized vision_model to it
65
+ model = cls(config, **kwargs)
66
+
67
+ return model
68
+
69
+ @classmethod
70
+ def from_pretrained_models(cls, *args, **kwargs):
71
+ """
72
+ Use this method when creating a new vloom model that hasn't been trained yet and that will be
73
+ composed of 2 pre-trained models - hence `pretrained_models`.
74
+ """
75
+
76
+ return cls.from_pretrained(*args, **kwargs, new_model=True)
77
+
78
+ @classmethod
79
+ def from_pretrained(cls, *model_args, is_resume=False, new_model=False, **kwargs):
80
+ """
81
+ Use this method when loading an already pretrained vloom model - either from a checkpoint or from hub.
82
+ For creating an untrained model use `from_pretrained_models` instead.
83
+ """
84
+
85
+ is_untrained_vloom_model = False
86
+ is_pretrained_vloom_model_resumed = False
87
+ is_pretrained_vloom_model_from_hub_or_path = False
88
+
89
+ # we have 3 use cases:
90
+ # 1. is_untrained_vloom_model - a totally new vloom model
91
+ # 2. is_pretrained_vloom_model_resumed - a pretrained vloom model being resumed from a
92
+ # checkpoint (instantiate a random empty model in this case)
93
+ # 3. is_pretrained_vloom_model_from_hub_or_path - a pretrained vloom model loaded from hub or local path
94
+ if new_model:
95
+ is_untrained_vloom_model = True
96
+ elif is_resume:
97
+ is_pretrained_vloom_model_resumed = True
98
+ else:
99
+ is_pretrained_vloom_model_from_hub_or_path = True
100
+
101
+ # torch_dtype is crucial for using the minimal amount of memory at load time
102
+ torch_dtype = kwargs.get("torch_dtype", None)
103
+
104
+ # config is:
105
+ # 1. either not passed and then we use the model's default config (used by tests)
106
+ # 2. passed and in which case it's one of:
107
+ # 2a. `PretrainedConfig` (a new m4 model)
108
+ # 2b. path to a json config (an already pretrained m4 model, usually resumed training)
109
+ config = kwargs.get("config", None)
110
+ if config is None:
111
+ config = cls.config_class.from_pretrained(*model_args, **kwargs, return_unused_kwargs=False)
112
+ elif not isinstance(config, PretrainedConfig):
113
+ # adapted from https://github.com/huggingface/transformers/blob/d0acc9537829e7d067edbb791473bbceb2ecf056/src/transformers/modeling_utils.py#L1920
114
+ assert isinstance(config, (str, os.PathLike))
115
+ config_path = str(config)
116
+ config = cls.config_class.from_pretrained(
117
+ config_path,
118
+ return_unused_kwargs=False,
119
+ **kwargs,
120
+ )
121
+
122
+ vision_model_name = config.vision_model_name
123
+ vision_model_params = eval(config.vision_model_params)
124
+
125
+ # 1. create an uninitialized vision_model to insert into the main model.
126
+ # It has to be created outside lm's `from_pretrained` and w/o zero.Init so that zero3+zero.Init works
127
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
128
+ vision_model_config = AutoConfig.from_pretrained(vision_model_name, **vision_model_params)
129
+ vision_model_from_config = AutoModel.from_config(vision_model_config, torch_dtype=torch_dtype)
130
+ # this extracts the desired submodule if the part we want is nested (e.g. as in clip)
131
+ kwargs["vision_model"] = vision_model_name_to_model(vision_model_name, vision_model_from_config)
132
+
133
+ # 2. create the vloom model
134
+ if is_untrained_vloom_model or is_pretrained_vloom_model_from_hub_or_path:
135
+ model = super().from_pretrained(*model_args, **kwargs)
136
+ elif is_pretrained_vloom_model_resumed:
137
+ # in the case of resume under deepspeed we create an empty model, and get deepspeed
138
+ # to load the weights from the checkpoint
139
+ # but not all models have these keys so handle the case they don't have them
140
+ _ = kwargs.pop("config", None)
141
+ model = super().from_pretrained(None, config=config, state_dict={}, **kwargs)
142
+
143
+ # 3. if is_untrained_vloom_model, now override the uninitialized vision_model with one with pretrained weights
144
+ if is_untrained_vloom_model:
145
+ cls.override_vision_model_wrapper(model, config, vision_model_name, vision_model_params, torch_dtype)
146
+
147
+ return model
148
+
149
+
150
+ class DecoupledEmbedding(nn.Embedding):
151
+ # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
152
+ """
153
+ Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings.
154
+ In practice, the regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0, then it will create `num_additional_embeddings` additional parameters that are always trained.
155
+ If `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
156
+ """
157
+
158
+ def __init__(
159
+ self,
160
+ num_embeddings,
161
+ num_additional_embeddings,
162
+ embedding_dim,
163
+ partially_freeze=False,
164
+ device=None,
165
+ dtype=None,
166
+ padding_idx=None,
167
+ **kwargs,
168
+ ) -> None:
169
+ """
170
+ num_additional_embeddings: int. Number of additional embeddings. Only useful when `partially_freeze=True`.
171
+ partially_freeze: bool. If True, the regular `weight` will be frozen. `additional_weight` is never frozen.
172
+
173
+ Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`, `max_norm` or `norm_type`. We are not supporting these.
174
+ """
175
+ if padding_idx is not None and padding_idx >= num_embeddings:
176
+ raise ValueError(f"padding_idx must be within num_embeddings. Got {padding_idx} and {num_embeddings}")
177
+ super().__init__(
178
+ num_embeddings=num_embeddings,
179
+ embedding_dim=embedding_dim,
180
+ device=device,
181
+ dtype=dtype,
182
+ padding_idx=padding_idx,
183
+ **kwargs,
184
+ )
185
+ self.num_embeddings = num_embeddings
186
+ self.padding_idx = padding_idx
187
+ self.num_additional_embeddings = num_additional_embeddings
188
+ self.partially_freeze = partially_freeze
189
+
190
+ if partially_freeze:
191
+ self.weight.requires_grad_(False)
192
+
193
+ if self.num_additional_embeddings > 0:
194
+ self.additional_embedding = nn.Embedding(
195
+ num_embeddings=self.num_additional_embeddings,
196
+ embedding_dim=embedding_dim,
197
+ device=device,
198
+ dtype=dtype,
199
+ )
200
+
201
+ def forward(self, input_ids):
202
+ """
203
+ we have 2 embeddings, with different indices - one pretrained self.weight and another
204
+ self.additional_embedding.weight that is being trained.
205
+
206
+ in order to make a lookup of the input ids, we:
207
+ 1. find out the indices of the entries belonging to the 2nd embedding
208
+ 2. extract those values while subtracting the size of the first embedding (num_embeddings),
209
+ since the 2nd embedding starts from 0 and not num_embeddings
210
+ 3. perform the 2nd embedding lookup
211
+ 4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index
212
+ 5. perform the 1st embedding lookup
213
+ 6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup
214
+
215
+ note: for the 1st embedding lookup we could have looked up only the low indices and not do
216
+ the padding, but then we have to create a new tensor and populate it with 2 tensors that are
217
+ spread out across various indices - i.e. not a simple concat - I haven't benchmarked the
218
+ complex case if it's any faster, given that seqlens are usually relatively short it's
219
+ probably not faster or if faster not by much - but might be a good idea to measure.
220
+
221
+ """
222
+ if self.num_additional_embeddings == 0:
223
+ return F.embedding(input_ids, self.weight)
224
+
225
+ # Clone so that we don't modify the original input_ids later on
226
+ input_ids = input_ids.clone()
227
+ additional_vocab_indices = torch.where(input_ids >= self.num_embeddings)
228
+ input_ids_additional_vocab = input_ids[additional_vocab_indices]
229
+ additional_embeddings = self.additional_embedding(input_ids_additional_vocab - self.num_embeddings)
230
+
231
+ # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
232
+ input_ids[additional_vocab_indices] = 0
233
+ full_vector = F.embedding(input_ids, self.weight)
234
+
235
+ # overwrite the records with high indices
236
+ full_vector[additional_vocab_indices] = additional_embeddings
237
+
238
+ return full_vector
239
+
240
+ def extra_repr(self) -> str:
241
+ return "num_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
242
+ self.num_embeddings,
243
+ self.num_additional_embeddings,
244
+ self.embedding_dim,
245
+ self.partially_freeze,
246
+ )
247
+
248
+ @classmethod
249
+ def from_pretrained(cls, embeddings, freeze=True, **kwargs):
250
+ raise NotImplementedError
251
+
252
+
253
+ class DecoupledLinear(nn.Linear):
254
+ # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
255
+ """
256
+ Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters.
257
+ In practice, the regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `out_additional_features` > 0, then it will create `out_additional_features * in_features` additional parameters that are always trained.
258
+ If `out_additional_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
259
+ """
260
+
261
+ def __init__(
262
+ self,
263
+ in_features: int,
264
+ out_features: int,
265
+ out_additional_features: int = 0,
266
+ bias: bool = True,
267
+ partially_freeze: bool = True,
268
+ device=None,
269
+ dtype=None,
270
+ ) -> None:
271
+ """
272
+ out_additional_features: int. Number of additional trainable dimensions. Only makes sense when `partially_freeze=True`.
273
+ partially_freeze: bool. If True, the regular `weight` will be frozen and extra parameters (if any) will be trainable. If False, default to the regular behavior of nn.Linear.
274
+ """
275
+ super().__init__(in_features, out_features, bias, device, dtype)
276
+ self.out_additional_features = out_additional_features
277
+ self.partially_freeze = partially_freeze
278
+
279
+ self.in_features = in_features
280
+ self.out_features = out_features
281
+
282
+ if partially_freeze:
283
+ self.weight.requires_grad_(False)
284
+ if bias:
285
+ self.bias.requires_grad_(False)
286
+
287
+ if out_additional_features > 0:
288
+ self.additional_fc = nn.Linear(
289
+ in_features=in_features,
290
+ out_features=out_additional_features,
291
+ bias=bias,
292
+ device=device,
293
+ dtype=dtype,
294
+ )
295
+
296
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
297
+ output = F.linear(input, self.weight, self.bias)
298
+
299
+ if self.out_additional_features > 0:
300
+ additional_features = F.linear(input, self.additional_fc.weight, self.additional_fc.bias)
301
+ output = torch.cat((output, additional_features), -1)
302
+
303
+ return output
304
+
305
+ def extra_repr(self) -> str:
306
+ """Overwriting `nn.Linear.extra_repr` to include new parameters."""
307
+ return "in_features={}, out_features={}, out_additional_features={}, bias={}, partially_freeze={}".format(
308
+ self.in_features,
309
+ self.out_features,
310
+ self.out_additional_features,
311
+ self.bias is not None,
312
+ self.partially_freeze,
313
+ )
314
+
315
+
316
+ if __name__ == "__main__":
317
+ emb = DecoupledEmbedding(num_embeddings=10, num_additional_embeddings=3, embedding_dim=5, partially_freeze=True)
318
+ for n, p in emb.named_parameters():
319
+ print(n, p.requires_grad)
320
+ idx = torch.tensor([[11, 1, 3]])
321
+ y = emb(idx)
322
+ loss = y.sum()
323
+ loss.backward()
324
+ print(emb.weight, emb.weight.grad)
325
+ print(emb.additional_embedding.weight, emb.additional_embedding.weight.grad)
326
+
327
+ lin = DecoupledLinear(in_features=3, out_features=4, out_additional_features=2, bias=True, partially_freeze=True)
328
+ for n, p in lin.named_parameters():
329
+ print(n, p.requires_grad)
330
+ x = torch.randn(12, 3)
331
+ y = lin(x)
332
+ loss = y.sum()
333
+ loss.backward()
334
+ print("Weight w and grad:", lin.weight, lin.weight.grad)
335
+ print("bias w and grad:", lin.bias, lin.bias.grad)
336
+ print("additional_fc.weight w and grad:", lin.additional_fc.weight, lin.additional_fc.weight.grad)
337
+ print("additional_bias w and grad:", lin.additional_fc.bias, lin.additional_fc.bias.grad)
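Reviewer note: the six lookup steps listed in the `DecoupledEmbedding.forward` docstring can be sketched without tensors. This is a simplified illustration in plain Python, assuming a flat list of token ids and lists of vectors in place of embedding matrices; `decoupled_lookup`, `base_table` and `extra_table` are hypothetical names, not part of the commit.

```python
def decoupled_lookup(input_ids, base_table, extra_table):
    """List-based analogue of the two-table lookup in DecoupledEmbedding.forward."""
    num_embeddings = len(base_table)
    ids = list(input_ids)  # copy, so the caller's ids are not modified

    # Steps 1-3: locate ids that belong to the additional table, shift them so
    # they start from 0, and look them up there.
    extra_positions = [i for i, tok in enumerate(ids) if tok >= num_embeddings]
    extra_vectors = [extra_table[ids[i] - num_embeddings] for i in extra_positions]

    # Steps 4-5: overwrite those ids with 0 so the base lookup succeeds; the
    # vectors fetched for these slots are placeholders only.
    for i in extra_positions:
        ids[i] = 0
    full = [base_table[tok] for tok in ids]

    # Step 6: replace the placeholder vectors with the additional-table vectors.
    for i, vec in zip(extra_positions, extra_vectors):
        full[i] = vec
    return full


base_table = [[float(i), float(i)] for i in range(10)]  # 10 regular tokens
extra_table = [[100.0, 100.0], [101.0, 101.0], [102.0, 102.0]]  # 3 additional tokens
token_ids = [11, 1, 3]
print(decoupled_lookup(token_ids, base_table, extra_table))
# [[101.0, 101.0], [1.0, 1.0], [3.0, 3.0]]
```

Token id 11 falls outside the 10 regular entries, so it maps to row `11 - 10 = 1` of the additional table, which matches the `idx = torch.tensor([[11, 1, 3]])` case exercised in the `__main__` block above.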
m4/models/perceiver/perceiver.py ADDED
@@ -0,0 +1,141 @@
1
+ """
2
+ perceiver.py
3
+ Generic interface to various configurations of the Perceiver Resampler, that simply takes in a series of (potentially
4
+ time-indexed) contextual embeddings, and "resamples" (compresses) them down to a pre-specified number of latents!
5
+ Note that the Perceiver in general resamples based solely off the *long-range* context; there's a nice opportunity here
6
+ to prime the Perceiver Resampler with say a single layer's worth of language embeddings (the target domain), and use
7
+ that to softly "retrieve & compress" what we need --> this would be a novel contribution we should explore.
8
+ References:
9
+ - DeepMind's Flamingo: https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model
10
+ - Code borrowed w/ love from: https://github.com/lucidrains/flamingo-pytorch
11
+ """
12
+ from typing import Optional, Tuple
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from einops import rearrange, repeat
17
+
18
+
19
+ class PerceiverResampler(nn.Module):
20
+ def __init__(self, config, embed_dim: int, depth: int, n_heads: int, head_dim: int, n_latents: int) -> None:
21
+ """
22
+ Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say from a ResNet or ViT or
23
+ MAE) of a given dimension, performs `depth` blocks of cross-attention with a fixed `n_latents` inputs, then
24
+ returns a Tensor of shape [bsz, n_latents, embed_dim].
25
+ :param embed_dim: Dimensionality of embeddings being fed to the Perceiver Resampler (also dimensionality of
26
+ latent embeddings *returned* by the Perceiver Resampler). Could be e.g., ViT embed_dim, ResNet
27
+ pool dim, and so on.
28
+ :param depth: Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3).
29
+ :param n_heads: Number of heads in each Transformer block (for multi-headed self-attention).
30
+ :param head_dim: Dimensionality of each head projection in the Transformer block.
31
+ :param n_latents: Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
32
+ """
33
+ super().__init__()
34
+ self.embed_dim, self.n_heads, self.head_dim, self.n_latents = embed_dim, n_heads, head_dim, n_latents
35
+ self.qk_layer_norms = config.qk_layer_norms_perceiver
36
+
37
+ # Create Latents for Perceiver
38
+ self.latents = nn.Parameter(torch.randn(self.n_latents, self.embed_dim), requires_grad=True)
39
+
40
+ self.intermediate_dim = (
41
+ self.embed_dim * 4 if not hasattr(config, "vision_embed_dim") else config.vision_embed_dim * 4
42
+ )
43
+ # Create Transformer Blocks
44
+ self.blocks = nn.ModuleList(
45
+ [
46
+ nn.ModuleList(
47
+ [
48
+ PerceiverAttention(self.embed_dim, self.n_heads, self.head_dim, self.qk_layer_norms),
49
+ MLP(self.intermediate_dim, config),
50
+ ]
51
+ )
52
+ for _ in range(depth)
53
+ ]
54
+ )
55
+ self.layer_norm = nn.LayerNorm(self.embed_dim)
56
+
57
+ def forward(self, context: torch.Tensor) -> torch.Tensor:
58
+ """Resample arbitrary length context & *compress* down to self.n_latents latent embeddings"""
59
+ latents = repeat(self.latents, "seq embed -> bsz seq embed", bsz=context.shape[0])
60
+
61
+ # Feed through Perceiver Attention blocks...
62
+ for attn, ff in self.blocks:
63
+ latents = attn(context, latents) + latents
64
+ latents = ff(latents) + latents
65
+
66
+ return self.layer_norm(latents)
67
+
68
+
69
+ class PerceiverAttention(nn.Module):
70
+ def __init__(self, embed_dim: int, n_heads: int, head_dim: int, qk_layer_norms: bool) -> None:
71
+ """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
72
+ super().__init__()
73
+ self.embed_dim, self.n_heads, self.head_dim = embed_dim, n_heads, head_dim
74
+ self.qk_layer_norms = qk_layer_norms
75
+ # Normalization & Scaling
76
+ self.context_layer_norm = nn.LayerNorm(self.embed_dim)
77
+ self.latents_layer_norm = nn.LayerNorm(self.embed_dim)
78
+ if self.qk_layer_norms:
79
+ self.q_layer_norm = nn.LayerNorm(self.head_dim)
80
+ self.k_layer_norm = nn.LayerNorm(self.head_dim)
81
+
82
+ self.qk_scale = self.head_dim**-0.5
83
+
84
+ # Q, K, V Projection (no bias -- detail from Perceiver/Flamingo Papers).
85
+ self.q_proj = nn.Linear(self.embed_dim, self.n_heads * self.head_dim, bias=False)
86
+ self.k_proj = nn.Linear(self.embed_dim, self.n_heads * self.head_dim, bias=False)
87
+ self.v_proj = nn.Linear(self.embed_dim, self.n_heads * self.head_dim, bias=False)
88
+
89
+ self.output_proj = nn.Linear(self.n_heads * self.head_dim, embed_dim, bias=False)
90
+
91
+ def forward(self, context: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
92
+ """
93
+ Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension!
94
+ :param context: Tensor of shape [bsz, seq, embed_dim] representing long-form context to resample.
95
+ :param latents: Tensor of shape [bsz, n_latents, embed_dim] representing fixed length latents to compress to.
96
+ :return: Tensor of shape [bsz, n_latents, embed_dim] representing attention over latents w/ cross from context.
97
+ """
98
+ context = self.context_layer_norm(context)
99
+ latents = self.latents_layer_norm(latents)
100
+
101
+ # Query, Key, Value Projections --> Note that in Flamingo, latents are *concatenated* with context prior to attn!
102
+ # Note: This results in queries w/ `seq = n_latents`, and keys, values with `seq = len(context) + n_latents`
103
+ q = self.q_proj(latents)
104
+ k = self.k_proj(torch.cat([context, latents], dim=-2))
105
+ v = self.v_proj(torch.cat([context, latents], dim=-2))
106
+
107
+ # Multiheaded Self-Attention w/ stable softmax (subtract per-row max -- `amax` -- before softmax call)
108
+ # =>> `attn` should be a 2D matrix of shape [n_latents x (context + n_latents)]
109
+ q, k, v = [rearrange(x, "bsz seq (heads embed) -> bsz heads seq embed", heads=self.n_heads) for x in (q, k, v)]
110
+ if self.qk_layer_norms:
111
+ q = self.q_layer_norm(q)
112
+ k = self.k_layer_norm(k)
113
+
114
+ scores = torch.einsum("... i d, ... j d -> ... i j", q * self.qk_scale, k)
115
+ stabilized_scores = scores - (scores.amax(dim=-1, keepdim=True).detach())
116
+ attn = stabilized_scores.softmax(dim=-1)
117
+
118
+ # Attend & project back to output...
119
+ resampled = torch.einsum("... i j, ... j d -> ... i d", attn, v)
120
+ return self.output_proj(
121
+ rearrange(resampled, "bsz heads seq embed -> bsz seq (heads embed)", heads=self.n_heads)
122
+ )
123
+
124
+
125
+ class MLP(nn.Module):
126
+ def __init__(self, intermediate_size, config):
127
+ """Simple MLP block with intermediate_size and embedding size"""
128
+ super().__init__()
129
+ self.embed_dim = config.vision_embed_dim
130
+ self.ln = nn.LayerNorm(self.embed_dim)
131
+ self.fc = nn.Linear(self.embed_dim, intermediate_size, bias=False)
132
+ self.act = nn.ReLU()
133
+ self.c_proj = nn.Linear(intermediate_size, self.embed_dim, bias=False)
134
+
135
+ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
136
+ hidden_states = self.ln(hidden_states)
137
+ hidden_states = self.fc(hidden_states)
138
+ hidden_states = self.act(hidden_states)
139
+ hidden_states = self.c_proj(hidden_states)
140
+
141
+ return hidden_states
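Reviewer note: `PerceiverAttention.forward` above subtracts the per-row maximum (`amax`) from the scores before calling `softmax`. The sketch below shows why that shift is safe, using plain Python floats for a single row of scores; `stable_softmax` is a hypothetical helper, not part of the commit. Subtracting a constant from every score in a row leaves the softmax output unchanged, while keeping every exponent at or below zero so `exp` cannot overflow.

```python
import math


def stable_softmax(scores):
    """Softmax over one row of scores, shifted by the row maximum for stability."""
    row_max = max(scores)
    exps = [math.exp(s - row_max) for s in scores]  # every exponent is <= 0
    total = sum(exps)
    return [e / total for e in exps]


# A naive math.exp(1000.0) would raise OverflowError; the shifted version is fine.
large = stable_softmax([1000.0, 1001.0, 1002.0])
small = stable_softmax([0.0, 1.0, 2.0])
print(large)
print(small)  # same distribution as `large`
```

In the diff the maximum is additionally passed through `.detach()`, which only affects gradient bookkeeping and not the forward values shown here.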
m4/models/vbloom/__init__.py ADDED
File without changes
m4/models/vbloom/configuration_vbloom.py ADDED
@@ -0,0 +1,235 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 the Big Science Workshop and HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ VBloom configuration"""
16
+ import os
17
+ from typing import Tuple, Union
18
+
19
+ from transformers import AutoConfig
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27
+ "bigscience/bloom": "https://huggingface.co/bigscience/bloom/resolve/main/config.json",
28
+ "bigscience/bloom-560m": "https://huggingface.co/bigscience/bloom-560m/blob/main/config.json",
29
+ "bigscience/bloom-1b1": "https://huggingface.co/bigscience/bloom-1b1/blob/main/config.json",
30
+ "bigscience/bloom-1b7": "https://huggingface.co/bigscience/bloom-1b7/blob/main/config.json",
31
+ "bigscience/bloom-3b": "https://huggingface.co/bigscience/bloom-3b/blob/main/config.json",
32
+ "bigscience/bloom-7b1": "https://huggingface.co/bigscience/bloom-7b1/blob/main/config.json",
33
+ }
34
+
35
+
36
+ class VBloomConfig(PretrainedConfig):
37
+ """
38
+ This is the configuration class to store the configuration of a [`BloomModel`]. It is used to instantiate a Bloom
39
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
40
+ defaults will yield a similar configuration to the Bloom architecture
41
+ [bigscience/bloom](https://huggingface.co/bigscience/bloom).
42
+
43
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
44
+ documentation from [`PretrainedConfig`] for more information.
45
+
46
+ TODO: this doc is completely out of sync with the actual args
47
+
48
+ Args:
49
+ vocab_size (`int`, *optional*, defaults to 250880):
50
+ Vocabulary size of the Bloom model. Defines the number of different tokens that can be represented by the
51
+ `inputs_ids` passed when calling [`BloomModel`].
52
+ additional_vocab_size (`int`, *optional*, defaults to 0):
53
+ Additional vocabulary size of the model, typically for the special "<img>" token. Additional vocab tokens
54
+ are always trainable whereas regular vocab tokens can be frozen or not.
55
+ hidden_size (`int`, *optional*, defaults to 64):
56
+ Dimensionality of the embeddings and hidden states.
57
+ n_layer (`int`, *optional*, defaults to 2):
58
+ Number of hidden layers in the Transformer encoder.
59
+ n_head (`int`, *optional*, defaults to 8):
60
+ Number of attention heads for each attention layer in the Transformer encoder.
61
+ attn_pdrop (`float`, *optional*, defaults to 0.1):
62
+ The dropout ratio for the attention.
63
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
64
+ The epsilon to use in the layer normalization layers.
65
+ initializer_range (`float`, *optional*, defaults to 0.02):
66
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
67
+ alpha_initializer (`str`, *optional*, defaults to `"ones"`):
68
+ Initialization type for the alphas.
69
+ alphas_initializer_range (`float`, *optional*, defaults to 0.0):
70
+ The standard deviation of the truncated_normal_initializer for initializing the alphas in the Gated Cross Attention.
71
+ alpha_type (`str`, *optional*, defaults to `"vector"`):
72
+ Whether the gating alphas should be vectors or single floats.
73
+ apply_residual_connection_post_layernorm (`bool`, *optional*, defaults to `False`):
74
+ If enabled, use the layer norm of the hidden states as the residual in the transformer blocks
75
+ skip_bias_add (`bool`, *optional*, defaults to `True`):
76
+ If set to `True`, it will skip bias add for each linear layer in the transformer blocks
77
+ skip_bias_add_qkv (`bool`, *optional*, defaults to `False`):
78
+ If set to `True`, it will skip bias add for the first linear layer in the transformer blocks
79
+ hidden_dropout (`float`, *optional*, defaults to 0.0):
80
+ Dropout rate of the dropout function on the bias dropout.
81
+ attention_dropout (`float`, *optional*, defaults to 0.0):
82
+ Dropout rate applied to the attention probs
83
+ use_cache (`bool`, *optional*, defaults to `False`):
84
+ Whether or not the model should return the last key/values attentions (not used by all models).
85
+ pretraining_tp (`int`, *optional*, defaults to `1`):
86
+ Experimental feature. Tensor parallelism rank used during pretraining with Megatron. Please refer to [this
87
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
88
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
89
+ issue](https://github.com/pytorch/pytorch/issues/76232). Note also that this is enabled only when
90
+ `slow_but_exact=True`.
91
+ slow_but_exact (`bool`, *optional*, defaults to `False`):
92
+ Experimental feature. Whether to use slow but exact implementation of the attention mechanism. While
93
+ merging the TP rank tensors, due to slicing operations the results may be slightly different between the
94
+ model trained on Megatron and our model. Please refer to [this
95
+ issue](https://github.com/pytorch/pytorch/issues/76232). A solution to obtain more accurate results is to
96
+ enable this feature. Enabling this will hurt the computational time of the inference. This will probably be
97
+ resolved in the future once the main model has been fine-tuned with TP_rank=1.
98
+
99
+ Example:
100
+
101
+ ```python
102
+ >>> from transformers import BloomModel, BloomConfig
103
+
104
+ >>> # Initializing a Bloom configuration
105
+ >>> configuration = BloomConfig()
106
+
107
+ >>> # Initializing a model from the configuration
108
+ >>> model = BloomModel(configuration)
109
+
110
+ >>> # Accessing the model configuration
111
+ >>> configuration = model.config
112
+ ```"""
113
+
114
+ model_type = "vbloom"
115
+ keys_to_ignore_at_inference = ["past_key_values"]
116
+ attribute_map = {
117
+ "num_hidden_layers": "n_layer",
118
+ "num_attention_heads": "n_head",
119
+ }
120
+
121
+ def __init__(
122
+ self,
123
+ vocab_size=250880,
124
+ additional_vocab_size=0,
125
+ hidden_size=64,
126
+ n_layer=2,
127
+ n_head=8,
128
+ layer_norm_epsilon=1e-5,
129
+ initializer_range=0.02,
130
+ alpha_initializer="ones",
131
+ alphas_initializer_range=0.0,
132
+ alpha_type="vector",
133
+ use_cache=False,
134
+ bos_token_id=1,
135
+ eos_token_id=2,
136
+ apply_residual_connection_post_layernorm=False,
137
+ hidden_dropout=0.0,
138
+ attention_dropout=0.0,
139
+ pretraining_tp=1, # TP rank used when training with megatron
140
+ slow_but_exact=False,
141
+ cross_layer_interval=1,
142
+ tie_word_embeddings=False,
143
+ freeze_text_layers=True,
144
+ freeze_lm_head=False,
145
+ freeze_vision_layers=True,
146
+ vision_model_name="google/vit-base-patch16-224",
147
+ vision_model_params="{}",
148
+ vision_embed_dim=768,
149
+ image_token_index=250880,
150
+ use_resampler=False,
151
+ resampler_n_latents=64,
152
+ resampler_depth=6,
153
+ resampler_n_heads=16,
154
+ resampler_head_dim=96,
155
+ **kwargs,
156
+ ):
157
+ self.vocab_size = vocab_size
158
+ self.additional_vocab_size = additional_vocab_size
159
+ # Backward compatibility with n_embed kwarg
160
+ n_embed = kwargs.pop("n_embed", None)
161
+ self.hidden_size = hidden_size if n_embed is None else n_embed
162
+ self.n_layer = n_layer
163
+ self.n_head = n_head
164
+ self.layer_norm_epsilon = layer_norm_epsilon
165
+ self.initializer_range = initializer_range
166
+ self.alpha_initializer = alpha_initializer
167
+ self.alphas_initializer_range = alphas_initializer_range
168
+ self.alpha_type = alpha_type
169
+ self.use_cache = use_cache
170
+ self.pretraining_tp = pretraining_tp
171
+ self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
172
+ self.hidden_dropout = hidden_dropout
173
+ self.attention_dropout = attention_dropout
174
+
175
+ self.bos_token_id = bos_token_id
176
+ self.eos_token_id = eos_token_id
177
+ self.slow_but_exact = slow_but_exact
178
+
179
+ self.cross_layer_interval = cross_layer_interval
180
+ self.freeze_vision_layers = freeze_vision_layers
181
+ self.vision_model_name = vision_model_name
182
+ self.vision_model_params = vision_model_params
183
+
184
+ self.tie_word_embeddings = tie_word_embeddings
185
+ self.freeze_text_layers = freeze_text_layers
186
+ self.freeze_lm_head = freeze_lm_head
187
+ self.image_token_index = image_token_index
188
+
189
+ self.vision_embed_dim = vision_embed_dim
190
+
191
+ # Resampler params
192
+ self.use_resampler = use_resampler
193
+ self.resampler_n_latents = resampler_n_latents
194
+ self.resampler_depth = resampler_depth
195
+ self.resampler_n_heads = resampler_n_heads
196
+ self.resampler_head_dim = resampler_head_dim
197
+
198
+ # IMPORTANT: Do not do any __init__ args-based checks in the constructor, since
199
+ # PretrainedConfig.from_dict first instantiates the class with the config dict and only then
200
+ # updates the config object with `kwargs` from from_pretrained, so during the instantiation
201
+ # of this object many attributes have default values and haven't yet been overridden.
202
+ # Do any required checks inside `from_pretrained` once the superclass' `from_pretrained` was run.
203
+
204
+ super().__init__(
205
+ bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
206
+ )
207
+
208
+ def check_compatibilities(self):
209
+ if self.tie_word_embeddings and (self.freeze_text_layers != self.freeze_lm_head):
210
+ raise ValueError(
211
+ "if `tie_word_embeddings` is True, then `freeze_lm_head` and `freeze_text_layers` must be equal."
212
+ )
213
+
214
+ vision_model_params = eval(self.vision_model_params)
215
+ config = AutoConfig.from_pretrained(self.vision_model_name, **vision_model_params)
216
+ if hasattr(config, "vision_config"):
217
+ vision_config = config.vision_config
218
+ else:
219
+ vision_config = config
220
+ vision_embed_dim = vision_config.hidden_size
221
+ if self.vision_embed_dim != vision_embed_dim:
222
+ raise ValueError(
223
+ f"vision_embed_dim ({self.vision_embed_dim}) must match the hidden size of the vision model"
224
+ f" ({vision_embed_dim})"
225
+ )
226
+
227
+ @classmethod
228
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
229
+ outputs = super(VBloomConfig, cls).from_pretrained(pretrained_model_name_or_path, **kwargs)
230
+ if isinstance(outputs, Tuple):
231
+ # When called with return_unused_kwargs=True, the first item will be the config
232
+ outputs[0].check_compatibilities()
233
+ else:
234
+ outputs.check_compatibilities()
235
+ return outputs
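`check_compatibilities` above turns the `vision_model_params` string (default `"{}"`) into a dict with `eval`. As a hedged alternative, and not what this commit does, the same string could be parsed with the standard library's `ast.literal_eval`, which accepts only Python literals. The helper name `parse_vision_model_params` and the `image_size` key are hypothetical, chosen only for illustration.

```python
import ast


def parse_vision_model_params(raw: str) -> dict:
    """Parse a dict literal such as "{}" without executing arbitrary code.

    Hypothetical helper: the configuration class in the diff uses `eval` instead.
    """
    params = ast.literal_eval(raw)
    if not isinstance(params, dict):
        raise ValueError(f"expected a dict literal, got {type(params).__name__}")
    return params


print(parse_vision_model_params("{}"))
print(parse_vision_model_params('{"image_size": 224}'))
```

A string that is not a plain literal, for example a function call, makes `ast.literal_eval` raise `ValueError` instead of running it.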
m4/models/vbloom/modeling_vbloom.py ADDED
@@ -0,0 +1,1396 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 HuggingFace Inc. team and BigScience workshop.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch BLOOM model."""
16
+
17
+ import math
18
+ import warnings
19
+ from typing import Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+ from torch.nn import CrossEntropyLoss, LayerNorm
25
+ from torch.nn import functional as F
26
+ from transformers.file_utils import (
27
+ add_code_sample_docstrings,
28
+ add_start_docstrings,
29
+ add_start_docstrings_to_model_forward,
30
+ )
31
+ from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
32
+
33
+ from m4.models import DecoupledEmbedding, DecoupledLinear
34
+ from m4.models.common import (
35
+ expand_inputs_for_generation,
36
+ prepare_inputs_for_generation,
37
+ update_model_kwargs_for_generation,
38
+ )
39
+ from m4.models.custom_modules import VLOOMPreTrainedModelBase
40
+ from m4.models.perceiver.perceiver import PerceiverResampler
41
+ from m4.models.vbloom.configuration_vbloom import VBloomConfig
42
+ from m4.training.utils import (
43
+ compute_perceiver_tflops_per_batch_per_gpu,
44
+ compute_tflops_per_batch_per_gpu,
45
+ freeze_model,
46
+ )
47
+ from m4.utils import logging
48
+
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+ _CHECKPOINT_FOR_DOC = "bigscience/bloom-560m"
53
+ _CONFIG_FOR_DOC = "VBloomConfig"
54
+ _TOKENIZER_FOR_DOC = "BloomTokenizerFast"
55
+
56
+ BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [
57
+ "bigscience/bigscience-small-testing",
58
+ "bigscience/bloom-560m",
59
+ "bigscience/bloom-1b1",
60
+ "bigscience/bloom-1b7",
61
+ "bigscience/bloom-3b",
62
+ "bigscience/bloom-7b1",
63
+ "bigscience/bloom",
64
+ ]
65
+
66
+
67
+ def _make_causal_mask(
68
+ input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
69
+ ) -> torch.BoolTensor:
70
+ """
71
+ Make causal mask used for self-attention.
72
+ """
73
+ batch_size, target_length = input_ids_shape
74
+ mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
75
+ # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
76
+ seq_ids = torch.arange(target_length, device=device)
77
+ mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
78
+
79
+ if past_key_values_length > 0:
80
+ mask[:, :past_key_values_length] = False
81
+
82
+ expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
83
+ return expanded_mask
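To make the mask layout of `_make_causal_mask` concrete, here is a minimal pure-Python sketch of the same rule, assuming (as in the function above) that `True` marks a position that must not be attended to. The function name `make_causal_mask` and the list-of-lists representation are illustrative only; the real code builds a `torch.BoolTensor` and expands it over the batch.

```python
def make_causal_mask(target_length: int, past_key_values_length: int) -> list:
    """Return target_length rows of (past + target_length) booleans; True means masked."""
    mask = []
    for query_pos in range(target_length):
        # Cached (past) positions are always visible.
        past = [False] * past_key_values_length
        # Within the current chunk, a query may not look at later keys.
        current = [query_pos < key_pos for key_pos in range(target_length)]
        mask.append(past + current)
    return mask


for row in make_causal_mask(target_length=3, past_key_values_length=2):
    print("".join("x" if masked else "." for masked in row))
```

Each printed row is one query position; `x` marks a masked key, so the first two columns (the cached keys) are never masked.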
84
+
85
+
86
+ def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
87
+ """
88
+ Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
89
+ """
90
+ batch_size, src_length = mask.shape
91
+ tgt_length = tgt_length if tgt_length is not None else src_length
92
+
93
+ expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
94
+ return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
95
+
96
+
97
+ def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
98
+ """
99
+ Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
100
+ relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
101
+ `softmax(l+a) = softmax(l)`. Based on
102
+ https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
103
+ TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
104
+
105
+ Returns tensor shaped (batch_size * num_heads, 1, max_seq_len).
106
+ Args:
107
+ attention_mask (`torch.Tensor`):
108
+ Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
109
+ num_heads (`int`, *required*):
110
+ number of heads
111
+ dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
112
+ dtype of the output tensor
113
+ """
114
+ batch_size, seq_length = attention_mask.shape
115
+ closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
116
+ base = torch.tensor(
117
+ 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
118
+ )
119
+ powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
120
+ slopes = torch.pow(base, powers)
121
+
122
+ if closest_power_of_2 != num_heads:
123
+ extra_base = torch.tensor(
124
+ 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
125
+ )
126
+ num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
127
+ extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
128
+ slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
129
+
130
+ # Note: alibi will be added to the attention bias that will be applied to the query, key product of attention
131
+ # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
132
+ # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
133
+ # => the query_length dimension will then be broadcasted correctly
134
+ # This is more or less identical to T5's relative position bias:
135
+ # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
136
+ arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
137
+ alibi = slopes[..., None] * arange_tensor
138
+ return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
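The slope and position arithmetic in `build_alibi_tensor` can be followed without tensors. The sketch below mirrors the same formulas with plain Python lists; the helper names `alibi_slopes` and `alibi_positions` are hypothetical and exist only for this illustration.

```python
import math


def alibi_slopes(num_heads: int) -> list:
    """Per-head slopes, following the same formulas as build_alibi_tensor."""
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))
    slopes = [base**power for power in range(1, 1 + closest_power_of_2)]
    if closest_power_of_2 != num_heads:
        # Remaining heads take every other power of a second, larger base.
        extra_base = 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3)))
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        slopes += [extra_base**power for power in range(1, 1 + 2 * num_remaining_heads, 2)]
    return slopes


def alibi_positions(attention_mask: list) -> list:
    """(cumsum(mask) - 1) * mask for a single sequence of 0/1 values."""
    positions, running = [], 0
    for keep in attention_mask:
        running += keep
        positions.append((running - 1) * keep)
    return positions


print(alibi_slopes(8))
print(alibi_positions([0, 1, 1, 1]))
```

The bias for a head is its slope multiplied by these positions, so padded tokens (mask value 0) contribute a bias of 0 and real tokens are numbered from 0 regardless of where the padding sits.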
139
+
140
+
141
+ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
142
+ """
143
+ Dropout add function
144
+
145
+ Args:
146
+ x (`torch.tensor`, *required*):
147
+ input tensor
148
+ residual (`torch.tensor`, *required*):
149
+ residual tensor
150
+ prob (`float`, *required*):
151
+ dropout probability
152
+ training (`bool`, *required*):
153
+ training mode
154
+ """
155
+ out = F.dropout(x, p=prob, training=training)
156
+ out = residual + out
157
+ return out
158
+
159
+
160
+ def bloom_gelu_forward(x: torch.Tensor) -> torch.Tensor:
161
+ """
162
+ Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple implementation (inference) to
163
+ make the model jitable.
164
+
165
+ Args:
166
+ x (`torch.tensor`, *required*):
167
+ input hidden states
168
+ """
169
+ return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
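As a quick sanity check on the constants in `bloom_gelu_forward`, the following sketch compares the tanh approximation against the erf-based GELU using only the standard library. The function names `gelu_tanh` and `gelu_exact` are illustrative; scalar floats stand in for tensors.

```python
import math


def gelu_tanh(x: float) -> float:
    """Scalar version of the tanh approximation used in bloom_gelu_forward."""
    return x * 0.5 * (1.0 + math.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))


def gelu_exact(x: float) -> float:
    """Erf-based GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


grid = [i / 10.0 for i in range(-50, 51)]
worst = max(abs(gelu_tanh(x) - gelu_exact(x)) for x in grid)
print(f"largest absolute difference on [-5, 5]: {worst:.2e}")
```

The constant `0.79788456` is `sqrt(2 / pi)` rounded to eight decimals, which is why the two curves stay close over the whole grid.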
170
+
171
+
172
+ def bloom_gelu_back(g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
173
+ """
174
+ Gradient of the tanh approximation of GELU. The gradient of the exact GELU is: 0.5 * (1. + torch.erf(x * 0.70710678)) +
175
+ 0.3989423 * x * torch.exp(-0.5 * x * x)
176
+
177
+ Args:
178
+ g (`torch.tensor`, *required*):
179
+ gradient output tensor
180
+ x (`torch.tensor`, *required*):
181
+ input tensor
182
+ """
183
+ x = x[0] # x is a tuple of 1 element, needs to unpack it first
184
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
185
+ # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
186
+ ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
187
+ return ff * g
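The closed-form gradient in `bloom_gelu_back` can be checked numerically. The sketch below, which assumes scalar inputs and uses hypothetical helper names, compares the same expression against a central finite difference of the forward function.

```python
import math

SQRT_2_OVER_PI = 0.79788456  # same rounded constant as in the diff
CUBIC_COEFF = 0.044715


def gelu_tanh(x: float) -> float:
    return x * 0.5 * (1.0 + math.tanh(SQRT_2_OVER_PI * x * (1 + CUBIC_COEFF * x * x)))


def gelu_tanh_grad(x: float) -> float:
    """Scalar version of the `ff` expression in bloom_gelu_back."""
    tanh_out = math.tanh(SQRT_2_OVER_PI * x * (1 + CUBIC_COEFF * x * x))
    # 0.1070322243 is 3 * SQRT_2_OVER_PI * CUBIC_COEFF, from differentiating the cubic term.
    return 0.5 * x * ((1 - tanh_out * tanh_out) * (SQRT_2_OVER_PI + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )


def central_difference(f, x: float, h: float = 1e-5) -> float:
    return (f(x + h) - f(x - h)) / (2 * h)


for x in (-3.0, -1.0, 0.0, 0.5, 2.0):
    print(x, gelu_tanh_grad(x), central_difference(gelu_tanh, x))
```

The two columns should agree to several decimal places, which is the property the custom `GeLUFunction.backward` relies on.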
188
+
189
+
190
+ class GeLUFunction(torch.autograd.Function):
191
+ @staticmethod
192
+ def forward(ctx, input: torch.Tensor) -> torch.Tensor:
193
+ ctx.save_for_backward(input)
194
+ return bloom_gelu_forward(input)
195
+
196
+ @staticmethod
197
+ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
198
+ input = ctx.saved_tensors
199
+ tmp = bloom_gelu_back(grad_output, input)
200
+ return tmp
201
+
202
+
203
+ class BloomGelu(nn.Module):
204
+ """
205
+ BloomBiasGelu wrapper function that makes use of the simple function in inference mode to make the model
206
+ torchscriptable and uses the autograd function in training mode to get accurate gradients. Partly
207
+ copied from Megatron-DeepSpeed code and adapted for our needs
208
+
209
+ See here why autograd functions are not torchscriptable: https://github.com/pytorch/pytorch/issues/22329
210
+ """
211
+
212
+ def __init__(self):
213
+ super().__init__()
214
+
215
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
216
+ if self.training:
217
+ return GeLUFunction.apply(x)
218
+ else:
219
+ return bloom_gelu_forward(x)
220
+
221
+
222
+ class BloomAttention(nn.Module):
223
+ def __init__(self, config: VBloomConfig, is_cross_attention=False):
224
+ super().__init__()
225
+
226
+ self.pretraining_tp = config.pretraining_tp
227
+ self.slow_but_exact = config.slow_but_exact
228
+
229
+ self.hidden_size = config.hidden_size
230
+ self.num_heads = config.n_head
231
+ self.head_dim = self.hidden_size // self.num_heads
232
+ self.split_size = self.hidden_size
233
+ self.hidden_dropout = config.hidden_dropout
234
+
235
+ if self.head_dim * self.num_heads != self.hidden_size:
236
+ raise ValueError(
237
+ f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
238
+ f" {self.num_heads})."
239
+ )
240
+
241
+ # Layer-wise attention scaling
242
+ self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
243
+ self.beta = 1.0
244
+
245
+ self.is_cross_attention = is_cross_attention
246
+
247
+ if self.is_cross_attention:
248
+ self.query = nn.Linear(self.hidden_size, 1 * self.hidden_size, bias=True)
249
+ kv_input_dim = self.hidden_size if not hasattr(config, "vision_embed_dim") else config.vision_embed_dim
250
+ self.key_value = nn.Linear(kv_input_dim, 2 * self.hidden_size, bias=True)
251
+ else:
252
+ self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
253
+
254
+ self.dense = nn.Linear(self.hidden_size, self.hidden_size)
255
+ self.attention_dropout = nn.Dropout(config.attention_dropout)
256
+
257
+ if self.is_cross_attention:
258
+ # The alpha stuff
259
+ self.act = nn.Tanh()
260
+
261
+ if config.alpha_initializer == "zeros":
262
+ if config.alpha_type == "vector":
263
+ self.alpha_cross_attn = nn.Parameter(torch.zeros(1, 1, self.hidden_size))
264
+ elif config.alpha_type == "float":
265
+ self.alpha_cross_attn = nn.Parameter(torch.zeros(1))
266
+ else:
267
+ raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
268
+
269
+ elif config.alpha_initializer == "ones":
270
+ if config.alpha_type == "vector":
271
+ self.alpha_cross_attn = nn.Parameter(torch.ones(1, 1, self.hidden_size))
272
+ elif config.alpha_type == "float":
273
+ self.alpha_cross_attn = nn.Parameter(torch.ones(1))
274
+ else:
275
+ raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
276
+
277
+ elif config.alpha_initializer in {"normal", "gaussian", "random"}:
278
+ if config.alpha_type == "vector":
279
+ self.alpha_cross_attn = nn.Parameter(
280
+ torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1, 1, self.hidden_size))
281
+ )
282
+ elif config.alpha_type == "float":
283
+ self.alpha_cross_attn = nn.Parameter(
284
+ torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1,))
285
+ )
286
+ else:
287
+ raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
288
+
289
+ else:
290
+ raise NotImplementedError(
291
+ f"Alpha initialization scheme {config.alpha_initializer} not yet implemented!"
292
+ )
293
+
294
+ def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
295
+ """
296
+ Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
297
+ storage as `fused_qkv`
298
+
299
+ Args:
300
+ fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
301
+
302
+ Returns:
303
+ query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
304
+ value: [batch_size, seq_length, num_heads, head_dim]
305
+ """
306
+ batch_size, seq_length, n_times_hidden_size = fused_qkv.shape
307
+ n = int(n_times_hidden_size / self.hidden_size)
308
+ fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, n, self.head_dim)
309
+ outputs = ()
310
+ for i in range(n):
311
+ outputs += (fused_qkv[..., i, :],)
312
+ return outputs
313
+
314
+ def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
315
+ """
316
+ Merge heads together over the last dimension
317
+
318
+ Args:
319
+ x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
320
+
321
+ Returns:
322
+ torch.tensor: [batch_size, seq_length, num_heads * head_dim]
323
+ """
324
+ # What we want to achieve is:
325
+ # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
326
+ batch_size_and_num_heads, seq_length, _ = x.shape
327
+ batch_size = batch_size_and_num_heads // self.num_heads
328
+
329
+ # First view to decompose the batch size
330
+ # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
331
+ x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)
332
+
333
+ # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
334
+ x = x.permute(0, 2, 1, 3)
335
+
336
+ # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
337
+ return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)
338
+
339
+ def forward(
340
+ self,
341
+ hidden_states: torch.Tensor,
342
+ residual: torch.Tensor,
343
+ alibi: torch.Tensor,
344
+ attention_mask: torch.Tensor,
345
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
346
+ head_mask: Optional[torch.Tensor] = None,
347
+ encoder_hidden_states: Optional[torch.Tensor] = None,
348
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
349
+ use_cache: bool = False,
350
+ output_attentions: bool = False,
351
+ ):
352
+ if not self.is_cross_attention:
353
+ fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
354
+
355
+ # 3 x [batch_size, seq_length, num_heads, head_dim]
356
+ (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
357
+ else:
358
+ if encoder_hidden_states is not None:
359
+ attention_mask = encoder_attention_mask
360
+ q = self.query(hidden_states)
361
+ kv = self.key_value(encoder_hidden_states)
362
+
363
+ query_layer = self._split_heads(q)[0]
364
+ key_layer, value_layer = self._split_heads(kv)
365
+
366
+ batch_size, q_length, _, _ = query_layer.shape
367
+ _, kv_length, _, _ = key_layer.shape
368
+
369
+ query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
370
+ key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, kv_length)
371
+ value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, kv_length, self.head_dim)
372
+ if layer_past is not None:
373
+ past_key, past_value = layer_past
374
+ # concatenate along seq_length dimension:
375
+ # - key: [batch_size * self.num_heads, head_dim, kv_length]
376
+ # - value: [batch_size * self.num_heads, kv_length, head_dim]
377
+ key_layer = torch.cat((past_key, key_layer), dim=2)
378
+ value_layer = torch.cat((past_value, value_layer), dim=1)
379
+ _, _, kv_length = key_layer.shape
380
+
381
+ if use_cache is True:
382
+ present = (key_layer, value_layer)
383
+ else:
384
+ present = None
385
+
386
+ # [batch_size * num_heads, q_length, kv_length]
387
+ # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
388
+ if alibi is None:
389
+ alibi = torch.empty(
390
+ batch_size * self.num_heads, q_length, kv_length, dtype=query_layer.dtype, device=query_layer.device
391
+ )
392
+
393
+ matmul_result = alibi.baddbmm(
394
+ batch1=query_layer,
395
+ batch2=key_layer,
396
+ beta=0.0 if self.is_cross_attention else self.beta,
397
+ alpha=self.inv_norm_factor,
398
+ )
399
+
400
+ # change view to [batch_size, num_heads, q_length, kv_length]
401
+ attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
402
+
403
+ # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
404
+ input_dtype = attention_scores.dtype
405
+ # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
406
+ if input_dtype == torch.float16:
407
+ attention_scores = attention_scores.to(torch.float)
408
+ attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
409
+ attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
410
+
411
+ # [batch_size, num_heads, q_length, kv_length]
412
+ attention_probs = self.attention_dropout(attention_probs)
413
+
414
+ if head_mask is not None:
415
+ attention_probs = attention_probs * head_mask
416
+
417
+ # change view [batch_size x num_heads, q_length, kv_length]
418
+ attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
419
+
420
+ # matmul: [batch_size * num_heads, q_length, head_dim]
421
+ context_layer = torch.bmm(attention_probs_reshaped, value_layer)
422
+
423
+ # merge heads: [batch_size, q_length, num_heads * head_dim]
424
+ context_layer = self._merge_heads(context_layer)
425
+
426
+ # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
427
+ if self.pretraining_tp > 1 and self.slow_but_exact:
428
+ slices = self.hidden_size / self.pretraining_tp
429
+ output_tensor = torch.zeros_like(context_layer)
430
+ for i in range(self.pretraining_tp):
431
+ output_tensor = output_tensor + F.linear(
432
+ context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
433
+ self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
434
+ )
435
+ else:
436
+ output_tensor = self.dense(context_layer)
437
+
438
+ if not self.is_cross_attention:
439
+ output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
440
+ else:
441
+ output_tensor = dropout_add(
442
+ self.act(self.alpha_cross_attn) * output_tensor, residual, self.hidden_dropout, self.training
443
+ )
444
+
445
+ outputs = (output_tensor, present)
446
+ if output_attentions:
447
+ outputs += (attention_probs,)
448
+
449
+ return outputs
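In the cross-attention branch above, the attention output is scaled by `tanh(alpha_cross_attn)` before being added to the residual. The sketch below shows the effect of that gate on plain Python lists, ignoring dropout; `gated_residual` is a hypothetical name used only for this illustration.

```python
import math


def gated_residual(residual: list, branch_output: list, alpha: float) -> list:
    """residual + tanh(alpha) * branch_output, element-wise, with dropout left out."""
    gate = math.tanh(alpha)
    return [r + gate * o for r, o in zip(residual, branch_output)]


residual = [1.0, 2.0]
branch_output = [10.0, 20.0]

# alpha_initializer="zeros": the gate is closed and the block starts as an identity.
print(gated_residual(residual, branch_output, alpha=0.0))
# alpha_initializer="ones" (the default in VBloomConfig): the gate starts at tanh(1).
print(gated_residual(residual, branch_output, alpha=1.0))
```

With `alpha_type="vector"` the same rule applies per hidden dimension, since `alpha_cross_attn` then has shape `(1, 1, hidden_size)` and broadcasts over batch and sequence.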
450
+
451
+
452
+ class BloomMLP(nn.Module):
453
+ def __init__(self, config: VBloomConfig, is_gated=False):
454
+ super().__init__()
455
+ hidden_size = config.hidden_size
456
+
457
+ self.pretraining_tp = config.pretraining_tp
458
+ self.slow_but_exact = config.slow_but_exact
459
+ self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
460
+ self.gelu_impl = BloomGelu()
461
+ self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
462
+ self.hidden_dropout = config.hidden_dropout
463
+
464
+ # The alpha stuff
465
+ self.is_gated = is_gated
466
+ if is_gated:
467
+ self.act = nn.Tanh()
468
+
469
+ if config.alpha_initializer == "zeros":
470
+ if config.alpha_type == "vector":
471
+ self.alpha_dense = nn.Parameter(torch.zeros(1, 1, hidden_size))
472
+ elif config.alpha_type == "float":
473
+ self.alpha_dense = nn.Parameter(torch.zeros(1))
474
+ else:
475
+ raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
476
+
477
+ elif config.alpha_initializer == "ones":
478
+ if config.alpha_type == "vector":
479
+ self.alpha_dense = nn.Parameter(torch.ones(1, 1, hidden_size))
480
+ elif config.alpha_type == "float":
481
+ self.alpha_dense = nn.Parameter(torch.ones(1))
482
+ else:
483
+ raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
484
+
485
+ elif config.alpha_initializer in {"normal", "gaussian", "random"}:
486
+ if config.alpha_type == "vector":
487
+ self.alpha_dense = nn.Parameter(
488
+ torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1, 1, hidden_size))
489
+ )
490
+ elif config.alpha_type == "float":
491
+ self.alpha_dense = nn.Parameter(
492
+ torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1,))
493
+ )
494
+ else:
495
+ raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
496
+
497
+ else:
498
+ raise NotImplementedError(
499
+ f"Alpha initialization scheme {config.alpha_initializer} not yet implemented!"
500
+ )
501
+
502
+ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
503
+ hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
504
+
505
+ if self.pretraining_tp > 1 and self.slow_but_exact:
506
+ intermediate_output = torch.zeros_like(residual)
507
+ slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
508
+ for i in range(self.pretraining_tp):
509
+ intermediate_output = intermediate_output + F.linear(
510
+ hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],
511
+ self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)],
512
+ )
513
+ else:
514
+ intermediate_output = self.dense_4h_to_h(hidden_states)
515
+
516
+ if not self.is_gated:
517
+ output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
518
+ else:
519
+ output = dropout_add(
520
+ self.act(self.alpha_dense) * intermediate_output, residual, self.hidden_dropout, self.training
521
+ )
522
+
523
+ return output
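The `slow_but_exact` branches above split the input features into `pretraining_tp` slices, multiply each slice by the matching columns of the weight matrix, and sum the partial results. A minimal sketch of that idea with plain lists follows; it covers only the weight product (no bias), and the names `linear` and `sliced_linear` are hypothetical.

```python
def linear(x: list, weight: list) -> list:
    """y[o] = sum_i x[i] * weight[o][i], i.e. the bias-free part of a linear layer."""
    return [sum(xi * wi for xi, wi in zip(x, row)) for row in weight]


def sliced_linear(x: list, weight: list, tp: int) -> list:
    """Same product, accumulated slice by slice over the input features."""
    slice_size = len(x) // tp
    out = [0] * len(weight)
    for rank in range(tp):
        lo, hi = rank * slice_size, (rank + 1) * slice_size
        partial = linear(x[lo:hi], [row[lo:hi] for row in weight])
        out = [a + b for a, b in zip(out, partial)]
    return out


x = [1, 2, 3, 4]
weight = [[1, 0, 2, 1], [0, 1, 1, 3]]
print(linear(x, weight))
print(sliced_linear(x, weight, tp=2))
```

With integers the two results are identical. With floating-point tensors the sliced version only changes the order of summation, which is presumably why it can reproduce the rounding behaviour of a tensor-parallel run more closely.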
524
+
525
+
526
+ class BloomBlock(nn.Module):
527
+ def __init__(self, config: VBloomConfig):
528
+ super().__init__()
529
+ hidden_size = config.hidden_size
530
+
531
+ self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
532
+ self.num_heads = config.n_head
533
+ self.self_attention = BloomAttention(config)
534
+ self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
535
+
536
+ self.mlp = BloomMLP(config)
537
+
538
+ self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
539
+ self.hidden_dropout = config.hidden_dropout
540
+
541
+ def forward(
542
+ self,
543
+ hidden_states: torch.Tensor,
544
+ alibi: torch.Tensor,
545
+ attention_mask: torch.Tensor,
546
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
547
+ head_mask: Optional[torch.Tensor] = None,
548
+ use_cache: bool = False,
549
+ output_attentions: bool = False,
550
+ ):
551
+ # hidden_states: [batch_size, seq_length, hidden_size]
552
+
553
+ # Layer norm at the beginning of the transformer layer.
554
+ layernorm_output = self.input_layernorm(hidden_states)
555
+
556
+ # Choose the residual: the layernorm output or the raw block input.
557
+ if self.apply_residual_connection_post_layernorm:
558
+ residual = layernorm_output
559
+ else:
560
+ residual = hidden_states
561
+
562
+ # Self attention.
563
+ attn_outputs = self.self_attention(
564
+ layernorm_output,
565
+ residual,
566
+ layer_past=layer_past,
567
+ attention_mask=attention_mask,
568
+ alibi=alibi,
569
+ head_mask=head_mask,
570
+ use_cache=use_cache,
571
+ output_attentions=output_attentions,
572
+ )
573
+
574
+ attention_output = attn_outputs[0]
575
+
576
+ outputs = attn_outputs[1:]
577
+
578
+ layernorm_output = self.post_attention_layernorm(attention_output)
579
+
580
+ # Get residual
581
+ if self.apply_residual_connection_post_layernorm:
582
+ residual = layernorm_output
583
+ else:
584
+ residual = attention_output
585
+
586
+ # MLP.
587
+ output = self.mlp(layernorm_output, residual)
588
+
589
+ if use_cache:
590
+ outputs = (output,) + outputs
591
+ else:
592
+ outputs = (output,) + outputs[1:]
593
+
594
+ return outputs # hidden_states, present, attentions
595
+
596
+
597
+ class VBloomGatedCrossAttentionBlock(nn.Module):
598
+ def __init__(self, config):
599
+ super().__init__()
600
+ hidden_size = config.hidden_size
601
+
602
+ self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
603
+ self.num_heads = config.n_head
604
+ self.cross_attention = BloomAttention(config, is_cross_attention=True)
605
+ self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
606
+
607
+ self.gated_mlp = BloomMLP(config, is_gated=True)
608
+
609
+ self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
610
+ self.hidden_dropout = config.hidden_dropout
611
+
612
+ def forward(
613
+ self,
614
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
615
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
616
+ attention_mask: Optional[torch.FloatTensor] = None,
617
+ head_mask: Optional[torch.FloatTensor] = None,
618
+ image_hidden_states: Optional[torch.Tensor] = None,
619
+ image_attention_mask: Optional[torch.FloatTensor] = None,
620
+ use_cache: Optional[bool] = False,
621
+ output_attentions: Optional[bool] = False,
622
+ ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
623
+ # hidden_states: [batch_size, seq_length, hidden_size]
624
+
625
+ # Layer norm at the beginning of the transformer layer.
626
+ layernorm_output = self.input_layernorm(hidden_states)
627
+
628
+ # Choose the residual: the layernorm output or the raw block input.
629
+ if self.apply_residual_connection_post_layernorm:
630
+ residual = layernorm_output
631
+ else:
632
+ residual = hidden_states
633
+
634
+ # Cross attention over the image hidden states.
635
+ attn_outputs = self.cross_attention(
636
+ layernorm_output,
637
+ residual,
638
+ alibi=None,
639
+ layer_past=layer_past,
640
+ attention_mask=attention_mask,
641
+ head_mask=head_mask,
642
+ encoder_hidden_states=image_hidden_states,
643
+ encoder_attention_mask=image_attention_mask,
644
+ use_cache=use_cache,
645
+ output_attentions=output_attentions,
646
+ )
647
+
648
+ attention_output = attn_outputs[0]
649
+
650
+ outputs = attn_outputs[1:]
651
+
652
+ layernorm_output = self.post_attention_layernorm(attention_output)
653
+
654
+ # Get residual
655
+ if self.apply_residual_connection_post_layernorm:
656
+ residual = layernorm_output
657
+ else:
658
+ residual = attention_output
659
+
660
+ # MLP.
661
+ output = self.gated_mlp(layernorm_output, residual)
662
+
663
+ if use_cache:
664
+ outputs = (output,) + outputs
665
+ else:
666
+ outputs = (output,) + outputs[1:]
667
+
668
+ return outputs # hidden_states, present, attentions
669
+
670
+
671
+ class VBloomPreTrainedModel(VLOOMPreTrainedModelBase):
672
+ """
673
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
674
+ models.
675
+ """
676
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
677
+
678
+ config_class = VBloomConfig
679
+ base_model_prefix = "transformer"
680
+ supports_gradient_checkpointing = True
681
+ _no_split_modules = ["BloomBlock"]
682
+
683
+ def __init__(self, *inputs, **kwargs):
684
+ super().__init__(*inputs, **kwargs)
685
+
686
+ def _init_weights(self, module: nn.Module):
687
+ """Initialize the weights."""
688
+ if isinstance(module, nn.Linear):
689
+ # Slightly different from the TF version which uses truncated_normal for initialization
690
+ # cf https://github.com/pytorch/pytorch/pull/5617
691
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
692
+ if module.bias is not None:
693
+ module.bias.data.zero_()
694
+ elif isinstance(module, nn.Embedding):
695
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
696
+ if module.padding_idx is not None:
697
+ module.weight.data[module.padding_idx].zero_()
698
+ elif isinstance(module, LayerNorm):
699
+ module.bias.data.zero_()
700
+ module.weight.data.fill_(1.0)
701
+
702
+ def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
703
+ if isinstance(module, VBloomModel):
704
+ module.gradient_checkpointing = value
705
+
706
+ @classmethod
707
+ def override_vision_model_wrapper(cls, model, config, vision_model_name, vision_model_params, torch_dtype):
708
+ # This can be called via `from_pretrained` from a class with or without a head, so we extract the headless model first
709
+ beheaded_model = model.transformer if hasattr(model, "transformer") else model
710
+ cls.override_vision_model(beheaded_model, vision_model_name, vision_model_params, torch_dtype)
711
+ beheaded_model.freeze_relevant_params(config)
712
+
713
+
714
+ BLOOM_START_DOCSTRING = r"""
715
+
716
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
717
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, etc.).
718
+
719
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
720
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
721
+ and behavior.
722
+
723
+ Parameters:
724
+ config ([`VBloomConfig`]): Model configuration class with all the parameters of the model.
725
+ Initializing with a config file does not load the weights associated with the model, only the
726
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
727
+ """
728
+
729
+ BLOOM_INPUTS_DOCSTRING = r"""
730
+ Args:
731
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
732
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
733
+ (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
734
+
735
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
736
+ `input_ids`.
737
+
738
+ Indices can be obtained using [`BloomTokenizerFast`]. See [`PreTrainedTokenizer.encode`] and
739
+ [`PreTrainedTokenizer.__call__`] for details.
740
+
741
+ [What are input IDs?](../glossary#input-ids)
742
+ past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layer`):
743
+ Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
744
+ `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
745
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
746
+
747
+ Each element of `past_key_values` is a tuple (past_key, past_value):
748
+ - past_key: [batch_size * num_heads, head_dim, kv_length]
749
+ - past_value: [batch_size * num_heads, kv_length, head_dim]
750
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
751
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
752
+
753
+ - 1 for tokens that are **not masked**,
754
+ - 0 for tokens that are **masked**.
755
+
756
+ [What are attention masks?](../glossary#attention-mask)
757
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
758
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
759
+
760
+ - 1 indicates the head is **not masked**,
761
+ - 0 indicates the head is **masked**.
762
+
763
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
764
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
765
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
766
+ model's internal embedding lookup matrix.
767
+
768
+ If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
769
+ `past_key_values`).
770
+ use_cache (`bool`, *optional*):
771
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
772
+ `past_key_values`).
773
+ output_attentions (`bool`, *optional*):
774
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
775
+ tensors for more detail.
776
+ output_hidden_states (`bool`, *optional*):
777
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
778
+ more detail.
779
+ return_dict (`bool`, *optional*):
780
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
781
+ """
782
+
783
+
784
+ @add_start_docstrings(
785
+ "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
786
+ BLOOM_START_DOCSTRING,
787
+ )
788
+ class VBloomModel(VBloomPreTrainedModel):
789
+ def __init__(self, config: VBloomConfig, vision_model=None):
790
+ super().__init__(config)
791
+
792
+ self.embed_dim = config.hidden_size
793
+ self.num_heads = config.n_head
794
+
795
+ # Embedding + LN Embedding
796
+ self.word_embeddings = DecoupledEmbedding(
797
+ num_embeddings=config.vocab_size,
798
+ num_additional_embeddings=config.additional_vocab_size,
799
+ embedding_dim=self.embed_dim,
800
+ partially_freeze=config.freeze_text_layers,
801
+ )
802
+ self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
803
+
804
+ # Transformer blocks
805
+ self.h = nn.ModuleList([BloomBlock(config) for _ in range(config.num_hidden_layers)])
806
+
807
+ # Final Layer Norm
808
+ self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
809
+
810
+ self.cross_layer_interval = config.cross_layer_interval
811
+ num_cross_layers = config.num_hidden_layers // self.cross_layer_interval
812
+ self.gated_cross_attn_layers = nn.ModuleList(
813
+ [VBloomGatedCrossAttentionBlock(config) for i in range(num_cross_layers)]
814
+ )
815
+
816
+ # Perceiver Resampler
817
+ if config.use_resampler:
818
+ self.perceiver_resampler = PerceiverResampler(
819
+ self.config,
820
+ self.config.vision_embed_dim,
821
+ config.resampler_depth,
822
+ config.resampler_n_heads,
823
+ config.resampler_head_dim,
824
+ config.resampler_n_latents,
825
+ )
826
+ self.gradient_checkpointing = False
827
+
828
+ # Attach a possibly uninitialized vision model here; `from_pretrained` loads its pre-trained weights later.
829
+ # This avoids losing the vision weights when `from_pretrained` is called on the main model
830
+ self.vision_model = vision_model
831
+
832
+ # Initialize weights and apply final processing
833
+ self.post_init()
834
+
835
+ self.freeze_relevant_params(config)
836
+
837
+ def freeze_relevant_params(self, config=None):
838
+ if config is None:
839
+ config = self.config
840
+
841
+ if config.freeze_text_layers:
842
+ self.freeze_text_layers()
843
+
844
+ if config.freeze_vision_layers:
845
+ freeze_model(self.vision_model)
846
+
847
+ def freeze_text_layers(self):
848
+ for module in [self.word_embeddings_layernorm, self.h, self.ln_f]:
849
+ freeze_model(module)
850
+
851
+ def get_input_embeddings(self):
852
+ return self.word_embeddings
853
+
854
+ def _prepare_attn_mask(
855
+ self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
856
+ ) -> torch.BoolTensor:
857
+ # create causal mask
858
+ # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
859
+ combined_attention_mask = None
860
+ device = attention_mask.device
861
+ _, src_length = input_shape
862
+
863
+ if src_length > 1:
864
+ combined_attention_mask = _make_causal_mask(
865
+ input_shape, device=device, past_key_values_length=past_key_values_length
866
+ )
867
+
868
+ # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
869
+ expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
870
+ combined_attention_mask = (
871
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
872
+ )
873
+
874
+ return combined_attention_mask
875
+
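`_prepare_attn_mask` above ORs a causal mask with the expanded padding mask, and skips the causal part when only one new token is processed. The sketch below reproduces that logic for a single sequence with plain Python lists. It assumes the convention that `True` means "do not attend", and the helper names are hypothetical, not the source's `_make_causal_mask` or `_expand_mask`.

```python
def make_causal_mask(tgt_len, past_len):
    """True where a target position would look at a future source position."""
    return [
        [(src - past_len) > tgt for src in range(past_len + tgt_len)]
        for tgt in range(tgt_len)
    ]


def expand_padding_mask(attention_mask, tgt_len):
    """attention_mask uses 1 for real tokens and 0 for padding."""
    return [[m == 0 for m in attention_mask] for _ in range(tgt_len)]


def combine(attention_mask, tgt_len, past_len):
    padding = expand_padding_mask(attention_mask, tgt_len)
    if tgt_len <= 1:
        # A single new token can see every cached position: padding mask only.
        return padding
    causal = make_causal_mask(tgt_len, past_len)
    return [
        [p or c for p, c in zip(p_row, c_row)]
        for p_row, c_row in zip(padding, causal)
    ]


# One left-padded sequence of three tokens, no cached keys.
mask = combine([0, 1, 1], tgt_len=3, past_len=0)

# One new token on top of two cached positions.
step_mask = combine([1, 1, 1], tgt_len=1, past_len=2)
```

Row `i` of `mask` lists which source positions target token `i` must ignore. The padded position is masked in every row, and future positions are masked on top of that.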
876
+ def set_input_embeddings(self, new_embeddings: torch.Tensor):
877
+ self.word_embeddings = new_embeddings
878
+
879
+ @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
880
+ @add_code_sample_docstrings(
881
+ processor_class=_TOKENIZER_FOR_DOC,
882
+ checkpoint=_CHECKPOINT_FOR_DOC,
883
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
884
+ config_class=_CONFIG_FOR_DOC,
885
+ )
886
+ def forward(
887
+ self,
888
+ input_ids: Optional[torch.LongTensor] = None,
889
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
890
+ attention_mask: Optional[torch.Tensor] = None,
891
+ head_mask: Optional[torch.LongTensor] = None,
892
+ inputs_embeds: Optional[torch.FloatTensor] = None,
893
+ pixel_values: Optional[torch.FloatTensor] = None,
894
+ image_embeddings: Optional[torch.FloatTensor] = None,
895
+ image_attention_mask: Optional[torch.Tensor] = None,
896
+ crossblock_head_mask: Optional[torch.Tensor] = None,
897
+ use_cache: Optional[bool] = None,
898
+ output_attentions: Optional[bool] = None,
899
+ output_hidden_states: Optional[bool] = None,
900
+ return_dict: Optional[bool] = None,
901
+ **deprecated_arguments,
902
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
903
+ if deprecated_arguments.pop("position_ids", False) is not False:
904
+ # `position_ids` could have been a `torch.Tensor` or `None`, so defaulting the pop to `False` lets us detect whether users explicitly passed `None`
905
+ warnings.warn(
906
+ (
907
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely"
908
+ " ignore passing `position_ids`."
909
+ ),
910
+ FutureWarning,
911
+ )
912
+ if len(deprecated_arguments) > 0:
913
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
914
+
915
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
916
+ output_hidden_states = (
917
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
918
+ )
919
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
920
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
921
+
922
+ if input_ids is not None and inputs_embeds is not None:
923
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
924
+ elif input_ids is not None:
925
+ batch_size, seq_length = input_ids.shape
926
+ elif inputs_embeds is not None:
927
+ batch_size, seq_length, _ = inputs_embeds.shape
928
+ else:
929
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
930
+
931
+ if past_key_values is None:
932
+ past_key_values = tuple([None] * len(self.h))
933
+
934
+ # Prepare head mask if needed
935
+ # 1.0 in head_mask indicates we keep the head
936
+ # attention_probs has shape batch_size x num_heads x N x N
937
+ # head_mask has shape n_layer x batch x num_heads x N x N
938
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
939
+
940
+ if inputs_embeds is None:
941
+ inputs_embeds = self.word_embeddings(input_ids)
942
+
943
+ hidden_states = self.word_embeddings_layernorm(inputs_embeds)
944
+
945
+ presents = () if use_cache else None
946
+ all_self_attentions = () if output_attentions else None
947
+ all_hidden_states = () if output_hidden_states else None
948
+
949
+ # Compute alibi tensor: check build_alibi_tensor documentation
950
+ seq_length_with_past = seq_length
951
+ past_key_values_length = 0
952
+ if past_key_values[0] is not None:
953
+ past_key_values_length = past_key_values[0][0].shape[2]
954
+ seq_length_with_past = seq_length_with_past + past_key_values_length
955
+ if attention_mask is None:
956
+ attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
957
+ else:
958
+ attention_mask = attention_mask.to(hidden_states.device)
959
+
960
+ alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
961
+
962
+ causal_mask = self._prepare_attn_mask(
963
+ attention_mask,
964
+ input_shape=(batch_size, seq_length),
965
+ past_key_values_length=past_key_values_length,
966
+ )
967
+
968
+ if pixel_values is not None and image_embeddings is not None:
969
+ raise ValueError("You cannot specify both pixel_values and image_embeddings at the same time")
970
+ elif pixel_values is not None:
971
+ pixel_values = pixel_values.to(dtype=self.dtype, device=hidden_states.device) # fp16 compatibility; `input_ids` may be None when `inputs_embeds` is passed
972
+ batch_size, num_images = pixel_values.size(0), pixel_values.size(1)
973
+ pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:])
974
+ # Get sequence from the vision encoder
975
+ image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
976
+ elif image_embeddings is not None:
977
+ batch_size, num_images, image_seq_len, image_hidden_size = image_embeddings.size()
978
+ image_hidden_states = image_embeddings.to(dtype=self.dtype, device=hidden_states.device)
979
+ image_hidden_states = image_hidden_states.view(batch_size * num_images, image_seq_len, image_hidden_size)
980
+
981
+ if self.config.use_resampler:
982
+ image_hidden_states = self.perceiver_resampler(image_hidden_states)
983
+ image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
984
+ image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size)
985
+ # Make image_attention_mask compatible with hidden states
986
+ text_seq_len = image_attention_mask.size(1)
987
+ image_attention_mask = image_attention_mask.unsqueeze(
988
+ -1
989
+ ) # TODO: unclear why the last few tokens do not attend when there is only a single image
990
+ image_attention_mask = image_attention_mask.repeat(1, 1, 1, image_seq_len)
991
+ image_attention_mask = image_attention_mask.view(batch_size, text_seq_len, num_images * image_seq_len)
992
+
993
+ if image_hidden_states is not None:
994
+ image_batch_size, image_sequence_length, _ = image_hidden_states.size()
995
+ image_hidden_shape = (image_batch_size, image_sequence_length)
996
+ if image_attention_mask is None:
997
+ image_attention_mask = torch.ones(image_hidden_shape, device=hidden_states.device)
998
+ # image_attention_mask = self.invert_attention_mask(image_attention_mask)
999
+ image_attention_mask = image_attention_mask.to(torch.bool)
1000
+ image_attention_mask = image_attention_mask[:, None, :, :]
1001
+ else:
1002
+ image_attention_mask = None
1003
+
1004
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
1005
+ if output_hidden_states:
1006
+ all_hidden_states = all_hidden_states + (hidden_states,)
1007
+
1008
+ def vblock(
1009
+ main_block,
1010
+ hidden_states,
1011
+ alibi,
1012
+ layer_past,
1013
+ attention_mask,
1014
+ layer_head_mask,
1015
+ use_cache,
1016
+ output_attentions,
1017
+ image_hidden_states,
1018
+ image_attention_mask,
1019
+ layer_idx,
1020
+ cross_layer_interval,
1021
+ gated_cross_attn_layers,
1022
+ ):
1023
+ if layer_idx % cross_layer_interval == 0:
1024
+ xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval]
1025
+ outputs = xblock(
1026
+ hidden_states,
1027
+ attention_mask=attention_mask,
1028
+ image_hidden_states=image_hidden_states,
1029
+ image_attention_mask=image_attention_mask,
1030
+ use_cache=use_cache,
1031
+ output_attentions=output_attentions,
1032
+ )
1033
+ hidden_states = outputs[0]
1034
+
1035
+ outputs = main_block(
1036
+ hidden_states,
1037
+ alibi=alibi,
1038
+ layer_past=layer_past,
1039
+ attention_mask=attention_mask,
1040
+ head_mask=layer_head_mask,
1041
+ use_cache=use_cache,
1042
+ output_attentions=output_attentions,
1043
+ )
1044
+
1045
+ return outputs
1046
+
1047
+ if self.gradient_checkpointing and self.training:
1048
+ layer_past = None
1049
+ if use_cache:
1050
+ logger.warning_once(
1051
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1052
+ )
1053
+ use_cache = False
1054
+
1055
+ outputs = torch.utils.checkpoint.checkpoint(
1056
+ vblock,
1057
+ block,
1058
+ hidden_states,
1059
+ alibi,
1060
+ layer_past,
1061
+ causal_mask,
1062
+ head_mask[i],
1063
+ use_cache,
1064
+ output_attentions,
1065
+ image_hidden_states,
1066
+ image_attention_mask,
1067
+ i,
1068
+ self.cross_layer_interval,
1069
+ self.gated_cross_attn_layers,
1070
+ )
1071
+ else:
1072
+ outputs = vblock(
1073
+ block,
1074
+ hidden_states,
1075
+ alibi=alibi,
1076
+ layer_past=layer_past,
1077
+ attention_mask=causal_mask,
1078
+ layer_head_mask=head_mask[i],
1079
+ use_cache=use_cache,
1080
+ output_attentions=output_attentions,
1081
+ image_hidden_states=image_hidden_states,
1082
+ image_attention_mask=image_attention_mask,
1083
+ layer_idx=i,
1084
+ cross_layer_interval=self.cross_layer_interval,
1085
+ gated_cross_attn_layers=self.gated_cross_attn_layers,
1086
+ )
1087
+
1088
+ hidden_states = outputs[0]
1089
+ if use_cache is True:
1090
+ presents = presents + (outputs[1],)
1091
+
1092
+ if output_attentions:
1093
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
1094
+
1095
+ # Add last hidden state
1096
+ hidden_states = self.ln_f(hidden_states)
1097
+
1098
+ if output_hidden_states:
1099
+ all_hidden_states = all_hidden_states + (hidden_states,)
1100
+
1101
+ if not return_dict:
1102
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
1103
+
1104
+ return BaseModelOutputWithPastAndCrossAttentions(
1105
+ last_hidden_state=hidden_states,
1106
+ past_key_values=presents,
1107
+ hidden_states=all_hidden_states,
1108
+ attentions=all_self_attentions,
1109
+ )
1110
+
1111
+
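The `vblock` helper above runs a gated cross-attention block before every language block whose index is a multiple of `cross_layer_interval`, and `__init__` allocates `num_hidden_layers // cross_layer_interval` such blocks. The sketch below tabulates that pairing. The function `cross_attention_plan` is a hypothetical illustration, not part of the source.

```python
def cross_attention_plan(num_hidden_layers, cross_layer_interval):
    """Return (number of allocated cross blocks, [(layer_idx, cross_block_idx), ...])."""
    num_cross_layers = num_hidden_layers // cross_layer_interval
    plan = []
    for layer_idx in range(num_hidden_layers):
        if layer_idx % cross_layer_interval == 0:
            plan.append((layer_idx, layer_idx // cross_layer_interval))
    return num_cross_layers, plan


# Layer count divisible by the interval: every index is in range.
num_cross, plan = cross_attention_plan(num_hidden_layers=8, cross_layer_interval=4)

# Layer count not divisible by the interval: the last index equals the list length.
uneven_cross, uneven_plan = cross_attention_plan(num_hidden_layers=5, cross_layer_interval=2)
```

In the second case the plan asks for cross block index 2 while only 2 blocks (indices 0 and 1) would be allocated, so indexing would fail under this reading of the code. Whether any real configuration uses a layer count that is not a multiple of the interval is not stated in the source.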
1112
+ @add_start_docstrings(
1113
+ """
1114
+ The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
1115
+ embeddings).
1116
+ """,
1117
+ BLOOM_START_DOCSTRING,
1118
+ )
1119
+ class VBloomForCausalLM(VBloomPreTrainedModel):
1120
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
1121
+
1122
+ def __init__(self, config: VBloomConfig, vision_model=None):
1123
+ super().__init__(config)
1124
+ self.transformer = VBloomModel(config, vision_model=vision_model)
1125
+ self.lm_head = DecoupledLinear(
1126
+ in_features=config.hidden_size,
1127
+ out_features=config.vocab_size,
1128
+ out_additional_features=config.additional_vocab_size,
1129
+ bias=False,
1130
+ partially_freeze=config.freeze_lm_head,
1131
+ )
1132
+ # Initialize weights and apply final processing
1133
+ self.post_init()
1134
+
1135
+ def get_output_embeddings(self):
1136
+ return self.lm_head
1137
+
1138
+ def set_output_embeddings(self, new_embeddings: torch.Tensor):
1139
+ self.lm_head = new_embeddings
1140
+
1141
+ def tie_weights(self):
1142
+ """
1143
+ Override `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of DecoupledLinear and DecoupledEmbedding.
1144
+ """
1145
+ output_embeddings = self.get_output_embeddings()
1146
+ input_embeddings = self.get_input_embeddings()
1147
+
1148
+ if getattr(self.config, "tie_word_embeddings", True):
1149
+ output_embeddings.weight = input_embeddings.weight
1150
+ if input_embeddings.num_additional_embeddings > 0:
1151
+ assert output_embeddings.out_additional_features == input_embeddings.num_additional_embeddings
1152
+ output_embeddings.additional_fc.weight = input_embeddings.additional_embedding.weight
1153
+
1154
+ if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
1155
+ output_embeddings.out_features = input_embeddings.num_embeddings
1156
+ if hasattr(output_embeddings, "out_additional_features") and hasattr(
1157
+ input_embeddings, "num_additional_embeddings"
1158
+ ):
1159
+ output_embeddings.out_additional_features = input_embeddings.num_additional_embeddings
1160
+
1161
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
1162
+ inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs)
1163
+ unwanted_kwargs = ["position_ids", "token_type_ids"]
1164
+ for kwarg in unwanted_kwargs:
1165
+ inputs.pop(kwarg, None)
1166
+ return inputs
1167
+
1168
+ @staticmethod
1169
+ def _expand_inputs_for_generation(
1170
+ *args,
1171
+ **model_kwargs,
1172
+ ):
1173
+ return expand_inputs_for_generation(*args, **model_kwargs)
1174
+
1175
+ @staticmethod
1176
+ def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
1177
+ return update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder)
1178
+
1179
+ @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
1180
+ @add_code_sample_docstrings(
1181
+ processor_class=_TOKENIZER_FOR_DOC,
1182
+ checkpoint=_CHECKPOINT_FOR_DOC,
1183
+ output_type=CausalLMOutputWithCrossAttentions,
1184
+ config_class=_CONFIG_FOR_DOC,
1185
+ )
1186
+ def forward(
1187
+ self,
1188
+ input_ids: Optional[torch.LongTensor] = None,
1189
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1190
+ attention_mask: Optional[torch.Tensor] = None,
1191
+ head_mask: Optional[torch.Tensor] = None,
1192
+ inputs_embeds: Optional[torch.Tensor] = None,
1193
+ pixel_values: Optional[torch.FloatTensor] = None,
1194
+ image_embeddings: Optional[torch.FloatTensor] = None,
1195
+ image_attention_mask: Optional[torch.Tensor] = None,
1196
+ crossblock_head_mask: Optional[torch.Tensor] = None,
1197
+ labels: Optional[torch.Tensor] = None,
1198
+ use_cache: Optional[bool] = None,
1199
+ output_attentions: Optional[bool] = None,
1200
+ output_hidden_states: Optional[bool] = None,
1201
+ return_dict: Optional[bool] = None,
1202
+ **deprecated_arguments,
1203
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
1204
+ r"""
1205
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1206
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1207
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
1208
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
1209
+ """
1210
+ if deprecated_arguments.pop("position_ids", False) is not False:
1211
+ # `position_ids` could have been a `torch.Tensor` or `None`, so defaulting the pop to `False` lets us detect whether users explicitly passed `None`
1212
+ warnings.warn(
1213
+ (
1214
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely"
1215
+ " ignore passing `position_ids`."
1216
+ ),
1217
+ FutureWarning,
1218
+ )
1219
+ if len(deprecated_arguments) > 0:
1220
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
1221
+
1222
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1223
+
1224
+ transformer_outputs = self.transformer(
1225
+ input_ids,
1226
+ past_key_values=past_key_values,
1227
+ attention_mask=attention_mask,
1228
+ head_mask=head_mask,
1229
+ inputs_embeds=inputs_embeds,
1230
+ pixel_values=pixel_values,
1231
+ image_embeddings=image_embeddings,
1232
+ image_attention_mask=image_attention_mask,
1233
+ crossblock_head_mask=crossblock_head_mask,
1234
+ use_cache=use_cache,
1235
+ output_attentions=output_attentions,
1236
+ output_hidden_states=output_hidden_states,
1237
+ return_dict=return_dict,
1238
+ )
1239
+ hidden_states = transformer_outputs[0]
1240
+
1241
+ lm_logits = self.lm_head(hidden_states)
1242
+
1243
+ loss = None
1244
+ if labels is not None:
1245
+ # Shift so that tokens < n predict n
1246
+ if attention_mask is not None:
1247
+ shift_attention_mask = attention_mask[..., 1:]
1248
+ shift_logits = lm_logits[..., :-1, :][shift_attention_mask != 0].contiguous()
1249
+ shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
1250
+ else:
1251
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1252
+ shift_labels = labels[..., 1:].contiguous()
1253
+ # Flatten the tokens
1254
+ loss_fct = CrossEntropyLoss()
1255
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1256
+
1257
+ if not return_dict:
1258
+ output = (lm_logits,) + transformer_outputs[1:]
1259
+ return ((loss,) + output) if loss is not None else output
1260
+
1261
+ return CausalLMOutputWithCrossAttentions(
1262
+ loss=loss,
1263
+ logits=lm_logits,
1264
+ past_key_values=transformer_outputs.past_key_values,
1265
+ hidden_states=transformer_outputs.hidden_states,
1266
+ attentions=transformer_outputs.attentions,
1267
+ )
1268
+
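The loss computation in `forward` above pairs the logits at position `t` with the label at position `t + 1` and drops pairs whose target is masked out by `attention_mask`. The sketch below walks through this for one sequence using plain Python lists. The helpers `shifted_pairs` and `cross_entropy` are hypothetical, and skipping `-100` targets stands in for the default `ignore_index` of `CrossEntropyLoss`.

```python
import math


def shifted_pairs(logits, labels, attention_mask=None):
    """Pair logits[t] with labels[t + 1], keeping only attended, non-ignored targets."""
    pairs = []
    for t in range(len(labels) - 1):
        target = labels[t + 1]
        attended = attention_mask is None or attention_mask[t + 1] != 0
        if attended and target != -100:
            pairs.append((logits[t], target))
    return pairs


def cross_entropy(pairs):
    """Mean negative log-likelihood of the targets under softmax(logits)."""
    losses = []
    for row, target in pairs:
        log_z = math.log(sum(math.exp(v) for v in row))
        losses.append(log_z - row[target])
    return sum(losses) / len(losses)


logits = [[0.0, 0.0], [2.0, 0.0], [0.0, 1.0]]  # three positions, vocabulary of 2
labels = [0, 1, 0]
attention_mask = [1, 1, 0]  # the last token is padding

pairs = shifted_pairs(logits, labels, attention_mask)
loss = cross_entropy(pairs)
```

Only the pair (position 0 predicting token 1) survives, because the target at position 2 is padding. The logits at the final position never contribute, since nothing follows them.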
1269
+ @staticmethod
1270
+ def _reorder_cache(
1271
+ past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
1272
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
1273
+ """
1274
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1275
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1276
+ beam_idx at every generation step.
1277
+
1278
+ Output shares the same memory storage as `past`.
1279
+ """
1280
+ batch_size_times_num_heads, head_dim, seq_length = past[0][0].shape
1281
+ batch_size = len(beam_idx)
1282
+ num_heads = batch_size_times_num_heads // batch_size
1283
+ # Get a copy of `beam_idx` on all the devices where we need those indices.
1284
+ device_to_beam_idx = {
1285
+ past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
1286
+ }
1287
+ # key: layer_past[0] [batch_size * num_heads, head_dim, seq_length]
1288
+ # value: layer_past[1] [batch_size * num_heads, seq_length, head_dim]
1289
+ return tuple(
1290
+ (
1291
+ layer_past[0]
1292
+ .view(batch_size, num_heads, head_dim, seq_length)
1293
+ .index_select(0, device_to_beam_idx[layer_past[0].device])
1294
+ .view(batch_size_times_num_heads, head_dim, seq_length),
1295
+ layer_past[1]
1296
+ .view(batch_size, num_heads, seq_length, head_dim)
1297
+ .index_select(0, device_to_beam_idx[layer_past[1].device])
1298
+ .view(batch_size_times_num_heads, seq_length, head_dim),
1299
+ )
1300
+ for layer_past in past
1301
+ )
1302
+
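`_reorder_cache` above views each cached tensor as `(batch_size, num_heads, ...)`, selects batch rows by `beam_idx`, and flattens back to `batch_size * num_heads` rows. The sketch below shows the same reindexing on a flat Python list. The function `reorder_flat_cache` and the string placeholders are hypothetical stand-ins for per-head tensors.

```python
def reorder_flat_cache(layer_cache, beam_idx, num_heads):
    """layer_cache holds batch_size * num_heads entries, batch-major then head."""
    return [
        layer_cache[b * num_heads + h]
        for b in beam_idx
        for h in range(num_heads)
    ]


# Two beams with two heads each; both surviving beams continue from beam 1.
cache = ["b0h0", "b0h1", "b1h0", "b1h1"]
reordered = reorder_flat_cache(cache, beam_idx=[1, 1], num_heads=2)
```

All heads of a selected beam move together, which is why the source reshapes to expose the batch dimension before calling `index_select`.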
1303
+ def get_model_tflops_per_batch_per_gpu(self, hparams, data_param, tokenizer, max_num_images):
1304
+ config_vl_model = self.config
1305
+
1306
+ language_embed_size = config_vl_model.hidden_size
1307
+ vision_config = self.transformer.vision_model.config
1308
+ num_language_layers = config_vl_model.n_layer
1309
+ ffn_inner_size = 4 * config_vl_model.hidden_size
1310
+
1311
+ # Get vision model block info
1312
+ vision_patch_size = vision_config.patch_size
1313
+ vision_hidden_size = vision_config.hidden_size
1314
+ num_vision_layers = vision_config.num_hidden_layers
1315
+ # The +1 is for the CLS token
1316
+ single_image_seq_len = (vision_config.image_size // vision_patch_size) ** 2 + 1
1317
+ vision_exp_factor = vision_config.intermediate_size // vision_hidden_size
1318
+
1319
+ # Get language and cross-attention block info
1320
+ num_cross_attn_layers = num_language_layers // config_vl_model.cross_layer_interval
1321
+ language_seq_len = data_param.max_seq_len
1322
+ language_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4
1323
+ cross_att_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4
1324
+ k_v_cross_attn_seq_len = (
1325
+ (self.config.resampler_n_latents * max_num_images)
1326
+ if self.config.use_resampler
1327
+ else (single_image_seq_len * max_num_images)
1328
+ )
1329
+
1330
+ language_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
1331
+ num_layers=num_language_layers,
1332
+ batch_size=hparams.batch_size_per_gpu,
1333
+ q_seq_len=language_seq_len,
1334
+ k_seq_len=language_seq_len,
1335
+ hidden_size=language_embed_size,
1336
+ kv_in_dim=language_embed_size,
1337
+ ff_exp_factor=language_exp_factor,
1338
+ grad_acc_size=hparams.grad_acc_size,
1339
+ swiglu=False,
1340
+ vocab_size=tokenizer.vocab_size,
1341
+ count_backward=True, # Always True regardless of freezing, because gradients are computed for cross-attentions
1342
+ use_grad_checkpointing=hparams.gradient_checkpointing,
1343
+ )
1344
+ cross_attention_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
1345
+ num_layers=num_cross_attn_layers,
1346
+ batch_size=hparams.batch_size_per_gpu,
1347
+ q_seq_len=language_seq_len,
1348
+ k_seq_len=k_v_cross_attn_seq_len,
1349
+ hidden_size=language_embed_size,
1350
+ kv_in_dim=vision_hidden_size,
1351
+ ff_exp_factor=cross_att_exp_factor,
1352
+ grad_acc_size=hparams.grad_acc_size,
1353
+ swiglu=False,
1354
+ vocab_size=None,
1355
+ count_backward=True,
1356
+ use_grad_checkpointing=hparams.gradient_checkpointing,
1357
+ )
1358
+ vision_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
1359
+ num_layers=num_vision_layers,
1360
+ batch_size=hparams.batch_size_per_gpu * max_num_images,
1361
+ q_seq_len=single_image_seq_len,
1362
+ k_seq_len=single_image_seq_len,
1363
+ hidden_size=vision_hidden_size,
1364
+ kv_in_dim=vision_hidden_size,
1365
+ ff_exp_factor=vision_exp_factor,
1366
+ grad_acc_size=hparams.grad_acc_size,
1367
+ swiglu=False,
1368
+ vocab_size=None,
1369
+ count_backward=not hparams.model_params["freeze_vision_layers"],
1370
+ use_grad_checkpointing=hparams.gradient_checkpointing,
1371
+ )
1372
+ if self.config.use_resampler:
1373
+ perceiver_tflops_per_batch_per_gpu = compute_perceiver_tflops_per_batch_per_gpu(
1374
+ num_layers=self.config.resampler_depth,
1375
+ batch_size=hparams.batch_size_per_gpu * max_num_images,
1376
+ q_seq_len=self.config.resampler_n_latents,
1377
+ vision_embed_seq_len=single_image_seq_len,
1378
+ q_k_v_input_dim=vision_hidden_size,
1379
+ attention_hidden_size=self.config.resampler_n_heads * self.config.resampler_head_dim,
1380
+ ff_exp_factor=cross_att_exp_factor,
1381
+ count_backward=True,
1382
+ use_grad_checkpointing=hparams.gradient_checkpointing,
1383
+ )
1384
+ flop_count = (
1385
+ language_tflops_per_batch_per_gpu
1386
+ + cross_attention_tflops_per_batch_per_gpu
1387
+ + vision_tflops_per_batch_per_gpu
1388
+ + perceiver_tflops_per_batch_per_gpu
1389
+ )
1390
+ else:
1391
+ flop_count = (
1392
+ language_tflops_per_batch_per_gpu
1393
+ + cross_attention_tflops_per_batch_per_gpu
1394
+ + vision_tflops_per_batch_per_gpu
1395
+ )
1396
+ return flop_count
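As a rough illustration of the sequence-length arithmetic in `get_model_tflops_per_batch_per_gpu` above, the sketch below recomputes the per-image token count and the cross-attention key/value length. The helper names and the example values (224-pixel images, 16-pixel patches, 64 latents, 5 images) are assumptions chosen for illustration; they are not taken from any training config in this commit.

```python
def single_image_seq_len(image_size: int, patch_size: int) -> int:
    """Vision tokens per image: one per patch, plus one for the CLS token."""
    return (image_size // patch_size) ** 2 + 1


def cross_attn_kv_seq_len(
    image_size: int,
    patch_size: int,
    max_num_images: int,
    use_resampler: bool,
    resampler_n_latents: int = 64,
) -> int:
    """Key/value length seen by each cross-attention layer (hypothetical helper)."""
    if use_resampler:
        # The resampler compresses every image to a fixed number of latents.
        return resampler_n_latents * max_num_images
    # Without a resampler, every patch token (and CLS) of every image is attended to.
    return single_image_seq_len(image_size, patch_size) * max_num_images


if __name__ == "__main__":
    print(single_image_seq_len(224, 16))                            # 197
    print(cross_attn_kv_seq_len(224, 16, 5, use_resampler=False))   # 985
    print(cross_attn_kv_seq_len(224, 16, 5, use_resampler=True))    # 320
```

Under these assumed values the resampler shrinks the cross-attention key/value length roughly threefold, which is why the two branches of `k_v_cross_attn_seq_len` can give quite different FLOP estimates.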
m4/models/vgpt2/__init__.py ADDED
File without changes
m4/models/vgpt2/configuration_vgpt2.py ADDED
@@ -0,0 +1,288 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ OpenAI GPT-2 configuration"""
17
+ import os
18
+ from typing import Tuple, Union
19
+
20
+ from transformers import AutoConfig
21
+ from transformers.configuration_utils import PretrainedConfig
22
+ from transformers.utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+ GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
28
+ "gpt2": "https://huggingface.co/gpt2/resolve/main/config.json",
29
+ "gpt2-medium": "https://huggingface.co/gpt2-medium/resolve/main/config.json",
30
+ "gpt2-large": "https://huggingface.co/gpt2-large/resolve/main/config.json",
31
+ "gpt2-xl": "https://huggingface.co/gpt2-xl/resolve/main/config.json",
32
+ "distilgpt2": "https://huggingface.co/distilgpt2/resolve/main/config.json",
33
+ }
34
+
35
+
36
+ class VGPT2Config(PretrainedConfig):
37
+ """
38
+ This is the configuration class to store the configuration of a [`GPT2Model`] or a [`TFGPT2Model`]. It is used to
39
+ instantiate a GPT-2 model according to the specified arguments, defining the model architecture. Instantiating a
40
+ configuration with the defaults will yield a similar configuration to that of the GPT-2
41
+ [gpt2](https://huggingface.co/gpt2) architecture.
42
+
43
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
44
+ documentation from [`PretrainedConfig`] for more information.
45
+
46
+ TODO: this doc is completely out of sync with the actual args
47
+
48
+ Args:
49
+ vocab_size (`int`, *optional*, defaults to 50257):
50
+ Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the
51
+ `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`].
52
+ additional_vocab_size (`int`, *optional*, defaults to 0):
53
+ Additional vocabulary size of the model, typically for the special "<img>" token. Additional vocab tokens
54
+ are always trainable whereas regular vocab tokens can be frozen or not.
55
+ n_positions (`int`, *optional*, defaults to 1024):
56
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
57
+ just in case (e.g., 512 or 1024 or 2048).
58
+ n_embd (`int`, *optional*, defaults to 768):
59
+ Dimensionality of the embeddings and hidden states.
60
+ n_layer (`int`, *optional*, defaults to 12):
61
+ Number of hidden layers in the Transformer encoder.
62
+ n_head (`int`, *optional*, defaults to 12):
63
+ Number of attention heads for each attention layer in the Transformer encoder.
64
+ n_inner (`int`, *optional*, defaults to None):
65
+ Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd
66
+ activation_function (`str`, *optional*, defaults to `"gelu_new"`):
67
+ Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
68
+ resid_pdrop (`float`, *optional*, defaults to 0.1):
69
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
70
+ embd_pdrop (`float`, *optional*, defaults to 0.1):
71
+ The dropout ratio for the embeddings.
72
+ attn_pdrop (`float`, *optional*, defaults to 0.1):
73
+ The dropout ratio for the attention.
74
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
75
+ The epsilon to use in the layer normalization layers.
76
+ initializer_range (`float`, *optional*, defaults to 0.02):
77
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
78
+ alpha_initializer (`str`, *optional*, defaults to `"ones"`):
79
+ Initialization type for the alphas.
80
+ alphas_initializer_range (`float`, *optional*, defaults to 0.0):
81
+ The standard deviation of the truncated_normal_initializer for initializing the alphas in the Gated Cross Attention.
82
+ alpha_type (`str`, *optional*, defaults to `"vector"`):
83
+ Whether the gating alphas should be vectors or single floats.
84
+ summary_type (`string`, *optional*, defaults to `"cls_index"`):
85
+ Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
86
+ [`TFGPT2DoubleHeadsModel`].
87
+
88
+ Has to be one of the following options:
89
+
90
+ - `"last"`: Take the last token hidden state (like XLNet).
91
+ - `"first"`: Take the first token hidden state (like BERT).
92
+ - `"mean"`: Take the mean of all tokens hidden states.
93
+ - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2).
94
+ - `"attn"`: Not implemented now, use multi-head attention.
95
+ summary_use_proj (`bool`, *optional*, defaults to `True`):
96
+ Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
97
+ [`TFGPT2DoubleHeadsModel`].
98
+
99
+ Whether or not to add a projection after the vector extraction.
100
+ summary_activation (`str`, *optional*):
101
+ Argument used when doing sequence summary. Used for the multiple choice head in
102
+ [`GPT2DoubleHeadsModel`].
103
+
104
+ Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation.
105
+ summary_proj_to_labels (`bool`, *optional*, defaults to `True`):
106
+ Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
107
+ [`TFGPT2DoubleHeadsModel`].
108
+
109
+ Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes.
110
+ summary_first_dropout (`float`, *optional*, defaults to 0.1):
111
+ Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
112
+ [`TFGPT2DoubleHeadsModel`].
113
+
114
+ The dropout ratio to be used after the projection and activation.
115
+ scale_attn_weights (`bool`, *optional*, defaults to `True`):
116
+ Scale attention weights by dividing by sqrt(head_dim).
117
+ use_cache (`bool`, *optional*, defaults to `True`):
118
+ Whether or not the model should return the last key/values attentions (not used by all models).
119
+ scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
120
+ Whether to additionally scale attention weights by `1 / (layer_idx + 1)`.
121
+ reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
122
+ Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention
123
+ dot-product/softmax to float() when training with mixed precision.
124
+ cross_layer_interval (`int`, *optional*, defaults to 1):
125
+ Interval for cross attention (from text to image) layers.
126
+
127
+ Example:
128
+
129
+ ```python
130
+ >>> from transformers import GPT2Model, GPT2Config
131
+
132
+ >>> # Initializing a GPT2 configuration
133
+ >>> configuration = GPT2Config()
134
+
135
+ >>> # Initializing a model from the configuration
136
+ >>> model = GPT2Model(configuration)
137
+
138
+ >>> # Accessing the model configuration
139
+ >>> configuration = model.config
140
+ ```"""
141
+
142
+ model_type = "vgpt2"
143
+ keys_to_ignore_at_inference = ["past_key_values"]
144
+ attribute_map = {
145
+ "hidden_size": "n_embd",
146
+ "max_position_embeddings": "n_positions",
147
+ "num_attention_heads": "n_head",
148
+ "num_hidden_layers": "n_layer",
149
+ }
150
+
151
+ def __init__(
152
+ self,
153
+ vocab_size=50257,
154
+ additional_vocab_size=0,
155
+ n_positions=1024,
156
+ n_embd=768,
157
+ n_layer=12,
158
+ n_head=12,
159
+ n_inner=None,
160
+ activation_function="gelu_new",
161
+ resid_pdrop=0.1,
162
+ embd_pdrop=0.1,
163
+ attn_pdrop=0.1,
164
+ layer_norm_epsilon=1e-5,
165
+ initializer_range=0.02,
166
+ alpha_initializer="ones",
167
+ alphas_initializer_range=0.0,
168
+ alpha_type="vector",
169
+ summary_type="cls_index",
170
+ summary_use_proj=True,
171
+ summary_activation=None,
172
+ summary_proj_to_labels=True,
173
+ summary_first_dropout=0.1,
174
+ scale_attn_weights=True,
175
+ use_cache=True,
176
+ bos_token_id=50256,
177
+ eos_token_id=50256,
178
+ scale_attn_by_inverse_layer_idx=False,
179
+ reorder_and_upcast_attn=False,
180
+ cross_layer_interval=1,
181
+ tie_word_embeddings=False,
182
+ freeze_text_layers=True,
183
+ freeze_lm_head=False,
184
+ freeze_vision_layers=True,
185
+ vision_model_name="google/vit-base-patch16-224",
186
+ vision_model_params="{}",
187
+ vision_embed_dim=768,
188
+ vision_image_size=224,
189
+ image_token_index=50257,
190
+ use_resampler=False,
191
+ resampler_n_latents=64,
192
+ resampler_depth=6,
193
+ resampler_n_heads=16,
194
+ resampler_head_dim=96,
195
+ **kwargs,
196
+ ):
197
+ self.vocab_size = vocab_size
198
+ self.additional_vocab_size = additional_vocab_size
199
+ self.n_positions = n_positions
200
+ self.n_embd = n_embd
201
+ self.n_layer = n_layer
202
+ self.n_head = n_head
203
+ self.n_inner = n_inner
204
+ self.activation_function = activation_function
205
+ self.resid_pdrop = resid_pdrop
206
+ self.embd_pdrop = embd_pdrop
207
+ self.attn_pdrop = attn_pdrop
208
+ self.layer_norm_epsilon = layer_norm_epsilon
209
+ self.initializer_range = initializer_range
210
+ self.alpha_initializer = alpha_initializer
211
+ self.alphas_initializer_range = alphas_initializer_range
212
+ self.alpha_type = alpha_type
213
+ self.summary_type = summary_type
214
+ self.summary_use_proj = summary_use_proj
215
+ self.summary_activation = summary_activation
216
+ self.summary_first_dropout = summary_first_dropout
217
+ self.summary_proj_to_labels = summary_proj_to_labels
218
+ self.scale_attn_weights = scale_attn_weights
219
+ self.use_cache = use_cache
220
+ self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
221
+ self.reorder_and_upcast_attn = reorder_and_upcast_attn
222
+
223
+ self.bos_token_id = bos_token_id
224
+ self.eos_token_id = eos_token_id
225
+ self.cross_layer_interval = cross_layer_interval
226
+ self.freeze_vision_layers = freeze_vision_layers
227
+ self.vision_model_name = vision_model_name
228
+ self.vision_model_params = vision_model_params
229
+
230
+ self.tie_word_embeddings = tie_word_embeddings
231
+ self.freeze_text_layers = freeze_text_layers
232
+ self.freeze_lm_head = freeze_lm_head
233
+ self.image_token_index = image_token_index
234
+
235
+ self.vision_embed_dim = vision_embed_dim
236
+ self.vision_image_size = vision_image_size
237
+
238
+ # Resampler params
239
+ self.use_resampler = use_resampler
240
+ self.resampler_n_latents = resampler_n_latents
241
+ self.resampler_depth = resampler_depth
242
+ self.resampler_n_heads = resampler_n_heads
243
+ self.resampler_head_dim = resampler_head_dim
244
+
245
+ # IMPORTANT: Do not do any __init__ args-based checks in the constructor, since
246
+ # PretrainedConfig.from_dict first instantiates the class with the config dict and only then
247
+ # updates the config object with `kwargs` from from_pretrained, so during the instantiation
248
+ # of this object many attributes have default values and haven't yet been overridden.
249
+ # Do any required checks inside `from_pretrained` once the superclass' `from_pretrained` was run.
250
+
251
+ super().__init__(
252
+ bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
253
+ )
254
+
255
+ def check_compatibilities(self):
256
+ if self.tie_word_embeddings and (self.freeze_text_layers != self.freeze_lm_head):
257
+ raise ValueError(
258
+ "if `tie_word_embeddings` is True, then `freeze_lm_head` and `freeze_text_layers` must be equal."
259
+ )
260
+
261
+ vision_model_params = eval(self.vision_model_params)
262
+ config = AutoConfig.from_pretrained(self.vision_model_name, **vision_model_params)
263
+ if hasattr(config, "vision_config"):
264
+ vision_config = config.vision_config
265
+ else:
266
+ vision_config = config
267
+ vision_embed_dim = vision_config.hidden_size
268
+ if self.vision_embed_dim != vision_embed_dim:
269
+ raise ValueError(
270
+ f"vision_embed_dim ({self.vision_embed_dim}) must match the hidden size of the vision model"
271
+ f" ({vision_embed_dim})"
272
+ )
273
+ vision_image_size = vision_config.image_size
274
+ if self.vision_image_size != vision_image_size:
275
+ raise ValueError(
276
+ f"vision_image_size ({self.vision_image_size}) must match the image size of the vision model"
277
+ f" ({vision_image_size})"
278
+ )
279
+
280
+ @classmethod
281
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
282
+ outputs = super(VGPT2Config, cls).from_pretrained(pretrained_model_name_or_path, **kwargs)
283
+ if isinstance(outputs, Tuple):
284
+ # When called with return_unused_kwargs=True, the first item will be the config
285
+ outputs[0].check_compatibilities()
286
+ else:
287
+ outputs.check_compatibilities()
288
+ return outputs
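`check_compatibilities` above turns the `vision_model_params` string into a dict with `eval`. A minimal sketch of a stricter alternative follows, assuming the string only ever holds Python literals (as the default `"{}"` does). The helper name `parse_vision_model_params` is hypothetical and not part of this commit; `ast.literal_eval` is the standard-library call doing the work.

```python
import ast


def parse_vision_model_params(raw: str) -> dict:
    """Parse a stringified dict of literals without executing arbitrary code."""
    # literal_eval accepts only literals (dicts, lists, strings, numbers, ...)
    # and raises ValueError or SyntaxError on anything else, e.g. function calls.
    params = ast.literal_eval(raw)
    if not isinstance(params, dict):
        raise ValueError(f"expected a dict literal, got {type(params).__name__}")
    return params


if __name__ == "__main__":
    print(parse_vision_model_params("{}"))
    print(parse_vision_model_params('{"image_size": 224}'))
```

The trade-off is that expressions such as `torch.float16` would no longer be accepted inside the string, so this only fits if the parameters passed in practice are plain literals.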
m4/models/vgpt2/modeling_vgpt2.py ADDED
@@ -0,0 +1,1384 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch OpenAI GPT-2 model."""
17
+
18
+ import math
19
+ import os
20
+ from typing import Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+ from torch.cuda.amp import autocast
26
+ from torch.nn import CrossEntropyLoss
27
+ from transformers.activations import ACT2FN
28
+ from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
29
+ from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
30
+ from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
31
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
32
+
33
+ from m4.models import DecoupledEmbedding, DecoupledLinear
34
+ from m4.models.common import (
35
+ expand_inputs_for_generation,
36
+ prepare_inputs_for_generation,
37
+ update_model_kwargs_for_generation,
38
+ )
39
+ from m4.models.custom_modules import VLOOMPreTrainedModelBase
40
+ from m4.models.perceiver.perceiver import PerceiverResampler
41
+ from m4.models.vgpt2.configuration_vgpt2 import VGPT2Config
42
+ from m4.training.utils import (
43
+ compute_perceiver_tflops_per_batch_per_gpu,
44
+ compute_tflops_per_batch_per_gpu,
45
+ freeze_model,
46
+ )
47
+ from m4.utils import logging
48
+
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+ _CHECKPOINT_FOR_DOC = "gpt2"
53
+ _CONFIG_FOR_DOC = "VGPT2Config"
54
+ _TOKENIZER_FOR_DOC = "GPT2Tokenizer"
55
+
56
+ GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
57
+ "gpt2",
58
+ "gpt2-medium",
59
+ "gpt2-large",
60
+ "gpt2-xl",
61
+ "distilgpt2",
62
+ # See all GPT-2 models at https://huggingface.co/models?filter=gpt2
63
+ ]
64
+
65
+
66
+ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
67
+ """Load tf checkpoints in a pytorch model"""
68
+ try:
69
+ import re
70
+
71
+ import tensorflow as tf
72
+ except ImportError:
73
+ logger.error(
74
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
75
+ "https://www.tensorflow.org/install/ for installation instructions."
76
+ )
77
+ raise
78
+ tf_path = os.path.abspath(gpt2_checkpoint_path)
79
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
80
+ # Load weights from TF model
81
+ init_vars = tf.train.list_variables(tf_path)
82
+ names = []
83
+ arrays = []
84
+ for name, shape in init_vars:
85
+ logger.info(f"Loading TF weight {name} with shape {shape}")
86
+ array = tf.train.load_variable(tf_path, name)
87
+ names.append(name)
88
+ arrays.append(array.squeeze())
89
+
90
+ for name, array in zip(names, arrays):
91
+ name = name[6:] # skip "model/"
92
+ name = name.split("/")
93
+ pointer = model
94
+ for m_name in name:
95
+ if re.fullmatch(r"[A-Za-z]+\d+", m_name):
96
+ scope_names = re.split(r"(\d+)", m_name)
97
+ else:
98
+ scope_names = [m_name]
99
+ if scope_names[0] == "w" or scope_names[0] == "g":
100
+ pointer = getattr(pointer, "weight")
101
+ elif scope_names[0] == "b":
102
+ pointer = getattr(pointer, "bias")
103
+ elif scope_names[0] == "wpe" or scope_names[0] == "wte":
104
+ pointer = getattr(pointer, scope_names[0])
105
+ pointer = getattr(pointer, "weight")
106
+ else:
107
+ pointer = getattr(pointer, scope_names[0])
108
+ if len(scope_names) >= 2:
109
+ num = int(scope_names[1])
110
+ pointer = pointer[num]
111
+ try:
112
+ assert (
113
+ pointer.shape == array.shape
114
+ ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
115
+ except AssertionError as e:
116
+ e.args += (pointer.shape, array.shape)
117
+ raise
118
+ logger.info(f"Initialize PyTorch weight {name}")
119
+ pointer.data = torch.from_numpy(array)
120
+ return model
121
+
122
+
123
+ class GPT2Attention(nn.Module):
124
+ def __init__(self, config, is_cross_attention=False, layer_idx=None):
125
+ super().__init__()
126
+
127
+ max_positions = config.max_position_embeddings
128
+ self.register_buffer(
129
+ "bias",
130
+ torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
131
+ 1, 1, max_positions, max_positions
132
+ ),
133
+ )
134
+ self.register_buffer("masked_bias", torch.tensor(-1e4))
135
+
136
+ self.embed_dim = config.hidden_size
137
+ self.num_heads = config.num_attention_heads
138
+ self.head_dim = self.embed_dim // self.num_heads
139
+ self.split_size = self.embed_dim
140
+ if self.head_dim * self.num_heads != self.embed_dim:
141
+ raise ValueError(
142
+ f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
143
+ f" {self.num_heads})."
144
+ )
145
+
146
+ self.scale_attn_weights = config.scale_attn_weights
147
+ self.is_cross_attention = is_cross_attention
148
+
149
+ # Layer-wise attention scaling, reordering, and upcasting
150
+ self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
151
+ self.layer_idx = layer_idx
152
+ self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
153
+
154
+ if self.is_cross_attention:
155
+ in_dim = self.embed_dim if not hasattr(config, "vision_embed_dim") else config.vision_embed_dim
156
+ self.c_attn = Conv1D(2 * self.embed_dim, in_dim)
157
+ self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
158
+ else:
159
+ self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
160
+ self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
161
+
162
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
163
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
164
+
165
+ self.pruned_heads = set()
166
+
167
+ def prune_heads(self, heads):
168
+ if len(heads) == 0:
169
+ return
170
+ heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
171
+ index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
172
+
173
+ # Prune conv1d layers
174
+ self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
175
+ self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
176
+
177
+ # Update hyper params
178
+ self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
179
+ self.num_heads = self.num_heads - len(heads)
180
+ self.pruned_heads = self.pruned_heads.union(heads)
181
+
182
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
183
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
184
+
185
+ if self.scale_attn_weights:
186
+ attn_weights = attn_weights / torch.tensor(
187
+ value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
188
+ )
189
+
190
+ # Layer-wise attention scaling
191
+ if self.scale_attn_by_inverse_layer_idx:
192
+ attn_weights = attn_weights / float(self.layer_idx + 1)
193
+
194
+ if not self.is_cross_attention:
195
+ # Only the "normal" (self-)attention layer applies the causal mask
196
+ query_length, key_length = query.size(-2), key.size(-2)
197
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
198
+ mask_value = torch.finfo(attn_weights.dtype).min
199
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
200
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
201
+ mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
202
+ attn_weights = torch.where(causal_mask, attn_weights, mask_value)
203
+
204
+ if attention_mask is not None:
205
+ # Apply the attention mask
206
+ attn_weights = attn_weights + attention_mask
207
+
208
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
209
+
210
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
211
+ attn_weights = attn_weights.type(value.dtype)
212
+ attn_weights = self.attn_dropout(attn_weights)
213
+
214
+ # Mask heads if we want to
215
+ if head_mask is not None:
216
+ attn_weights = attn_weights * head_mask
217
+
218
+ attn_output = torch.matmul(attn_weights, value)
219
+
220
+ return attn_output, attn_weights
221
+
222
+ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
223
+ # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
224
+ bsz, num_heads, q_seq_len, dk = query.size()
225
+ _, _, k_seq_len, _ = key.size()
226
+
227
+ # Preallocate attn_weights for `baddbmm`
228
+ attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
229
+
230
+ # Compute Scale Factor
231
+ scale_factor = 1.0
232
+ if self.scale_attn_weights:
233
+ scale_factor /= float(value.size(-1)) ** 0.5
234
+
235
+ if self.scale_attn_by_inverse_layer_idx:
236
+ scale_factor /= float(self.layer_idx + 1)
237
+
238
+ # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
239
+ with autocast(enabled=False):
240
+ q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
241
+ attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
242
+ attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
243
+
244
+ if not self.is_cross_attention:
245
+ # Only the "normal" (self-)attention layer applies the causal mask
246
+ query_length, key_length = query.size(-2), key.size(-2)
247
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
248
+ mask_value = torch.finfo(attn_weights.dtype).min
249
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
250
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
251
+ mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
252
+ attn_weights = torch.where(causal_mask, attn_weights, mask_value)
253
+
254
+ if attention_mask is not None:
255
+ # Apply the attention mask
256
+ attn_weights = attn_weights + attention_mask
257
+
258
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
259
+
260
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
261
+ if attn_weights.dtype != torch.float32:
262
+ raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
263
+ attn_weights = attn_weights.type(value.dtype)
264
+ attn_weights = self.attn_dropout(attn_weights)
265
+
266
+ # Mask heads if we want to
267
+ if head_mask is not None:
268
+ attn_weights = attn_weights * head_mask
269
+
270
+ attn_output = torch.matmul(attn_weights, value)
271
+
272
+ return attn_output, attn_weights
273
+
274
+ def _split_heads(self, tensor, num_heads, attn_head_size):
275
+ """
276
+ Splits hidden_size dim into attn_head_size and num_heads
277
+ """
278
+ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
279
+ tensor = tensor.view(new_shape)
280
+ return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
281
+
282
+ def _merge_heads(self, tensor, num_heads, attn_head_size):
283
+ """
284
+ Merges attn_head_size dim and num_attn_heads dim into hidden_size
285
+ """
286
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
287
+ new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
288
+ return tensor.view(new_shape)
289
+
290
+ def forward(
291
+ self,
292
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
293
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
294
+ attention_mask: Optional[torch.FloatTensor] = None,
295
+ head_mask: Optional[torch.FloatTensor] = None,
296
+ encoder_hidden_states: Optional[torch.Tensor] = None,
297
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
298
+ use_cache: Optional[bool] = False,
299
+ output_attentions: Optional[bool] = False,
300
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
301
+ if encoder_hidden_states is not None:
302
+ if not hasattr(self, "q_attn"):
303
+ raise ValueError(
304
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
305
+ "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
306
+ )
307
+
308
+ query = self.q_attn(hidden_states)
309
+ key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
310
+ attention_mask = encoder_attention_mask
311
+ else:
312
+ query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
313
+
314
+ query = self._split_heads(query, self.num_heads, self.head_dim)
315
+ key = self._split_heads(key, self.num_heads, self.head_dim)
316
+ value = self._split_heads(value, self.num_heads, self.head_dim)
317
+
318
+ if layer_past is not None:
319
+ past_key, past_value = layer_past
320
+ key = torch.cat((past_key, key), dim=-2)
321
+ value = torch.cat((past_value, value), dim=-2)
322
+
323
+ if use_cache is True:
324
+ present = (key, value)
325
+ else:
326
+ present = None
327
+
328
+ if self.reorder_and_upcast_attn:
329
+ attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
330
+ else:
331
+ attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
332
+
333
+ attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
334
+ attn_output = self.c_proj(attn_output)
335
+ attn_output = self.resid_dropout(attn_output)
336
+
337
+ outputs = (attn_output, present)
338
+ if output_attentions:
339
+ outputs += (attn_weights,)
340
+
341
+ return outputs # a, present, (attentions)
342
+
343
+
344
+ class GPT2MLP(nn.Module):
345
+ def __init__(self, intermediate_size, config):
346
+ super().__init__()
347
+ embed_dim = config.hidden_size
348
+ self.c_fc = Conv1D(intermediate_size, embed_dim)
349
+ self.c_proj = Conv1D(embed_dim, intermediate_size)
350
+ self.act = ACT2FN[config.activation_function]
351
+ self.dropout = nn.Dropout(config.resid_pdrop)
352
+
353
+ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
354
+ hidden_states = self.c_fc(hidden_states)
355
+ hidden_states = self.act(hidden_states)
356
+ hidden_states = self.c_proj(hidden_states)
357
+ hidden_states = self.dropout(hidden_states)
358
+ return hidden_states
359
+
360
+
361
+ class GPT2Block(nn.Module):
362
+ def __init__(self, config, layer_idx=None):
363
+ super().__init__()
364
+ hidden_size = config.hidden_size
365
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
366
+
367
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
368
+ self.attn = GPT2Attention(config, layer_idx=layer_idx)
369
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
370
+
371
+ if config.add_cross_attention:
372
+ self.crossattention = GPT2Attention(config, is_cross_attention=True, layer_idx=layer_idx)
373
+ self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
374
+
375
+ self.mlp = GPT2MLP(inner_dim, config)
376
+
377
+ def forward(
378
+ self,
379
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
380
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
381
+ attention_mask: Optional[torch.FloatTensor] = None,
382
+ head_mask: Optional[torch.FloatTensor] = None,
383
+ encoder_hidden_states: Optional[torch.Tensor] = None,
384
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
385
+ use_cache: Optional[bool] = False,
386
+ output_attentions: Optional[bool] = False,
387
+ ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
388
+ residual = hidden_states
389
+ hidden_states = self.ln_1(hidden_states)
390
+ attn_outputs = self.attn(
391
+ hidden_states,
392
+ layer_past=layer_past,
393
+ attention_mask=attention_mask,
394
+ head_mask=head_mask,
395
+ use_cache=use_cache,
396
+ output_attentions=output_attentions,
397
+ )
398
+ attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
399
+ outputs = attn_outputs[1:]
400
+ # residual connection
401
+ hidden_states = attn_output + residual
402
+
403
+ if encoder_hidden_states is not None:
404
+ # apply the additional cross-attention block over the encoder hidden states
405
+ if not hasattr(self, "crossattention"):
406
+ raise ValueError(
407
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
408
+ "cross-attention layers by setting `config.add_cross_attention=True`"
409
+ )
410
+ residual = hidden_states
411
+ hidden_states = self.ln_cross_attn(hidden_states)
412
+ cross_attn_outputs = self.crossattention(
413
+ hidden_states,
414
+ attention_mask=attention_mask,
415
+ head_mask=head_mask,
416
+ encoder_hidden_states=encoder_hidden_states,
417
+ encoder_attention_mask=encoder_attention_mask,
418
+ output_attentions=output_attentions,
419
+ )
420
+ attn_output = cross_attn_outputs[0]
421
+ # residual connection
422
+ hidden_states = residual + attn_output
423
+ outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
424
+
425
+ residual = hidden_states
426
+ hidden_states = self.ln_2(hidden_states)
427
+ feed_forward_hidden_states = self.mlp(hidden_states)
428
+ # residual connection
429
+ hidden_states = residual + feed_forward_hidden_states
430
+
431
+ if use_cache:
432
+ outputs = (hidden_states,) + outputs
433
+ else:
434
+ outputs = (hidden_states,) + outputs[1:]
435
+
436
+ return outputs # hidden_states, present, (attentions, cross_attentions)
437
+
438
+
439
+ class VGPT2GatedCrossAttentionBlock(nn.Module):
440
+ def __init__(self, config, layer_idx=None):
441
+ super().__init__()
442
+ hidden_size = config.hidden_size
443
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
444
+
445
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
446
+ self.cross_attn = GPT2Attention(config, is_cross_attention=True, layer_idx=layer_idx)
447
+ self.mlp = GPT2MLP(inner_dim, config)
448
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
449
+ self.act = nn.Tanh()
450
+
451
+ if config.alpha_initializer == "zeros":
452
+ if config.alpha_type == "vector":
453
+ self.alpha_cross_attn = nn.Parameter(torch.zeros(1, 1, hidden_size))
454
+ self.alpha_dense = nn.Parameter(torch.zeros(1, 1, hidden_size))
455
+ elif config.alpha_type == "float":
456
+ self.alpha_cross_attn = nn.Parameter(torch.zeros(1))
457
+ self.alpha_dense = nn.Parameter(torch.zeros(1))
458
+ else:
459
+ raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
460
+
461
+ elif config.alpha_initializer == "ones":
462
+ if config.alpha_type == "vector":
463
+ self.alpha_cross_attn = nn.Parameter(torch.ones(1, 1, hidden_size))
464
+ self.alpha_dense = nn.Parameter(torch.ones(1, 1, hidden_size))
465
+ elif config.alpha_type == "float":
466
+ self.alpha_cross_attn = nn.Parameter(torch.ones(1))
467
+ self.alpha_dense = nn.Parameter(torch.ones(1))
468
+ else:
469
+ raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
470
+
471
+ elif config.alpha_initializer in {"normal", "gaussian", "random"}:
472
+ if config.alpha_type == "vector":
473
+ self.alpha_cross_attn = nn.Parameter(
474
+ torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1, 1, hidden_size))
475
+ )
476
+ self.alpha_dense = nn.Parameter(
477
+ torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1, 1, hidden_size))
478
+ )
479
+ elif config.alpha_type == "float":
480
+ self.alpha_cross_attn = nn.Parameter(
481
+ torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1,))
482
+ )
483
+ self.alpha_dense = nn.Parameter(torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1,)))
484
+ else:
485
+ raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
486
+
487
+ else:
488
+ raise NotImplementedError(f"Alpha initialization scheme {config.alpha_initializer} not yet implemented!")
489
+
490
+ def forward(
491
+ self,
492
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
493
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
494
+ attention_mask: Optional[torch.FloatTensor] = None,
495
+ head_mask: Optional[torch.FloatTensor] = None,
496
+ image_hidden_states: Optional[torch.Tensor] = None,
497
+ image_attention_mask: Optional[torch.FloatTensor] = None,
498
+ use_cache: Optional[bool] = False,
499
+ output_attentions: Optional[bool] = False,
500
+ ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
501
+ if image_hidden_states is None:
502
+ raise ValueError(
503
+ "`image_hidden_states` is required for VGPT2 cross attention module which are visual features to be"
504
+ " conditioned on."
505
+ )
506
+ # gated cross-attention over the image hidden states
507
+
508
+ # TODO(aps): Handle cross attention in the outputs
509
+ # if not hasattr(self, "crossattention"):
510
+ # raise ValueError(
511
+ # f"If `image_hidden_states` are passed, {self} has to be instantiated with "
512
+ # "cross-attention layers by setting `config.add_cross_attention=True`"
513
+ # )
514
+ residual = hidden_states
515
+
516
+ hidden_states = self.ln_1(hidden_states)
517
+ cross_attn_outputs = self.cross_attn(
518
+ hidden_states,
519
+ attention_mask=attention_mask,
520
+ head_mask=head_mask,
521
+ encoder_hidden_states=image_hidden_states,
522
+ encoder_attention_mask=image_attention_mask,
523
+ output_attentions=output_attentions,
524
+ )
525
+ attn_output = cross_attn_outputs[0]
526
+ outputs = cross_attn_outputs[1:]
527
+ # residual connection
528
+ hidden_states = residual + self.act(self.alpha_cross_attn) * attn_output
529
+ outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
530
+
531
+ residual = hidden_states
532
+ hidden_states = self.ln_2(hidden_states)
533
+ feed_forward_hidden_states = self.mlp(hidden_states)
534
+ # residual connection
535
+ hidden_states = residual + self.act(self.alpha_dense) * feed_forward_hidden_states
536
+
537
+ if use_cache:
538
+ outputs = (hidden_states,) + outputs
539
+ else:
540
+ outputs = (hidden_states,) + outputs[1:]
541
+
542
+ return outputs # hidden_states, present, (attentions, cross_attentions)
543
+
544
+
545
+ class VGPT2PreTrainedModel(VLOOMPreTrainedModelBase):
546
+ """
547
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
548
+ models.
549
+ """
550
+
551
+ config_class = VGPT2Config
552
+ load_tf_weights = load_tf_weights_in_gpt2
553
+ base_model_prefix = "transformer"
554
+ is_parallelizable = True
555
+ supports_gradient_checkpointing = True
556
+ _no_split_modules = ["GPT2Block"]
557
+
558
+ def __init__(self, *inputs, **kwargs):
559
+ super().__init__(*inputs, **kwargs)
560
+
561
+ def _init_weights(self, module):
562
+ """Initialize the weights."""
563
+ if isinstance(module, (nn.Linear, Conv1D)):
564
+ # Slightly different from the TF version which uses truncated_normal for initialization
565
+ # cf https://github.com/pytorch/pytorch/pull/5617
566
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
567
+ if module.bias is not None:
568
+ module.bias.data.zero_()
569
+ elif isinstance(module, nn.Embedding):
570
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
571
+ if module.padding_idx is not None:
572
+ module.weight.data[module.padding_idx].zero_()
573
+ elif isinstance(module, nn.LayerNorm):
574
+ module.bias.data.zero_()
575
+ module.weight.data.fill_(1.0)
576
+
577
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
578
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
579
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
580
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
581
+ #
582
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
583
+ for name, p in module.named_parameters():
584
+ if name == "c_proj.weight":
585
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
586
+ p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
587
+
588
+ def _set_gradient_checkpointing(self, module, value=False):
589
+ if isinstance(module, VGPT2Model):
590
+ module.gradient_checkpointing = value
591
+
592
+ @classmethod
593
+ def override_vision_model_wrapper(cls, model, config, vision_model_name, vision_model_params, torch_dtype):
594
+ # this can be called via from_pretrained from a class w/ head or w/o head so we extract the beheaded model version
595
+ beheaded_model = model.transformer if hasattr(model, "transformer") else model
596
+ cls.override_vision_model(beheaded_model, vision_model_name, vision_model_params, torch_dtype)
597
+ beheaded_model.freeze_relevant_params(config)
598
+
599
+
600
+ GPT2_START_DOCSTRING = r"""
601
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
602
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
603
+ etc.)
604
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
605
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
606
+ and behavior.
607
+ Parameters:
608
+ config ([`VGPT2Config`]): Model configuration class with all the parameters of the model.
609
+ Initializing with a config file does not load the weights associated with the model, only the
610
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
611
+ """
612
+
613
+ GPT2_INPUTS_DOCSTRING = r"""
614
+ Args:
615
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
616
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
617
+ `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
618
+ sequence tokens in the vocabulary.
619
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
620
+ `input_ids`.
621
+ Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
622
+ [`PreTrainedTokenizer.__call__`] for details.
623
+ [What are input IDs?](../glossary#input-ids)
624
+ past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
625
+ Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
626
+ `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
627
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
628
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
629
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
630
+ - 1 for tokens that are **not masked**,
631
+ - 0 for tokens that are **masked**.
632
+ If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
633
+ `past_key_values`. In other words, the `attention_mask` always has to have the length:
634
+ `len(past_key_values) + len(input_ids)`
635
+ [What are attention masks?](../glossary#attention-mask)
636
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
637
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
638
+ 1]`:
639
+ - 0 corresponds to a *sentence A* token,
640
+ - 1 corresponds to a *sentence B* token.
641
+ [What are token type IDs?](../glossary#token-type-ids)
642
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
643
+ Indices of the position of each input sequence token in the position embeddings. Selected in the range `[0,
644
+ config.max_position_embeddings - 1]`.
645
+ [What are position IDs?](../glossary#position-ids)
646
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
647
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
648
+ - 1 indicates the head is **not masked**,
649
+ - 0 indicates the head is **masked**.
650
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
651
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
652
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
653
+ model's internal embedding lookup matrix.
654
+ If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
655
+ `past_key_values`).
656
+ use_cache (`bool`, *optional*):
657
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
658
+ `past_key_values`).
659
+ output_attentions (`bool`, *optional*):
660
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
661
+ tensors for more detail.
662
+ output_hidden_states (`bool`, *optional*):
663
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
664
+ more detail.
665
+ return_dict (`bool`, *optional*):
666
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
667
+ """
668
+ PARALLELIZE_DOCSTRING = r"""
669
+ This is an experimental feature and is subject to change at a moment's notice.
670
+ Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
671
+ it will evenly distribute blocks across all devices.
672
+ Args:
673
+ device_map (`Dict[int, list]`, optional, defaults to None):
674
+ A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
675
+ automatically mapped to the first device (for esoteric reasons). That means that the first device should
676
+ have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
677
+ following number of attention modules:
678
+ - gpt2: 12
679
+ - gpt2-medium: 24
680
+ - gpt2-large: 36
681
+ - gpt2-xl: 48
682
+ Example:
683
+ ```python
684
+ # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
685
+ model = GPT2LMHeadModel.from_pretrained("gpt2-xl")
686
+ device_map = {
687
+ 0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
688
+ 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
689
+ 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
690
+ 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
691
+ }
692
+ model.parallelize(device_map)
693
+ ```
694
+ """
695
+ DEPARALLELIZE_DOCSTRING = r"""
696
+ Moves the model to cpu from a model parallel state.
697
+ Example:
698
+ ```python
699
+ # On a 4 GPU machine with gpt2-large:
700
+ model = GPT2LMHeadModel.from_pretrained("gpt2-large")
701
+ device_map = {
702
+ 0: [0, 1, 2, 3, 4, 5, 6, 7],
703
+ 1: [8, 9, 10, 11, 12, 13, 14, 15],
704
+ 2: [16, 17, 18, 19, 20, 21, 22, 23],
705
+ 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
706
+ }
707
+ model.parallelize(device_map) # Splits the model across several devices
708
+ model.deparallelize() # Puts the model back on CPU and frees memory by calling torch.cuda.empty_cache()
709
+ ```
710
+ """
711
+
712
+
713
+ @add_start_docstrings(
714
+ "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
715
+ GPT2_START_DOCSTRING,
716
+ )
717
+ class VGPT2Model(VGPT2PreTrainedModel):
718
+ _keys_to_ignore_on_load_missing = ["attn.masked_bias"]
719
+
720
+ def __init__(self, config, vision_model=None):
721
+ super().__init__(config)
722
+
723
+ self.embed_dim = config.hidden_size
724
+ self.config = config
725
+
726
+ self.wte = DecoupledEmbedding(
727
+ num_embeddings=config.vocab_size,
728
+ num_additional_embeddings=config.additional_vocab_size,
729
+ embedding_dim=self.embed_dim,
730
+ partially_freeze=config.freeze_text_layers,
731
+ )
732
+ self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
733
+
734
+ self.drop = nn.Dropout(config.embd_pdrop)
735
+ self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
736
+ self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
737
+
738
+ self.cross_layer_interval = config.cross_layer_interval
739
+ num_cross_layers = config.num_hidden_layers // self.cross_layer_interval
740
+ self.gated_cross_attn_layers = nn.ModuleList(
741
+ [VGPT2GatedCrossAttentionBlock(config, layer_idx=i) for i in range(num_cross_layers)]
742
+ )
743
+
744
+ # Perceiver Resampler
745
+ if config.use_resampler:
746
+ self.perceiver_resampler = PerceiverResampler(
747
+ self.config,
748
+ self.config.vision_embed_dim,
749
+ config.resampler_depth,
750
+ config.resampler_n_heads,
751
+ config.resampler_head_dim,
752
+ config.resampler_n_latents,
753
+ )
754
+ # Model parallel
755
+ self.model_parallel = False
756
+ self.device_map = None
757
+ self.gradient_checkpointing = False
758
+ # will be vocab_size because of indices starting from 0
759
+ self.image_token_idx = config.image_token_index
760
+
761
+ # Start from an uninitialized model; `from_pretrained` later loads the pre-trained weights into it.
762
+ # This avoids losing weights when `from_pretrained` is called on the main model.
763
+ self.vision_model = vision_model
764
+
765
+ # Initialize weights and apply final processing
766
+ self.post_init()
767
+
768
+ self.freeze_relevant_params(config)
769
+
770
+ def freeze_relevant_params(self, config=None):
771
+ if config is None:
772
+ config = self.config
773
+
774
+ if config.freeze_text_layers:
775
+ self.freeze_text_layers()
776
+
777
+ if config.freeze_vision_layers:
778
+ freeze_model(self.vision_model)
779
+
780
+ def freeze_text_layers(self):
781
+ for module in [self.wpe, self.h, self.ln_f]:
782
+ freeze_model(module)
783
+
784
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
785
+ # TODO(aps): Implement later for VGPT2
786
+ def parallelize(self, device_map=None):
787
+ # Check validity of device_map
788
+ self.device_map = (
789
+ get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
790
+ )
791
+ assert_device_map(self.device_map, len(self.h))
792
+ self.model_parallel = True
793
+ self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
794
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
795
+ self.wte = self.wte.to(self.first_device)
796
+ self.wpe = self.wpe.to(self.first_device)
797
+ # Load onto devices
798
+ for k, v in self.device_map.items():
799
+ for block in v:
800
+ cuda_device = "cuda:" + str(k)
801
+ self.h[block] = self.h[block].to(cuda_device)
802
+ # ln_f to last
803
+ self.ln_f = self.ln_f.to(self.last_device)
804
+
805
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
806
+ # TODO(aps): Implement later for VGPT2
807
+ def deparallelize(self):
808
+ self.model_parallel = False
809
+ self.device_map = None
810
+ self.first_device = "cpu"
811
+ self.last_device = "cpu"
812
+ self.wte = self.wte.to("cpu")
813
+ self.wpe = self.wpe.to("cpu")
814
+ for index in range(len(self.h)):
815
+ self.h[index] = self.h[index].to("cpu")
816
+ self.ln_f = self.ln_f.to("cpu")
817
+ torch.cuda.empty_cache()
818
+
819
+ def get_input_embeddings(self):
820
+ return self.wte
821
+
822
+ def set_input_embeddings(self, new_embeddings):
823
+ self.wte = new_embeddings
824
+
825
+ def _prune_heads(self, heads_to_prune):
826
+ """
827
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
828
+ """
829
+ for layer, heads in heads_to_prune.items():
830
+ self.h[layer].attn.prune_heads(heads)
831
+
832
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
833
+ @add_code_sample_docstrings(
834
+ processor_class=_TOKENIZER_FOR_DOC,
835
+ checkpoint=_CHECKPOINT_FOR_DOC,
836
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
837
+ config_class=_CONFIG_FOR_DOC,
838
+ )
839
+ def forward(
840
+ self,
841
+ input_ids: Optional[torch.LongTensor] = None,
842
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
843
+ attention_mask: Optional[torch.FloatTensor] = None,
844
+ token_type_ids: Optional[torch.LongTensor] = None,
845
+ position_ids: Optional[torch.LongTensor] = None,
846
+ head_mask: Optional[torch.FloatTensor] = None,
847
+ inputs_embeds: Optional[torch.FloatTensor] = None,
848
+ pixel_values: Optional[torch.FloatTensor] = None,
849
+ image_embeddings: Optional[torch.FloatTensor] = None,
850
+ image_attention_mask: Optional[torch.Tensor] = None,
851
+ crossblock_head_mask: Optional[torch.Tensor] = None,
852
+ use_cache: Optional[bool] = None,
853
+ output_attentions: Optional[bool] = None,
854
+ output_hidden_states: Optional[bool] = None,
855
+ return_dict: Optional[bool] = None,
856
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
857
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
858
+ output_hidden_states = (
859
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
860
+ )
861
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
862
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
863
+
864
+ if input_ids is not None and inputs_embeds is not None:
865
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
866
+ elif input_ids is not None:
867
+ input_shape = input_ids.size()
868
+ input_ids = input_ids.view(-1, input_shape[-1])
869
+ batch_size = input_ids.shape[0]
870
+ elif inputs_embeds is not None:
871
+ input_shape = inputs_embeds.size()[:-1]
872
+ batch_size = inputs_embeds.shape[0]
873
+ else:
874
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
875
+
876
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
877
+
878
+ if token_type_ids is not None:
879
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
880
+ if position_ids is not None:
881
+ position_ids = position_ids.view(-1, input_shape[-1])
882
+
883
+ if past_key_values is None:
884
+ past_length = 0
885
+ past_key_values = tuple([None] * len(self.h))
886
+ else:
887
+ past_length = past_key_values[0][0].size(-2)
888
+ if position_ids is None:
889
+ position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
890
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
891
+
892
+ # GPT2Attention mask.
893
+ if attention_mask is not None:
894
+ if batch_size <= 0:
895
+ raise ValueError("batch_size has to be defined and > 0")
896
+ attention_mask = attention_mask.view(batch_size, -1)
897
+ # We create a 4D attention mask from a 2D tensor mask.
898
+ # Sizes are [batch_size, 1, 1, to_seq_length]
899
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
900
+ # this attention mask is simpler than the triangular masking of causal attention
901
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
902
+ attention_mask = attention_mask[:, None, None, :]
903
+
904
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
905
+ # masked positions, this operation will create a tensor which is 0.0 for
906
+ # positions we want to attend and the dtype's smallest value for masked positions.
907
+ # Since we are adding it to the raw scores before the softmax, this is
908
+ # effectively the same as removing these entirely.
909
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
910
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
911
+
912
+ # If a 2D or 3D attention mask is provided for the cross-attention
913
+ # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
914
+ if pixel_values is not None and image_embeddings is not None:
915
+ raise ValueError("You cannot specify both pixel_values and image_embeddings at the same time")
916
+ elif pixel_values is not None:
917
+ pixel_values = pixel_values.to(dtype=self.dtype, device=device) # fp16 compatibility
918
+ batch_size, num_images = pixel_values.size(0), pixel_values.size(1)
919
+ pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:])
920
+ # Get sequence from the vision encoder
921
+ image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
922
+ elif image_embeddings is not None:
923
+ batch_size, num_images, image_seq_len, image_hidden_size = image_embeddings.size()
924
+ image_hidden_states = image_embeddings.to(dtype=self.dtype, device=device)
925
+ image_hidden_states = image_hidden_states.view(batch_size * num_images, image_seq_len, image_hidden_size)
926
+
927
+ if self.config.use_resampler:
928
+ image_hidden_states = self.perceiver_resampler(image_hidden_states)
929
+ image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
930
+ image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size)
931
+
932
+ # Make image_attention_mask compatible with hidden states
933
+ text_seq_len = image_attention_mask.size(1)
934
+ image_attention_mask = image_attention_mask.unsqueeze(-1)
935
+ image_attention_mask = image_attention_mask.repeat(1, 1, 1, image_seq_len)
936
+ image_attention_mask = image_attention_mask.view(batch_size, text_seq_len, num_images * image_seq_len)
937
+ if image_hidden_states is not None:
938
+ image_batch_size, image_sequence_length, _ = image_hidden_states.size()
939
+ image_hidden_shape = (image_batch_size, image_sequence_length)
940
+ if image_attention_mask is None:
941
+ image_attention_mask = torch.ones(image_hidden_shape, device=device)
942
+ image_attention_mask = self.invert_attention_mask(image_attention_mask)
943
+ else:
944
+ image_attention_mask = None
945
+
946
+ # Prepare head mask if needed
947
+ # 1.0 in head_mask indicates we keep the head
948
+ # attention_probs has shape bsz x n_heads x N x N
949
+ # head_mask has shape n_layer x batch x n_heads x N x N
950
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
951
+
952
+ if inputs_embeds is None:
953
+ inputs_embeds = self.wte(input_ids)
954
+
955
+ position_embeds = self.wpe(position_ids)
956
+ hidden_states = inputs_embeds + position_embeds
957
+
958
+ if token_type_ids is not None:
959
+ token_type_embeds = self.wte(token_type_ids)
960
+ hidden_states = hidden_states + token_type_embeds
961
+
962
+ hidden_states = self.drop(hidden_states)
963
+
964
+ output_shape = input_shape + (hidden_states.size(-1),)
965
+
966
+ presents = () if use_cache else None
967
+ all_self_attentions = () if output_attentions else None
968
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
969
+ all_hidden_states = () if output_hidden_states else None
970
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
971
+ # Model parallel
972
+ if self.model_parallel:
973
+ torch.cuda.set_device(hidden_states.device)
974
+ # Ensure layer_past is on same device as hidden_states (might not be correct)
975
+ if layer_past is not None:
976
+ layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
977
+ # Ensure that attention_mask is always on the same device as hidden_states
978
+ if attention_mask is not None:
979
+ attention_mask = attention_mask.to(hidden_states.device)
980
+ if isinstance(head_mask, torch.Tensor):
981
+ head_mask = head_mask.to(hidden_states.device)
982
+ if output_hidden_states:
983
+ all_hidden_states = all_hidden_states + (hidden_states,)
984
+
985
+ def vblock(
986
+ main_block,
987
+ hidden_states,
988
+ layer_past,
989
+ attention_mask,
990
+ layer_head_mask,
991
+ use_cache,
992
+ output_attentions,
993
+ image_hidden_states,
994
+ image_attention_mask,
995
+ layer_idx,
996
+ cross_layer_interval,
997
+ gated_cross_attn_layers,
998
+ ):
999
+ # TODO(aps): Add cross attention values to respective lists
1000
+ # TODO(aps): Add xblock head mask support
1001
+ if layer_idx % cross_layer_interval == 0:
1002
+ xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval]
1003
+ outputs = xblock(
1004
+ hidden_states,
1005
+ attention_mask=attention_mask,
1006
+ image_hidden_states=image_hidden_states,
1007
+ image_attention_mask=image_attention_mask,
1008
+ use_cache=use_cache,
1009
+ output_attentions=output_attentions,
1010
+ )
1011
+ hidden_states = outputs[0]
1012
+
1013
+ outputs = main_block(
1014
+ hidden_states,
1015
+ layer_past=layer_past,
1016
+ attention_mask=attention_mask,
1017
+ head_mask=layer_head_mask,
1018
+ use_cache=use_cache,
1019
+ output_attentions=output_attentions,
1020
+ )
1021
+
1022
+ return outputs
1023
+
1024
+ if self.gradient_checkpointing and self.training:
1025
+ layer_past = None
1026
+ if use_cache:
1027
+ logger.warning_once(
1028
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1029
+ )
1030
+ use_cache = False
1031
+
1032
+ outputs = torch.utils.checkpoint.checkpoint(
1033
+ vblock,
1034
+ block,
1035
+ hidden_states,
1036
+ layer_past,
1037
+ attention_mask,
1038
+ head_mask[i],
1039
+ use_cache,
1040
+ output_attentions,
1041
+ image_hidden_states,
1042
+ image_attention_mask,
1043
+ i,
1044
+ self.cross_layer_interval,
1045
+ self.gated_cross_attn_layers,
1046
+ )
1047
+ else:
1048
+ outputs = vblock(
1049
+ block,
1050
+ hidden_states,
1051
+ layer_past=layer_past,
1052
+ attention_mask=attention_mask,
1053
+ layer_head_mask=head_mask[i],
1054
+ use_cache=use_cache,
1055
+ output_attentions=output_attentions,
1056
+ image_hidden_states=image_hidden_states,
1057
+ image_attention_mask=image_attention_mask,
1058
+ layer_idx=i,
1059
+ cross_layer_interval=self.cross_layer_interval,
1060
+ gated_cross_attn_layers=self.gated_cross_attn_layers,
1061
+ )
1062
+
1063
+ hidden_states = outputs[0]
1064
+ if use_cache is True:
1065
+ presents = presents + (outputs[1],)
1066
+
1067
+ if output_attentions:
1068
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
1069
+ if self.config.add_cross_attention:
1070
+ all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
1071
+
1072
+ # Model Parallel: If it's the last layer for that device, put things on the next device
1073
+ if self.model_parallel:
1074
+ for k, v in self.device_map.items():
1075
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
1076
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
1077
+
1078
+ hidden_states = self.ln_f(hidden_states)
1079
+
1080
+ hidden_states = hidden_states.view(output_shape)
1081
+ # Add last hidden state
1082
+ if output_hidden_states:
1083
+ all_hidden_states = all_hidden_states + (hidden_states,)
1084
+
1085
+ if not return_dict:
1086
+ return tuple(
1087
+ v
1088
+ for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
1089
+ if v is not None
1090
+ )
1091
+
1092
+ return BaseModelOutputWithPastAndCrossAttentions(
1093
+ last_hidden_state=hidden_states,
1094
+ past_key_values=presents,
1095
+ hidden_states=all_hidden_states,
1096
+ attentions=all_self_attentions,
1097
+ cross_attentions=all_cross_attentions,
1098
+ )
1099
+
1100
+
1101
+ @add_start_docstrings(
1102
+ """
1103
+ The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
1104
+ embeddings).
1105
+ """,
1106
+ GPT2_START_DOCSTRING,
1107
+ )
1108
+ class VGPT2LMHeadModel(VGPT2PreTrainedModel):
1109
+ _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
1110
+
1111
+ def __init__(self, config, vision_model=None):
1112
+ super().__init__(config)
1113
+ self.transformer = VGPT2Model(config, vision_model=vision_model)
1114
+ self.lm_head = DecoupledLinear(
1115
+ in_features=config.n_embd,
1116
+ out_features=config.vocab_size,
1117
+ out_additional_features=config.additional_vocab_size,
1118
+ bias=False,
1119
+ partially_freeze=config.freeze_lm_head,
1120
+ )
1121
+
1122
+ # Model parallel
1123
+ self.model_parallel = False
1124
+ self.device_map = None
1125
+
1126
+ # Initialize weights and apply final processing
1127
+ self.post_init()
1128
+
1129
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1130
+ def parallelize(self, device_map=None):
1131
+ self.device_map = (
1132
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
1133
+ if device_map is None
1134
+ else device_map
1135
+ )
1136
+ assert_device_map(self.device_map, len(self.transformer.h))
1137
+ self.transformer.parallelize(self.device_map)
1138
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
1139
+ self.model_parallel = True
1140
+
1141
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1142
+ def deparallelize(self):
1143
+ self.transformer.deparallelize()
1144
+ self.transformer = self.transformer.to("cpu")
1145
+ self.lm_head = self.lm_head.to("cpu")
1146
+ self.model_parallel = False
1147
+ torch.cuda.empty_cache()
1148
+
1149
+ def get_output_embeddings(self):
1150
+ return self.lm_head
1151
+
1152
+ def set_output_embeddings(self, new_embeddings):
1153
+ self.lm_head = new_embeddings
1154
+
1155
+ def tie_weights(self):
1156
+ """
1157
+ Overwrite `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of DecoupledLinear and DecoupledEmbedding.
1158
+ """
1159
+ output_embeddings = self.get_output_embeddings()
1160
+ input_embeddings = self.get_input_embeddings()
1161
+
1162
+ if getattr(self.config, "tie_word_embeddings", True):
1163
+ output_embeddings.weight = input_embeddings.weight
1164
+ if input_embeddings.num_additional_embeddings > 0:
1165
+ assert output_embeddings.out_additional_features == input_embeddings.num_additional_embeddings
1166
+ output_embeddings.additional_fc.weight = input_embeddings.additional_embedding.weight
1167
+
1168
+ if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
1169
+ output_embeddings.out_features = input_embeddings.num_embeddings
1170
+ if hasattr(output_embeddings, "out_additional_features") and hasattr(
1171
+ input_embeddings, "num_additional_embeddings"
1172
+ ):
1173
+ output_embeddings.out_additional_features = input_embeddings.num_additional_embeddings
1174
+
1175
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
1176
+ return prepare_inputs_for_generation(input_ids, past=past, **kwargs)
1177
+
1178
+ @staticmethod
1179
+ def _expand_inputs_for_generation(
1180
+ *args,
1181
+ **model_kwargs,
1182
+ ):
1183
+ return expand_inputs_for_generation(*args, **model_kwargs)
1184
+
1185
+ @staticmethod
1186
+ def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
1187
+ return update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder)
1188
+
1189
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1190
+ @add_code_sample_docstrings(
1191
+ processor_class=_TOKENIZER_FOR_DOC,
1192
+ checkpoint=_CHECKPOINT_FOR_DOC,
1193
+ output_type=CausalLMOutputWithCrossAttentions,
1194
+ config_class=_CONFIG_FOR_DOC,
1195
+ )
1196
+ def forward(
1197
+ self,
1198
+ input_ids: Optional[torch.LongTensor] = None,
1199
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1200
+ attention_mask: Optional[torch.FloatTensor] = None,
1201
+ token_type_ids: Optional[torch.LongTensor] = None,
1202
+ position_ids: Optional[torch.LongTensor] = None,
1203
+ head_mask: Optional[torch.FloatTensor] = None,
1204
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1205
+ pixel_values: Optional[torch.FloatTensor] = None,
1206
+ image_embeddings: Optional[torch.FloatTensor] = None,
1207
+ image_attention_mask: Optional[torch.Tensor] = None,
1208
+ crossblock_head_mask: Optional[torch.Tensor] = None,
1209
+ labels: Optional[torch.LongTensor] = None,
1210
+ use_cache: Optional[bool] = None,
1211
+ output_attentions: Optional[bool] = None,
1212
+ output_hidden_states: Optional[bool] = None,
1213
+ return_dict: Optional[bool] = None,
1214
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
1215
+ r"""
1216
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1217
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1218
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
1219
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
1220
+ """
1221
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1222
+
1223
+ transformer_outputs = self.transformer(
1224
+ input_ids,
1225
+ past_key_values=past_key_values,
1226
+ attention_mask=attention_mask,
1227
+ token_type_ids=token_type_ids,
1228
+ position_ids=position_ids,
1229
+ head_mask=head_mask,
1230
+ inputs_embeds=inputs_embeds,
1231
+ pixel_values=pixel_values,
1232
+ image_embeddings=image_embeddings,
1233
+ image_attention_mask=image_attention_mask,
1234
+ crossblock_head_mask=crossblock_head_mask,
1235
+ use_cache=use_cache,
1236
+ output_attentions=output_attentions,
1237
+ output_hidden_states=output_hidden_states,
1238
+ return_dict=return_dict,
1239
+ )
1240
+ hidden_states = transformer_outputs[0]
1241
+
1242
+ # Set device for model parallelism
1243
+ if self.model_parallel:
1244
+ torch.cuda.set_device(self.transformer.first_device)
1245
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
1246
+
1247
+ lm_logits = self.lm_head(hidden_states)
1248
+
1249
+ loss = None
1250
+ if labels is not None:
1251
+ # Shift so that tokens < n predict n
1252
+ if attention_mask is not None:
1253
+ shift_attention_mask = attention_mask[..., 1:]
1254
+ shift_logits = lm_logits[..., :-1, :][shift_attention_mask != 0].contiguous()
1255
+ shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
1256
+ else:
1257
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1258
+ shift_labels = labels[..., 1:].contiguous()
1259
+ # Flatten the tokens
1260
+ loss_fct = CrossEntropyLoss()
1261
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1262
+
1263
+ if not return_dict:
1264
+ output = (lm_logits,) + transformer_outputs[1:]
1265
+ return ((loss,) + output) if loss is not None else output
1266
+
1267
+ return CausalLMOutputWithCrossAttentions(
1268
+ loss=loss,
1269
+ logits=lm_logits,
1270
+ past_key_values=transformer_outputs.past_key_values,
1271
+ hidden_states=transformer_outputs.hidden_states,
1272
+ attentions=transformer_outputs.attentions,
1273
+ cross_attentions=transformer_outputs.cross_attentions,
1274
+ )
1275
+
1276
+ @staticmethod
1277
+ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
1278
+ """
1279
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1280
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1281
+ beam_idx at every generation step.
1282
+ """
1283
+ return tuple(
1284
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
1285
+ for layer_past in past
1286
+ )
1287
+
1288
+ def get_model_tflops_per_batch_per_gpu(self, hparams, data_param, tokenizer, max_num_images):
1289
+ config_vl_model = self.config
1290
+
1291
+ language_embed_size = config_vl_model.n_embd
1292
+ num_language_layers = config_vl_model.n_layer
1293
+ ffn_inner_size = config_vl_model.n_inner
1294
+
1295
+ vision_config = self.transformer.vision_model.config
1296
+ if hasattr(vision_config, "vision_config"):
1297
+ vision_config = vision_config.vision_config
1298
+
1299
+ # Get vision model blocks infos
1300
+ vision_patch_size = vision_config.patch_size
1301
+ vision_hidden_size = vision_config.hidden_size
1302
+ num_vision_layers = vision_config.num_hidden_layers
1303
+ # The +1 is for the CLS token
1304
+ single_image_seq_len = (vision_config.image_size // vision_patch_size) ** 2 + 1
1305
+ vision_exp_factor = vision_config.intermediate_size // vision_hidden_size
1306
+
1307
+ # Get language and cross-att blocks infos
1308
+ num_cross_attn_layers = num_language_layers // config_vl_model.cross_layer_interval
1309
+ language_seq_len = data_param.max_seq_len
1310
+ language_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4
1311
+ cross_att_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4
1312
+ k_v_cross_attn_seq_len = (
1313
+ (self.config.resampler_n_latents * max_num_images)
1314
+ if self.config.use_resampler
1315
+ else (single_image_seq_len * max_num_images)
1316
+ )
1317
+
1318
+ language_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
1319
+ num_layers=num_language_layers,
1320
+ batch_size=hparams.batch_size_per_gpu,
1321
+ q_seq_len=language_seq_len,
1322
+ k_seq_len=language_seq_len,
1323
+ hidden_size=language_embed_size,
1324
+ kv_in_dim=language_embed_size,
1325
+ ff_exp_factor=language_exp_factor,
1326
+ grad_acc_size=hparams.grad_acc_size,
1327
+ swiglu=False,
1328
+ vocab_size=tokenizer.vocab_size,
1329
+ count_backward=True, # Always True regardless of freezing, because gradients are computed for cross-attentions
1330
+ use_grad_checkpointing=hparams.gradient_checkpointing,
1331
+ )
1332
+ cross_attention_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
1333
+ num_layers=num_cross_attn_layers,
1334
+ batch_size=hparams.batch_size_per_gpu,
1335
+ q_seq_len=language_seq_len,
1336
+ k_seq_len=k_v_cross_attn_seq_len,
1337
+ hidden_size=language_embed_size,
1338
+ kv_in_dim=vision_hidden_size,
1339
+ ff_exp_factor=cross_att_exp_factor,
1340
+ grad_acc_size=hparams.grad_acc_size,
1341
+ swiglu=False,
1342
+ vocab_size=None,
1343
+ count_backward=True,
1344
+ use_grad_checkpointing=hparams.gradient_checkpointing,
1345
+ )
1346
+ vision_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
1347
+ num_layers=num_vision_layers,
1348
+ batch_size=hparams.batch_size_per_gpu * max_num_images,
1349
+ q_seq_len=single_image_seq_len,
1350
+ k_seq_len=single_image_seq_len,
1351
+ hidden_size=vision_hidden_size,
1352
+ kv_in_dim=vision_hidden_size,
1353
+ ff_exp_factor=vision_exp_factor,
1354
+ grad_acc_size=hparams.grad_acc_size,
1355
+ swiglu=False,
1356
+ vocab_size=None,
1357
+ count_backward=not hparams.model_params["freeze_vision_layers"],
1358
+ use_grad_checkpointing=hparams.gradient_checkpointing,
1359
+ )
1360
+ if self.config.use_resampler:
1361
+ perceiver_tflops_per_batch_per_gpu = compute_perceiver_tflops_per_batch_per_gpu(
1362
+ num_layers=self.config.resampler_depth,
1363
+ batch_size=hparams.batch_size_per_gpu * max_num_images,
1364
+ q_seq_len=self.config.resampler_n_latents,
1365
+ vision_embed_seq_len=single_image_seq_len,
1366
+ q_k_v_input_dim=vision_hidden_size,
1367
+ attention_hidden_size=self.config.resampler_n_heads * self.config.resampler_head_dim,
1368
+ ff_exp_factor=cross_att_exp_factor,
1369
+ count_backward=True,
1370
+ use_grad_checkpointing=hparams.gradient_checkpointing,
1371
+ )
1372
+ flop_count = (
1373
+ language_tflops_per_batch_per_gpu
1374
+ + cross_attention_tflops_per_batch_per_gpu
1375
+ + vision_tflops_per_batch_per_gpu
1376
+ + perceiver_tflops_per_batch_per_gpu
1377
+ )
1378
+ else:
1379
+ flop_count = (
1380
+ language_tflops_per_batch_per_gpu
1381
+ + cross_attention_tflops_per_batch_per_gpu
1382
+ + vision_tflops_per_batch_per_gpu
1383
+ )
1384
+ return flop_count
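The sequence-length bookkeeping in `get_model_tflops_per_batch_per_gpu` above can be checked by hand. The sketch below is a standalone illustration and not part of the commit: the helper names `single_image_seq_len`, `cross_attn_kv_seq_len` and `cross_attn_layer_indices` are hypothetical, and the numbers assume a ViT-style encoder with 224-pixel images and 16-pixel patches, 5 images per sample, and 64 resampler latents.

```python
def single_image_seq_len(image_size: int, patch_size: int) -> int:
    """Number of vision tokens per image: one per patch, plus one CLS token."""
    return (image_size // patch_size) ** 2 + 1


def cross_attn_kv_seq_len(
    image_size: int,
    patch_size: int,
    max_num_images: int,
    use_resampler: bool,
    resampler_n_latents: int,
) -> int:
    """Key/value sequence length seen by each gated cross-attention layer."""
    if use_resampler:
        # The perceiver resampler compresses every image to a fixed number of latents.
        return resampler_n_latents * max_num_images
    # Without the resampler, all vision tokens of all images are attended to.
    return single_image_seq_len(image_size, patch_size) * max_num_images


def cross_attn_layer_indices(num_layers: int, cross_layer_interval: int) -> list:
    """Language layers that are preceded by a cross-attention block
    (the `layer_idx % cross_layer_interval == 0` test in `vblock`)."""
    return [i for i in range(num_layers) if i % cross_layer_interval == 0]


if __name__ == "__main__":
    print(single_image_seq_len(224, 16))                  # 197
    print(cross_attn_kv_seq_len(224, 16, 5, True, 64))    # 320
    print(cross_attn_kv_seq_len(224, 16, 5, False, 64))   # 985
    print(cross_attn_layer_indices(24, 4))                # [0, 4, 8, 12, 16, 20]
```

Under these assumed values the resampler shortens the cross-attention key/value sequence from 985 to 320 tokens, which is why the FLOP estimate branches on `use_resampler`.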
m4/models/vgpt_neo/__init__.py ADDED
File without changes
m4/models/vgpt_neo/configuration_vgpt_neo.py ADDED
@@ -0,0 +1,250 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ GPT Neo model configuration"""
16
+ import os
17
+ from typing import Tuple, Union
18
+
19
+ from transformers import AutoConfig
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27
+ "EleutherAI/gpt-neo-125M": "https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/config.json",
28
+ "EleutherAI/gpt-neo-1.3B": "https://huggingface.co/EleutherAI/gpt-neo-1.3B/resolve/main/config.json",
29
+ # See all GPTNeo models at https://huggingface.co/models?filter=gpt_neo
30
+ }
31
+
32
+
33
+ class VGPTNeoConfig(PretrainedConfig):
34
+ r"""
35
+ This is the configuration class to store the configuration of a [`GPTNeoModel`]. It is used to instantiate a GPT
36
+ Neo model according to the specified arguments, defining the model architecture. Instantiating a configuration with
37
+ the defaults will yield a similar configuration to that of the GPTNeo
38
+ [EleutherAI/gpt-neo-1.3B](https://huggingface.co/EleutherAI/gpt-neo-1.3B) architecture.
39
+
40
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
41
+ documentation from [`PretrainedConfig`] for more information.
42
+
43
+ TODO: this doc is completely out of sync with the actual args
44
+
45
+ Args:
46
+ vocab_size (`int`, *optional*, defaults to 50257):
47
+ Vocabulary size of the GPT Neo model. Defines the number of different tokens that can be represented by the
48
+ `input_ids` passed when calling [`GPTNeoModel`], i.e. the range of token ids accepted by the forward
49
+ method of [`GPTNeoModel`].
50
+ additional_vocab_size (`int`, *optional*, defaults to 0):
51
+ Additional vocabulary size of the model, typically for the special "<img>" token. Additional vocab tokens
52
+ are always trainable whereas regular vocab tokens can be frozen or not.
53
+ attention_types (`List`, *optional*, defaults to `[[["global", "local"], 12]]`):
54
+ The type of attention for each layer in a `List` of the following format `[[["attention_type"],
55
+ num_layers]]`, e.g. for a 24-layer model `[[["global"], 24]]` or `[[["global", "local"], 12]]`. Choose the
56
+ value of `attention_type` from `["global", "local"]`
57
+ hidden_size (`int`, *optional*, defaults to 2048):
58
+ Dimensionality of the encoder layers and the pooler layer.
59
+ num_layers (`int`, *optional*, defaults to 24):
60
+ Number of hidden layers in the Transformer encoder.
61
+ num_heads (`int`, *optional*, defaults to 16):
62
+ Number of attention heads for each attention layer in the Transformer encoder.
63
+ intermediate_size (`int`, *optional*, defaults to 8192):
64
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
65
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu_new"`):
66
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
67
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
68
+ embed_dropout (`float`, *optional*, defaults to 0.0):
69
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
70
+ attention_dropout (`float`, *optional*, defaults to 0.0):
71
+ The dropout ratio for the attention probabilities.
72
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
73
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
74
+ just in case (e.g., 512 or 1024 or 2048).
75
+ type_vocab_size (`int`, *optional*, defaults to 2):
76
+ The vocabulary size of the `token_type_ids` passed when calling [`GPTNeoModel`].
77
+ initializer_range (`float`, *optional*, defaults to 0.02):
78
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
79
+ alpha_initializer (`str`, *optional*, defaults to `"ones"`):
80
+ Initialization type for the alphas.
81
+ alphas_initializer_range (`float`, *optional*, defaults to 0.0):
82
+ The standard deviation of the truncated_normal_initializer for initializing the alphas in the Gated Cross Attention.
83
+ alpha_type (`str`, *optional*, defaults to `"vector"`):
84
+ Whether the gating alphas should be vectors or single floats.
85
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
86
+ The epsilon used by the layer normalization layers.
87
+ use_cache (`bool`, *optional*, defaults to `True`):
88
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
89
+ relevant if `config.is_decoder=True`.
90
+ cross_layer_interval (`int`, *optional*, defaults to 1):
91
+ Interval for cross attention (from text to image) layers.
92
+ Example:
93
+ ```python
94
+ >>> from transformers import GPTNeoConfig, GPTNeoModel
95
+ >>> # Initializing a GPTNeo EleutherAI/gpt-neo-1.3B style configuration
96
+ >>> configuration = GPTNeoConfig()
97
+ >>> # Initializing a model (with random weights) from the EleutherAI/gpt-neo-1.3B style configuration
98
+ >>> model = GPTNeoModel(configuration)
99
+ >>> # Accessing the model configuration
100
+ >>> configuration = model.config
101
+ ```"""
102
+ model_type = "vgpt_neo"
103
+ keys_to_ignore_at_inference = ["past_key_values"]
104
+ attribute_map = {"num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
105
+
106
+ def __init__(
107
+ self,
108
+ vocab_size=50257,
109
+ additional_vocab_size=0,
110
+ max_position_embeddings=2048,
111
+ hidden_size=2048,
112
+ num_layers=24,
113
+ attention_types=[[["global", "local"], 12]],
114
+ num_heads=16,
115
+ intermediate_size=None,
116
+ window_size=256,
117
+ activation_function="gelu_new",
118
+ resid_dropout=0.0,
119
+ embed_dropout=0.0,
120
+ attention_dropout=0.0,
121
+ layer_norm_epsilon=1e-5,
122
+ initializer_range=0.02,
123
+ alpha_initializer="ones",
124
+ alphas_initializer_range=0.0,
125
+ alpha_type="vector",
126
+ summary_type="cls_index",
127
+ summary_use_proj=True,
128
+ summary_activation=None,
129
+ summary_proj_to_labels=True,
130
+ summary_first_dropout=0.1,
131
+ use_cache=True,
132
+ bos_token_id=50256,
133
+ eos_token_id=50256,
134
+ cross_layer_interval=1,
135
+ tie_word_embeddings=False,
136
+ freeze_text_layers=True,
137
+ freeze_lm_head=False,
138
+ freeze_vision_layers=True,
139
+ vision_model_name="google/vit-base-patch16-224",
140
+ vision_model_params="{}",
141
+ vision_embed_dim=768,
142
+ vision_image_size=224,
143
+ image_token_index=50257,
144
+ use_resampler=False,
145
+ resampler_n_latents=64,
146
+ resampler_depth=6,
147
+ resampler_n_heads=16,
148
+ resampler_head_dim=96,
149
+ **kwargs,
150
+ ):
151
+ self.vocab_size = vocab_size
152
+ self.additional_vocab_size = additional_vocab_size
153
+ self.max_position_embeddings = max_position_embeddings
154
+ self.hidden_size = hidden_size
155
+ self.num_layers = num_layers
156
+ self.num_heads = num_heads
157
+ self.intermediate_size = intermediate_size
158
+ self.window_size = window_size
159
+ self.activation_function = activation_function
160
+ self.resid_dropout = resid_dropout
161
+ self.embed_dropout = embed_dropout
162
+ self.attention_dropout = attention_dropout
163
+ self.layer_norm_epsilon = layer_norm_epsilon
164
+ self.initializer_range = initializer_range
165
+ self.alpha_initializer = alpha_initializer
166
+ self.alphas_initializer_range = alphas_initializer_range
167
+ self.alpha_type = alpha_type
168
+ self.summary_type = summary_type
169
+ self.summary_use_proj = summary_use_proj
170
+ self.summary_activation = summary_activation
171
+ self.summary_first_dropout = summary_first_dropout
172
+ self.summary_proj_to_labels = summary_proj_to_labels
173
+ self.use_cache = use_cache
174
+
175
+ self.bos_token_id = bos_token_id
176
+ self.eos_token_id = eos_token_id
177
+ self.cross_layer_interval = cross_layer_interval
178
+ self.freeze_vision_layers = freeze_vision_layers
179
+ self.vision_model_name = vision_model_name
180
+ self.vision_model_params = vision_model_params
181
+
182
+ self.tie_word_embeddings = tie_word_embeddings
183
+ self.freeze_text_layers = freeze_text_layers
184
+ self.freeze_lm_head = freeze_lm_head
185
+ self.image_token_index = image_token_index
186
+ self.attention_types = attention_types
187
+ self.attention_layers = self.expand_attention_types_params(attention_types)
188
+
189
+ self.vision_embed_dim = vision_embed_dim
190
+ self.vision_image_size = vision_image_size
191
+
192
+ # Resampler params
193
+ self.use_resampler = use_resampler
194
+ self.resampler_n_latents = resampler_n_latents
195
+ self.resampler_depth = resampler_depth
196
+ self.resampler_n_heads = resampler_n_heads
197
+ self.resampler_head_dim = resampler_head_dim
198
+
199
+ # IMPORTANT: Do not do any __init__ args-based checks in the constructor, since
200
+ # PretrainedConfig.from_dict first instantiates the class with the config dict and only then
201
+ # updates the config object with `kwargs` from from_pretrained, so during the instantiation
202
+ # of this object many attributes have default values and haven't yet been overridden.
203
+ # Do any required checks inside `from_pretrained` once the superclass' `from_pretrained` was run.
204
+
205
+ super().__init__(
206
+ bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
207
+ )
208
+
209
+ def check_compatibilities(self):
210
+ if self.tie_word_embeddings and (self.freeze_text_layers != self.freeze_lm_head):
211
+ raise ValueError(
212
+ "if `tie_word_embeddings` is True, then `freeze_lm_head` and `freeze_text_layers` must be equal."
213
+ )
214
+
215
+ vision_model_params = eval(self.vision_model_params)
216
+ config = AutoConfig.from_pretrained(self.vision_model_name, **vision_model_params)
217
+ if hasattr(config, "vision_config"):
218
+ vision_config = config.vision_config
219
+ else:
220
+ vision_config = config
221
+ vision_embed_dim = vision_config.hidden_size
222
+ if self.vision_embed_dim != vision_embed_dim:
223
+ raise ValueError(
224
+ f"vision_embed_dim ({self.vision_embed_dim}) must match the hidden size of the vision model"
225
+ f" ({vision_embed_dim})"
226
+ )
227
+ vision_image_size = vision_config.image_size
228
+ if self.vision_image_size != vision_image_size:
229
+ raise ValueError(
230
+ f"vision_image_size ({self.vision_image_size}) must match the image size of the vision model"
231
+ f" ({vision_image_size})"
232
+ )
233
+
234
+ @classmethod
235
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
236
+ outputs = super(VGPTNeoConfig, cls).from_pretrained(pretrained_model_name_or_path, **kwargs)
237
+ if isinstance(outputs, Tuple):
238
+ # When called with return_unused_kwargs=True, the first item will be the config
239
+ outputs[0].check_compatibilities()
240
+ else:
241
+ outputs.check_compatibilities()
242
+ return outputs
243
+
244
+ @staticmethod
245
+ def expand_attention_types_params(attention_types):
246
+ attentions = []
247
+ for item in attention_types:
248
+ for _ in range(item[1]):
249
+ attentions.extend(item[0])
250
+ return attentions
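`expand_attention_types_params` above unrolls the compact `attention_types` spec into one entry per layer. The following standalone sketch reproduces the same loop so the default can be verified in isolation; the function name `expand_attention_types` is a stand-in chosen for illustration and does not exist in the commit.

```python
def expand_attention_types(attention_types):
    """Unroll [[pattern, repeats], ...] into a flat per-layer list of attention types."""
    attentions = []
    for pattern, repeats in attention_types:
        for _ in range(repeats):
            # Each repeat appends the whole pattern, e.g. ["global", "local"].
            attentions.extend(pattern)
    return attentions


layers = expand_attention_types([[["global", "local"], 12]])
print(len(layers))  # 24
print(layers[:4])   # ['global', 'local', 'global', 'local']
```

With the default values shown in this diff, the expanded list has 24 entries, which matches `num_layers=24`. A spec whose expanded length differs from `num_layers` would presumably leave some layers without an attention type, so it is worth checking the two together when changing either.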
m4/models/vgpt_neo/modeling_vgpt_neo.py ADDED
@@ -0,0 +1,1182 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The Eleuther AI and HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch GPT Neo model."""
16
+
17
+
18
+ import os
19
+ from typing import Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+ from torch.nn import CrossEntropyLoss
25
+ from transformers.activations import ACT2FN
26
+ from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
27
+ from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
28
+
29
+ from m4.models import DecoupledEmbedding, DecoupledLinear
30
+ from m4.models.common import (
31
+ expand_inputs_for_generation,
32
+ prepare_inputs_for_generation,
33
+ update_model_kwargs_for_generation,
34
+ )
35
+ from m4.models.custom_modules import VLOOMPreTrainedModelBase
36
+ from m4.models.perceiver.perceiver import PerceiverResampler
37
+ from m4.models.vgpt_neo.configuration_vgpt_neo import VGPTNeoConfig
38
+ from m4.training.utils import (
39
+ compute_perceiver_tflops_per_batch_per_gpu,
40
+ compute_tflops_per_batch_per_gpu,
41
+ freeze_model,
42
+ )
43
+ from m4.utils import logging
44
+
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+ _CHECKPOINT_FOR_DOC = "EleutherAI/gpt-neo-1.3B"
49
+ _CONFIG_FOR_DOC = "VGPTNeoConfig"
50
+ _TOKENIZER_FOR_DOC = "GPT2Tokenizer"
51
+
52
+ GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST = [
53
+ "EleutherAI/gpt-neo-125M",
54
+ "EleutherAI/gpt-neo-1.3B",
55
+ # See all GPTNeo models at https://huggingface.co/models?filter=gpt_neo
56
+ ]
57
+
58
+
59
+ def load_tf_weights_in_gpt_neo(model, config, gpt_neo_checkpoint_path):
60
+ """Load tf checkpoints in a pytorch model"""
61
+ try:
62
+ import re
63
+
64
+ import tensorflow as tf
65
+ except ImportError:
66
+ logger.error(
67
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
68
+ "https://www.tensorflow.org/install/ for installation instructions."
69
+ )
70
+ raise
71
+ tf_path = os.path.abspath(gpt_neo_checkpoint_path)
72
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
73
+ # Load weights from TF model
74
+ init_vars = tf.train.list_variables(tf_path)
75
+ names = []
76
+ arrays = []
77
+ for name, shape in init_vars:
78
+ if "global_step" not in name and "adam" not in name:
79
+ array = tf.train.load_variable(tf_path, name)
80
+ array = tf.dtypes.cast(array.squeeze(), tf.float32).numpy()
81
+ name = name.replace("attn/q", "attn/attention/q_proj/w")
82
+ name = name.replace("attn/k", "attn/attention/k_proj/w")
83
+ name = name.replace("attn/v", "attn/attention/v_proj/w")
84
+ name = name.replace("attn/o", "attn/attention/out_proj/w")
85
+ name = name.replace("norm_1", "ln_1")
86
+ name = name.replace("norm_2", "ln_2")
87
+ name = name.replace("attn/compute_output_bias/o_b", "attn/attention/out_proj/b")
88
+ name = name.replace("conv1d_main/c_fc/kernel", "c_fc/w")
89
+ name = name.replace("conv1d_main/c_fc/bias", "c_fc/b")
90
+ name = name.replace("conv1d_main/c_proj/kernel", "c_proj/w")
91
+ name = name.replace("conv1d_main/c_proj/bias", "c_proj/b")
92
+
93
+ names.append(name)
94
+ arrays.append(array)
95
+
96
+ for name, array in zip(names, arrays):
97
+ name = name[5:] # skip "gpt2/"
98
+ name = name.split("/")
99
+ pointer = model.transformer
100
+ for m_name in name:
101
+ if re.fullmatch(r"[A-Za-z]+\d+", m_name):
102
+ scope_names = re.split(r"(\d+)", m_name)
103
+ else:
104
+ scope_names = [m_name]
105
+ if scope_names[0] == "w" or scope_names[0] == "g":
106
+ pointer = getattr(pointer, "weight")
107
+ elif scope_names[0] == "b":
108
+ pointer = getattr(pointer, "bias")
109
+ elif scope_names[0] == "wpe" or scope_names[0] == "wte":
110
+ pointer = getattr(pointer, scope_names[0])
111
+ pointer = getattr(pointer, "weight")
112
+ else:
113
+ pointer = getattr(pointer, scope_names[0])
114
+ if len(scope_names) >= 2:
115
+ num = int(scope_names[1])
116
+ pointer = pointer[num]
117
+
118
+ if name[-1] == "w" and name[-2] in ["out_proj", "k_proj", "q_proj", "v_proj", "c_proj", "c_fc"]:
119
+ array = array.transpose()
120
+
121
+ if name == ["wte"]:
122
+ # if vocab is padded, then trim off the padding embeddings
123
+ array = array[: config.vocab_size]
124
+
125
+ if pointer.shape != array.shape:
126
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched {name}")
127
+
128
+ print(f"Initialize PyTorch weight {name}")
129
+ pointer.data = torch.from_numpy(array)
130
+
131
+ # init the final linear layer using word embeddings
132
+ embs = model.transformer.wte.weight
133
+ lin = nn.Linear(embs.size()[1], embs.size()[0], bias=False)
134
+ lin.weight = embs
135
+ model.set_output_embeddings(lin)
136
+ return model
137
+
138
+
139
+ class GPTNeoSelfAttention(nn.Module):
140
+ def __init__(self, config, attention_type, is_cross_attention=False):
141
+ super().__init__()
142
+
143
+ max_positions = config.max_position_embeddings
144
+ bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
145
+ 1, 1, max_positions, max_positions
146
+ )
147
+
148
+ # local causal self attention is a sliding window where each token can only attend to the previous
149
+ # window_size tokens. This is implemented by updating the causal mask such that for each token
150
+ # all other tokens are masked except the previous window_size tokens.
151
+ if attention_type == "local":
152
+ bias = torch.bitwise_xor(bias, torch.tril(bias, -config.window_size))
153
+ self.is_cross_attention = is_cross_attention
154
+ self.register_buffer("bias", bias)
155
+ self.register_buffer("masked_bias", torch.tensor(-1e9))
156
+
157
+ self.attn_dropout = nn.Dropout(float(config.attention_dropout))
158
+ self.resid_dropout = nn.Dropout(float(config.resid_dropout))
159
+
160
+ self.embed_dim = config.hidden_size
161
+ self.num_heads = config.num_heads
162
+ self.head_dim = self.embed_dim // self.num_heads
163
+ if self.head_dim * self.num_heads != self.embed_dim:
164
+ raise ValueError(
165
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
166
+ f" {self.num_heads})."
167
+ )
168
+ if self.is_cross_attention:
169
+ in_dim = self.embed_dim if not hasattr(config, "vision_embed_dim") else config.vision_embed_dim
170
+ self.k_proj = nn.Linear(in_dim, self.embed_dim, bias=False)
171
+ self.v_proj = nn.Linear(in_dim, self.embed_dim, bias=False)
172
+ else:
173
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
174
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
175
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
176
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
177
+
178
+ def _split_heads(self, tensor, num_heads, attn_head_size):
179
+ """
180
+ Splits hidden_size dim into attn_head_size and num_heads
181
+ """
182
+ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
183
+ tensor = tensor.view(new_shape)
184
+ return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
185
+
186
+ def _merge_heads(self, tensor, num_heads, attn_head_size):
187
+ """
188
+ Merges attn_head_size dim and num_attn_heads dim into hidden_size
189
+ """
190
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
191
+ new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
192
+ return tensor.view(new_shape)
193
+
194
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
195
+ # Keep the attention weights computation in fp32 to avoid overflow issues
196
+ query = query.to(torch.float32)
197
+ key = key.to(torch.float32)
198
+
199
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
200
+
201
+ if not self.is_cross_attention:
202
+ query_length, key_length = query.size(-2), key.size(-2)
203
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
204
+ mask_value = torch.finfo(attn_weights.dtype).min
205
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
206
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
207
+ mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
208
+ attn_weights = torch.where(causal_mask, attn_weights, mask_value)
209
+
210
+ if attention_mask is not None:
211
+ # Apply the attention mask
212
+ attn_weights = attn_weights + attention_mask
213
+
214
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
215
+ attn_weights = attn_weights.to(value.dtype)
216
+ attn_weights = self.attn_dropout(attn_weights)
217
+
218
+ # Mask heads if we want to
219
+ if head_mask is not None:
220
+ attn_weights = attn_weights * head_mask
221
+
222
+ attn_output = torch.matmul(attn_weights, value)
223
+
224
+ return attn_output, attn_weights
225
+
226
+ def forward(
227
+ self,
228
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
229
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
230
+ attention_mask: Optional[torch.FloatTensor] = None,
231
+ head_mask: Optional[torch.FloatTensor] = None,
232
+ encoder_hidden_states: Optional[torch.Tensor] = None,
233
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
234
+ use_cache: Optional[bool] = False,
235
+ output_attentions: Optional[bool] = False,
236
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
237
+ if encoder_hidden_states is not None:
238
+ key = self.k_proj(encoder_hidden_states)
239
+ value = self.v_proj(encoder_hidden_states)
240
+ attention_mask = encoder_attention_mask
241
+ else:
242
+ key = self.k_proj(hidden_states)
243
+ value = self.v_proj(hidden_states)
244
+ query = self.q_proj(hidden_states)
245
+
246
+ query = self._split_heads(query, self.num_heads, self.head_dim)
247
+ key = self._split_heads(key, self.num_heads, self.head_dim)
248
+ value = self._split_heads(value, self.num_heads, self.head_dim)
249
+
250
+ if layer_past is not None:
251
+ past_key = layer_past[0]
252
+ past_value = layer_past[1]
253
+ key = torch.cat((past_key, key), dim=-2)
254
+ value = torch.cat((past_value, value), dim=-2)
255
+
256
+ if use_cache is True:
257
+ present = (key, value)
258
+ else:
259
+ present = None
260
+
261
+ attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
262
+
263
+ attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
264
+ attn_output = self.out_proj(attn_output)
265
+ attn_output = self.resid_dropout(attn_output)
266
+
267
+ outputs = (attn_output, present)
268
+ if output_attentions:
269
+ outputs += (attn_weights,)
270
+
271
+ return outputs # a, present, (attentions)
272
+
273
+
274
+ class GPTNeoAttention(nn.Module):
275
+ def __init__(self, config, layer_id=0, is_cross_attention=False):
276
+ super().__init__()
277
+ self.layer_id = layer_id
278
+ self.attention_layers = config.attention_layers
279
+ self.attention_type = self.attention_layers[layer_id]
280
+ if self.attention_type in ["global", "local"]:
281
+ self.attention = GPTNeoSelfAttention(config, self.attention_type, is_cross_attention=is_cross_attention)
282
+ else:
283
+ raise NotImplementedError(
284
+ "Only attn layer types 'global' and 'local' exist, but got `config.attention_layers`: "
285
+ f"{config.attention_layers}. Select attn layer types from ['global', 'local'] only."
286
+ )
287
+
288
+ def forward(
289
+ self,
290
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
291
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
292
+ attention_mask: Optional[torch.FloatTensor] = None,
293
+ head_mask: Optional[torch.FloatTensor] = None,
294
+ encoder_hidden_states: Optional[torch.Tensor] = None,
295
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
296
+ use_cache: Optional[bool] = False,
297
+ output_attentions: Optional[bool] = False,
298
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
299
+ return self.attention(
300
+ hidden_states,
301
+ attention_mask=attention_mask,
302
+ layer_past=layer_past,
303
+ head_mask=head_mask,
304
+ encoder_hidden_states=encoder_hidden_states,
305
+ encoder_attention_mask=encoder_attention_mask,
306
+ use_cache=use_cache,
307
+ output_attentions=output_attentions,
308
+ )
309
+
310
+
311
+ class GPTNeoMLP(nn.Module):
312
+ def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * hidden_size
313
+ super().__init__()
314
+ embed_dim = config.hidden_size
315
+ self.c_fc = nn.Linear(embed_dim, intermediate_size)
316
+ self.c_proj = nn.Linear(intermediate_size, embed_dim)
317
+ self.act = ACT2FN[config.activation_function]
318
+ self.dropout = nn.Dropout(float(config.resid_dropout))
319
+
320
+ def forward(self, hidden_states):
321
+ hidden_states = self.c_fc(hidden_states)
322
+ hidden_states = self.act(hidden_states)
323
+ hidden_states = self.c_proj(hidden_states)
324
+ hidden_states = self.dropout(hidden_states)
325
+ return hidden_states
326
+
327
+
328
+ class GPTNeoBlock(nn.Module):
329
+ def __init__(self, config, layer_id):
330
+ super().__init__()
331
+ hidden_size = config.hidden_size
332
+ inner_dim = config.intermediate_size if config.intermediate_size is not None else 4 * hidden_size
333
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
334
+ self.attn = GPTNeoAttention(config, layer_id, is_cross_attention=False)
335
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
336
+
337
+ self.mlp = GPTNeoMLP(inner_dim, config)
338
+
339
+ def forward(
340
+ self,
341
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
342
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
343
+ attention_mask: Optional[torch.FloatTensor] = None,
344
+ head_mask: Optional[torch.FloatTensor] = None,
345
+ encoder_hidden_states: Optional[torch.Tensor] = None,
346
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
347
+ use_cache: Optional[bool] = False,
348
+ output_attentions: Optional[bool] = False,
349
+ ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
350
+ residual = hidden_states
351
+ hidden_states = self.ln_1(hidden_states)
352
+ attn_outputs = self.attn(
353
+ hidden_states,
354
+ layer_past=layer_past,
355
+ attention_mask=attention_mask,
356
+ head_mask=head_mask,
357
+ use_cache=use_cache,
358
+ output_attentions=output_attentions,
359
+ )
360
+ attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
361
+ outputs = attn_outputs[1:]
362
+ # residual connection
363
+ hidden_states = attn_output + residual
364
+ residual = hidden_states
365
+ hidden_states = self.ln_2(hidden_states)
366
+ feed_forward_hidden_states = self.mlp(hidden_states)
367
+ # residual connection
368
+ hidden_states = residual + feed_forward_hidden_states
369
+
370
+ if use_cache:
371
+ outputs = (hidden_states,) + outputs
372
+ else:
373
+ outputs = (hidden_states,) + outputs[1:]
374
+
375
+ return outputs # hidden_states, present, (attentions, cross_attentions)
376
+
377
+
378
+ class VGPTNeoGatedCrossAttentionBlock(nn.Module):
379
+ def __init__(self, config, layer_id):
380
+ super().__init__()
381
+ hidden_size = config.hidden_size
382
+ inner_dim = config.intermediate_size if config.intermediate_size is not None else 4 * hidden_size
383
+
384
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
385
+ self.cross_attn = GPTNeoAttention(config, layer_id, is_cross_attention=True)
386
+ self.mlp = GPTNeoMLP(inner_dim, config)
387
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
388
+ self.act = nn.Tanh()
389
+
390
+ if config.alpha_initializer == "zeros":
391
+ if config.alpha_type == "vector":
392
+ self.alpha_cross_attn = nn.Parameter(torch.zeros(1, 1, hidden_size))
393
+ self.alpha_dense = nn.Parameter(torch.zeros(1, 1, hidden_size))
394
+ elif config.alpha_type == "float":
395
+ self.alpha_cross_attn = nn.Parameter(torch.zeros(1))
396
+ self.alpha_dense = nn.Parameter(torch.zeros(1))
397
+ else:
398
+ raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
399
+
400
+ elif config.alpha_initializer == "ones":
401
+ if config.alpha_type == "vector":
402
+ self.alpha_cross_attn = nn.Parameter(torch.ones(1, 1, hidden_size))
403
+ self.alpha_dense = nn.Parameter(torch.ones(1, 1, hidden_size))
404
+ elif config.alpha_type == "float":
405
+ self.alpha_cross_attn = nn.Parameter(torch.ones(1))
406
+ self.alpha_dense = nn.Parameter(torch.ones(1))
407
+ else:
408
+ raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
409
+
410
+ elif config.alpha_initializer in {"normal", "gaussian", "random"}:
411
+ if config.alpha_type == "vector":
412
+ self.alpha_cross_attn = nn.Parameter(
413
+ torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1, 1, hidden_size))
414
+ )
415
+ self.alpha_dense = nn.Parameter(
416
+ torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1, 1, hidden_size))
417
+ )
418
+ elif config.alpha_type == "float":
419
+ self.alpha_cross_attn = nn.Parameter(
420
+ torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1,))
421
+ )
422
+ self.alpha_dense = nn.Parameter(torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1,)))
423
+ else:
424
+ raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
425
+
426
+ else:
427
+ raise NotImplementedError(f"Alpha initialization scheme {config.alpha_initializer} not yet implemented!")
428
+
429
+ def forward(
430
+ self,
431
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
432
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
433
+ attention_mask: Optional[torch.FloatTensor] = None,
434
+ head_mask: Optional[torch.FloatTensor] = None,
435
+ image_hidden_states: Optional[torch.Tensor] = None,
436
+ image_attention_mask: Optional[torch.FloatTensor] = None,
437
+ use_cache: Optional[bool] = False,
438
+ output_attentions: Optional[bool] = False,
439
+ ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
440
+ if image_hidden_states is None:
441
+ raise ValueError(
442
+ "`image_hidden_states` (the visual features to condition on) is required for the VGPTNeo cross attention"
443
+ " module."
444
+ )
445
+ # add one self-attention block for cross-attention
446
+
447
+ # TODO(aps): Handle cross attention in the outputs
448
+ # if not hasattr(self, "crossattention"):
449
+ # raise ValueError(
450
+ # f"If `image_hidden_states` are passed, {self} has to be instantiated with "
451
+ # "cross-attention layers by setting `config.add_cross_attention=True`"
452
+ # )
453
+ residual = hidden_states
454
+
455
+ hidden_states = self.ln_1(hidden_states)
456
+ cross_attn_outputs = self.cross_attn(
457
+ hidden_states,
458
+ attention_mask=attention_mask,
459
+ head_mask=head_mask,
460
+ encoder_hidden_states=image_hidden_states,
461
+ encoder_attention_mask=image_attention_mask,
462
+ output_attentions=output_attentions,
463
+ )
464
+ attn_output = cross_attn_outputs[0]
465
+ outputs = cross_attn_outputs[1:]
466
+ # residual connection
467
+ hidden_states = residual + self.act(self.alpha_cross_attn) * attn_output
468
+ outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
469
+
470
+ residual = hidden_states
471
+ hidden_states = self.ln_2(hidden_states)
472
+ feed_forward_hidden_states = self.mlp(hidden_states)
473
+ # residual connection
474
+ hidden_states = residual + self.act(self.alpha_dense) * feed_forward_hidden_states
475
+
476
+ if use_cache:
477
+ outputs = (hidden_states,) + outputs
478
+ else:
479
+ outputs = (hidden_states,) + outputs[1:]
480
+
481
+ return outputs
482
+
483
+
484
+ class VGPTNeoPreTrainedModel(VLOOMPreTrainedModelBase):
485
+ """
486
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
487
+ models.
488
+ """
489
+
490
+ config_class = VGPTNeoConfig
491
+ load_tf_weights = load_tf_weights_in_gpt_neo
492
+ base_model_prefix = "transformer"
493
+ supports_gradient_checkpointing = True
494
+ _no_split_modules = ["GPTNeoBlock"]
495
+
496
+ def __init__(self, *inputs, **kwargs):
497
+ super().__init__(*inputs, **kwargs)
498
+
499
+ def _init_weights(self, module):
500
+ """Initialize the weights."""
501
+ if isinstance(module, (nn.Linear,)):
502
+ # Slightly different from the TF version which uses truncated_normal for initialization
503
+ # cf https://github.com/pytorch/pytorch/pull/5617
504
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
505
+ if module.bias is not None:
506
+ module.bias.data.zero_()
507
+ elif isinstance(module, nn.Embedding):
508
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
509
+ if module.padding_idx is not None:
510
+ module.weight.data[module.padding_idx].zero_()
511
+ elif isinstance(module, nn.LayerNorm):
512
+ module.bias.data.zero_()
513
+ module.weight.data.fill_(1.0)
514
+
515
+ def _set_gradient_checkpointing(self, module, value=False):
516
+ if isinstance(module, VGPTNeoModel):
517
+ module.gradient_checkpointing = value
518
+
519
+ @classmethod
520
+ def override_vision_model_wrapper(cls, model, config, vision_model_name, vision_model_params, torch_dtype):
521
+ # this can be called via from_pretrained from a class w/ head or w/o head so we extract the beheaded model version
522
+ beheaded_model = model.transformer if hasattr(model, "transformer") else model
523
+ cls.override_vision_model(beheaded_model, vision_model_name, vision_model_params, torch_dtype)
524
+ beheaded_model.freeze_relevant_params(config)
525
+
526
+
527
+ GPT_NEO_START_DOCSTRING = r"""
528
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
529
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
530
+ etc.)
531
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
532
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
533
+ and behavior.
534
+ Parameters:
535
+ config ([`VGPTNeoConfig`]): Model configuration class with all the parameters of the model.
536
+ Initializing with a config file does not load the weights associated with the model, only the
537
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
538
+ """
539
+
540
+ GPT_NEO_INPUTS_DOCSTRING = r"""
541
+ Args:
542
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
543
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
544
+ `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
545
+ sequence tokens in the vocabulary.
546
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
547
+ `input_ids`.
548
+ Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
549
+ [`PreTrainedTokenizer.__call__`] for details.
550
+ [What are input IDs?](../glossary#input-ids)
551
+ past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.num_layers`):
552
+ Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
553
+ `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
554
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
555
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
556
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
557
+ - 1 for tokens that are **not masked**,
558
+ - 0 for tokens that are **masked**.
559
+ [What are attention masks?](../glossary#attention-mask)
560
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
561
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
562
+ 1]`:
563
+ - 0 corresponds to a *sentence A* token,
564
+ - 1 corresponds to a *sentence B* token.
565
+ [What are token type IDs?](../glossary#token-type-ids)
566
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
567
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
568
+ config.max_position_embeddings - 1]`.
569
+ [What are position IDs?](../glossary#position-ids)
570
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
571
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
572
+ - 1 indicates the head is **not masked**,
573
+ - 0 indicates the head is **masked**.
574
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
575
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
576
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
577
+ model's internal embedding lookup matrix.
578
+ If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
579
+ `past_key_values`).
580
+ use_cache (`bool`, *optional*):
581
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
582
+ `past_key_values`).
583
+ output_attentions (`bool`, *optional*):
584
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
585
+ tensors for more detail.
586
+ output_hidden_states (`bool`, *optional*):
587
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
588
+ more detail.
589
+ return_dict (`bool`, *optional*):
590
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
591
+ """
592
+
593
+
594
+ @add_start_docstrings(
595
+ "The bare GPT Neo Model transformer outputting raw hidden-states without any specific head on top.",
596
+ GPT_NEO_START_DOCSTRING,
597
+ )
598
+ class VGPTNeoModel(VGPTNeoPreTrainedModel):
599
+ def __init__(self, config, vision_model=None):
600
+ super().__init__(config)
601
+
602
+ self.embed_dim = config.hidden_size
603
+ self.wte = DecoupledEmbedding(
604
+ num_embeddings=config.vocab_size,
605
+ num_additional_embeddings=config.additional_vocab_size,
606
+ embedding_dim=self.embed_dim,
607
+ partially_freeze=config.freeze_text_layers,
608
+ )
609
+ self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
610
+ self.drop = nn.Dropout(float(config.embed_dropout))
611
+ self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)])
612
+ self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
613
+
614
+ self.cross_layer_interval = config.cross_layer_interval
615
+ num_cross_layers = config.num_layers // self.cross_layer_interval
616
+ self.gated_cross_attn_layers = nn.ModuleList(
617
+ [VGPTNeoGatedCrossAttentionBlock(config, layer_id=i) for i in range(num_cross_layers)]
618
+ )
619
+
620
+ # Perceiver Resampler
621
+ if config.use_resampler:
622
+ self.perceiver_resampler = PerceiverResampler(
623
+ self.config,
624
+ self.config.vision_embed_dim,
625
+ config.resampler_depth,
626
+ config.resampler_n_heads,
627
+ config.resampler_head_dim,
628
+ config.resampler_n_latents,
629
+ )
630
+ self.gradient_checkpointing = False
631
+ self.image_token_idx = config.image_token_index
632
+
633
+ # Store the (still uninitialized) vision model here; `from_pretrained` loads its pre-trained weights later,
634
+ # which avoids losing those weights when `from_pretrained` is called on the main model.
635
+ self.vision_model = vision_model
636
+
637
+ # Initialize weights and apply final processing
638
+ self.post_init()
639
+
640
+ self.freeze_relevant_params(config)
641
+
642
+ def freeze_relevant_params(self, config=None):
643
+ if config is None:
644
+ config = self.config
645
+
646
+ if config.freeze_text_layers:
647
+ self.freeze_text_layers()
648
+
649
+ if config.freeze_vision_layers:
650
+ freeze_model(self.vision_model)
651
+
652
+ def freeze_text_layers(self):
653
+ for module in [self.wpe, self.h, self.ln_f]:
654
+ freeze_model(module)
655
+
656
+ def get_input_embeddings(self):
657
+ return self.wte
658
+
659
+ def set_input_embeddings(self, new_embeddings):
660
+ self.wte = new_embeddings
661
+
662
+ @add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
663
+ @add_code_sample_docstrings(
664
+ processor_class=_TOKENIZER_FOR_DOC,
665
+ checkpoint=_CHECKPOINT_FOR_DOC,
666
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
667
+ config_class=_CONFIG_FOR_DOC,
668
+ )
669
+ def forward(
670
+ self,
671
+ input_ids: Optional[torch.LongTensor] = None,
672
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
673
+ attention_mask: Optional[torch.FloatTensor] = None,
674
+ token_type_ids: Optional[torch.LongTensor] = None,
675
+ position_ids: Optional[torch.LongTensor] = None,
676
+ head_mask: Optional[torch.FloatTensor] = None,
677
+ inputs_embeds: Optional[torch.FloatTensor] = None,
678
+ pixel_values: Optional[torch.FloatTensor] = None,
679
+ image_embeddings: Optional[torch.FloatTensor] = None,
680
+ image_attention_mask: Optional[torch.Tensor] = None,
681
+ crossblock_head_mask: Optional[torch.Tensor] = None,
682
+ use_cache: Optional[bool] = None,
683
+ output_attentions: Optional[bool] = None,
684
+ output_hidden_states: Optional[bool] = None,
685
+ return_dict: Optional[bool] = None,
686
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
687
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
688
+ output_hidden_states = (
689
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
690
+ )
691
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
692
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
693
+
694
+ if input_ids is not None and inputs_embeds is not None:
695
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
696
+ elif input_ids is not None:
697
+ input_shape = input_ids.size()
698
+ input_ids = input_ids.view(-1, input_shape[-1])
699
+ batch_size = input_ids.shape[0]
700
+ elif inputs_embeds is not None:
701
+ input_shape = inputs_embeds.size()[:-1]
702
+ batch_size = inputs_embeds.shape[0]
703
+ else:
704
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
705
+
706
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
707
+
708
+ if token_type_ids is not None:
709
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
710
+ if position_ids is not None:
711
+ position_ids = position_ids.view(-1, input_shape[-1])
712
+
713
+ if past_key_values is None:
714
+ past_length = 0
715
+ past_key_values = tuple([None] * len(self.h))
716
+ else:
717
+ past_length = past_key_values[0][0].size(-2)
718
+ if position_ids is None:
719
+ position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
720
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
721
+
722
+ # Attention mask.
723
+ if attention_mask is not None:
724
+ if batch_size <= 0:
725
+ raise ValueError("batch_size has to be defined and > 0")
726
+ attention_mask = attention_mask.view(batch_size, -1)
727
+ # We create a 4D attention mask from a 2D tensor mask.
728
+ # Sizes are [batch_size, 1, 1, to_seq_length]
729
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
730
+ # this attention mask is simpler than the triangular masking of causal attention
731
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
732
+ attention_mask = attention_mask[:, None, None, :]
733
+
734
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
735
+ # masked positions, this operation will create a tensor which is 0.0 for
736
+ # positions we want to attend and the dtype's smallest value for masked positions.
737
+ # Since we are adding it to the raw scores before the softmax, this is
738
+ # effectively the same as removing these entirely.
739
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
740
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
741
+
742
+ # Compute the image hidden states used by the cross-attention blocks, either from
743
+ # raw pixels (through the vision encoder) or from precomputed image embeddings.
744
+ if pixel_values is not None and image_embeddings is not None:
745
+ raise ValueError("You cannot specify both pixel_values and image_embeddings at the same time")
746
+ elif pixel_values is not None:
747
+ pixel_values = pixel_values.to(dtype=self.dtype, device=device) # fp16 compatibility; `device` also covers the inputs_embeds-only case
748
+ batch_size, num_images = pixel_values.size(0), pixel_values.size(1)
749
+ pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:])
750
+ # Get sequence from the vision encoder
751
+ image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
752
+ elif image_embeddings is not None:
753
+ batch_size, num_images, image_seq_len, image_hidden_size = image_embeddings.size()
754
+ image_hidden_states = image_embeddings.to(dtype=self.dtype, device=device)
755
+ image_hidden_states = image_hidden_states.view(batch_size * num_images, image_seq_len, image_hidden_size)
756
+
757
+ if self.config.use_resampler:
758
+ image_hidden_states = self.perceiver_resampler(image_hidden_states)
759
+ image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
760
+ image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size)
761
+ # Make image_attention_mask compatible with hidden states
762
+ text_seq_len = image_attention_mask.size(1)
763
+ image_attention_mask = image_attention_mask.unsqueeze(-1)
764
+ image_attention_mask = image_attention_mask.repeat(1, 1, 1, image_seq_len)
765
+ image_attention_mask = image_attention_mask.view(batch_size, text_seq_len, num_images * image_seq_len)
766
+
767
+ if image_hidden_states is not None:
768
+ image_batch_size, image_sequence_length, _ = image_hidden_states.size()
769
+ image_hidden_shape = (image_batch_size, image_sequence_length)
770
+ if image_attention_mask is None:
771
+ image_attention_mask = torch.ones(image_hidden_shape, device=device)
772
+ # image_attention_mask = self.invert_attention_mask(image_attention_mask)
773
+ image_attention_mask = image_attention_mask.to(torch.bool)
774
+ image_attention_mask = image_attention_mask[:, None, :, :]
775
+ else:
776
+ image_attention_mask = None
777
+
778
+ # Prepare head mask if needed
779
+ # 1.0 in head_mask indicates we keep the head
780
+ # attention_probs has shape bsz x n_heads x N x N
781
+ # head_mask has shape n_layer x batch x n_heads x N x N
782
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
783
+
784
+ if inputs_embeds is None:
785
+ inputs_embeds = self.wte(input_ids)
786
+
787
+ position_embeds = self.wpe(position_ids)
788
+ hidden_states = inputs_embeds + position_embeds
789
+
790
+ if token_type_ids is not None:
791
+ token_type_embeds = self.wte(token_type_ids)
792
+ hidden_states = hidden_states + token_type_embeds
793
+
794
+ hidden_states = self.drop(hidden_states)
795
+
796
+ output_shape = input_shape + (hidden_states.size(-1),)
797
+
798
+ presents = () if use_cache else None
799
+ all_self_attentions = () if output_attentions else None
800
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
801
+ all_hidden_states = () if output_hidden_states else None
802
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
803
+ if output_hidden_states:
804
+ all_hidden_states = all_hidden_states + (hidden_states,)
805
+
806
+ def vblock(
807
+ main_block,
808
+ hidden_states,
809
+ layer_past,
810
+ attention_mask,
811
+ layer_head_mask,
812
+ use_cache,
813
+ output_attentions,
814
+ image_hidden_states,
815
+ image_attention_mask,
816
+ layer_idx,
817
+ cross_layer_interval,
818
+ gated_cross_attn_layers,
819
+ ):
820
+ # TODO(aps): Add cross attention values to respective lists
821
+ # TODO(aps): Add xblock head mask support
822
+ if layer_idx % cross_layer_interval == 0:
823
+ xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval]
824
+ outputs = xblock(
825
+ hidden_states,
826
+ attention_mask=attention_mask,
827
+ image_hidden_states=image_hidden_states,
828
+ image_attention_mask=image_attention_mask,
829
+ use_cache=use_cache,
830
+ output_attentions=output_attentions,
831
+ )
832
+ hidden_states = outputs[0]
833
+
834
+ outputs = main_block(
835
+ hidden_states,
836
+ layer_past=layer_past,
837
+ attention_mask=attention_mask,
838
+ head_mask=layer_head_mask,
839
+ use_cache=use_cache,
840
+ output_attentions=output_attentions,
841
+ )
842
+
843
+ return outputs
844
+
845
+ if self.gradient_checkpointing and self.training:
846
+ if use_cache:
847
+ logger.warning_once(
848
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
849
+ )
850
+ use_cache = False
851
+
852
+ outputs = torch.utils.checkpoint.checkpoint(
853
+ vblock,
854
+ block,
855
+ hidden_states,
856
+ layer_past,
857
+ attention_mask,
858
+ head_mask[i],
859
+ use_cache,
860
+ output_attentions,
861
+ image_hidden_states,
862
+ image_attention_mask,
863
+ i,
864
+ self.cross_layer_interval,
865
+ self.gated_cross_attn_layers,
866
+ )
867
+ else:
868
+ outputs = vblock(
869
+ block,
870
+ hidden_states,
871
+ layer_past=layer_past,
872
+ attention_mask=attention_mask,
873
+ layer_head_mask=head_mask[i],
874
+ use_cache=use_cache,
875
+ output_attentions=output_attentions,
876
+ layer_idx=i,
877
+ image_hidden_states=image_hidden_states,
878
+ image_attention_mask=image_attention_mask,
879
+ cross_layer_interval=self.cross_layer_interval,
880
+ gated_cross_attn_layers=self.gated_cross_attn_layers,
881
+ )
882
+
883
+ hidden_states = outputs[0]
884
+ if use_cache is True:
885
+ presents = presents + (outputs[1],)
886
+
887
+ if output_attentions:
888
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
889
+ if self.config.add_cross_attention:
890
+ all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
891
+
892
+ hidden_states = self.ln_f(hidden_states)
893
+
894
+ hidden_states = hidden_states.view(output_shape)
895
+ # Add last hidden state
896
+ if output_hidden_states:
897
+ all_hidden_states = all_hidden_states + (hidden_states,)
898
+
899
+ if not return_dict:
900
+ return tuple(
901
+ v
902
+ for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
903
+ if v is not None
904
+ )
905
+
906
+ return BaseModelOutputWithPastAndCrossAttentions(
907
+ last_hidden_state=hidden_states,
908
+ past_key_values=presents,
909
+ hidden_states=all_hidden_states,
910
+ attentions=all_self_attentions,
911
+ cross_attentions=all_cross_attentions,
912
+ )
913
+
914
+
915
+ @add_start_docstrings(
916
+ """
917
+ The GPT Neo Model transformer with a language modeling head on top (linear layer with weights tied to the input
918
+ embeddings).
919
+ """,
920
+ GPT_NEO_START_DOCSTRING,
921
+ )
922
+ class VGPTNeoForCausalLM(VGPTNeoPreTrainedModel):
923
+ _keys_to_ignore_on_load_missing = [
924
+ r"h\.\d+\.attn\.masked_bias",
925
+ r"lm_head.weight",
926
+ r"h\.\d+\.attn\.attention\.bias",
927
+ ]
928
+ _keys_to_ignore_on_save = [r"lm_head.weight"]
929
+
930
+ def __init__(self, config, vision_model=None):
931
+ super().__init__(config)
932
+ self.transformer = VGPTNeoModel(config, vision_model=vision_model)
933
+ self.lm_head = DecoupledLinear(
934
+ in_features=config.hidden_size,
935
+ out_features=config.vocab_size,
936
+ out_additional_features=config.additional_vocab_size,
937
+ bias=False,
938
+ partially_freeze=config.freeze_lm_head,
939
+ )
940
+
941
+ # Initialize weights and apply final processing
942
+ self.post_init()
943
+
944
+ def get_output_embeddings(self):
945
+ return self.lm_head
946
+
947
+ def set_output_embeddings(self, new_embeddings):
948
+ self.lm_head = new_embeddings
949
+
950
+ def tie_weights(self):
951
+ """
952
+ Overwrite `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of DecoupledLinear and DecoupledEmbedding.
953
+ """
954
+ output_embeddings = self.get_output_embeddings()
955
+ input_embeddings = self.get_input_embeddings()
956
+
957
+ if getattr(self.config, "tie_word_embeddings", True):
958
+ output_embeddings.weight = input_embeddings.weight
959
+ if input_embeddings.num_additional_embeddings > 0:
960
+ assert output_embeddings.out_additional_features == input_embeddings.num_additional_embeddings
961
+ output_embeddings.additional_fc.weight = input_embeddings.additional_embedding.weight
962
+
963
+ if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
964
+ output_embeddings.out_features = input_embeddings.num_embeddings
965
+ if hasattr(output_embeddings, "out_additional_features") and hasattr(
966
+ input_embeddings, "num_additional_embeddings"
967
+ ):
968
+ output_embeddings.out_additional_features = input_embeddings.num_additional_embeddings
969
+
970
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
971
+ return prepare_inputs_for_generation(input_ids, past=past, **kwargs)
972
+
973
+ @staticmethod
974
+ def _expand_inputs_for_generation(
975
+ *args,
976
+ **model_kwargs,
977
+ ):
978
+ return expand_inputs_for_generation(*args, **model_kwargs)
979
+
980
+ @staticmethod
981
+ def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
982
+ return update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder)
983
+
984
+ @add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
985
+ @add_code_sample_docstrings(
986
+ processor_class=_TOKENIZER_FOR_DOC,
987
+ checkpoint=_CHECKPOINT_FOR_DOC,
988
+ output_type=CausalLMOutputWithCrossAttentions,
989
+ config_class=_CONFIG_FOR_DOC,
990
+ )
991
+ def forward(
992
+ self,
993
+ input_ids: Optional[torch.Tensor] = None,
994
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
995
+ attention_mask: Optional[torch.Tensor] = None,
996
+ token_type_ids: Optional[torch.Tensor] = None,
997
+ position_ids: Optional[torch.Tensor] = None,
998
+ head_mask: Optional[torch.Tensor] = None,
999
+ inputs_embeds: Optional[torch.Tensor] = None,
1000
+ pixel_values: Optional[torch.FloatTensor] = None,
1001
+ image_embeddings: Optional[torch.FloatTensor] = None,
1002
+ image_attention_mask: Optional[torch.Tensor] = None,
1003
+ crossblock_head_mask: Optional[torch.Tensor] = None,
1004
+ labels: Optional[torch.Tensor] = None,
1005
+ use_cache: Optional[bool] = None,
1006
+ output_attentions: Optional[bool] = None,
1007
+ output_hidden_states: Optional[bool] = None,
1008
+ return_dict: Optional[bool] = None,
1009
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
1010
+ r"""
1011
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1012
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1013
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
1014
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`.
1015
+ """
1016
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1017
+
1018
+ transformer_outputs = self.transformer(
1019
+ input_ids,
1020
+ past_key_values=past_key_values,
1021
+ attention_mask=attention_mask,
1022
+ token_type_ids=token_type_ids,
1023
+ position_ids=position_ids,
1024
+ head_mask=head_mask,
1025
+ inputs_embeds=inputs_embeds,
1026
+ pixel_values=pixel_values,
1027
+ image_embeddings=image_embeddings,
1028
+ image_attention_mask=image_attention_mask,
1029
+ crossblock_head_mask=crossblock_head_mask,
1030
+ use_cache=use_cache,
1031
+ output_attentions=output_attentions,
1032
+ output_hidden_states=output_hidden_states,
1033
+ return_dict=return_dict,
1034
+ )
1035
+ hidden_states = transformer_outputs[0]
1036
+
1037
+ lm_logits = self.lm_head(hidden_states)
1038
+
1039
+ loss = None
1040
+ if labels is not None:
1041
+ # Compute loss in fp32 to match with mesh-tf version
1042
+ # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
1043
+ lm_logits = lm_logits.to(torch.float32)
1044
+
1045
+ # Shift so that tokens < n predict n
1046
+ if attention_mask is not None:
1047
+ shift_attention_mask = attention_mask[..., 1:]
1048
+ shift_logits = lm_logits[..., :-1, :][shift_attention_mask != 0].contiguous()
1049
+ shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
1050
+ else:
1051
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1052
+ shift_labels = labels[..., 1:].contiguous()
1053
+ # Flatten the tokens
1054
+ loss_fct = CrossEntropyLoss()
1055
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1056
+
1057
+ lm_logits = lm_logits.to(hidden_states.dtype)
1058
+ loss = loss.to(hidden_states.dtype)
1059
+
1060
+ if not return_dict:
1061
+ output = (lm_logits,) + transformer_outputs[1:]
1062
+ return ((loss,) + output) if loss is not None else output
1063
+
1064
+ return CausalLMOutputWithCrossAttentions(
1065
+ loss=loss,
1066
+ logits=lm_logits,
1067
+ past_key_values=transformer_outputs.past_key_values,
1068
+ hidden_states=transformer_outputs.hidden_states,
1069
+ attentions=transformer_outputs.attentions,
1070
+ cross_attentions=transformer_outputs.cross_attentions,
1071
+ )
1072
+
1073
+ @staticmethod
1074
+ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
1075
+ """
1076
+ This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or
1077
+ [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1078
+ beam_idx at every generation step.
1079
+ """
1080
+ return tuple(
1081
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
1082
+ for layer_past in past
1083
+ )
1084
+
1085
+ def get_model_tflops_per_batch_per_gpu(self, hparams, data_param, tokenizer, max_num_images):
1086
+ config_vl_model = self.config
1087
+
1088
+ language_embed_size = config_vl_model.hidden_size
1089
+ vision_config = self.transformer.vision_model.config
1090
+ num_language_layers = config_vl_model.num_layers
1091
+ ffn_inner_size = (
1092
+ config_vl_model.intermediate_size
1093
+ if config_vl_model.intermediate_size is not None
1094
+ else 4 * config_vl_model.hidden_size
1095
+ )
1096
+
1097
+ # Get vision model blocks infos
1098
+ vision_patch_size = vision_config.patch_size
1099
+ vision_hidden_size = vision_config.hidden_size
1100
+ num_vision_layers = vision_config.num_hidden_layers
1101
+ # The +1 is for the CLS token
1102
+ single_image_seq_len = (vision_config.image_size // vision_patch_size) ** 2 + 1
1103
+ vision_exp_factor = vision_config.intermediate_size // vision_hidden_size
1104
+
1105
+ # Get language and cross-att blocks infos
1106
+ num_cross_attn_layers = num_language_layers // config_vl_model.cross_layer_interval
1107
+ language_seq_len = data_param.max_seq_len
1108
+ language_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4
1109
+ cross_att_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4
1110
+ k_v_cross_attn_seq_len = (
1111
+ (self.config.resampler_n_latents * max_num_images)
1112
+ if self.config.use_resampler
1113
+ else (single_image_seq_len * max_num_images)
1114
+ )
1115
+
1116
+ language_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
1117
+ num_layers=num_language_layers,
1118
+ batch_size=hparams.batch_size_per_gpu,
1119
+ q_seq_len=language_seq_len,
1120
+ k_seq_len=language_seq_len,
1121
+ hidden_size=language_embed_size,
1122
+ kv_in_dim=language_embed_size,
1123
+ ff_exp_factor=language_exp_factor,
1124
+ grad_acc_size=hparams.grad_acc_size,
1125
+ swiglu=False,
1126
+ vocab_size=tokenizer.vocab_size,
1127
+ count_backward=True, # Always True regardless of freezing, because gradients are computed for cross-attentions
1128
+ use_grad_checkpointing=hparams.gradient_checkpointing,
1129
+ )
1130
+ cross_attention_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
1131
+ num_layers=num_cross_attn_layers,
1132
+ batch_size=hparams.batch_size_per_gpu,
1133
+ q_seq_len=language_seq_len,
1134
+ k_seq_len=k_v_cross_attn_seq_len,
1135
+ hidden_size=language_embed_size,
1136
+ kv_in_dim=vision_hidden_size,
1137
+ ff_exp_factor=cross_att_exp_factor,
1138
+ grad_acc_size=hparams.grad_acc_size,
1139
+ swiglu=False,
1140
+ vocab_size=None,
1141
+ count_backward=True,
1142
+ use_grad_checkpointing=hparams.gradient_checkpointing,
1143
+ )
1144
+ vision_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
1145
+ num_layers=num_vision_layers,
1146
+ batch_size=hparams.batch_size_per_gpu * max_num_images,
1147
+ q_seq_len=single_image_seq_len,
1148
+ k_seq_len=single_image_seq_len,
1149
+ hidden_size=vision_hidden_size,
1150
+ kv_in_dim=vision_hidden_size,
1151
+ ff_exp_factor=vision_exp_factor,
1152
+ grad_acc_size=hparams.grad_acc_size,
1153
+ swiglu=False,
1154
+ vocab_size=None,
1155
+ count_backward=not hparams.model_params["freeze_vision_layers"],
1156
+ use_grad_checkpointing=hparams.gradient_checkpointing,
1157
+ )
1158
+ if self.config.use_resampler:
1159
+ perceiver_tflops_per_batch_per_gpu = compute_perceiver_tflops_per_batch_per_gpu(
1160
+ num_layers=self.config.resampler_depth,
1161
+ batch_size=hparams.batch_size_per_gpu * max_num_images,
1162
+ q_seq_len=self.config.resampler_n_latents,
1163
+ vision_embed_seq_len=single_image_seq_len,
1164
+ q_k_v_input_dim=vision_hidden_size,
1165
+ attention_hidden_size=self.config.resampler_n_heads * self.config.resampler_head_dim,
1166
+ ff_exp_factor=cross_att_exp_factor,
1167
+ count_backward=True,
1168
+ use_grad_checkpointing=hparams.gradient_checkpointing,
1169
+ )
1170
+ flop_count = (
1171
+ language_tflops_per_batch_per_gpu
1172
+ + cross_attention_tflops_per_batch_per_gpu
1173
+ + vision_tflops_per_batch_per_gpu
1174
+ + perceiver_tflops_per_batch_per_gpu
1175
+ )
1176
+ else:
1177
+ flop_count = (
1178
+ language_tflops_per_batch_per_gpu
1179
+ + cross_attention_tflops_per_batch_per_gpu
1180
+ + vision_tflops_per_batch_per_gpu
1181
+ )
1182
+ return flop_count
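As a reading aid for the `vblock` logic and the FLOP bookkeeping in this file, here is a minimal pure-Python sketch, not the m4 implementation. The helper names `cross_attention_schedule` and `cross_attention_kv_len` are hypothetical. The sketch assumes a ViT-style encoder that emits one CLS token plus one token per patch, as the `+1` comment in `get_model_tflops_per_batch_per_gpu` suggests.

```python
from typing import Dict


def cross_attention_schedule(num_layers: int, cross_layer_interval: int) -> Dict[int, int]:
    """Map each language layer that is preceded by a gated cross-attention block
    to the index of that block, mirroring `layer_idx % cross_layer_interval == 0`."""
    if cross_layer_interval <= 0:
        raise ValueError("cross_layer_interval must be a positive integer")
    return {
        layer_idx: layer_idx // cross_layer_interval
        for layer_idx in range(num_layers)
        if layer_idx % cross_layer_interval == 0
    }


def cross_attention_kv_len(
    image_size: int,
    patch_size: int,
    max_num_images: int,
    use_resampler: bool = False,
    resampler_n_latents: int = 64,
) -> int:
    """Length of the key/value sequence seen by the cross-attention blocks."""
    if use_resampler:
        # The resampler compresses every image to a fixed number of latents.
        return resampler_n_latents * max_num_images
    # One token per patch plus one CLS token (assumed ViT-style encoder).
    single_image_seq_len = (image_size // patch_size) ** 2 + 1
    return single_image_seq_len * max_num_images


print(cross_attention_schedule(num_layers=6, cross_layer_interval=2))  # {0: 0, 2: 1, 4: 2}
print(cross_attention_kv_len(224, 16, max_num_images=2))  # 394
print(cross_attention_kv_len(224, 16, max_num_images=2, use_resampler=True))  # 128
```

Because layer 0 always qualifies, this schedule visits `ceil(num_layers / cross_layer_interval)` layers, which equals `num_layers // cross_layer_interval` only when the interval divides the layer count.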
m4/models/vllama/configuration_vllama.py ADDED
@@ -0,0 +1,204 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ LLaMA model configuration"""
21
+ import os
22
+ from typing import Tuple, Union
23
+
24
+ from transformers import AutoConfig
25
+ from transformers.configuration_utils import PretrainedConfig
26
+ from transformers.utils import logging
27
+
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+ LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
32
+
33
+
34
+ class VLlamaConfig(PretrainedConfig):
35
+ r"""
36
+ TODO: update docstring with respect to new arguments
37
+
38
+ This is the configuration class to store the configuration of a [`~LlamaModel`]. It is used to instantiate an LLaMA
39
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
40
+ defaults will yield a similar configuration to that of the LLaMA-7B.
41
+
42
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
43
+ documentation from [`PretrainedConfig`] for more information.
44
+
45
+
46
+ Args:
47
+ vocab_size (`int`, *optional*, defaults to 32000):
48
+ Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
49
+ `input_ids` passed when calling [`~LlamaModel`].
50
+ hidden_size (`int`, *optional*, defaults to 4096):
51
+ Dimension of the hidden representations.
52
+ intermediate_size (`int`, *optional*, defaults to 11008):
53
+ Dimension of the MLP representations.
54
+ num_hidden_layers (`int`, *optional*, defaults to 32):
55
+ Number of hidden layers in the Transformer decoder.
56
+ num_attention_heads (`int`, *optional*, defaults to 32):
57
+ Number of attention heads for each attention layer in the Transformer decoder.
58
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
59
+ The non-linear activation function (function or string) in the decoder.
60
+ initializer_range (`float`, *optional*, defaults to 0.02):
61
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
62
+ rms_norm_eps (`float`, *optional*, defaults to 1e-6):
63
+ The epsilon used by the rms normalization layers.
64
+ use_cache (`bool`, *optional*, defaults to `True`):
65
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
66
+ relevant if `config.is_decoder=True`.
67
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
68
+ Whether to tie the input and output word embeddings.
69
+ Example:
70
+
71
+ ```python
72
+ >>> from transformers import LlamaModel, LlamaConfig
73
+
74
+ >>> # Initializing a LLaMA llama-7b style configuration
75
+ >>> configuration = LlamaConfig()
76
+
77
+ >>> # Initializing a model from the llama-7b style configuration
78
+ >>> model = LlamaModel(configuration)
79
+
80
+ >>> # Accessing the model configuration
81
+ >>> configuration = model.config
82
+ ```"""
83
+ model_type = "vllama"
84
+
85
+ def __init__(
86
+ self,
87
+ vocab_size=32000,
88
+ additional_vocab_size=0,
89
+ hidden_size=4096,
90
+ intermediate_size=11008,
91
+ num_hidden_layers=32,
92
+ num_attention_heads=32,
93
+ dropout=0.0,
94
+ hidden_act="silu",
95
+ initializer_range=0.02,
96
+ alpha_initializer="ones",
97
+ alphas_initializer_range=0.0,
98
+ alpha_type="vector",
99
+ rms_norm_eps=1e-6,
100
+ use_cache=True,
101
+ pad_token_id=0,
102
+ bos_token_id=1,
103
+ eos_token_id=2,
104
+ tie_word_embeddings=False,
105
+ cross_layer_interval=1,
106
+ cross_layer_activation_function="swiglu",
107
+ qk_layer_norms=False,
108
+ qk_layer_norms_perceiver=False,
109
+ freeze_text_layers=True,
110
+ freeze_text_module_exceptions=[],
111
+ freeze_lm_head=False,
112
+ freeze_vision_layers=True,
113
+ freeze_vision_module_exceptions=[],
114
+ vision_model_name="google/vit-base-patch16-224",
115
+ vision_model_params="{}",
116
+ vision_embed_dim=768,
117
+ vision_image_size=224,
118
+ use_resampler=False,
119
+ resampler_n_latents=64,
120
+ resampler_depth=6,
121
+ resampler_n_heads=16,
122
+ resampler_head_dim=96,
123
+ **kwargs,
124
+ ):
125
+ self.vocab_size = vocab_size
126
+ self.additional_vocab_size = additional_vocab_size
127
+ self.hidden_size = hidden_size
128
+ self.intermediate_size = intermediate_size
129
+ self.num_hidden_layers = num_hidden_layers
130
+ self.num_attention_heads = num_attention_heads
131
+ self.dropout = dropout
132
+ self.hidden_act = hidden_act
133
+ self.initializer_range = initializer_range
134
+ self.alpha_initializer = alpha_initializer
135
+ self.alphas_initializer_range = alphas_initializer_range
136
+ self.alpha_type = alpha_type
137
+ self.rms_norm_eps = rms_norm_eps
138
+ self.use_cache = use_cache
139
+ super().__init__(
140
+ pad_token_id=pad_token_id,
141
+ bos_token_id=bos_token_id,
142
+ eos_token_id=eos_token_id,
143
+ tie_word_embeddings=tie_word_embeddings,
144
+ **kwargs,
145
+ )
146
+
147
+ self.cross_layer_interval = cross_layer_interval
148
+ self.cross_layer_activation_function = cross_layer_activation_function
149
+ self.qk_layer_norms = qk_layer_norms
150
+ self.qk_layer_norms_perceiver = qk_layer_norms_perceiver
151
+ self.freeze_vision_layers = freeze_vision_layers
152
+ self.vision_model_name = vision_model_name
153
+ self.vision_model_params = vision_model_params
154
+
155
+ self.freeze_text_layers = freeze_text_layers
156
+ self.freeze_text_module_exceptions = freeze_text_module_exceptions
157
+ self.freeze_vision_module_exceptions = freeze_vision_module_exceptions
158
+ self.freeze_lm_head = freeze_lm_head
159
+
160
+ self.vision_embed_dim = vision_embed_dim
161
+ self.vision_image_size = vision_image_size
162
+
163
+ # Resampler params
164
+ self.use_resampler = use_resampler
165
+ self.resampler_n_latents = resampler_n_latents
166
+ self.resampler_depth = resampler_depth
167
+ self.resampler_n_heads = resampler_n_heads
168
+ self.resampler_head_dim = resampler_head_dim
169
+
170
+ # IMPORTANT: Do not do any __init__ args-based checks in the constructor, since
171
+ # PretrainedConfig.from_dict first instantiates the class with the config dict and only then
172
+ # updates the config object with `kwargs` from from_pretrained, so during the instantiation
173
+ # of this object many attributes have default values and haven't yet been overridden.
174
+ # Do any required checks inside `from_pretrained` once the superclass' `from_pretrained` was run.
175
+
176
+ def check_compatibilities(self):
177
+ vision_model_params = eval(self.vision_model_params)
178
+ config = AutoConfig.from_pretrained(self.vision_model_name, **vision_model_params)
179
+ if hasattr(config, "vision_config"):
180
+ vision_config = config.vision_config
181
+ else:
182
+ vision_config = config
183
+ vision_embed_dim = vision_config.hidden_size
184
+ if self.vision_embed_dim != vision_embed_dim:
185
+ raise ValueError(
186
+ f"vision_embed_dim ({self.vision_embed_dim}) must match the hidden size of the vision model"
187
+ f" ({vision_embed_dim})"
188
+ )
189
+ vision_image_size = vision_config.image_size
190
+ if self.vision_image_size != vision_image_size:
191
+ raise ValueError(
192
+ f"vision_image_size ({self.vision_image_size}) must match the image size of the vision model"
193
+ f" ({vision_image_size})"
194
+ )
195
+
196
+ @classmethod
197
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
198
+ outputs = super(VLlamaConfig, cls).from_pretrained(pretrained_model_name_or_path, **kwargs)
199
+ if isinstance(outputs, Tuple):
200
+ # When called with return_unused_kwargs=True, the first item will be the config
201
+ outputs[0].check_compatibilities()
202
+ else:
203
+ outputs.check_compatibilities()
204
+ return outputs
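The compatibility check in this configuration file can be illustrated in isolation. The sketch below is a hedged stand-in rather than the m4 code: `parse_params` and `check_vision_compat` are hypothetical helpers, the vision config is faked with a `SimpleNamespace` instead of `AutoConfig.from_pretrained`, and `ast.literal_eval` is swapped in for `eval` on the assumption that the stringified parameters only ever contain Python literals.

```python
import ast
from types import SimpleNamespace


def parse_params(params: str) -> dict:
    """Parse a stringified dict such as '{"use_fast": True}' without executing code."""
    value = ast.literal_eval(params)
    if not isinstance(value, dict):
        raise ValueError(f"expected a dict literal, got {type(value).__name__}")
    return value


def check_vision_compat(vision_embed_dim: int, vision_image_size: int, vision_config) -> None:
    """Raise ValueError when the configured sizes disagree with the vision config."""
    if vision_embed_dim != vision_config.hidden_size:
        raise ValueError(
            f"vision_embed_dim ({vision_embed_dim}) must match the hidden size of the vision model"
            f" ({vision_config.hidden_size})"
        )
    if vision_image_size != vision_config.image_size:
        raise ValueError(
            f"vision_image_size ({vision_image_size}) must match the image size of the vision model"
            f" ({vision_config.image_size})"
        )


# Hypothetical stand-in for the config object of the vision model.
fake_vision_config = SimpleNamespace(hidden_size=768, image_size=224)

check_vision_compat(768, 224, fake_vision_config)  # matching sizes: returns silently
print(parse_params('{"use_fast": True}'))  # {'use_fast': True}
```

Running the check only after `from_pretrained` has applied its overrides, as the comment in `__init__` asks, avoids validating default values that are about to be replaced.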
m4/models/vllama/make_tiny_llama.py ADDED
@@ -0,0 +1,51 @@
1
+ #!/usr/bin/env python
2
+
3
+ # This script creates a super tiny model that is useful inside tests, when we just want to test that
4
+ # the machinery works, without needing to check the quality of the outcomes.
5
+ #
6
+ # usage: adjust the configs if wanted, but otherwise just run the script
7
+
8
+ from pathlib import Path
9
+
10
+ from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
11
+
12
+
13
+ mname_tiny = "tiny-random-LlamaForCausalLM"
14
+
15
+ path = Path(mname_tiny)
16
+ path.mkdir(parents=True, exist_ok=True)
17
+
18
+ config = LlamaConfig()
19
+ config.update(
20
+ dict(
21
+ vocab_size=32000,
22
+ hidden_size=16,
23
+ intermediate_size=16 * 4,
24
+ num_hidden_layers=2,
25
+ num_attention_heads=4,
26
+ )
27
+ )
28
+ model = LlamaForCausalLM(config)
29
+ tokenizer = LlamaTokenizer.from_pretrained("path_to_llama_7b")
30
+
31
+ # Test w/ one text
32
+ query = "This is a test"
33
+ query_tokens = tokenizer(query, return_tensors="pt")
34
+
35
+ inputs = {
35
+ "input_ids": query_tokens["input_ids"],
36
+ "attention_mask": query_tokens["attention_mask"],
37
+ }
38
+
39
+ out_gen = model.generate(**inputs)
41
+ text = tokenizer.batch_decode(out_gen)
42
+
43
+ # Save model + config + tokenizer
44
+ model.half() # makes it smaller
45
+ model.save_pretrained(path)
46
+ tokenizer.save_pretrained(path)
47
+
48
+ # test we can load it back
49
+ model = LlamaForCausalLM.from_pretrained(path)
50
+
51
+ print(f"Generated {mname_tiny} - Upload the generated folder to the hub")
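To give a sense of how small the model produced by this script is, the following back-of-the-envelope sketch counts its parameters. It is an estimate under stated assumptions, not something the script itself computes: `llama_param_count` is a hypothetical helper, and it assumes the standard LLaMA layout (bias-free q/k/v/o and gate/up/down projections, two RMSNorm weights per layer, one final norm, and an untied LM head).

```python
def llama_param_count(
    vocab_size: int,
    hidden_size: int,
    intermediate_size: int,
    num_hidden_layers: int,
    tie_word_embeddings: bool = False,
) -> int:
    """Approximate parameter count for a LLaMA-style decoder (assumed layout, no biases)."""
    embed = vocab_size * hidden_size
    attn = 4 * hidden_size * hidden_size       # q, k, v and o projections
    mlp = 3 * hidden_size * intermediate_size  # gate, up and down projections
    norms = 2 * hidden_size                    # two RMSNorm weights per layer
    per_layer = attn + mlp + norms
    final_norm = hidden_size
    lm_head = 0 if tie_word_embeddings else vocab_size * hidden_size
    return embed + num_hidden_layers * per_layer + final_norm + lm_head


# Values taken from the config.update(...) call in the script.
n_params = llama_param_count(
    vocab_size=32000,
    hidden_size=16,
    intermediate_size=16 * 4,
    num_hidden_layers=2,
)
print(n_params)  # 1032272
```

Almost all of these parameters sit in the two vocabulary-sized matrices, so under these assumptions the `model.half()` call leaves roughly 2 MB of weights on disk.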
m4/models/vllama/make_tiny_model.py ADDED
@@ -0,0 +1,114 @@
1
+ #!/usr/bin/env python
2
+
3
+ # This script creates a super tiny model that is useful inside tests, when we just want to test that
4
+ # the machinery works, without needing to check the quality of the outcomes.
5
+ #
6
+ # usage: adjust the configs if wanted, but otherwise just run the script
7
+
8
+ from pathlib import Path
9
+ from types import SimpleNamespace
10
+
11
+ import torchvision.transforms as transforms
12
+ from PIL import Image
13
+
14
+ from m4.models.vllama.modeling_vllama import VLlamaConfig, VLlamaForCausalLM
15
+ from m4.training.packing import image_attention_mask_for_packed_input_ids, incremental_to_binary_attention_mask
16
+ from m4.training.utils import get_tokenizer
17
+
18
+
19
+ mname_tiny = "tiny-random-vllama-clip"
20
+
21
+ path = Path(mname_tiny)
22
+ path.mkdir(parents=True, exist_ok=True)
23
+
24
+ # from the hardcoded https://github.com/huggingface/m4/blob/adf102f0000cb2632cd8a3ebb87398c65e448a97/m4/training/main.py#L80
25
+ additional_vocab_size = 2
26
+
27
+ config = VLlamaConfig()
28
+ config.update(
29
+ dict(
30
+ ffn_dim=64,
31
+ hidden_size=16,
32
+ max_position_embeddings=128,
33
+ num_attention_heads=4,
34
+ num_hidden_layers=2,
35
+ word_embed_proj_dim=16,
36
+ max_new_tokens=100,
37
+ use_resampler=True,
38
+ resampler_depth=2,
39
+ resampler_head_dim=8,
40
+ resampler_n_heads=2,
41
+ resampler_n_latents=16,
42
+ vision_embed_dim=32,
43
+ vision_image_size=30,
44
+ vision_model_name="hf-internal-testing/tiny-random-clip",
45
+ vision_model_params="{}",
46
+ vocab_size=32000,
47
+ additional_vocab_size=additional_vocab_size,
48
+ )
49
+ )
50
+
51
+ # print(config)
52
+ # can now modify config to say tiny values
53
+
54
+ model = VLlamaForCausalLM.from_config(config)
55
+ # print(model.config)
56
+ # print(model)
57
+
58
+ tokenizer_config = dict(
59
+ tokenizer_add_special_tokens="{}",
60
+ tokenizer_add_tokens=(
61
+ '[AddedToken("<fake_token_around_image>", rstrip=False, lstrip=False), AddedToken("<image>", rstrip=False,'
62
+ " lstrip=False)]"
63
+ ),
64
+ tokenizer_name="HuggingFaceM4/huggy-llama-tokenizer-7b",
65
+ tokenizer_params='{"use_fast": True}',
66
+ )
67
+ tokenizer_config = SimpleNamespace(**tokenizer_config)
68
+ # print(tokenizer_config)
69
+
70
+ tokenizer = get_tokenizer(
71
+ tokenizer_name=tokenizer_config.tokenizer_name,
72
+ tokenizer_add_tokens=tokenizer_config.tokenizer_add_tokens,
73
+ tokenizer_add_special_tokens=tokenizer_config.tokenizer_add_special_tokens,
74
+ tokenizer_params=tokenizer_config.tokenizer_params,
75
+ additional_vocab_size=model.config.additional_vocab_size,
76
+ model_vocab_size=model.config.vocab_size,
77
+ )
78
+ assert "<image>" in tokenizer.get_vocab()
79
+
80
+ # Test w/ one image and one text
81
+ query = "<fake_token_around_image><image><fake_token_around_image>This is a picture of a cat."
82
+ query_tokens = tokenizer(query, return_tensors="pt")
83
+
84
+ num_images_per_ex = 1
85
+ pixel_values = transforms.ToTensor()(Image.new("RGB", (30, 30))).repeat(1, 1, 1, 1).unsqueeze(0)
86
+ image_attention_mask, _ = image_attention_mask_for_packed_input_ids(query_tokens["input_ids"], tokenizer)
87
+ image_attention_mask = incremental_to_binary_attention_mask(image_attention_mask, num_classes=num_images_per_ex)
88
+
89
+ input = {
90
+ "input_ids": query_tokens["input_ids"],
91
+ "attention_mask": query_tokens["attention_mask"],
92
+ "pixel_values": pixel_values,
94
+ "image_attention_mask": image_attention_mask,
95
+ }
96
+ # debug shapes
97
+ # print(query_tokens["input_ids"].shape)
98
+ # print(query_tokens["attention_mask"].shape)
99
+ # print(pixel_values.shape)
100
+ # print(image_attention_mask.shape)
101
+
102
+ out_gen = model.generate(**input)
103
+ text = tokenizer.batch_decode(out_gen)
104
+ # print(text)
105
+
106
+ # Save model + config + tokenizer
107
+ model.half() # makes it smaller
108
+ model.save_pretrained(path)
109
+ tokenizer.save_pretrained(path)
110
+
111
+ # test we can load it back
112
+ model = VLlamaForCausalLM.from_pretrained(path)
113
+
114
+ print(f"Generated {mname_tiny} - Upload the generated folder to the hub")
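The script above converts an "incremental" image attention mask into a binary one before calling `generate`. The diff does not show how `incremental_to_binary_attention_mask` is implemented, so the following is only a sketch of one plausible behavior. It assumes the incremental mask stores, for each token, the index of the image it may attend to, with `-1` meaning no image. `to_binary_mask` is a hypothetical pure-Python stand-in, not the m4 helper.

```python
# Hypothetical stand-in for the conversion step; the real helper operates on
# tensors and may differ in its details.
def to_binary_mask(incremental_mask, num_images):
    """One-hot encode per-token image indices; -1 is treated as "no image"."""
    binary = []
    for row in incremental_mask:
        binary_row = []
        for image_idx in row:
            if image_idx >= num_images:
                raise ValueError(f"image index {image_idx} is out of range for {num_images} images")
            one_hot = [0] * num_images
            if image_idx >= 0:
                one_hot[image_idx] = 1   # this token may attend to image `image_idx`
            binary_row.append(one_hot)
        binary.append(binary_row)
    return binary


# One sequence of six tokens: the first precedes any image, the next three
# follow image 0, and the last two follow image 1.
incremental = [[-1, 0, 0, 0, 1, 1]]
binary = to_binary_mask(incremental, num_images=2)
print(binary)
```

The result has shape `(batch, seq_len, num_images)`, which matches the role `num_classes=num_images_per_ex` plays in the script if the assumption above is right.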
m4/models/vllama/modeling_vllama.py ADDED
@@ -0,0 +1,1260 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch LLaMA model."""
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import CrossEntropyLoss
27
+ from transformers.activations import ACT2FN
28
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
29
+ from transformers.modeling_utils import PretrainedConfig
30
+ from transformers.utils import (
31
+ ContextManagers,
32
+ add_start_docstrings,
33
+ add_start_docstrings_to_model_forward,
34
+ replace_return_docstrings,
35
+ )
36
+
37
+ from m4.models import DecoupledEmbedding, DecoupledLinear
38
+ from m4.models.common import (
39
+ expand_inputs_for_generation,
40
+ prepare_inputs_for_generation,
41
+ update_model_kwargs_for_generation,
42
+ )
43
+ from m4.models.custom_modules import VLOOMPreTrainedModelBase
44
+ from m4.models.perceiver.perceiver import PerceiverResampler
45
+ from m4.models.vllama.configuration_vllama import VLlamaConfig
46
+ from m4.training.utils import (
47
+ compute_perceiver_tflops_per_batch_per_gpu,
48
+ compute_tflops_per_batch_per_gpu,
49
+ deepspeed_gathered_parameters_context_manager,
50
+ freeze_model,
51
+ )
52
+ from m4.utils import logging
53
+
54
+
55
+ logger = logging.get_logger(__name__)
56
+
57
+ _CONFIG_FOR_DOC = "VLlamaConfig"
58
+
59
+
60
+ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
61
+ """
62
+ Make causal mask used for uni-directional (causal) self-attention.
63
+ """
64
+ bsz, tgt_len = input_ids_shape
65
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
66
+ mask_cond = torch.arange(mask.size(-1))
67
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
68
+ mask = mask.to(dtype)
69
+
70
+ if past_key_values_length > 0:
71
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
72
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
73
+
74
+
75
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
76
+ """
77
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
78
+ """
79
+ bsz, src_len = mask.size()
80
+ tgt_len = tgt_len if tgt_len is not None else src_len
81
+
82
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
83
+
84
+ inverted_mask = 1.0 - expanded_mask
85
+
86
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
87
+
88
+
89
+ class LlamaRMSNorm(nn.Module):
90
+ def __init__(self, hidden_size, eps=1e-6):
91
+ """
92
+ LlamaRMSNorm is equivalent to T5LayerNorm
93
+ """
94
+ super().__init__()
95
+ self.weight = nn.Parameter(torch.ones(hidden_size))
96
+ self.variance_epsilon = eps
97
+
98
+ def forward(self, hidden_states):
99
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
100
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
101
+
102
+ # convert into half-precision if necessary
103
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
104
+ hidden_states = hidden_states.to(self.weight.dtype)
105
+
106
+ return self.weight * hidden_states
107
+
108
+
109
+ class LlamaRotaryEmbedding(torch.nn.Module):
110
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
111
+ super().__init__()
112
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
113
+ self.register_buffer("inv_freq", inv_freq)
114
+
115
+ # Build here to make `torch.jit.trace` work.
116
+ self.max_seq_len_cached = max_position_embeddings
117
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
118
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
119
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
120
+ emb = torch.cat((freqs, freqs), dim=-1)
121
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
122
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
123
+
124
+ def forward(self, x, seq_len=None):
125
+ # x: [bs, num_attention_heads, seq_len, head_size]
126
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
127
+ if seq_len > self.max_seq_len_cached:
128
+ self.max_seq_len_cached = seq_len
129
+ t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
130
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
131
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
132
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
133
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
134
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
135
+ return (
136
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
137
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
138
+ )
139
+
140
+
141
+ def rotate_half(x):
142
+ """Rotates half the hidden dims of the input."""
143
+ x1 = x[..., : x.shape[-1] // 2]
144
+ x2 = x[..., x.shape[-1] // 2 :]
145
+ return torch.cat((-x2, x1), dim=-1)
146
+
147
+
148
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
149
+ gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
150
+ gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
151
+ cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
152
+ sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
153
+ q_embed = (q * cos) + (rotate_half(q) * sin)
154
+ k_embed = (k * cos) + (rotate_half(k) * sin)
155
+ return q_embed, k_embed
156
+
157
+
158
+ class LlamaMLP(nn.Module):
159
+ def __init__(
160
+ self,
161
+ hidden_size: int,
162
+ intermediate_size: int,
163
+ hidden_act: str,
164
+ ):
165
+ super().__init__()
166
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
167
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
168
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
169
+ self.act_fn = ACT2FN[hidden_act]
170
+
171
+ def forward(self, x):
172
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
173
+
174
+
175
+ class LlamaAttention(nn.Module):
176
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
177
+
178
+ def __init__(
179
+ self,
180
+ hidden_size: int,
181
+ num_heads: int,
182
+ dropout: float = 0.0,
183
+ is_cross_attention: bool = False,
184
+ config: PretrainedConfig = None,
185
+ qk_layer_norms: bool = False,
186
+ ):
187
+ super().__init__()
188
+ self.hidden_size = hidden_size
189
+ self.num_heads = num_heads
190
+ self.head_dim = hidden_size // num_heads
191
+ self.dropout = dropout
192
+
193
+ if (self.head_dim * num_heads) != self.hidden_size:
194
+ raise ValueError(
195
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
196
+ f" and `num_heads`: {num_heads})."
197
+ )
198
+
199
+ self.is_cross_attention = is_cross_attention
200
+
201
+ if self.is_cross_attention:
202
+ kv_input_dim = self.hidden_size if not hasattr(config, "vision_embed_dim") else config.vision_embed_dim
203
+ self.q_proj = nn.Linear(
204
+ self.hidden_size,
205
+ num_heads * self.head_dim,
206
+ bias=False,
207
+ )
208
+ self.k_proj = nn.Linear(kv_input_dim, num_heads * self.head_dim, bias=False)
209
+ self.v_proj = nn.Linear(
210
+ kv_input_dim,
211
+ num_heads * self.head_dim,
212
+ bias=False,
213
+ )
214
+ else:
215
+ self.q_proj = nn.Linear(
216
+ self.hidden_size,
217
+ num_heads * self.head_dim,
218
+ bias=False,
219
+ )
220
+ self.k_proj = nn.Linear(
221
+ self.hidden_size,
222
+ num_heads * self.head_dim,
223
+ bias=False,
224
+ )
225
+ self.v_proj = nn.Linear(
226
+ self.hidden_size,
227
+ num_heads * self.head_dim,
228
+ bias=False,
229
+ )
230
+ self.o_proj = nn.Linear(
231
+ num_heads * self.head_dim,
232
+ hidden_size,
233
+ bias=False,
234
+ )
235
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim)
236
+
237
+ self.qk_layer_norms = qk_layer_norms
238
+ if self.qk_layer_norms:
239
+ self.q_layer_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
240
+ self.k_layer_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
241
+
242
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
243
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
244
+
245
+ def forward(
246
+ self,
247
+ hidden_states: torch.Tensor,
248
+ key_value_states: Optional[torch.Tensor] = None,
249
+ attention_mask: Optional[torch.Tensor] = None,
250
+ position_ids: Optional[torch.LongTensor] = None,
251
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
252
+ output_attentions: bool = False,
253
+ use_cache: bool = False,
254
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
255
+ # if key_value_states are provided this layer is used as a cross-attention layer
256
+ is_cross_attention = self.is_cross_attention or key_value_states is not None
257
+
258
+ bsz, q_len, _ = hidden_states.size()
259
+
260
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
261
+ if not is_cross_attention:
262
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
263
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
264
+ else:
265
+ _, kv_len, _ = key_value_states.size() # Note that, in this case, `kv_len` == `kv_seq_len`
266
+ key_states = self.k_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
267
+ value_states = (
268
+ self.v_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
269
+ )
270
+
271
+ kv_seq_len = key_states.shape[-2]
272
+ if past_key_value is not None:
273
+ kv_seq_len += past_key_value[0].shape[-2]
274
+ if not is_cross_attention:
275
+ cos, sin = self.rotary_emb(value_states, seq_len=max(kv_seq_len, q_len))
276
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
277
+ # [bsz, nh, t, hd]
278
+
279
+ if past_key_value is not None:
280
+ # reuse k, v, self_attention
281
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
282
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
283
+
284
+ past_key_value = (key_states, value_states) if use_cache else None
285
+
286
+ if self.qk_layer_norms:
287
+ query_states = self.q_layer_norm(query_states)
288
+ key_states = self.k_layer_norm(key_states)
289
+
290
+ if attention_mask is not None:
291
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
292
+ raise ValueError(
293
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
294
+ )
295
+
296
+ attn_output = nn.functional.scaled_dot_product_attention(
297
+ query_states,
298
+ key_states,
299
+ value_states,
300
+ attn_mask=attention_mask,
301
+ dropout_p=self.dropout,
302
+ )
303
+
304
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
305
+ raise ValueError(
306
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
307
+ f" {attn_output.size()}"
308
+ )
309
+
310
+ attn_output = attn_output.transpose(1, 2)
311
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
312
+
313
+ attn_output = self.o_proj(attn_output)
314
+
315
+ attn_weights = None
316
+ logger.warning_once(
317
+ "attn_weights are not extracted in scaled_dot_product_attention. The model returns None instead"
318
+ )
319
+
320
+ return attn_output, attn_weights, past_key_value
321
+
322
+
323
+ class LlamaDecoderLayer(nn.Module):
324
+ def __init__(self, config: VLlamaConfig):
325
+ super().__init__()
326
+ self.hidden_size = config.hidden_size
327
+ self.self_attn = LlamaAttention(
328
+ hidden_size=self.hidden_size,
329
+ num_heads=config.num_attention_heads,
330
+ dropout=config.dropout,
331
+ config=config,
332
+ )
333
+ self.mlp = LlamaMLP(
334
+ hidden_size=self.hidden_size,
335
+ intermediate_size=config.intermediate_size,
336
+ hidden_act=config.hidden_act,
337
+ )
338
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
339
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
340
+ self.dropout = config.dropout
341
+
342
+ def forward(
343
+ self,
344
+ hidden_states: torch.Tensor,
345
+ attention_mask: Optional[torch.Tensor] = None,
346
+ position_ids: Optional[torch.LongTensor] = None,
347
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
348
+ output_attentions: Optional[bool] = False,
349
+ use_cache: Optional[bool] = False,
350
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
351
+ """
352
+ Args:
353
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
354
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
355
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
356
+ output_attentions (`bool`, *optional*):
357
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
358
+ returned tensors for more detail.
359
+ use_cache (`bool`, *optional*):
360
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
361
+ (see `past_key_values`).
362
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
363
+ """
364
+
365
+ residual = hidden_states
366
+
367
+ hidden_states = self.input_layernorm(hidden_states)
368
+
369
+ # Self Attention
370
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
371
+ hidden_states=hidden_states,
372
+ attention_mask=attention_mask,
373
+ position_ids=position_ids,
374
+ past_key_value=past_key_value,
375
+ output_attentions=output_attentions,
376
+ use_cache=use_cache,
377
+ )
378
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
379
+ hidden_states = residual + hidden_states
380
+
381
+ # Fully Connected
382
+ residual = hidden_states
383
+ hidden_states = self.post_attention_layernorm(hidden_states)
384
+ hidden_states = self.mlp(hidden_states)
385
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
386
+ hidden_states = residual + hidden_states
387
+
388
+ outputs = (hidden_states,)
389
+
390
+ if output_attentions:
391
+ outputs += (self_attn_weights,)
392
+
393
+ if use_cache:
394
+ outputs += (present_key_value,)
395
+
396
+ return outputs
397
+
398
+
399
+ class VLlamaGatedCrossAttentionLayer(nn.Module):
400
+ def __init__(self, config: VLlamaConfig):
401
+ super().__init__()
402
+ self.hidden_size = config.hidden_size
403
+ self.cross_attn = LlamaAttention(
404
+ hidden_size=self.hidden_size,
405
+ num_heads=config.num_attention_heads,
406
+ is_cross_attention=True,
407
+ dropout=config.dropout,
408
+ config=config,
409
+ qk_layer_norms=config.qk_layer_norms,
410
+ )
411
+ self.mlp = LlamaMLP(
412
+ hidden_size=self.hidden_size,
413
+ intermediate_size=config.intermediate_size,
414
+ hidden_act=config.hidden_act,
415
+ )
416
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
417
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
418
+ self.config = config.dropout  # NOTE: despite its name, this attribute stores the dropout probability, not the config object
419
+
420
+ self.act_cross_attn = nn.Tanh()
421
+ self.act_dense = nn.Tanh()
422
+
423
+ if config.alpha_initializer == "zeros":
424
+ if config.alpha_type == "vector":
425
+ self.alpha_cross_attn = nn.Parameter(torch.zeros(1, 1, self.hidden_size))
426
+ self.alpha_dense = nn.Parameter(torch.zeros(1, 1, self.hidden_size))
427
+ elif config.alpha_type == "float":
428
+ self.alpha_cross_attn = nn.Parameter(torch.zeros(1))
429
+ self.alpha_dense = nn.Parameter(torch.zeros(1))
430
+ else:
431
+ raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
432
+
433
+ elif config.alpha_initializer == "ones":
434
+ if config.alpha_type == "vector":
435
+ self.alpha_cross_attn = nn.Parameter(torch.ones(1, 1, self.hidden_size))
436
+ self.alpha_dense = nn.Parameter(torch.ones(1, 1, self.hidden_size))
437
+ elif config.alpha_type == "float":
438
+ self.alpha_cross_attn = nn.Parameter(torch.ones(1))
439
+ self.alpha_dense = nn.Parameter(torch.ones(1))
440
+ else:
441
+ raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
442
+
443
+ elif config.alpha_initializer in {"normal", "gaussian", "random"}:
444
+ if config.alpha_type == "vector":
445
+ self.alpha_cross_attn = nn.Parameter(
446
+ torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1, 1, self.hidden_size))
447
+ )
448
+ self.alpha_dense = nn.Parameter(
449
+ torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1, 1, self.hidden_size))
450
+ )
451
+ elif config.alpha_type == "float":
452
+ self.alpha_cross_attn = nn.Parameter(
453
+ torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1,))
454
+ )
455
+ self.alpha_dense = nn.Parameter(torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1,)))
456
+ else:
457
+ raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
458
+
459
+ else:
460
+ raise NotImplementedError(f"Alpha initialization scheme {config.alpha_initializer} not yet implemented!")
461
+
462
+ if not (hasattr(self, "alpha_cross_attn") and hasattr(self, "alpha_dense")):
463
+ raise ValueError("Alpha parameters not initialized correctly!")
464
+
465
+ def forward(
466
+ self,
467
+ hidden_states: torch.Tensor,
468
+ attention_mask: Optional[torch.Tensor] = None,
469
+ image_hidden_states: Optional[torch.Tensor] = None,
470
+ image_attention_mask: Optional[torch.Tensor] = None,
471
+ output_attentions: Optional[bool] = False,
472
+ use_cache: Optional[bool] = False,
473
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
474
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
475
+ """
476
+ Args:
477
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
478
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
479
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
480
+ output_attentions (`bool`, *optional*):
481
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
482
+ returned tensors for more detail.
483
+ use_cache (`bool`, *optional*):
484
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
485
+ (see `past_key_values`).
486
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
487
+ """
488
+ if image_hidden_states is None:
489
+ raise ValueError(
490
+ "`image_hidden_states` is required for VLlama cross attention module which are visual features to be"
491
+ " conditioned on."
492
+ )
493
+
494
+ if past_key_value is not None:
495
+ raise NotImplementedError("Past key value states are not implemented for VLlama cross attention module.")
496
+
497
+ residual = hidden_states
498
+
499
+ hidden_states = self.input_layernorm(hidden_states)
500
+
501
+ # Cross Attention
502
+ hidden_states, self_attn_weights, present_key_value = self.cross_attn(
503
+ hidden_states=hidden_states,
504
+ key_value_states=image_hidden_states,
505
+ attention_mask=image_attention_mask,
506
+ output_attentions=output_attentions,
507
+ )
508
+ hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
509
+ hidden_states = residual + self.act_cross_attn(self.alpha_cross_attn) * hidden_states
510
+
511
+ # Fully Connected
512
+ residual = hidden_states
513
+ hidden_states = self.post_attention_layernorm(hidden_states)
514
+ hidden_states = self.mlp(hidden_states)
515
+ hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
516
+ hidden_states = residual + self.act_dense(self.alpha_dense) * hidden_states
517
+
518
+ outputs = (hidden_states,)
519
+
520
+ if output_attentions:
521
+ outputs += (self_attn_weights,)
522
+
523
+ if use_cache:
524
+ outputs += (present_key_value,)
525
+
526
+ return outputs
527
+
528
+
529
+ LLAMA_START_DOCSTRING = r"""
530
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
531
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
532
+ etc.)
533
+
534
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
535
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
536
+ and behavior.
537
+
538
+ Parameters:
539
+ config ([`VLlamaConfig`]):
540
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
541
+ load the weights associated with the model, only the configuration. Check out the
542
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
543
+ """
544
+
545
+
546
+ @add_start_docstrings(
547
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
548
+ LLAMA_START_DOCSTRING,
549
+ )
550
+ class VLlamaPreTrainedModel(VLOOMPreTrainedModelBase):
551
+ config_class = VLlamaConfig
552
+ base_model_prefix = "model"
553
+ supports_gradient_checkpointing = True
554
+ _no_split_modules = ["LlamaDecoderLayer", "VLlamaGatedCrossAttentionLayer"]
555
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
556
+
557
+ def _init_weights(self, module):
558
+ def init_a_linear(module, mean=0.0, std=self.config.initializer_range):
559
+ with ContextManagers(deepspeed_gathered_parameters_context_manager(module.weight, modify=True)):
560
+ module.weight.data.normal_(mean=mean, std=std)
561
+ if module.bias is not None:
562
+ with ContextManagers(deepspeed_gathered_parameters_context_manager(module.bias, modify=True)):
563
+ module.bias.data.zero_()
564
+
565
+ if isinstance(module, VLlamaGatedCrossAttentionLayer):
566
+ for sub_module_name, sub_module in module.named_modules():
567
+ if isinstance(sub_module, nn.Linear):
568
+ if "down_proj" in sub_module_name:
569
+ factor = 2 * self.config.num_hidden_layers
570
+ else:
571
+ factor = 1.0
572
+ init_a_linear(sub_module, std=(0.4 / (sub_module.in_features * factor)) ** 0.5)
573
+ elif isinstance(module, PerceiverResampler):
574
+ with ContextManagers(deepspeed_gathered_parameters_context_manager(module.latents, modify=True)):
575
+ module.latents.data.normal_(mean=0.0, std=(1.0 / self.config.vision_embed_dim) ** 0.5)
576
+ for sub_module_name, sub_module in module.named_modules():
577
+ if isinstance(sub_module, nn.Linear):
578
+ if "c_proj" in sub_module_name:
579
+ factor = 2 * self.config.num_hidden_layers
580
+ else:
581
+ factor = 1.0
582
+ init_a_linear(sub_module, std=(0.4 / (self.config.vision_embed_dim * factor)) ** 0.5)
583
+ elif isinstance(module, nn.Embedding):
584
+ with ContextManagers(deepspeed_gathered_parameters_context_manager(module.weight, modify=True)):
585
+ module.weight.data.normal_(mean=0.0, std=(1.0 / self.config.hidden_size) ** 0.5)
586
+ if module.padding_idx is not None:
587
+ module.weight.data[module.padding_idx].zero_()
588
+ elif isinstance(module, DecoupledLinear):
589
+ if hasattr(module, "additional_fc"):
590
+ init_a_linear(module.additional_fc, std=(1.0 / (module.additional_fc.in_features)) ** 0.5)
591
+
592
+ def _set_gradient_checkpointing(self, module, value=False):
593
+ if isinstance(module, VLlamaModel):
594
+ module.gradient_checkpointing = value
595
+
596
+ @classmethod
597
+ def override_vision_model_wrapper(cls, model, config, vision_model_name, vision_model_params, torch_dtype):
598
+ # this can be called via from_pretrained from a class w/ head or w/o head so we extract the beheaded model version
599
+ beheaded_model = model.model if hasattr(model, "model") else model
600
+ cls.override_vision_model(beheaded_model, vision_model_name, vision_model_params, torch_dtype)
601
+ beheaded_model.freeze_relevant_params(config)
602
+
603
+
604
+ LLAMA_INPUTS_DOCSTRING = r"""
605
+ Args:
606
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
607
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
608
+ it.
609
+
610
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
611
+ [`PreTrainedTokenizer.__call__`] for details.
612
+
613
+ [What are input IDs?](../glossary#input-ids)
614
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
615
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
616
+
617
+ - 1 for tokens that are **not masked**,
618
+ - 0 for tokens that are **masked**.
619
+
620
+ [What are attention masks?](../glossary#attention-mask)
621
+
622
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
623
+ [`PreTrainedTokenizer.__call__`] for details.
624
+
625
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
626
+ `past_key_values`).
627
+
628
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
629
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
630
+ information on the default strategy.
631
+
632
+ - 1 indicates the head is **not masked**,
633
+ - 0 indicates the head is **masked**.
634
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
635
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
636
+ config.n_positions - 1]`.
637
+ [What are position IDs?](../glossary#position-ids)
638
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
639
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
640
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
641
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
642
+
643
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
644
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
645
+
646
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
647
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
648
+ `input_ids` of shape `(batch_size, sequence_length)`.
649
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
650
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
651
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
652
+ model's internal embedding lookup matrix.
653
+ use_cache (`bool`, *optional*):
654
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
655
+ `past_key_values`).
656
+ output_attentions (`bool`, *optional*):
657
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
658
+ tensors for more detail.
659
+ output_hidden_states (`bool`, *optional*):
660
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
661
+ more detail.
662
+ return_dict (`bool`, *optional*):
663
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
664
+ """
665
+
666
+
667
+ @add_start_docstrings(
668
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
669
+ LLAMA_START_DOCSTRING,
670
+ )
671
+ class VLlamaModel(VLlamaPreTrainedModel):
672
+ """
673
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
674
+
675
+ Args:
676
+ config: VLlamaConfig
677
+ """
678
+
679
+ def __init__(self, config: VLlamaConfig, vision_model=None):
680
+ super().__init__(config)
681
+ self.config = config
682
+ self.padding_idx = config.pad_token_id
683
+ self.vocab_size = config.vocab_size
684
+
685
+ self.embed_tokens = DecoupledEmbedding(
686
+ num_embeddings=config.vocab_size,
687
+ num_additional_embeddings=config.additional_vocab_size,
688
+ embedding_dim=config.hidden_size,
689
+ partially_freeze=config.freeze_text_layers,
690
+ padding_idx=self.padding_idx,
691
+ )
692
+
693
+ # The vision model is passed in uninitialized; `from_pretrained` loads its pre-trained weights later,
694
+ # which avoids losing those weights when `from_pretrained` is called on the main model
695
+ self.vision_model = vision_model
696
+
697
+ # Perceiver Resampler
698
+ if config.use_resampler:
699
+ self.perceiver_resampler = PerceiverResampler(
700
+ self.config,
701
+ self.config.vision_embed_dim,
702
+ config.resampler_depth,
703
+ config.resampler_n_heads,
704
+ config.resampler_head_dim,
705
+ config.resampler_n_latents,
706
+ )
707
+
708
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
709
+
710
+ self.cross_layer_interval = config.cross_layer_interval
711
+ num_cross_layers = config.num_hidden_layers // self.cross_layer_interval
712
+ self.gated_cross_attn_layers = nn.ModuleList(
713
+ [VLlamaGatedCrossAttentionLayer(config) for _ in range(num_cross_layers)]
714
+ )
715
+ self.gradient_checkpointing = False
716
+
717
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
718
+
719
+ self.gradient_checkpointing = False
720
+ # Initialize weights and apply final processing
721
+ self.post_init()
722
+
723
+ self.freeze_relevant_params(config)
724
+
725
+ def freeze_relevant_params(self, config=None):
726
+ if config is None:
727
+ config = self.config
728
+
729
+ if config.freeze_text_layers:
730
+ self.freeze_text_layers(config.freeze_text_module_exceptions)
731
+
732
+ if config.freeze_vision_layers:
733
+ freeze_model(self.vision_model, module_exceptions=config.freeze_vision_module_exceptions)
734
+
735
+ def freeze_text_layers(self, module_exceptions):
736
+ for module in [self.layers, self.norm]:
737
+ freeze_model(module, module_exceptions=module_exceptions)
738
+
739
+ def get_input_embeddings(self):
740
+ return self.embed_tokens
741
+
742
+ def set_input_embeddings(self, value):
743
+ self.embed_tokens = value
744
+
745
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
746
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
747
+ # create causal mask
748
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
749
+ combined_attention_mask = None
750
+ if input_shape[-1] > 1:
751
+ combined_attention_mask = _make_causal_mask(
752
+ input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
753
+ ).to(inputs_embeds.device)
754
+
755
+ if attention_mask is not None:
756
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
757
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
758
+ inputs_embeds.device
759
+ )
760
+ combined_attention_mask = (
761
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
762
+ )
763
+
764
+ return combined_attention_mask
765
+
766
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
767
+ def forward(
768
+ self,
769
+ input_ids: torch.LongTensor = None,
770
+ attention_mask: Optional[torch.Tensor] = None,
771
+ position_ids: Optional[torch.LongTensor] = None,
772
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
773
+ inputs_embeds: Optional[torch.FloatTensor] = None,
774
+ pixel_values: Optional[torch.FloatTensor] = None,
775
+ image_embeddings: Optional[torch.FloatTensor] = None,
776
+ image_attention_mask: Optional[torch.Tensor] = None,
777
+ use_cache: Optional[bool] = None,
778
+ output_attentions: Optional[bool] = None,
779
+ output_hidden_states: Optional[bool] = None,
780
+ return_dict: Optional[bool] = None,
781
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
782
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
783
+
784
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
785
+ output_hidden_states = (
786
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
787
+ )
788
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
789
+
790
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
791
+
792
+ # retrieve input_ids and inputs_embeds
793
+ if input_ids is not None and inputs_embeds is not None:
794
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
795
+ elif input_ids is not None:
796
+ batch_size, seq_length = input_ids.shape
797
+ elif inputs_embeds is not None:
798
+ batch_size, seq_length, _ = inputs_embeds.shape
799
+ else:
800
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
801
+
802
+ seq_length_with_past = seq_length
803
+ past_key_values_length = 0
804
+
805
+ if past_key_values is not None:
806
+ past_key_values_length = past_key_values[0][0].shape[2]
807
+ seq_length_with_past = seq_length_with_past + past_key_values_length
808
+
809
+ if attention_mask is not None and position_ids is None:
810
+ # create position_ids on the fly for batch generation
811
+ position_ids = attention_mask.long().cumsum(-1) - 1
812
+ position_ids.masked_fill_(attention_mask == 0, 1)
813
+ elif position_ids is None:
814
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
815
+ position_ids = torch.arange(
816
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
817
+ )
818
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
819
+ else:
820
+ position_ids = position_ids.view(-1, seq_length).long()
821
+
822
+ if pixel_values is not None and image_embeddings is not None:
823
+ raise ValueError("You cannot specify both pixel_values and image_embeddings at the same time")
824
+ elif pixel_values is not None:
825
+ pixel_values = pixel_values.to(dtype=self.dtype, device=device) # fp16 compatibility; `device` also works when only inputs_embeds is given
826
+ batch_size, num_images = pixel_values.size(0), pixel_values.size(1)
827
+ pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:])
828
+ # Get sequence from the vision encoder
829
+ image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
830
+ elif image_embeddings is not None:
831
+ batch_size, num_images, image_seq_len, image_hidden_size = image_embeddings.size()
832
+ image_hidden_states = image_embeddings.to(dtype=self.dtype, device=device)
833
+ image_hidden_states = image_hidden_states.view(batch_size * num_images, image_seq_len, image_hidden_size)
834
+
835
+ if self.config.use_resampler:
836
+ image_hidden_states = self.perceiver_resampler(image_hidden_states)
837
+ image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
838
+ image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size)
839
+ # Make image_attention_mask compatible with hidden states
840
+ text_seq_len = image_attention_mask.size(1)
841
+ image_attention_mask = image_attention_mask.unsqueeze(-1)
842
+ image_attention_mask = image_attention_mask.repeat(1, 1, 1, image_seq_len)
843
+ image_attention_mask = image_attention_mask.view(batch_size, text_seq_len, num_images * image_seq_len)
844
+
845
+ if image_hidden_states is not None:
846
+ image_batch_size, image_sequence_length, _ = image_hidden_states.size()
847
+ image_hidden_shape = (image_batch_size, image_sequence_length)
848
+ if image_attention_mask is None:
849
+ image_attention_mask = torch.ones(image_hidden_shape, device=device)
850
+ image_attention_mask = self.invert_attention_mask(image_attention_mask)
851
+ else:
852
+ image_attention_mask = None
853
+
854
+ if inputs_embeds is None:
855
+ inputs_embeds = self.embed_tokens(input_ids)
856
+ # prepare the causal attention mask
857
+ if attention_mask is None:
858
+ attention_mask = torch.ones(
859
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
860
+ )
861
+ attention_mask = self._prepare_decoder_attention_mask(
862
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
863
+ )
864
+
865
+ hidden_states = inputs_embeds
866
+
867
+ if self.gradient_checkpointing and self.training:
868
+ if use_cache:
869
+ logger.warning_once(
870
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
871
+ )
872
+ use_cache = False
873
+
874
+ # decoder layers
875
+ all_hidden_states = () if output_hidden_states else None
876
+ all_self_attns = () if output_attentions else None
877
+ next_decoder_cache = () if use_cache else None
878
+
879
+ for idx, decoder_layer in enumerate(self.layers):
880
+ if output_hidden_states:
881
+ all_hidden_states += (hidden_states,)
882
+
883
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
884
+
885
+ def vblock(
886
+ main_block,
887
+ hidden_states,
888
+ attention_mask,
889
+ position_ids,
890
+ past_key_value,
891
+ image_hidden_states,
892
+ image_attention_mask,
893
+ output_attentions,
894
+ use_cache,
895
+ layer_idx,
896
+ cross_layer_interval,
897
+ gated_cross_attn_layers,
898
+ ):
899
+ # TODO(ls): Add cross attention values to respective lists
900
+ if layer_idx % cross_layer_interval == 0:
901
+ xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval]
902
+ outputs = xblock(
903
+ hidden_states,
904
+ attention_mask=attention_mask,
905
+ image_hidden_states=image_hidden_states,
906
+ image_attention_mask=image_attention_mask,
907
+ output_attentions=output_attentions,
908
+ use_cache=use_cache,
909
+ past_key_value=None, # not implemented
910
+ )
911
+ hidden_states = outputs[0]
912
+
913
+ layer_outputs = main_block(
914
+ hidden_states,
915
+ attention_mask=attention_mask,
916
+ position_ids=position_ids,
917
+ past_key_value=past_key_value,
918
+ output_attentions=output_attentions,
919
+ use_cache=use_cache,
920
+ )
921
+
922
+ return layer_outputs
923
+
924
+ if self.gradient_checkpointing and self.training:
925
+ past_key_value = None
926
+ if use_cache:
927
+ logger.warning_once(
928
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
929
+ )
930
+ use_cache = False
931
+
932
+ layer_outputs = torch.utils.checkpoint.checkpoint(
933
+ vblock,
934
+ decoder_layer,
935
+ hidden_states,
936
+ attention_mask,
937
+ position_ids,
938
+ past_key_value,
939
+ image_hidden_states,
940
+ image_attention_mask,
941
+ output_attentions,
942
+ use_cache,
943
+ idx,
944
+ self.cross_layer_interval,
945
+ self.gated_cross_attn_layers,
946
+ )
947
+ else:
948
+ layer_outputs = vblock(
949
+ decoder_layer,
950
+ hidden_states,
951
+ attention_mask=attention_mask,
952
+ position_ids=position_ids,
953
+ past_key_value=past_key_value,
954
+ image_hidden_states=image_hidden_states,
955
+ image_attention_mask=image_attention_mask,
956
+ output_attentions=output_attentions,
957
+ use_cache=use_cache,
958
+ layer_idx=idx,
959
+ cross_layer_interval=self.cross_layer_interval,
960
+ gated_cross_attn_layers=self.gated_cross_attn_layers,
961
+ )
962
+
963
+ hidden_states = layer_outputs[0]
964
+
965
+ if use_cache:
966
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
967
+
968
+ if output_attentions:
969
+ all_self_attns += (layer_outputs[1],)
970
+
971
+ hidden_states = self.norm(hidden_states)
972
+
973
+ # add hidden states from the last decoder layer
974
+ if output_hidden_states:
975
+ all_hidden_states += (hidden_states,)
976
+
977
+ next_cache = next_decoder_cache if use_cache else None
978
+ if not return_dict:
979
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
980
+ return BaseModelOutputWithPast(
981
+ last_hidden_state=hidden_states,
982
+ past_key_values=next_cache,
983
+ hidden_states=all_hidden_states,
984
+ attentions=all_self_attns,
985
+ )
986
+
987
+
988
+ class VLlamaForCausalLM(VLlamaPreTrainedModel):
989
+ _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
990
+
991
+ def __init__(self, config, vision_model=None):
992
+ super().__init__(config)
993
+ self.model = VLlamaModel(config, vision_model=vision_model)
994
+
995
+ self.lm_head = DecoupledLinear(
996
+ in_features=config.hidden_size,
997
+ out_features=config.vocab_size,
998
+ out_additional_features=config.additional_vocab_size,
999
+ bias=False,
1000
+ partially_freeze=config.freeze_lm_head,
1001
+ )
1002
+
1003
+ # Initialize weights and apply final processing
1004
+ self.post_init()
1005
+
1006
+ def get_input_embeddings(self):
1007
+ return self.model.embed_tokens
1008
+
1009
+ def set_input_embeddings(self, value):
1010
+ self.model.embed_tokens = value
1011
+
1012
+ def get_output_embeddings(self):
1013
+ return self.lm_head
1014
+
1015
+ def set_output_embeddings(self, new_embeddings):
1016
+ self.lm_head = new_embeddings
1017
+
1018
+ def set_decoder(self, decoder):
1019
+ self.model = decoder
1020
+
1021
+ def get_decoder(self):
1022
+ return self.model
1023
+
1024
+ def tie_weights(self):
1025
+ """
1026
+ Override `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of DecoupledLinear and DecoupledEmbedding.
1027
+ """
1028
+ output_embeddings = self.get_output_embeddings()
1029
+ input_embeddings = self.get_input_embeddings()
1030
+
1031
+ if getattr(self.config, "tie_word_embeddings", True):
1032
+ output_embeddings.weight = input_embeddings.weight
1033
+ if input_embeddings.num_additional_embeddings > 0:
1034
+ assert output_embeddings.out_additional_features == input_embeddings.num_additional_embeddings
1035
+ output_embeddings.additional_fc.weight = input_embeddings.additional_embedding.weight
1036
+
1037
+ if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
1038
+ output_embeddings.out_features = input_embeddings.num_embeddings
1039
+ if hasattr(output_embeddings, "out_additional_features") and hasattr(
1040
+ input_embeddings, "num_additional_embeddings"
1041
+ ):
1042
+ output_embeddings.out_additional_features = input_embeddings.num_additional_embeddings
1043
+
1044
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1045
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1046
+ def forward(
1047
+ self,
1048
+ input_ids: torch.LongTensor = None,
1049
+ attention_mask: Optional[torch.Tensor] = None,
1050
+ position_ids: Optional[torch.LongTensor] = None,
1051
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1052
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1053
+ pixel_values: Optional[torch.FloatTensor] = None,
1054
+ image_embeddings: Optional[torch.FloatTensor] = None,
1055
+ image_attention_mask: Optional[torch.Tensor] = None,
1056
+ labels: Optional[torch.LongTensor] = None,
1057
+ use_cache: Optional[bool] = None,
1058
+ output_attentions: Optional[bool] = None,
1059
+ output_hidden_states: Optional[bool] = None,
1060
+ return_dict: Optional[bool] = None,
1061
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1062
+ r"""
1063
+ Args:
1064
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1065
+ Labels for computing the causal (next-token) language modeling loss. Indices should either be in `[0, ...,
1066
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1067
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1068
+
1069
+ Returns:
1070
+
1071
+ Example:
1072
+
1073
+ ```python
1074
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
1075
+
1076
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1077
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1078
+
1079
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1080
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1081
+
1082
+ >>> # Generate
1083
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1084
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1085
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1086
+ ```"""
1087
+
1088
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1089
+ output_hidden_states = (
1090
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1091
+ )
1092
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1093
+
1094
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
1095
+ outputs = self.model(
1096
+ input_ids=input_ids,
1097
+ attention_mask=attention_mask,
1098
+ position_ids=position_ids,
1099
+ past_key_values=past_key_values,
1100
+ inputs_embeds=inputs_embeds,
1101
+ pixel_values=pixel_values,
1102
+ image_embeddings=image_embeddings,
1103
+ image_attention_mask=image_attention_mask,
1104
+ use_cache=use_cache,
1105
+ output_attentions=output_attentions,
1106
+ output_hidden_states=output_hidden_states,
1107
+ return_dict=return_dict,
1108
+ )
1109
+
1110
+ hidden_states = outputs[0]
1111
+ logits = self.lm_head(hidden_states)
1112
+
1113
+ loss = None
1114
+ if labels is not None:
1115
+ # Shift so that tokens < n predict n
1116
+ if attention_mask is not None:
1117
+ shift_attention_mask = attention_mask[..., 1:]
1118
+ shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
1119
+ shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
1120
+ else:
1121
+ shift_logits = logits[..., :-1, :].contiguous()
1122
+ shift_labels = labels[..., 1:].contiguous()
1123
+ # Flatten the tokens
1124
+ loss_fct = CrossEntropyLoss()
1125
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1126
+
1127
+ if not return_dict:
1128
+ output = (logits,) + outputs[1:]
1129
+ return (loss,) + output if loss is not None else output
1130
+
1131
+ return CausalLMOutputWithPast(
1132
+ loss=loss,
1133
+ logits=logits,
1134
+ past_key_values=outputs.past_key_values,
1135
+ hidden_states=outputs.hidden_states,
1136
+ attentions=outputs.attentions,
1137
+ )
1138
+
1139
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
1140
+ inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs)
1141
+ unwanted_kwargs = ["token_type_ids"]
1142
+ for kwarg in unwanted_kwargs:
1143
+ inputs.pop(kwarg, None)
1144
+ return inputs
1145
+
1146
+ @staticmethod
1147
+ def _expand_inputs_for_generation(
1148
+ *args,
1149
+ **model_kwargs,
1150
+ ):
1151
+ return expand_inputs_for_generation(*args, **model_kwargs)
1152
+
1153
+ @staticmethod
1154
+ def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
1155
+ return update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder)
1156
+
1157
+ @staticmethod
1158
+ def _reorder_cache(past, beam_idx):
1159
+ reordered_past = ()
1160
+ for layer_past in past:
1161
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1162
+ return reordered_past
1163
+
1164
+ def get_model_tflops_per_batch_per_gpu(self, hparams, data_param, tokenizer, max_num_images):
1165
+ config_vl_model = self.config
1166
+
1167
+ language_embed_size = config_vl_model.hidden_size
1168
+ num_language_layers = config_vl_model.num_hidden_layers
1169
+ ffn_inner_size = config_vl_model.intermediate_size
1170
+
1171
+ vision_config = self.model.vision_model.config
1172
+ if hasattr(vision_config, "vision_config"):
1173
+ vision_config = vision_config.vision_config
1174
+
1175
+ # Get vision model blocks infos
1176
+ vision_patch_size = vision_config.patch_size
1177
+ vision_hidden_size = vision_config.hidden_size
1178
+ num_vision_layers = vision_config.num_hidden_layers
1179
+ # The +1 is for the CLS token
1180
+ single_image_seq_len = (vision_config.image_size // vision_patch_size) ** 2 + 1
1181
+ vision_exp_factor = vision_config.intermediate_size // vision_hidden_size
1182
+
1183
+ # Get language and cross-att blocks infos
1184
+ num_cross_attn_layers = num_language_layers // config_vl_model.cross_layer_interval
1185
+ language_seq_len = data_param.max_seq_len
1186
+ language_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4
1187
+ cross_att_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4
1188
+ k_v_cross_attn_seq_len = (
1189
+ (self.config.resampler_n_latents * max_num_images)
1190
+ if self.config.use_resampler
1191
+ else (single_image_seq_len * max_num_images)
1192
+ )
1193
+
1194
+ language_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
1195
+ num_layers=num_language_layers,
1196
+ batch_size=hparams.batch_size_per_gpu,
1197
+ q_seq_len=language_seq_len,
1198
+ k_seq_len=language_seq_len,
1199
+ hidden_size=language_embed_size,
1200
+ kv_in_dim=language_embed_size,
1201
+ ff_exp_factor=language_exp_factor,
1202
+ grad_acc_size=hparams.grad_acc_size,
1203
+ swiglu=True,
1204
+ vocab_size=tokenizer.vocab_size,
1205
+ count_backward=True, # Always True regardless of freezing, because gradients are computed for cross-attentions
1206
+ use_grad_checkpointing=hparams.gradient_checkpointing,
1207
+ )
1208
+ cross_attention_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
1209
+ num_layers=num_cross_attn_layers,
1210
+ batch_size=hparams.batch_size_per_gpu,
1211
+ q_seq_len=language_seq_len,
1212
+ k_seq_len=k_v_cross_attn_seq_len,
1213
+ hidden_size=language_embed_size,
1214
+ kv_in_dim=vision_hidden_size,
1215
+ ff_exp_factor=cross_att_exp_factor,
1216
+ grad_acc_size=hparams.grad_acc_size,
1217
+ swiglu=True,
1218
+ vocab_size=None,
1219
+ count_backward=True,
1220
+ use_grad_checkpointing=hparams.gradient_checkpointing,
1221
+ )
1222
+ vision_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
1223
+ num_layers=num_vision_layers,
1224
+ batch_size=hparams.batch_size_per_gpu * max_num_images,
1225
+ q_seq_len=single_image_seq_len,
1226
+ k_seq_len=single_image_seq_len,
1227
+ hidden_size=vision_hidden_size,
1228
+ kv_in_dim=vision_hidden_size,
1229
+ ff_exp_factor=vision_exp_factor,
1230
+ grad_acc_size=hparams.grad_acc_size,
1231
+ swiglu=False,
1232
+ vocab_size=None,
1233
+ count_backward=not hparams.model_params["freeze_vision_layers"],
1234
+ use_grad_checkpointing=hparams.gradient_checkpointing,
1235
+ )
1236
+ if self.config.use_resampler:
1237
+ perceiver_tflops_per_batch_per_gpu = compute_perceiver_tflops_per_batch_per_gpu(
1238
+ num_layers=self.config.resampler_depth,
1239
+ batch_size=hparams.batch_size_per_gpu * max_num_images,
1240
+ q_seq_len=self.config.resampler_n_latents,
1241
+ vision_embed_seq_len=single_image_seq_len,
1242
+ q_k_v_input_dim=vision_hidden_size,
1243
+ attention_hidden_size=self.config.resampler_n_heads * self.config.resampler_head_dim,
1244
+ ff_exp_factor=cross_att_exp_factor,
1245
+ count_backward=True,
1246
+ use_grad_checkpointing=hparams.gradient_checkpointing,
1247
+ )
1248
+ flop_count = (
1249
+ language_tflops_per_batch_per_gpu
1250
+ + cross_attention_tflops_per_batch_per_gpu
1251
+ + vision_tflops_per_batch_per_gpu
1252
+ + perceiver_tflops_per_batch_per_gpu
1253
+ )
1254
+ else:
1255
+ flop_count = (
1256
+ language_tflops_per_batch_per_gpu
1257
+ + cross_attention_tflops_per_batch_per_gpu
1258
+ + vision_tflops_per_batch_per_gpu
1259
+ )
1260
+ return flop_count
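A note on the layer interleaving in `VLlamaModel.forward` above: `vblock` runs a gated cross-attention layer before every decoder layer whose index is a multiple of `cross_layer_interval`, while `__init__` allocates `num_hidden_layers // cross_layer_interval` such layers. The following is a minimal, dependency-free sketch of that index arithmetic only (no tensors, no model code). The helper names `cross_attention_schedule` and `allocated_cross_layers` are hypothetical and do not appear in the commit.

```python
from typing import List, Tuple


def cross_attention_schedule(num_hidden_layers: int, cross_layer_interval: int) -> List[Tuple[int, int]]:
    """Return (decoder_layer_idx, cross_attn_layer_idx) pairs.

    Hypothetical helper: it mirrors the `layer_idx % cross_layer_interval == 0`
    test and the `layer_idx // cross_layer_interval` lookup used in `vblock`.
    """
    schedule = []
    for layer_idx in range(num_hidden_layers):
        if layer_idx % cross_layer_interval == 0:
            schedule.append((layer_idx, layer_idx // cross_layer_interval))
    return schedule


def allocated_cross_layers(num_hidden_layers: int, cross_layer_interval: int) -> int:
    """Number of gated cross-attention layers built in `__init__` (floor division)."""
    return num_hidden_layers // cross_layer_interval


if __name__ == "__main__":
    # Depth divisible by the interval: every requested index has a layer.
    print(cross_attention_schedule(32, 4))
    print(allocated_cross_layers(32, 4))

    # Depth not divisible by the interval: compare the largest requested
    # cross-attention index with the number of allocated layers.
    print(cross_attention_schedule(10, 4))
    print(allocated_cross_layers(10, 4))
```

Under these assumptions, 32 layers with an interval of 4 give eight cross-attention calls with indices 0 to 7, which matches the eight allocated layers. For a depth that is not a multiple of the interval (for example 10 layers with an interval of 4), the sketch requests index 2 while only two layers would be allocated, so such configurations are probably best avoided or checked up front.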
m4/models/vopt/__init__.py ADDED
File without changes
m4/models/vopt/configuration_vopt.py ADDED
@@ -0,0 +1,250 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The Metaseq Authors and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ OPT model configuration"""
16
+ import os
17
+ from typing import Tuple, Union
18
+
19
+ from transformers import AutoConfig
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ OPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27
+ "facebook/opt-125m": "https://huggingface.co/facebook/opt-125m/blob/main/config.json",
28
+ "facebook/opt-350m": "https://huggingface.co/facebook/opt-350m/blob/main/config.json",
29
+ "facebook/opt-1.3b": "https://huggingface.co/facebook/opt-1.3b/blob/main/config.json",
30
+ "facebook/opt-2.7b": "https://huggingface.co/facebook/opt-2.7b/blob/main/config.json",
31
+ "facebook/opt-6.7b": "https://huggingface.co/facebook/opt-6.7b/blob/main/config.json",
32
+ "facebook/opt-13b": "https://huggingface.co/facebook/opt-13b/blob/main/config.json",
33
+ }
34
+
35
+
36
+ class VOPTConfig(PretrainedConfig):
37
+ r"""
38
+ This is the configuration class to store the configuration of an [`OPTModel`]. It is used to instantiate an OPT model
39
+ according to the specified arguments, defining the model architecture. Instantiating a configuration with the
40
+ defaults will yield a similar configuration to that of the OPT
41
+ [facebook/opt-350m](https://huggingface.co/facebook/opt-350m) architecture.
42
+
43
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
44
+ documentation from [`PretrainedConfig`] for more information.
45
+
46
+ TODO: this doc is completely out of sync with the actual args
47
+
48
+ Args:
49
+ vocab_size (`int`, *optional*, defaults to 50272):
50
+ Vocabulary size of the OPT model. Defines the number of different tokens that can be represented by the
51
+ `input_ids` passed when calling [`OPTModel`]
52
+ additional_vocab_size (`int`, *optional*, defaults to 0):
53
+ Additional vocabulary size of the model, typically for the special "<img>" token. Additional vocab tokens
54
+ are always trainable whereas regular vocab tokens can be frozen or not.
55
+ hidden_size (`int`, *optional*, defaults to 768):
56
+ Dimensionality of the layers and the pooler layer.
57
+ num_hidden_layers (`int`, *optional*, defaults to 12):
58
+ Number of decoder layers.
59
+ ffn_dim (`int`, *optional*, defaults to 3072):
60
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
61
+ num_attention_heads (`int`, *optional*, defaults to 12):
62
+ Number of attention heads for each attention layer in the Transformer decoder.
63
+ activation_function (`str` or `function`, *optional*, defaults to `"relu"`):
64
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
65
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
66
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
67
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
68
+ just in case (e.g., 512 or 1024 or 2048).
69
+ do_layer_norm_before (`bool`, *optional*, defaults to `True`):
70
+ Whether to perform layer normalization before the attention block.
71
+ word_embed_proj_dim (`int`, *optional*):
72
+ `word_embed_proj_dim` can be set to down-project word embeddings, *e.g.* `opt-350m`. Defaults to
73
+ `hidden_size`.
74
+ dropout (`float`, *optional*, defaults to 0.1):
75
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
76
+ attention_dropout (`float`, *optional*, defaults to 0.0):
77
+ The dropout ratio for the attention probabilities.
78
+ layerdrop (`float`, *optional*, defaults to 0.0):
79
+ The LayerDrop probability. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) for more
80
+ details.
81
+ init_std (`float`, *optional*, defaults to 0.02):
82
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
83
+ alpha_initializer (`str`, *optional*, defaults to `"ones"`):
84
+ Initialization type for the alphas.
85
+ alphas_initializer_range (`float`, *optional*, defaults to 0.0):
86
+ The standard deviation of the truncated_normal_initializer for initializing the alphas in the Gated Cross Attention.
87
+ alpha_type (`str`, *optional*, defaults to `"vector"`):
88
+ Whether the gating alphas should be vectors or single floats.
89
+ use_cache (`bool`, *optional*, defaults to `True`):
90
+ Whether or not the model should return the last key/values attentions (not used by all models).
91
+ cross_layer_interval (`int`, *optional*, defaults to 1):
92
+ Interval for cross attention (from text to image) layers.
93
+ Example:
94
+
95
+ ```python
96
+ >>> from transformers import OPTModel, OPTConfig
97
+
98
+ >>> # Initializing an OPT facebook/opt-350m style configuration
99
+ >>> configuration = OPTConfig()
100
+
101
+ >>> # Initializing a model from the facebook/opt-350m style configuration
102
+ >>> model = OPTModel(configuration)
103
+
104
+ >>> # Accessing the model configuration
105
+ >>> configuration = model.config
106
+ ```"""
107
+ model_type = "vopt"
108
+ keys_to_ignore_at_inference = ["past_key_values"]
109
+
110
+ def __init__(
111
+ self,
112
+ vocab_size=50272,
113
+ additional_vocab_size=0,
114
+ hidden_size=768,
115
+ num_hidden_layers=12,
116
+ ffn_dim=3072,
117
+ max_position_embeddings=2048,
118
+ do_layer_norm_before=True,
119
+ _remove_final_layer_norm=False,
120
+ word_embed_proj_dim=None,
121
+ dropout=0.1,
122
+ attention_dropout=0.0,
123
+ num_attention_heads=12,
124
+ activation_function="relu",
125
+ layerdrop=0.0,
126
+ init_std=0.02,
127
+ alpha_initializer="ones",
128
+ alphas_initializer_range=0.0,
129
+ alpha_type="vector",
130
+ use_cache=True,
131
+ pad_token_id=1,
132
+ bos_token_id=2,
133
+ eos_token_id=2,
134
+ cross_layer_interval=1,
135
+ cross_layer_activation_function="swiglu",
136
+ normformer_layer_norms=False,
137
+ qk_layer_norms=False,
138
+ rms_norm=False,
139
+ qk_layer_norms_perceiver=False,
140
+ tie_word_embeddings=False,
141
+ freeze_text_layers=True,
142
+ freeze_text_module_exceptions=[],
143
+ freeze_lm_head=False,
144
+ freeze_vision_layers=True,
145
+ freeze_vision_module_exceptions=[],
146
+ vision_model_name="google/vit-base-patch16-224",
147
+ vision_model_params="{}",
148
+ vision_embed_dim=768,
149
+ vision_image_size=224,
150
+ image_token_index=50257, # TODO: change this to right value
151
+ use_resampler=False,
152
+ resampler_n_latents=64,
153
+ resampler_depth=6,
154
+ resampler_n_heads=16,
155
+ resampler_head_dim=96,
156
+ **kwargs,
157
+ ):
158
+ super().__init__(
159
+ pad_token_id=pad_token_id,
160
+ bos_token_id=bos_token_id,
161
+ eos_token_id=eos_token_id,
162
+ tie_word_embeddings=tie_word_embeddings,
163
+ **kwargs,
164
+ )
165
+ self.vocab_size = vocab_size
166
+ self.additional_vocab_size = additional_vocab_size
167
+ self.max_position_embeddings = max_position_embeddings
168
+ self.num_attention_heads = num_attention_heads
169
+ self.word_embed_proj_dim = word_embed_proj_dim if word_embed_proj_dim is not None else hidden_size
170
+ self.ffn_dim = ffn_dim
171
+ self.hidden_size = hidden_size
172
+ self.num_hidden_layers = num_hidden_layers
173
+ self.dropout = dropout
174
+ self.attention_dropout = attention_dropout
175
+ self.activation_function = activation_function
176
+ self.init_std = init_std
177
+ self.alpha_initializer = alpha_initializer
178
+ self.alphas_initializer_range = alphas_initializer_range
179
+ self.alpha_type = alpha_type
180
+ self.layerdrop = layerdrop
181
+ self.use_cache = use_cache
182
+ self.do_layer_norm_before = do_layer_norm_before
183
+
184
+ # Note that the only purpose of `_remove_final_layer_norm` is to keep backward compatibility
185
+ # with checkpoints that have been fine-tuned before transformers v4.20.1
186
+ # see https://github.com/facebookresearch/metaseq/pull/164
187
+ self._remove_final_layer_norm = _remove_final_layer_norm
188
+
189
+ self.cross_layer_interval = cross_layer_interval
190
+ self.cross_layer_activation_function = cross_layer_activation_function
191
+ self.normformer_layer_norms = normformer_layer_norms
192
+ self.qk_layer_norms = qk_layer_norms
193
+ self.rms_norm = rms_norm
194
+ self.qk_layer_norms_perceiver = qk_layer_norms_perceiver
195
+ self.freeze_vision_layers = freeze_vision_layers
196
+ self.vision_model_name = vision_model_name
197
+ self.vision_model_params = vision_model_params
198
+
199
+ self.tie_word_embeddings = tie_word_embeddings
200
+ self.freeze_text_layers = freeze_text_layers
201
+ self.freeze_text_module_exceptions = freeze_text_module_exceptions
202
+ self.freeze_vision_module_exceptions = freeze_vision_module_exceptions
203
+ self.freeze_lm_head = freeze_lm_head
204
+ self.image_token_index = image_token_index
205
+
206
+ self.vision_embed_dim = vision_embed_dim
207
+ self.vision_image_size = vision_image_size
208
+
209
+ # Resampler params
210
+ self.use_resampler = use_resampler
211
+ self.resampler_n_latents = resampler_n_latents
212
+ self.resampler_depth = resampler_depth
213
+ self.resampler_n_heads = resampler_n_heads
214
+ self.resampler_head_dim = resampler_head_dim
215
+
216
+ # IMPORTANT: Do not do any __init__ args-based checks in the constructor, since
217
+ # PretrainedConfig.from_dict first instantiates the class with the config dict and only then
218
+ # updates the config object with `kwargs` from from_pretrained, so during the instantiation
219
+ # of this object many attributes have default values and haven't yet been overridden.
220
+ # Do any required checks inside `from_pretrained` once the superclass' `from_pretrained` was run.
221
+
222
+ def check_compatibilities(self):
223
+ vision_model_params = eval(self.vision_model_params)
224
+ config = AutoConfig.from_pretrained(self.vision_model_name, **vision_model_params)
225
+ if hasattr(config, "vision_config"):
226
+ vision_config = config.vision_config
227
+ else:
228
+ vision_config = config
229
+ vision_embed_dim = vision_config.hidden_size
230
+ if self.vision_embed_dim != vision_embed_dim:
231
+ raise ValueError(
232
+ f"vision_embed_dim ({self.vision_embed_dim}) must match the hidden size of the vision model"
233
+ f" ({vision_embed_dim})"
234
+ )
235
+ vision_image_size = vision_config.image_size
236
+ if self.vision_image_size != vision_image_size:
237
+ raise ValueError(
238
+ f"vision_image_size ({self.vision_image_size}) must match the image size of the vision model"
239
+ f" ({vision_image_size})"
240
+ )
241
+
242
+ @classmethod
243
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
244
+ outputs = super(VOPTConfig, cls).from_pretrained(pretrained_model_name_or_path, **kwargs)
245
+ if isinstance(outputs, tuple):
246
+ # When called with return_unused_kwargs=True, the first item will be the config
247
+ outputs[0].check_compatibilities()
248
+ else:
249
+ outputs.check_compatibilities()
250
+ return outputs
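
The comment at the end of `VOPTConfig.__init__` explains why cross-field checks are deferred until after `from_pretrained` has run. A minimal plain-Python sketch of that pattern follows. `ToyConfig`, its two fields, and its `from_dict` signature are hypothetical stand-ins for illustration, not the real `PretrainedConfig` API; in the actual code the vision hidden size is read from the vision model's own config rather than stored as a field.

```python
class ToyConfig:
    """Hypothetical stand-in for a config class (not the transformers API)."""

    def __init__(self, vision_embed_dim=768, vision_hidden_size=768):
        self.vision_embed_dim = vision_embed_dim
        self.vision_hidden_size = vision_hidden_size
        # No cross-field checks here: caller overrides may still be applied
        # after construction, so the object can be temporarily inconsistent.

    def check_compatibilities(self):
        if self.vision_embed_dim != self.vision_hidden_size:
            raise ValueError(
                f"vision_embed_dim ({self.vision_embed_dim}) must match the hidden size"
                f" of the vision model ({self.vision_hidden_size})"
            )

    @classmethod
    def from_dict(cls, config_dict, **overrides):
        config = cls(**config_dict)            # 1. build from the stored values
        for key, value in overrides.items():   # 2. apply the caller's overrides
            setattr(config, key, value)
        config.check_compatibilities()         # 3. validate the final state only
        return config


# The stored dict alone is inconsistent (32 vs. the default 768), but the
# override fixes it, so validating after step 2 succeeds.
config = ToyConfig.from_dict({"vision_embed_dim": 32}, vision_hidden_size=32)
```

Had the check lived in `__init__`, step 1 would have raised before the override was applied, which is the failure mode the source comment warns about.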
m4/models/vopt/make_tiny_model.py ADDED
@@ -0,0 +1,114 @@
1
+ #!/usr/bin/env python
2
+
3
+ # This script creates a super tiny model that is useful inside tests, when we just want to test that
4
+ # the machinery works, without needing to check the quality of the outcomes.
5
+ #
6
+ # usage: adjust the configs if wanted, but otherwise just run the script
7
+
8
+ from pathlib import Path
9
+ from types import SimpleNamespace
10
+
11
+ import torchvision.transforms as transforms
12
+ from PIL import Image
13
+
14
+ from m4.models.vopt.modeling_vopt import VOPTConfig, VOPTForCausalLM
15
+ from m4.training.packing import image_attention_mask_for_packed_input_ids, incremental_to_binary_attention_mask
16
+ from m4.training.utils import get_tokenizer
17
+
18
+
19
+ mname_tiny = "tiny-random-vopt-clip"
20
+
21
+ path = Path(mname_tiny)
22
+ path.mkdir(parents=True, exist_ok=True)
23
+
24
+ # from the hardcoded https://github.com/huggingface/m4/blob/adf102f0000cb2632cd8a3ebb87398c65e448a97/m4/training/main.py#L80
25
+ additional_vocab_size = 2
26
+
27
+ config = VOPTConfig()
28
+ config.update(
29
+ dict(
30
+ ffn_dim=64,
31
+ hidden_size=16,
32
+ max_position_embeddings=128,
33
+ num_attention_heads=4,
34
+ num_hidden_layers=2,
35
+ word_embed_proj_dim=16,
36
+ max_new_tokens=100,
37
+ use_resampler=True,
38
+ resampler_depth=2,
39
+ resampler_head_dim=8,
40
+ resampler_n_heads=2,
41
+ resampler_n_latents=16,
42
+ vision_embed_dim=32,
43
+ vision_image_size=30,
44
+ vision_model_name="hf-internal-testing/tiny-random-clip",
45
+ vision_model_params="{}",
46
+ vocab_size=50265,
47
+ additional_vocab_size=additional_vocab_size,
48
+ )
49
+ )
50
+
51
+ # print(config)
52
+ # can now modify config to say tiny values
53
+
54
+ model = VOPTForCausalLM.from_config(config)
55
+ # print(model.config)
56
+ # print(model)
57
+
58
+ tokenizer_config = dict(
59
+ tokenizer_add_special_tokens="{}",
60
+ tokenizer_add_tokens=(
61
+ '[AddedToken("<fake_token_around_image>", rstrip=False, lstrip=False), AddedToken("<image>", rstrip=False,'
62
+ " lstrip=False)]"
63
+ ),
64
+ tokenizer_name="facebook/opt-13b",
65
+ tokenizer_params='{"use_fast":True}',
66
+ )
67
+ tokenizer_config = SimpleNamespace(**tokenizer_config)
68
+ # print(tokenizer_config)
69
+
70
+ tokenizer = get_tokenizer(
71
+ tokenizer_name=tokenizer_config.tokenizer_name,
72
+ tokenizer_add_tokens=tokenizer_config.tokenizer_add_tokens,
73
+ tokenizer_add_special_tokens=tokenizer_config.tokenizer_add_special_tokens,
74
+ tokenizer_params=tokenizer_config.tokenizer_params,
75
+ additional_vocab_size=model.config.additional_vocab_size,
76
+ model_vocab_size=model.config.vocab_size,
77
+ )
78
+ assert "<image>" in tokenizer.get_vocab()
79
+
80
+ # Test w/ one image and one text
81
+ query = "<fake_token_around_image><image><fake_token_around_image>This is a picture of a cat."
82
+ query_tokens = tokenizer(query, return_tensors="pt")
83
+
84
+ num_images_per_ex = 1
85
+ pixel_values = transforms.ToTensor()(Image.new("RGB", (30, 30))).repeat(1, 1, 1, 1).unsqueeze(0)
86
+ image_attention_mask, _ = image_attention_mask_for_packed_input_ids(query_tokens["input_ids"], tokenizer)
87
+ image_attention_mask = incremental_to_binary_attention_mask(image_attention_mask, num_classes=num_images_per_ex)
88
+
89
+ input = {
90
+ "input_ids": query_tokens["input_ids"],
91
+ "attention_mask": query_tokens["attention_mask"],
92
+ "pixel_values": pixel_values,
94
+ "image_attention_mask": image_attention_mask,
95
+ }
96
+ # debug shapes
97
+ # print(query_tokens["input_ids"].shape)
98
+ # print(query_tokens["attention_mask"].shape)
99
+ # print(pixel_values.shape)
100
+ # print(image_attention_mask.shape)
101
+
102
+ out_gen = model.generate(**input)
103
+ text = tokenizer.batch_decode(out_gen)
104
+ # print(text)
105
+
106
+ # Save model + config + tokenizer
107
+ model.half() # makes it smaller
108
+ model.save_pretrained(path)
109
+ tokenizer.save_pretrained(path)
110
+
111
+ # test we can load it back
112
+ model = VOPTForCausalLM.from_pretrained(path)
113
+
114
+ print(f"Generated {mname_tiny} - Upload the generated folder to the hub")
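
The script above passes several settings as strings holding Python literals (`vision_model_params="{}"`, `tokenizer_params='{"use_fast":True}'`), and `VOPTConfig.check_compatibilities` turns such a string back into a dict with `eval`. As a hedged alternative, the sketch below parses the same strings with the standard library's `ast.literal_eval`, which accepts only literals. `parse_params` is a hypothetical helper name, not something defined in `m4`.

```python
import ast


def parse_params(params: str) -> dict:
    """Parse a string such as '{"use_fast":True}' into a dict.

    ast.literal_eval only accepts Python literals (dicts, lists, strings,
    numbers, booleans, None), so arbitrary expressions are rejected.
    """
    value = ast.literal_eval(params)
    if not isinstance(value, dict):
        raise TypeError(f"expected a dict literal, got {type(value).__name__}")
    return value


tokenizer_params = parse_params('{"use_fast":True}')
vision_model_params = parse_params("{}")
```

Note that a string such as the `tokenizer_add_tokens` value, which contains `AddedToken(...)` calls, is not a pure literal and would be rejected by this helper; it would still need a different parsing strategy.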
m4/models/vopt/modeling_vopt.py ADDED
@@ -0,0 +1,1513 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch OPT model."""
16
+ import random
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.utils.checkpoint
21
+ from torch import nn
22
+ from torch.nn import CrossEntropyLoss
23
+ from transformers.activations import ACT2FN
24
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
25
+ from transformers.utils import (
26
+ ContextManagers,
27
+ add_code_sample_docstrings,
28
+ add_start_docstrings,
29
+ add_start_docstrings_to_model_forward,
30
+ replace_return_docstrings,
31
+ )
32
+
33
+ from m4.models import DecoupledEmbedding, DecoupledLinear
34
+ from m4.models.common import (
35
+ expand_inputs_for_generation,
36
+ prepare_inputs_for_generation,
37
+ update_model_kwargs_for_generation,
38
+ )
39
+ from m4.models.custom_modules import VLOOMPreTrainedModelBase
40
+ from m4.models.perceiver.perceiver import PerceiverResampler
41
+ from m4.models.vopt.configuration_vopt import VOPTConfig
42
+ from m4.training.utils import (
43
+ compute_perceiver_tflops_per_batch_per_gpu,
44
+ compute_tflops_per_batch_per_gpu,
45
+ deepspeed_gathered_parameters_context_manager,
46
+ freeze_model,
47
+ )
48
+ from m4.utils import logging
49
+
50
+
51
+ logger = logging.get_logger(__name__)
52
+
53
+ _CHECKPOINT_FOR_DOC = "facebook/opt-350m"
54
+ _CONFIG_FOR_DOC = "VOPTConfig"
55
+ _TOKENIZER_FOR_DOC = "GPT2Tokenizer"
56
+
57
+ # Base model docstring
58
+ _EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]
59
+
60
+ # SequenceClassification docstring
61
+ _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ArthurZ/opt-350m-dummy-sc"
62
+ _SEQ_CLASS_EXPECTED_LOSS = 1.71
63
+ _SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'"
64
+
65
+
66
+ OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
67
+ "facebook/opt-125m",
68
+ "facebook/opt-350m",
69
+ "facebook/opt-1.3b",
70
+ "facebook/opt-2.7b",
71
+ "facebook/opt-6.7b",
72
+ "facebook/opt-13b",
73
+ "facebook/opt-30b",
74
+ # See all OPT models at https://huggingface.co/models?filter=opt
75
+ ]
76
+
77
+
78
+ class SwiGLUActivation(nn.Module):
79
+ def __init__(self, in_features: int, out_features: int):
80
+ super().__init__()
81
+ self.gate = nn.Linear(in_features, out_features, bias=False)
82
+
83
+ def forward(self, hidden_states_to_gate, hidden_states):
84
+ gate = self.gate(hidden_states)
85
+ return nn.functional.silu(gate) * hidden_states_to_gate
86
+
87
+
88
+ # Taken from LLaMA codebase
89
+ class RMSNorm(torch.nn.Module):
90
+ def __init__(self, dim: int, eps: float = 1e-6):
91
+ super().__init__()
92
+ self.eps = eps
93
+ self.weight = nn.Parameter(torch.ones(dim))
94
+
95
+ def _norm(self, x):
96
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
97
+
98
+ def forward(self, x):
99
+ output = self._norm(x.float()).type_as(x)
100
+ return output * self.weight
101
+
102
+
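
As an illustration of the arithmetic in `RMSNorm` above, here is a plain-Python sketch on a single feature vector, written so that it runs without PyTorch. The function name `rms_norm` and the list-based representation are assumptions made for this example only.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """Divide x by its root mean square, then scale each feature by its weight."""
    mean_square = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_square + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


normed = rms_norm([3.0, 4.0], weight=[1.0, 1.0])
# With unit weights, the mean of the squared outputs is approximately 1.
mean_square_after = sum(v * v for v in normed) / len(normed)
```

Unlike `nn.LayerNorm`, this normalization does not subtract the mean and has no bias term; it only rescales.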
103
+ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
104
+ """
105
+ Make causal mask used for uni-directional (causal) self-attention.
106
+ """
107
+ bsz, tgt_len = input_ids_shape
108
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
109
+ mask_cond = torch.arange(mask.size(-1))
110
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
111
+ mask = mask.to(dtype)
112
+
113
+ if past_key_values_length > 0:
114
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
115
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
116
+
117
+
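
The mask built by `_make_causal_mask` can be sketched with nested lists to make its layout explicit. This is an illustrative plain-Python version, not the tensor code above: `float("-inf")` stands in for `torch.finfo(dtype).min`, and the batch and head dimensions are omitted.

```python
def make_causal_mask(tgt_len, past_key_values_length=0, min_value=float("-inf")):
    """Return a tgt_len x (past + tgt_len) additive mask as nested lists.

    Entry [i][j] is 0.0 where query i may attend to key j and min_value
    where it may not. Cached (past) positions are always visible.
    """
    rows = []
    for i in range(tgt_len):
        past = [0.0] * past_key_values_length
        current = [0.0 if j <= i else min_value for j in range(tgt_len)]
        rows.append(past + current)
    return rows


mask = make_causal_mask(3)
mask_with_cache = make_causal_mask(2, past_key_values_length=1)
```

Each row allows attention to itself and to earlier positions only, which is what makes the decoder uni-directional.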
118
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
119
+ """
120
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
121
+ """
122
+ bsz, src_len = mask.size()
123
+ tgt_len = tgt_len if tgt_len is not None else src_len
124
+
125
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
126
+
127
+ inverted_mask = 1.0 - expanded_mask
128
+
129
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
130
+
131
+
132
+ class OPTLearnedPositionalEmbedding(nn.Embedding):
133
+ """
134
+ This module learns positional embeddings up to a fixed maximum size.
135
+ """
136
+
137
+ def __init__(self, num_embeddings: int, embedding_dim: int):
138
+ # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
139
+ # and adjust num_embeddings appropriately. Other models don't have this hack
140
+ self.offset = 2
141
+ super().__init__(num_embeddings + self.offset, embedding_dim)
142
+
143
+ def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
144
+ """`attention_mask` is expected to be [bsz x seqlen]."""
145
+ attention_mask = attention_mask.long()
146
+
147
+ # create positions depending on attention_mask
148
+ positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
149
+
150
+ # cut positions if `past_key_values_length` is > 0
151
+ positions = positions[:, past_key_values_length:]
152
+
153
+ return super().forward(positions + self.offset)
154
+
155
+
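
To make the position computation in `OPTLearnedPositionalEmbedding.forward` concrete, the following plain-Python sketch reproduces it for a single sequence. The helper name `position_ids` is hypothetical; the steps (cumulative sum of the mask, multiply by the mask, subtract one, drop cached positions, add the offset of 2) follow the code above.

```python
from itertools import accumulate


def position_ids(attention_mask, past_key_values_length=0, offset=2):
    """Compute embedding indices for one sequence from its 0/1 attention mask."""
    # Running count of real tokens, zeroed on padding, shifted to start at 0.
    positions = [c * m - 1 for c, m in zip(accumulate(attention_mask), attention_mask)]
    # Keep only positions for tokens that are not already in the cache.
    positions = positions[past_key_values_length:]
    return [p + offset for p in positions]


left_padded = position_ids([0, 0, 1, 1, 1])
next_token = position_ids([1, 1, 1], past_key_values_length=2)
```

Padding positions all map to the same index (`offset - 1`), while real tokens start at `offset` regardless of how much left padding precedes them.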
156
+ class OPTAttention(nn.Module):
157
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
158
+
159
+ def __init__(
160
+ self,
161
+ embed_dim: int,
162
+ num_heads: int,
163
+ dropout: float = 0.0,
164
+ is_decoder: bool = False,
165
+ bias: bool = True,
166
+ is_cross_attention=False,
167
+ config=None,
168
+ qk_layer_norms=False,
169
+ ):
170
+ super().__init__()
171
+ self.embed_dim = embed_dim
172
+ self.num_heads = num_heads
173
+ self.dropout = dropout
174
+ self.head_dim = embed_dim // num_heads
175
+
176
+ if (self.head_dim * num_heads) != self.embed_dim:
177
+ raise ValueError(
178
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
179
+ f" and `num_heads`: {num_heads})."
180
+ )
181
+ self.scaling = self.head_dim**-0.5
182
+ self.is_decoder = is_decoder
183
+
184
+ self.is_cross_attention = is_cross_attention
185
+
186
+ if self.is_cross_attention:
187
+ kv_input_dim = self.embed_dim if not hasattr(config, "vision_embed_dim") else config.vision_embed_dim
188
+ self.k_proj = nn.Linear(kv_input_dim, embed_dim, bias=bias)
189
+ self.v_proj = nn.Linear(kv_input_dim, embed_dim, bias=bias)
190
+
191
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
192
+ else:
193
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
194
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
195
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
196
+
197
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
198
+
199
+ self.qk_layer_norms = qk_layer_norms
200
+ if self.qk_layer_norms and config.rms_norm:
201
+ self.q_layer_norm = RMSNorm(self.head_dim, eps=1e-6)
202
+ self.k_layer_norm = RMSNorm(self.head_dim, eps=1e-6)
203
+ elif self.qk_layer_norms:
204
+ self.q_layer_norm = nn.LayerNorm(self.head_dim)
205
+ self.k_layer_norm = nn.LayerNorm(self.head_dim)
206
+
207
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
208
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
209
+
210
+ def forward(
211
+ self,
212
+ hidden_states: torch.Tensor,
213
+ key_value_states: Optional[torch.Tensor] = None,
214
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
215
+ attention_mask: Optional[torch.Tensor] = None,
216
+ layer_head_mask: Optional[torch.Tensor] = None,
217
+ output_attentions: bool = False,
218
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
219
+ """Input shape: Batch x Time x Channel"""
220
+
221
+ # if key_value_states are provided this layer is used as a cross-attention layer
222
+ # for the decoder
223
+ is_cross_attention = self.is_cross_attention or key_value_states is not None
224
+
225
+ bsz, tgt_len, _ = hidden_states.size()
226
+
227
+ # get query proj
228
+ query_states = self._shape(self.q_proj(hidden_states), -1, bsz)
229
+ # get key, value proj
230
+ if is_cross_attention and past_key_value is not None:
231
+ # reuse k,v, cross_attentions
232
+ key_states = past_key_value[0]
233
+ value_states = past_key_value[1]
234
+ elif is_cross_attention:
235
+ # cross_attentions
236
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
237
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
238
+ elif past_key_value is not None:
239
+ # reuse k, v, self_attention
240
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
241
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
242
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
243
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
244
+ else:
245
+ # self_attention
246
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
247
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
248
+
249
+ if self.is_decoder:
250
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
251
+ # Further calls to cross_attention layer can then reuse all cross-attention
252
+ # key/value_states (first "if" case)
253
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
254
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
255
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
256
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
257
+ past_key_value = (key_states, value_states)
258
+
259
+ if self.qk_layer_norms:
260
+ query_states = self.q_layer_norm(query_states)
261
+ key_states = self.k_layer_norm(key_states)
262
+
263
+ src_len = key_states.size(2)
264
+
265
+ if attention_mask is not None:
266
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
267
+ raise ValueError(
268
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
269
+ )
270
+ if layer_head_mask is not None:
271
+ if layer_head_mask.size() != (self.num_heads,):
272
+ raise ValueError(
273
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
274
+ f" {layer_head_mask.size()}"
275
+ )
276
+ attention_mask = attention_mask.expand(-1, self.num_heads, -1, -1)
277
+ attention_mask = attention_mask + layer_head_mask.view(1, -1, 1, 1)
278
+
279
+ attn_output = nn.functional.scaled_dot_product_attention(
280
+ query_states,
281
+ key_states,
282
+ value_states,
283
+ attn_mask=attention_mask,
284
+ dropout_p=self.dropout if self.training else 0.0,
285
+ )
286
+
287
+ attn_weights_reshaped = None
288
+ logger.warning_once(
289
+ "attn_weights are not extracted in scaled_dot_product_attention. The model returns None instead"
290
+ )
291
+ attn_output = attn_output.transpose(1, 2)
292
+
293
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
294
+ # partitioned across GPUs when using tensor-parallelism.
295
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
296
+
297
+ attn_output = self.out_proj(attn_output)
298
+
299
+ return attn_output, attn_weights_reshaped, past_key_value
300
+
301
+
302
+ class OPTDecoderLayer(nn.Module):
303
+ def __init__(self, config: VOPTConfig):
304
+ super().__init__()
305
+ self.embed_dim = config.hidden_size
306
+ self.self_attn = OPTAttention(
307
+ embed_dim=self.embed_dim,
308
+ num_heads=config.num_attention_heads,
309
+ dropout=config.attention_dropout,
310
+ is_decoder=True,
311
+ config=config,
312
+ )
313
+ self.do_layer_norm_before = config.do_layer_norm_before
314
+ self.dropout = config.dropout
315
+ self.activation_fn = ACT2FN[config.activation_function]
316
+
317
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
318
+ self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim)
319
+ self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim)
320
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
321
+
322
+ def forward(
323
+ self,
324
+ hidden_states: torch.Tensor,
325
+ attention_mask: Optional[torch.Tensor] = None,
326
+ layer_head_mask: Optional[torch.Tensor] = None,
327
+ output_attentions: Optional[bool] = False,
328
+ use_cache: Optional[bool] = False,
329
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
330
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
331
+ """
332
+ Args:
333
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
334
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
335
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
336
+ layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
337
+ `(num_attention_heads,)`.
338
+ output_attentions (`bool`, *optional*):
339
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
340
+ returned tensors for more detail.
341
+ use_cache (`bool`, *optional*):
342
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
343
+ (see `past_key_values`).
344
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
345
+ """
346
+
347
+ residual = hidden_states
348
+
349
+ # 125m, 1.3B, ..., 175B applies layer norm BEFORE attention
350
+ if self.do_layer_norm_before:
351
+ hidden_states = self.self_attn_layer_norm(hidden_states)
352
+
353
+ # Self Attention
354
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
355
+ hidden_states=hidden_states,
356
+ past_key_value=past_key_value,
357
+ attention_mask=attention_mask,
358
+ layer_head_mask=layer_head_mask,
359
+ output_attentions=output_attentions,
360
+ )
361
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
362
+ hidden_states = residual + hidden_states
363
+
364
+ # 350m applies layer norm AFTER attention
365
+ if not self.do_layer_norm_before:
366
+ hidden_states = self.self_attn_layer_norm(hidden_states)
367
+
368
+ # Fully Connected
369
+ hidden_states_shape = hidden_states.shape
370
+ hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
371
+ residual = hidden_states
372
+
373
+ # 125m, 1.3B, ..., 175B applies layer norm BEFORE the fully connected block
374
+ if self.do_layer_norm_before:
375
+ hidden_states = self.final_layer_norm(hidden_states)
376
+
377
+ hidden_states = self.fc1(hidden_states)
378
+ hidden_states = self.activation_fn(hidden_states)
379
+
380
+ hidden_states = self.fc2(hidden_states)
381
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
382
+
383
+ hidden_states = (residual + hidden_states).view(hidden_states_shape)
384
+
385
+ # 350m applies layer norm AFTER the fully connected block
386
+ if not self.do_layer_norm_before:
387
+ hidden_states = self.final_layer_norm(hidden_states)
388
+
389
+ outputs = (hidden_states,)
390
+
391
+ if output_attentions:
392
+ outputs += (self_attn_weights,)
393
+
394
+ if use_cache:
395
+ outputs += (present_key_value,)
396
+
397
+ return outputs
398
+
399
+
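The ordering controlled by `do_layer_norm_before` in the layer above can be sketched without PyTorch. This is a minimal illustration only: `residual_block`, `double`, and `shift` are hypothetical names, scalars stand in for tensors, and dropout is left out.

```python
from typing import Callable


def residual_block(
    x: float,
    sublayer: Callable[[float], float],
    norm: Callable[[float], float],
    norm_before: bool,
) -> float:
    """Apply one residual sub-block with the norm placed before or after it."""
    residual = x
    if norm_before:
        # Pre-norm: normalize the input of the sub-layer only.
        x = norm(x)
    x = residual + sublayer(x)
    if not norm_before:
        # Post-norm: normalize the sum of residual and sub-layer output.
        x = norm(x)
    return x


# Toy stand-ins so the two orderings give visibly different results.
double = lambda v: 2.0 * v   # plays the role of attention or the MLP
shift = lambda v: v + 1.0    # plays the role of the layer norm

pre_ln = residual_block(1.0, double, shift, norm_before=True)    # 1 + 2 * (1 + 1)
post_ln = residual_block(1.0, double, shift, norm_before=False)  # (1 + 2 * 1) + 1
```

In the pre-norm case the residual path carries the un-normalized input, while in the post-norm case the normalized sum is what the next sub-block receives.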
400
+ class VOPTGatedAttentionLayer(nn.Module):
401
+ def __init__(self, config: VOPTConfig):
402
+ """
403
+ Note: Based on `tr_101_cm401xPMD09_nobias`, setting the biases to False in all of the nn.Linear for the gated cross attention.
404
+ This provides a small stability gain at opt-13b scale.
405
+ """
406
+ super().__init__()
407
+ self.embed_dim = config.hidden_size
408
+ self.cross_attn = OPTAttention(
409
+ embed_dim=self.embed_dim,
410
+ num_heads=config.num_attention_heads,
411
+ dropout=config.attention_dropout,
412
+ is_decoder=True,
413
+ config=config,
414
+ is_cross_attention=True,
415
+ bias=False,
416
+ qk_layer_norms=config.qk_layer_norms,
417
+ )
418
+ self.do_layer_norm_before = config.do_layer_norm_before
419
+ self.normformer_layer_norms = config.normformer_layer_norms
420
+ self.dropout = config.dropout
421
+ if config.cross_layer_activation_function == "swiglu":
422
+ # We cannot put `SwiGLUActivation` in `ACT2FN` because it takes two arguments (`in_features` and
423
+ # `out_features`) that we don't know until entering this module.
424
+ self.activation_fn = SwiGLUActivation(self.embed_dim, config.ffn_dim)
425
+ else:
426
+ self.activation_fn = ACT2FN[config.cross_layer_activation_function]
427
+
428
+ if config.rms_norm:
429
+ self.self_attn_layer_norm = RMSNorm(self.embed_dim, eps=1e-6)
430
+ else:
431
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
432
+
433
+ if self.normformer_layer_norms:
434
+ self.self_attn_post_layer_norm = nn.LayerNorm(self.embed_dim)
435
+ self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=False)
436
+ self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=False)
437
+
438
+ if config.rms_norm:
439
+ self.final_layer_norm = RMSNorm(self.embed_dim, eps=1e-6)
440
+ else:
441
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
442
+
443
+ if self.normformer_layer_norms:
444
+ self.mlp_post_layer_norm = nn.LayerNorm(config.ffn_dim)
445
+
446
+ self.act_cross_attn = nn.Tanh()
447
+ self.act_dense = nn.Tanh()
448
+
449
+ if config.alpha_initializer == "zeros":
450
+ if config.alpha_type == "vector":
451
+ self.alpha_cross_attn = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
452
+ self.alpha_dense = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
453
+ elif config.alpha_type == "float":
454
+ self.alpha_cross_attn = nn.Parameter(torch.zeros(1))
455
+ self.alpha_dense = nn.Parameter(torch.zeros(1))
456
+ else:
457
+ raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
458
+
459
+ elif config.alpha_initializer == "ones":
460
+ if config.alpha_type == "vector":
461
+ self.alpha_cross_attn = nn.Parameter(torch.ones(1, 1, self.embed_dim))
462
+ self.alpha_dense = nn.Parameter(torch.ones(1, 1, self.embed_dim))
463
+ elif config.alpha_type == "float":
464
+ self.alpha_cross_attn = nn.Parameter(torch.ones(1))
465
+ self.alpha_dense = nn.Parameter(torch.ones(1))
466
+ else:
467
+ raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
468
+
469
+ elif config.alpha_initializer in {"normal", "gaussian", "random"}:
470
+ if config.alpha_type == "vector":
471
+ self.alpha_cross_attn = nn.Parameter(
472
+ torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1, 1, self.embed_dim))
473
+ )
474
+ self.alpha_dense = nn.Parameter(
475
+ torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1, 1, self.embed_dim))
476
+ )
477
+ elif config.alpha_type == "float":
478
+ self.alpha_cross_attn = nn.Parameter(
479
+ torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1,))
480
+ )
481
+ self.alpha_dense = nn.Parameter(torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1,)))
482
+ else:
483
+ raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
484
+
485
+ else:
486
+ raise NotImplementedError(f"Alpha initialization scheme {config.alpha_initializer} not yet implemented!")
487
+
488
+ assert hasattr(self, "alpha_cross_attn") and hasattr(self, "alpha_dense")
489
+
490
+ def forward(
491
+ self,
492
+ hidden_states: torch.Tensor,
493
+ attention_mask: Optional[torch.Tensor] = None,
494
+ layer_head_mask: Optional[torch.Tensor] = None,
495
+ image_hidden_states: Optional[torch.Tensor] = None,
496
+ image_attention_mask: Optional[torch.FloatTensor] = None,
497
+ output_attentions: Optional[bool] = False,
498
+ use_cache: Optional[bool] = False,
499
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
500
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
501
+ """
502
+ Args:
503
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
504
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
505
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
506
+ layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
507
+ `(encoder_attention_heads,)`.
508
+ output_attentions (`bool`, *optional*):
509
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
510
+ returned tensors for more detail.
511
+ use_cache (`bool`, *optional*):
512
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
513
+ (see `past_key_values`).
514
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
515
+ """
516
+ if image_hidden_states is None:
517
+ raise ValueError(
518
+ "`image_hidden_states` is required for the VOPT cross attention module; these are the visual features to be"
519
+ " conditioned on."
520
+ )
521
+
522
+ if past_key_value is not None:
523
+ raise NotImplementedError("Past key value states are not implemented for VOPT cross attention module.")
524
+
525
+ residual = hidden_states
526
+
527
+ # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
528
+ if self.do_layer_norm_before:
529
+ hidden_states = self.self_attn_layer_norm(hidden_states)
530
+
531
+ # Cross Attention (text hidden states attend to the image hidden states)
532
+ hidden_states, self_attn_weights, present_key_value = self.cross_attn(
533
+ hidden_states=hidden_states,
534
+ key_value_states=image_hidden_states,
535
+ attention_mask=image_attention_mask,
536
+ layer_head_mask=layer_head_mask,
537
+ output_attentions=output_attentions,
538
+ )
539
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
540
+ if self.normformer_layer_norms:
541
+ hidden_states = self.self_attn_post_layer_norm(hidden_states)
542
+ hidden_states = residual + self.act_cross_attn(self.alpha_cross_attn) * hidden_states
543
+
544
+ # 350m applies layer norm AFTER attention
545
+ if not self.do_layer_norm_before:
546
+ hidden_states = self.self_attn_layer_norm(hidden_states)
547
+
548
+ # Fully Connected
549
+ hidden_states_shape = hidden_states.shape
550
+ hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
551
+ residual = hidden_states
552
+
553
+ # 125m, 1.7B, ..., 175B apply layer norm BEFORE the fully connected block
554
+ if self.do_layer_norm_before:
555
+ hidden_states = self.final_layer_norm(hidden_states)
556
+
557
+ hidden_states_to_gate = self.fc1(hidden_states)
558
+ if isinstance(self.activation_fn, SwiGLUActivation):
559
+ hidden_states = self.activation_fn(hidden_states_to_gate, hidden_states)
560
+ else:
561
+ hidden_states = self.activation_fn(hidden_states_to_gate)
562
+
563
+ if self.normformer_layer_norms:
564
+ hidden_states = self.mlp_post_layer_norm(hidden_states)
565
+ hidden_states = self.fc2(hidden_states)
566
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
567
+
568
+ hidden_states = (residual + self.act_dense(self.alpha_dense) * hidden_states).view(hidden_states_shape)
569
+
570
+ # 350m applies layer norm AFTER the fully connected block
571
+ if not self.do_layer_norm_before:
572
+ hidden_states = self.final_layer_norm(hidden_states)
573
+
574
+ outputs = (hidden_states,)
575
+
576
+ if output_attentions:
577
+ outputs += (self_attn_weights,)
578
+
579
+ if use_cache:
580
+ outputs += (present_key_value,)
581
+
582
+ return outputs
583
+
584
+
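The tanh gating used in `VOPTGatedAttentionLayer` (`residual + tanh(alpha) * hidden_states`) can be illustrated with plain Python lists. This is a rough sketch, not the layer itself: `gated_residual` is a hypothetical helper, and a single float `alpha` corresponds to the `alpha_type == "float"` case.

```python
import math
from typing import List


def gated_residual(residual: List[float], branch: List[float], alpha: float) -> List[float]:
    """Add a branch output to the residual, scaled by tanh(alpha)."""
    gate = math.tanh(alpha)
    return [r + gate * b for r, b in zip(residual, branch)]


# With the "zeros" initializer the gate is tanh(0) = 0, so the layer
# initially passes the residual stream through unchanged.
at_init = gated_residual([1.0, 2.0], [5.0, -3.0], alpha=0.0)

# A large alpha saturates the gate near 1, giving a nearly plain residual add.
opened = gated_residual([1.0, 2.0], [5.0, -3.0], alpha=10.0)
```

One consequence of zero initialization is that the visual branch contributes nothing at the start of training, so the pre-trained text path is not disturbed until the gates move away from zero.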
585
+ OPT_START_DOCSTRING = r"""
586
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
587
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
588
+ etc.)
589
+
590
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
591
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
592
+ and behavior.
593
+
594
+ Parameters:
595
+ config ([`VOPTConfig`]):
596
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
597
+ load the weights associated with the model, only the configuration. Check out the
598
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
599
+ """
600
+
601
+
602
+ @add_start_docstrings(
603
+ "The bare OPT Model outputting raw hidden-states without any specific head on top.",
604
+ OPT_START_DOCSTRING,
605
+ )
606
+ class VOPTPreTrainedModel(VLOOMPreTrainedModelBase):
607
+ config_class = VOPTConfig
608
+ base_model_prefix = "model"
609
+ supports_gradient_checkpointing = True
610
+ _no_split_modules = ["OPTDecoderLayer", "VOPTGatedAttentionLayer", "CLIPEncoderLayer"]
611
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
612
+
613
+ def _init_weights(self, module):
614
+ def init_a_linear(module, mean=0.0, std=self.config.init_std):
615
+ with ContextManagers(deepspeed_gathered_parameters_context_manager(module.weight, modify=True)):
616
+ module.weight.data.normal_(mean=mean, std=std)
617
+ if module.bias is not None:
618
+ with ContextManagers(deepspeed_gathered_parameters_context_manager(module.bias, modify=True)):
619
+ module.bias.data.zero_()
620
+
621
+ if isinstance(module, VOPTGatedAttentionLayer):
622
+ for sub_module_name, sub_module in module.named_modules():
623
+ if isinstance(sub_module, nn.Linear):
624
+ if "fc2" in sub_module_name:
625
+ factor = 2 * self.config.num_hidden_layers
626
+ else:
627
+ factor = 1.0
628
+ init_a_linear(sub_module, std=(0.4 / (sub_module.in_features * factor)) ** 0.5)
629
+ elif isinstance(module, PerceiverResampler):
630
+ with ContextManagers(deepspeed_gathered_parameters_context_manager(module.latents, modify=True)):
631
+ module.latents.data.normal_(mean=0.0, std=(1.0 / self.config.vision_embed_dim) ** 0.5)
632
+ for sub_module_name, sub_module in module.named_modules():
633
+ if isinstance(sub_module, nn.Linear):
634
+ if "c_proj" in sub_module_name:
635
+ factor = 2 * self.config.num_hidden_layers
636
+ else:
637
+ factor = 1.0
638
+ init_a_linear(sub_module, std=(0.4 / (self.config.vision_embed_dim * factor)) ** 0.5)
639
+ elif isinstance(module, nn.Embedding):
640
+ with ContextManagers(deepspeed_gathered_parameters_context_manager(module.weight, modify=True)):
641
+ module.weight.data.normal_(mean=0.0, std=(1.0 / self.config.hidden_size) ** 0.5)
642
+ if module.padding_idx is not None:
643
+ module.weight.data[module.padding_idx].zero_()
644
+ elif isinstance(module, DecoupledLinear):
645
+ if hasattr(module, "additional_fc"):
646
+ init_a_linear(module.additional_fc, std=(1.0 / (module.additional_fc.in_features)) ** 0.5)
647
+
648
+ def _set_gradient_checkpointing(self, module, value=False):
649
+ if isinstance(module, (VOPTDecoder)):
650
+ module.gradient_checkpointing = value
651
+
652
+ @classmethod
653
+ def override_vision_model_wrapper(cls, model, config, vision_model_name, vision_model_params, torch_dtype):
654
+ # this can be called via from_pretrained from a class w/ head or w/o head so we extract the beheaded model version
655
+ beheaded_model = model.model if hasattr(model, "model") else model
656
+ cls.override_vision_model(beheaded_model.decoder, vision_model_name, vision_model_params, torch_dtype)
657
+ beheaded_model.freeze_relevant_params(config)
658
+
659
+
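The standard deviation that `_init_weights` passes to `init_a_linear` for the linear layers of a `VOPTGatedAttentionLayer` is simple arithmetic, reproduced below as a standalone sketch. The function name `gated_layer_linear_std` is hypothetical; the formula and the `"fc2"` name check are taken from the method above.

```python
def gated_layer_linear_std(in_features: int, sub_module_name: str, num_hidden_layers: int) -> float:
    """Return the init std for one nn.Linear inside a gated cross-attention layer."""
    # The output projection of the MLP ("fc2") is scaled down with depth;
    # every other linear layer uses a factor of 1.
    factor = 2 * num_hidden_layers if "fc2" in sub_module_name else 1.0
    return (0.4 / (in_features * factor)) ** 0.5


std_fc1 = gated_layer_linear_std(in_features=40, sub_module_name="fc1", num_hidden_layers=12)
std_fc2 = gated_layer_linear_std(in_features=10, sub_module_name="fc2", num_hidden_layers=2)
```

Both calls evaluate `sqrt(0.4 / 40)`, which shows how the depth factor on `fc2` acts like a larger fan-in.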
660
+ OPT_INPUTS_DOCSTRING = r"""
661
+ Args:
662
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
663
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
664
+ it.
665
+
666
+ Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
667
+ [`PreTrainedTokenizer.__call__`] for details.
668
+
669
+ [What are input IDs?](../glossary#input-ids)
670
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
671
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
672
+
673
+ - 1 for tokens that are **not masked**,
674
+ - 0 for tokens that are **masked**.
675
+
676
+ [What are attention masks?](../glossary#attention-mask)
677
+
678
+ Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
679
+ [`PreTrainedTokenizer.__call__`] for details.
680
+
681
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
682
+ `past_key_values`).
683
+
684
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
685
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
686
+ information on the default strategy.
687
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
688
+ Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
689
+
690
+ - 1 indicates the head is **not masked**,
691
+ - 0 indicates the head is **masked**.
692
+
693
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
694
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
695
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
696
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
697
+
698
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
699
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
700
+
701
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
702
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
703
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
704
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
705
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
706
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
707
+ model's internal embedding lookup matrix.
708
+ use_cache (`bool`, *optional*):
709
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
710
+ `past_key_values`).
711
+ output_attentions (`bool`, *optional*):
712
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
713
+ tensors for more detail.
714
+ output_hidden_states (`bool`, *optional*):
715
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
716
+ more detail.
717
+ return_dict (`bool`, *optional*):
718
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
719
+ """
720
+
721
+
722
+ class VOPTDecoder(VOPTPreTrainedModel):
723
+ """
724
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is an [`OPTDecoderLayer`]; a [`VOPTGatedAttentionLayer`] runs before every *config.cross_layer_interval*-th layer.
725
+
726
+ Args:
727
+ config: VOPTConfig
728
+ """
729
+
730
+ def __init__(self, config: VOPTConfig, vision_model=None):
731
+ super().__init__(config)
732
+ self.config = config
733
+ self.dropout = config.dropout
734
+ self.layerdrop = config.layerdrop
735
+ self.padding_idx = config.pad_token_id
736
+ self.max_target_positions = config.max_position_embeddings
737
+ self.vocab_size = config.vocab_size
738
+
739
+ self.embed_tokens = DecoupledEmbedding(
740
+ num_embeddings=config.vocab_size,
741
+ num_additional_embeddings=config.additional_vocab_size,
742
+ embedding_dim=config.word_embed_proj_dim,
743
+ partially_freeze=config.freeze_text_layers,
744
+ padding_idx=self.padding_idx,
745
+ )
746
+ self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)
747
+
748
+ # Load an uninitialized model and later in from_pretrained will load the pre-trained model -
749
+ # this avoids losing weights in `from_pretrained` on the main model
750
+ self.vision_model = vision_model
751
+
752
+ # Perceiver Resampler
753
+ if config.use_resampler:
754
+ self.perceiver_resampler = PerceiverResampler(
755
+ self.config,
756
+ self.config.vision_embed_dim,
757
+ config.resampler_depth,
758
+ config.resampler_n_heads,
759
+ config.resampler_head_dim,
760
+ config.resampler_n_latents,
761
+ )
762
+
763
+ if config.word_embed_proj_dim != config.hidden_size:
764
+ self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False)
765
+ else:
766
+ self.project_in = None
767
+
768
+ self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
769
+
770
+ self.cross_layer_interval = config.cross_layer_interval
771
+ num_cross_layers = config.num_hidden_layers // self.cross_layer_interval
772
+ self.gated_cross_attn_layers = nn.ModuleList(
773
+ [VOPTGatedAttentionLayer(config) for i in range(num_cross_layers)]
774
+ )
775
+ self.gradient_checkpointing = False
776
+
777
+ # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
778
+ # with checkpoints that have been fine-tuned before transformers v4.20.1
779
+ # see https://github.com/facebookresearch/metaseq/pull/164
780
+ if config.do_layer_norm_before and not config._remove_final_layer_norm:
781
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size)
782
+ else:
783
+ self.final_layer_norm = None
784
+
785
+ if config.word_embed_proj_dim != config.hidden_size:
786
+ self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)
787
+ else:
788
+ self.project_out = None
789
+
790
+ # Initialize weights and apply final processing
791
+ self.post_init()
792
+
793
+ def get_input_embeddings(self):
794
+ return self.embed_tokens
795
+
796
+ def set_input_embeddings(self, value):
797
+ self.embed_tokens = value
798
+
799
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
800
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
801
+ # create causal mask
802
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
803
+ combined_attention_mask = None
804
+ if input_shape[-1] > 1:
805
+ combined_attention_mask = _make_causal_mask(
806
+ input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
807
+ ).to(inputs_embeds.device)
808
+
809
+ if attention_mask is not None:
810
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
811
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
812
+ inputs_embeds.device
813
+ )
814
+ combined_attention_mask = (
815
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
816
+ )
817
+
818
+ return combined_attention_mask
819
+
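How `_prepare_decoder_attention_mask` combines a causal mask with a padding mask can be sketched with nested lists for a single example. This assumes `_make_causal_mask` and `_expand_mask` (not shown in this diff) produce additive masks, as the `+` above suggests. `NEG`, `causal_mask`, `expand_padding_mask`, and `combine` are hypothetical names, and `NEG` stands in for the most negative value of the tensor dtype.

```python
from typing import List

NEG = -1e9  # stand-in for "very large negative value" in an additive mask


def causal_mask(tgt_len: int, past_len: int = 0) -> List[List[float]]:
    """Query i may attend to every cached position and to positions <= i."""
    src_len = past_len + tgt_len
    return [[0.0 if j <= i + past_len else NEG for j in range(src_len)] for i in range(tgt_len)]


def expand_padding_mask(attention_mask: List[int], tgt_len: int) -> List[List[float]]:
    """Broadcast a 0/1 padding mask over all query positions."""
    return [[0.0 if keep else NEG for keep in attention_mask] for _ in range(tgt_len)]


def combine(a: List[List[float]], b: List[List[float]]) -> List[List[float]]:
    """Element-wise sum: a position is blocked if either mask blocks it."""
    return [[x + y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]


# Two tokens, the first of which is padding.
combined = combine(causal_mask(2), expand_padding_mask([0, 1], tgt_len=2))
```

Because the masks are additive, a zero entry means "attend" and a large negative entry drives the corresponding softmax weight towards zero.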
820
+ def forward(
821
+ self,
822
+ input_ids: torch.LongTensor = None,
823
+ attention_mask: Optional[torch.Tensor] = None,
824
+ head_mask: Optional[torch.Tensor] = None,
825
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
826
+ inputs_embeds: Optional[torch.FloatTensor] = None,
827
+ pixel_values: Optional[torch.FloatTensor] = None,
828
+ image_embeddings: Optional[torch.FloatTensor] = None,
829
+ image_attention_mask: Optional[torch.Tensor] = None,
830
+ crossblock_head_mask: Optional[torch.Tensor] = None, # TODO (ls): check if this is needed
831
+ use_cache: Optional[bool] = None,
832
+ output_attentions: Optional[bool] = None,
833
+ output_hidden_states: Optional[bool] = None,
834
+ return_dict: Optional[bool] = None,
835
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
836
+ r"""
837
+ Args:
838
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
839
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
840
+ provide it.
841
+
842
+ Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
843
+ [`PreTrainedTokenizer.__call__`] for details.
844
+
845
+ [What are input IDs?](../glossary#input-ids)
846
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
847
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
848
+
849
+ - 1 for tokens that are **not masked**,
850
+ - 0 for tokens that are **masked**.
851
+
852
+ [What are attention masks?](../glossary#attention-mask)
853
+ head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
854
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
855
+
856
+ - 1 indicates the head is **not masked**,
857
+ - 0 indicates the head is **masked**.
858
+
859
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
860
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
861
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
862
+
863
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
864
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
865
+
866
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
867
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
868
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
869
+
870
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
871
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
872
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
873
+ than the model's internal embedding lookup matrix.
874
+ output_attentions (`bool`, *optional*):
875
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
876
+ returned tensors for more detail.
877
+ output_hidden_states (`bool`, *optional*):
878
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
879
+ for more detail.
880
+ return_dict (`bool`, *optional*):
881
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
882
+ """
883
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
884
+
885
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
886
+ output_hidden_states = (
887
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
888
+ )
889
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
890
+
891
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
892
+
893
+ # retrieve input_ids and inputs_embeds
894
+ if input_ids is not None and inputs_embeds is not None:
895
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
896
+ elif input_ids is not None:
897
+ input_shape = input_ids.size()
898
+ input_ids = input_ids.view(-1, input_shape[-1])
899
+ elif inputs_embeds is not None:
900
+ input_shape = inputs_embeds.size()[:-1]
901
+ else:
902
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
903
+
904
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
905
+
906
+ if pixel_values is not None and image_embeddings is not None:
907
+ raise ValueError("You cannot specify both pixel_values and image_embeddings at the same time")
908
+ elif pixel_values is not None:
909
+ pixel_values = pixel_values.to(dtype=self.dtype, device=device) # fp16 compatibility
910
+ batch_size, num_images = pixel_values.size(0), pixel_values.size(1)
911
+ pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:])
912
+ # Get sequence from the vision encoder
913
+ image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
914
+ elif image_embeddings is not None:
915
+ batch_size, num_images, image_seq_len, image_hidden_size = image_embeddings.size()
916
+ image_hidden_states = image_embeddings.to(dtype=self.dtype, device=device)
917
+ image_hidden_states = image_hidden_states.view(batch_size * num_images, image_seq_len, image_hidden_size)
918
+
919
+ if self.config.use_resampler:
920
+ image_hidden_states = self.perceiver_resampler(image_hidden_states)
921
+ image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
922
+ image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size)
923
+ # Make image_attention_mask compatible with hidden states
924
+ text_seq_len = image_attention_mask.size(1)
925
+ image_attention_mask = image_attention_mask.unsqueeze(-1)
926
+ image_attention_mask = image_attention_mask.repeat(1, 1, 1, image_seq_len)
927
+ image_attention_mask = image_attention_mask.view(batch_size, text_seq_len, num_images * image_seq_len)
928
+
929
+ if image_hidden_states is not None:
930
+ image_batch_size, image_sequence_length, _ = image_hidden_states.size()
931
+ image_hidden_shape = (image_batch_size, image_sequence_length)
932
+ if image_attention_mask is None:
933
+ image_attention_mask = torch.ones(image_hidden_shape, device=device)
934
+ image_attention_mask = self.invert_attention_mask(image_attention_mask)
935
+ else:
936
+ image_attention_mask = None
937
+
938
+ if inputs_embeds is None:
939
+ inputs_embeds = self.embed_tokens(input_ids)
940
+
941
+ # embed positions
942
+ if attention_mask is None:
943
+ attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
944
+ pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
945
+
946
+ attention_mask = self._prepare_decoder_attention_mask(
947
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
948
+ )
949
+
950
+ if self.project_in is not None:
951
+ inputs_embeds = self.project_in(inputs_embeds)
952
+
953
+ hidden_states = inputs_embeds + pos_embeds
954
+
955
+ # decoder layers
956
+ all_hidden_states = () if output_hidden_states else None
957
+ all_self_attns = () if output_attentions else None
958
+ next_decoder_cache = () if use_cache else None
959
+
960
+ # check if head_mask has a correct number of layers specified if desired
961
+ for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
962
+ if attn_mask is not None:
963
+ if attn_mask.size()[0] != (len(self.layers)):
964
+ raise ValueError(
965
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
966
+ f" {head_mask.size()[0]}."
967
+ )
968
+
969
+ for idx, decoder_layer in enumerate(self.layers):
970
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
971
+ if output_hidden_states:
972
+ all_hidden_states += (hidden_states,)
973
+
974
+ dropout_probability = random.uniform(0, 1)
975
+ if self.training and (dropout_probability < self.layerdrop):
976
+ continue
977
+
978
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
979
+ layer_head_mask = head_mask[idx] if head_mask is not None else None
980
+
981
+ def vblock(
982
+ main_block,
983
+ hidden_states,
984
+ attention_mask,
985
+ layer_head_mask,
986
+ past_key_value,
987
+ image_hidden_states,
988
+ image_attention_mask,
989
+ output_attentions,
990
+ use_cache,
991
+ layer_idx,
992
+ cross_layer_interval,
993
+ gated_cross_attn_layers,
994
+ ):
995
+ # TODO(ls): Add cross attention values to respective lists
996
+ if layer_idx % cross_layer_interval == 0:
997
+ xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval]
998
+ outputs = xblock(
999
+ hidden_states,
1000
+ attention_mask=attention_mask,
1001
+ layer_head_mask=layer_head_mask,
1002
+ image_hidden_states=image_hidden_states,
1003
+ image_attention_mask=image_attention_mask,
1004
+ output_attentions=output_attentions,
1005
+ use_cache=use_cache,
1006
+ past_key_value=None, # not implemented
1007
+ )
1008
+ hidden_states = outputs[0]
1009
+
1010
+ layer_outputs = main_block(
1011
+ hidden_states,
1012
+ attention_mask=attention_mask,
1013
+ layer_head_mask=layer_head_mask,
1014
+ past_key_value=past_key_value,
1015
+ output_attentions=output_attentions,
1016
+ use_cache=use_cache,
1017
+ )
1018
+
1019
+ return layer_outputs
1020
+
1021
+ if self.gradient_checkpointing and self.training:
1022
+ past_key_value = None
1023
+ if use_cache:
1024
+ logger.warning_once(
1025
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1026
+ )
1027
+ use_cache = False
1028
+
1029
+ layer_outputs = torch.utils.checkpoint.checkpoint(
1030
+ vblock,
1031
+ decoder_layer,
1032
+ hidden_states,
1033
+ attention_mask,
1034
+ layer_head_mask,
1035
+ past_key_value,
1036
+ image_hidden_states,
1037
+ image_attention_mask,
1038
+ output_attentions,
1039
+ use_cache,
1040
+ idx,
1041
+ self.cross_layer_interval,
1042
+ self.gated_cross_attn_layers,
1043
+ )
1044
+ else:
1045
+ layer_outputs = vblock(
1046
+ decoder_layer,
1047
+ hidden_states,
1048
+ attention_mask=attention_mask,
1049
+ layer_head_mask=layer_head_mask,
1050
+ past_key_value=past_key_value,
1051
+ image_hidden_states=image_hidden_states,
1052
+ image_attention_mask=image_attention_mask,
1053
+ output_attentions=output_attentions,
1054
+ use_cache=use_cache,
1055
+ layer_idx=idx,
1056
+ cross_layer_interval=self.cross_layer_interval,
1057
+ gated_cross_attn_layers=self.gated_cross_attn_layers,
1058
+ )
1059
+
1060
+ hidden_states = layer_outputs[0]
1061
+
1062
+ if use_cache:
1063
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
1064
+
1065
+ if output_attentions:
1066
+ all_self_attns += (layer_outputs[1],)
1067
+
1068
+ if self.final_layer_norm is not None:
1069
+ hidden_states = self.final_layer_norm(hidden_states)
1070
+
1071
+ if self.project_out is not None:
1072
+ hidden_states = self.project_out(hidden_states)
1073
+
1074
+ # add hidden states from the last decoder layer
1075
+ if output_hidden_states:
1076
+ all_hidden_states += (hidden_states,)
1077
+
1078
+ next_cache = next_decoder_cache if use_cache else None
1079
+ if not return_dict:
1080
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1081
+ return BaseModelOutputWithPast(
1082
+ last_hidden_state=hidden_states,
1083
+ past_key_values=next_cache,
1084
+ hidden_states=all_hidden_states,
1085
+ attentions=all_self_attns,
1086
+ )
1087
+
1088
+
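The interleaving done by `vblock` in `VOPTDecoder.forward` can be made explicit with a small schedule function. This is a sketch under the indexing shown above (`layer_idx % cross_layer_interval == 0` selects block `layer_idx // cross_layer_interval`, and `num_hidden_layers // cross_layer_interval` blocks are built). `cross_attention_schedule` is a hypothetical helper.

```python
from typing import List, Optional, Tuple


def cross_attention_schedule(
    num_hidden_layers: int, cross_layer_interval: int
) -> List[Tuple[int, Optional[int]]]:
    """Pair each decoder layer with the gated cross-attention block that precedes it, if any."""
    num_cross_layers = num_hidden_layers // cross_layer_interval
    schedule = []
    for layer_idx in range(num_hidden_layers):
        if layer_idx % cross_layer_interval == 0:
            xblock_idx = layer_idx // cross_layer_interval
            if xblock_idx >= num_cross_layers:
                # Mirrors the out-of-range lookup the ModuleList would hit.
                raise IndexError(
                    f"layer {layer_idx} needs cross block {xblock_idx}, "
                    f"but only {num_cross_layers} were built"
                )
            schedule.append((layer_idx, xblock_idx))
        else:
            schedule.append((layer_idx, None))
    return schedule


schedule = cross_attention_schedule(num_hidden_layers=4, cross_layer_interval=2)
```

With this indexing, a layer count that is not a multiple of the interval (for example 5 layers with an interval of 2) asks for one more block than was built, so configurations appear to need `num_hidden_layers` divisible by `cross_layer_interval`.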
1089
+ @add_start_docstrings(
1090
+ "The bare OPT Model outputting raw hidden-states without any specific head on top.",
1091
+ OPT_START_DOCSTRING,
1092
+ )
1093
+ class VOPTModel(VOPTPreTrainedModel):
1094
+ def __init__(self, config: VOPTConfig, vision_model=None):
1095
+ super().__init__(config)
1096
+ self.decoder = VOPTDecoder(config, vision_model=vision_model)
1097
+
1098
+ # Initialize weights and apply final processing
1099
+ self.post_init()
1100
+
1101
+ self.freeze_relevant_params(config)
1102
+
1103
+ def freeze_relevant_params(self, config=None):
1104
+ if config is None:
1105
+ config = self.config
1106
+
1107
+ if config.freeze_text_layers:
1108
+ self.freeze_text_layers(config.freeze_text_module_exceptions)
1109
+
1110
+ if config.freeze_vision_layers:
1111
+ freeze_model(self.decoder.vision_model, module_exceptions=config.freeze_vision_module_exceptions)
1112
+
1113
+ def freeze_text_layers(self, module_exceptions):
1114
+ for module in [self.decoder.embed_positions, self.decoder.layers]:
1115
+ freeze_model(module, module_exceptions=module_exceptions)
1116
+
1117
+ if self.decoder.project_out is not None:
1118
+ freeze_model(self.decoder.project_out, module_exceptions=module_exceptions)
1119
+
1120
+ if self.decoder.final_layer_norm is not None:
1121
+ freeze_model(self.decoder.final_layer_norm, module_exceptions=module_exceptions)
1122
+
1123
+ def get_input_embeddings(self):
1124
+ return self.decoder.embed_tokens
1125
+
1126
+ def set_input_embeddings(self, value):
1127
+ self.decoder.embed_tokens = value
1128
+
1129
+ def get_decoder(self):
1130
+ return self.decoder
1131
+
1132
+ @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
1133
+ @add_code_sample_docstrings(
1134
+ processor_class=_TOKENIZER_FOR_DOC,
1135
+ checkpoint=_CHECKPOINT_FOR_DOC,
1136
+ output_type=BaseModelOutputWithPast,
1137
+ config_class=_CONFIG_FOR_DOC,
1138
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
1139
+ )
1140
+ def forward(
1141
+ self,
1142
+ input_ids: torch.LongTensor = None,
1143
+ attention_mask: Optional[torch.Tensor] = None,
1144
+ head_mask: Optional[torch.Tensor] = None,
1145
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1146
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1147
+ pixel_values: Optional[torch.FloatTensor] = None,
1148
+ image_embeddings: Optional[torch.FloatTensor] = None,
1149
+ image_attention_mask: Optional[torch.Tensor] = None,
1150
+ crossblock_head_mask: Optional[torch.Tensor] = None, # TODO (ls): check if this is needed
1151
+ use_cache: Optional[bool] = None,
1152
+ output_attentions: Optional[bool] = None,
1153
+ output_hidden_states: Optional[bool] = None,
1154
+ return_dict: Optional[bool] = None,
1155
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1156
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1157
+ output_hidden_states = (
1158
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1159
+ )
1160
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1161
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1162
+
1163
+ # decoder outputs consist of (dec_features, past_key_value, dec_hidden, dec_attn)
1164
+ decoder_outputs = self.decoder(
1165
+ input_ids=input_ids,
1166
+ attention_mask=attention_mask,
1167
+ head_mask=head_mask,
1168
+ past_key_values=past_key_values,
1169
+ inputs_embeds=inputs_embeds,
1170
+ pixel_values=pixel_values,
1171
+ image_embeddings=image_embeddings,
1172
+ image_attention_mask=image_attention_mask,
1173
+ crossblock_head_mask=crossblock_head_mask,
1174
+ use_cache=use_cache,
1175
+ output_attentions=output_attentions,
1176
+ output_hidden_states=output_hidden_states,
1177
+ return_dict=return_dict,
1178
+ )
1179
+
1180
+ if not return_dict:
1181
+ return decoder_outputs
1182
+
1183
+ return BaseModelOutputWithPast(
1184
+ last_hidden_state=decoder_outputs.last_hidden_state,
1185
+ past_key_values=decoder_outputs.past_key_values,
1186
+ hidden_states=decoder_outputs.hidden_states,
1187
+ attentions=decoder_outputs.attentions,
1188
+ )
1189
+
1190
+
1191
+ class VOPTForCausalLM(VOPTPreTrainedModel):
1192
+ _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
1193
+
1194
+ def __init__(self, config, vision_model=None):
1195
+ super().__init__(config)
1196
+
1197
+ # Initialize LM head first so that it is not directly offloaded to the CPU/disk
1198
+ # the lm_head weight is automatically tied to the embed tokens weight
1199
+ self.lm_head = DecoupledLinear(
1200
+ in_features=config.word_embed_proj_dim,
1201
+ out_features=config.vocab_size,
1202
+ out_additional_features=config.additional_vocab_size,
1203
+ bias=False,
1204
+ partially_freeze=config.freeze_lm_head,
1205
+ )
1206
+
1207
+ self.model = VOPTModel(config, vision_model=vision_model)
1208
+ # Initialize weights and apply final processing
1209
+ self.post_init()
1210
+
1211
+ def tie_weights(self):
1212
+ """
1213
+ Overwrite `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of DecoupledLinear and DecoupledEmbedding.
1214
+ """
1215
+ output_embeddings = self.get_output_embeddings()
1216
+ input_embeddings = self.get_input_embeddings()
1217
+
1218
+ if getattr(self.config, "tie_word_embeddings", True):
1219
+ output_embeddings.weight = input_embeddings.weight
1220
+ if input_embeddings.num_additional_embeddings > 0:
1221
+ assert output_embeddings.out_additional_features == input_embeddings.num_additional_embeddings
1222
+ output_embeddings.additional_fc.weight = input_embeddings.additional_embedding.weight
1223
+
1224
+ if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
1225
+ output_embeddings.out_features = input_embeddings.num_embeddings
1226
+ if hasattr(output_embeddings, "out_additional_features") and hasattr(
1227
+ input_embeddings, "num_additional_embeddings"
1228
+ ):
1229
+ output_embeddings.out_additional_features = input_embeddings.num_additional_embeddings
1230
+
1231
+ def get_input_embeddings(self):
1232
+ return self.model.decoder.embed_tokens
1233
+
1234
+ def set_input_embeddings(self, value):
1235
+ self.model.decoder.embed_tokens = value
1236
+
1237
+ def get_output_embeddings(self):
1238
+ return self.lm_head
1239
+
1240
+ def set_output_embeddings(self, new_embeddings):
1241
+ self.lm_head = new_embeddings
1242
+
1243
+ def set_decoder(self, decoder):
1244
+ self.model.decoder = decoder
1245
+
1246
+ def get_decoder(self):
1247
+ return self.model.decoder
1248
+
1249
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1250
+ def forward(
1251
+ self,
1252
+ input_ids: torch.LongTensor = None,
1253
+ attention_mask: Optional[torch.Tensor] = None,
1254
+ head_mask: Optional[torch.Tensor] = None,
1255
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1256
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1257
+ pixel_values: Optional[torch.FloatTensor] = None,
1258
+ image_embeddings: Optional[torch.FloatTensor] = None,
1259
+ image_attention_mask: Optional[torch.Tensor] = None,
1260
+ crossblock_head_mask: Optional[torch.Tensor] = None,
1261
+ labels: Optional[torch.LongTensor] = None,
1262
+ use_cache: Optional[bool] = None,
1263
+ output_attentions: Optional[bool] = None,
1264
+ output_hidden_states: Optional[bool] = None,
1265
+ return_dict: Optional[bool] = None,
1266
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1267
+ r"""
1268
+ Args:
1269
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1270
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
1271
+ provide it.
1272
+
1273
+ Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1274
+ [`PreTrainedTokenizer.__call__`] for details.
1275
+
1276
+ [What are input IDs?](../glossary#input-ids)
1277
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1278
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1279
+
1280
+ - 1 for tokens that are **not masked**,
1281
+ - 0 for tokens that are **masked**.
1282
+
1283
+ [What are attention masks?](../glossary#attention-mask)
1284
+ head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
1285
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
1286
+
1287
+ - 1 indicates the head is **not masked**,
1288
+ - 0 indicates the head is **masked**.
1289
+
1290
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1291
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1292
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
1293
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
1294
+ tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
1295
+
1296
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
1297
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
1298
+
1299
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
1300
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
1301
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1302
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1303
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
1304
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
1305
+ than the model's internal embedding lookup matrix.
1306
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1307
+ Labels for computing the causal language modeling loss. Indices should either be in `[0, ...,
1308
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1309
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1310
+ use_cache (`bool`, *optional*):
1311
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1312
+ (see `past_key_values`).
1313
+ output_attentions (`bool`, *optional*):
1314
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1315
+ returned tensors for more detail.
1316
+ output_hidden_states (`bool`, *optional*):
1317
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1318
+ for more detail.
1319
+ return_dict (`bool`, *optional*):
1320
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1321
+
1322
+ Returns:
1323
+
1324
+ Example:
1325
+
1326
+ ```python
1327
+ >>> from transformers import GPT2Tokenizer, OPTForCausalLM
1328
+
1329
+ >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
1330
+ >>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
1331
+
1332
+ >>> prompt = "Hey, are you consciours? Can you talk to me?"
1333
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1334
+
1335
+ >>> # Generate
1336
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1337
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1338
+ "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
1339
+ ```"""
1340
+
1341
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1342
+ output_hidden_states = (
1343
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1344
+ )
1345
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1346
+
1347
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
1348
+ outputs = self.model.decoder(
1349
+ input_ids=input_ids,
1350
+ attention_mask=attention_mask,
1351
+ head_mask=head_mask,
1352
+ past_key_values=past_key_values,
1353
+ inputs_embeds=inputs_embeds,
1354
+ pixel_values=pixel_values,
1355
+ image_embeddings=image_embeddings,
1356
+ image_attention_mask=image_attention_mask,
1357
+ crossblock_head_mask=crossblock_head_mask,
1358
+ use_cache=use_cache,
1359
+ output_attentions=output_attentions,
1360
+ output_hidden_states=output_hidden_states,
1361
+ return_dict=return_dict,
1362
+ )
1363
+
1364
+ logits = self.lm_head(outputs[0]).contiguous()
1365
+
1366
+ loss = None
1367
+ if labels is not None:
1368
+ # Shift so that tokens < n predict n
1369
+ if attention_mask is not None:
1370
+ shift_attention_mask = attention_mask[..., 1:]
1371
+ shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
1372
+ shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
1373
+ else:
1374
+ shift_logits = logits[..., :-1, :].contiguous()
1375
+ shift_labels = labels[..., 1:].contiguous()
1376
+ # Flatten the tokens
1377
+ loss_fct = CrossEntropyLoss()
1378
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1379
+
1380
+ if not return_dict:
1381
+ output = (logits,) + outputs[1:]
1382
+ return (loss,) + output if loss is not None else output
1383
+
1384
+ return CausalLMOutputWithPast(
1385
+ loss=loss,
1386
+ logits=logits,
1387
+ past_key_values=outputs.past_key_values,
1388
+ hidden_states=outputs.hidden_states,
1389
+ attentions=outputs.attentions,
1390
+ )
1391
+
1392
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
1393
+ inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs)
1394
+ unwanted_kwargs = ["position_ids", "token_type_ids"]
1395
+ for kwarg in unwanted_kwargs:
1396
+ inputs.pop(kwarg, None)
1397
+ return inputs
1398
+
1399
+ @staticmethod
1400
+ def _expand_inputs_for_generation(
1401
+ *args,
1402
+ **model_kwargs,
1403
+ ):
1404
+ return expand_inputs_for_generation(*args, **model_kwargs)
1405
+
1406
+ @staticmethod
1407
+ def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
1408
+ return update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder)
1409
+
1410
+ @staticmethod
1411
+ def _reorder_cache(past, beam_idx):
1412
+ reordered_past = ()
1413
+ for layer_past in past:
1414
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1415
+ return reordered_past
1416
+
1417
+ def get_model_tflops_per_batch_per_gpu(self, hparams, data_param, tokenizer, max_num_images):
1418
+ config_vl_model = self.config
1419
+
1420
+ language_embed_size = config_vl_model.hidden_size
1421
+ num_language_layers = config_vl_model.num_hidden_layers
1422
+ ffn_inner_size = config_vl_model.ffn_dim
1423
+
1424
+ vision_config = self.model.decoder.vision_model.config
1425
+ if hasattr(vision_config, "vision_config"):
1426
+ vision_config = vision_config.vision_config
1427
+
1428
+ # Get vision model blocks infos
1429
+ vision_patch_size = vision_config.patch_size
1430
+ vision_hidden_size = vision_config.hidden_size
1431
+ num_vision_layers = vision_config.num_hidden_layers
1432
+ # The +1 is for the CLS token
1433
+ single_image_seq_len = (vision_config.image_size // vision_patch_size) ** 2 + 1
1434
+ vision_exp_factor = vision_config.intermediate_size // vision_hidden_size
1435
+
1436
+ # Get language and cross-att blocks infos
1437
+ num_cross_attn_layers = num_language_layers // config_vl_model.cross_layer_interval
1438
+ language_seq_len = data_param.max_seq_len
1439
+ language_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4
1440
+ cross_att_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4
1441
+ k_v_cross_attn_seq_len = (
1442
+ (self.config.resampler_n_latents * max_num_images)
1443
+ if self.config.use_resampler
1444
+ else (single_image_seq_len * max_num_images)
1445
+ )
1446
+
1447
+ language_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
1448
+ num_layers=num_language_layers,
1449
+ batch_size=hparams.batch_size_per_gpu,
1450
+ q_seq_len=language_seq_len,
1451
+ k_seq_len=language_seq_len,
1452
+ hidden_size=language_embed_size,
1453
+ kv_in_dim=language_embed_size,
1454
+ ff_exp_factor=language_exp_factor,
1455
+ grad_acc_size=hparams.grad_acc_size,
1456
+ swiglu=False,
1457
+ vocab_size=tokenizer.vocab_size,
1458
+ count_backward=True, # Always True regardless of freezing, because gradients are computed for cross-attentions
1459
+ use_grad_checkpointing=hparams.gradient_checkpointing,
1460
+ )
1461
+ cross_attention_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
1462
+ num_layers=num_cross_attn_layers,
1463
+ batch_size=hparams.batch_size_per_gpu,
1464
+ q_seq_len=language_seq_len,
1465
+ k_seq_len=k_v_cross_attn_seq_len,
1466
+ hidden_size=language_embed_size,
1467
+ kv_in_dim=vision_hidden_size,
1468
+ ff_exp_factor=cross_att_exp_factor,
1469
+ grad_acc_size=hparams.grad_acc_size,
1470
+ swiglu=self.config.cross_layer_activation_function == "swiglu",
1471
+ vocab_size=None,
1472
+ count_backward=True,
1473
+ use_grad_checkpointing=hparams.gradient_checkpointing,
1474
+ )
1475
+ vision_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
1476
+ num_layers=num_vision_layers,
1477
+ batch_size=hparams.batch_size_per_gpu * max_num_images,
1478
+ q_seq_len=single_image_seq_len,
1479
+ k_seq_len=single_image_seq_len,
1480
+ hidden_size=vision_hidden_size,
1481
+ kv_in_dim=vision_hidden_size,
1482
+ ff_exp_factor=vision_exp_factor,
1483
+ grad_acc_size=hparams.grad_acc_size,
1484
+ swiglu=False,
1485
+ vocab_size=None,
1486
+ count_backward=not hparams.model_params["freeze_vision_layers"],
1487
+ use_grad_checkpointing=hparams.gradient_checkpointing,
1488
+ )
1489
+ if self.config.use_resampler:
1490
+ perceiver_tflops_per_batch_per_gpu = compute_perceiver_tflops_per_batch_per_gpu(
1491
+ num_layers=self.config.resampler_depth,
1492
+ batch_size=hparams.batch_size_per_gpu * max_num_images,
1493
+ q_seq_len=self.config.resampler_n_latents,
1494
+ vision_embed_seq_len=single_image_seq_len,
1495
+ q_k_v_input_dim=vision_hidden_size,
1496
+ attention_hidden_size=self.config.resampler_n_heads * self.config.resampler_head_dim,
1497
+ ff_exp_factor=cross_att_exp_factor,
1498
+ count_backward=True,
1499
+ use_grad_checkpointing=hparams.gradient_checkpointing,
1500
+ )
1501
+ flop_count = (
1502
+ language_tflops_per_batch_per_gpu
1503
+ + cross_attention_tflops_per_batch_per_gpu
1504
+ + vision_tflops_per_batch_per_gpu
1505
+ + perceiver_tflops_per_batch_per_gpu
1506
+ )
1507
+ else:
1508
+ flop_count = (
1509
+ language_tflops_per_batch_per_gpu
1510
+ + cross_attention_tflops_per_batch_per_gpu
1511
+ + vision_tflops_per_batch_per_gpu
1512
+ )
1513
+ return flop_count
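Review note on `get_model_tflops_per_batch_per_gpu`: the sequence-length arithmetic that feeds the cross-attention estimate can be checked in isolation. The sketch below is a minimal illustration, not part of the commit. The helper name `cross_attn_kv_seq_len` and the default of 64 resampler latents are assumptions made for the example; only the formulas mirror the diff above.

```python
def cross_attn_kv_seq_len(image_size, patch_size, max_num_images,
                          use_resampler=False, resampler_n_latents=64):
    """Key/value sequence length seen by the gated cross-attention layers.

    Hypothetical helper: mirrors `single_image_seq_len` and
    `k_v_cross_attn_seq_len` from the diff; the default of 64 latents
    is an assumed value, not taken from the config.
    """
    # Number of patches per image, plus one for the CLS token.
    single_image_seq_len = (image_size // patch_size) ** 2 + 1
    if use_resampler:
        # The resampler compresses each image to a fixed number of latents.
        return resampler_n_latents * max_num_images
    return single_image_seq_len * max_num_images


# ViT-base with 16x16 patches on 224x224 images: 14 * 14 + 1 = 197 tokens per image.
print(cross_attn_kv_seq_len(224, 16, max_num_images=5))
print(cross_attn_kv_seq_len(224, 16, max_num_images=5, use_resampler=True))
```

With a resampler the key/value length no longer depends on image resolution, which is why the cross-attention term uses `resampler_n_latents` in that branch.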
m4/models/vt5/__init__.py ADDED
File without changes
m4/models/vt5/configuration_vt5.py ADDED
@@ -0,0 +1,218 @@
1
+ # coding=utf-8
2
+ # Copyright 2020, The T5 Authors and HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ VT5 model configuration"""
16
+ import os
17
+ from typing import Tuple, Union
18
+
19
+ from transformers import AutoConfig
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ T5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27
+ "t5-small": "https://huggingface.co/t5-small/resolve/main/config.json",
28
+ "t5-base": "https://huggingface.co/t5-base/resolve/main/config.json",
29
+ "t5-large": "https://huggingface.co/t5-large/resolve/main/config.json",
30
+ "t5-3b": "https://huggingface.co/t5-3b/resolve/main/config.json",
31
+ "t5-11b": "https://huggingface.co/t5-11b/resolve/main/config.json",
32
+ }
33
+
34
+
35
+ class VT5Config(PretrainedConfig):
36
+ r"""
37
+ This is the configuration class to store the configuration of a [`T5Model`] or a [`TFT5Model`]. It is used to
38
+ instantiate a T5 model according to the specified arguments, defining the model architecture. Instantiating a
39
+ configuration with the defaults will yield a similar configuration to that of the T5
40
+ [t5-small](https://huggingface.co/t5-small) architecture.
41
+
42
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
43
+ documentation from [`PretrainedConfig`] for more information.
44
+
45
+ TODO: this doc is completely out of sync with the actual args
46
+
47
+ Arguments:
48
+ vocab_size (`int`, *optional*, defaults to 32128):
49
+ Vocabulary size of the T5 model. Defines the number of different tokens that can be represented by the
50
+ `inputs_ids` passed when calling [`T5Model`] or [`TFT5Model`].
51
+ d_model (`int`, *optional*, defaults to 512):
52
+ Size of the encoder layers and the pooler layer.
53
+ d_kv (`int`, *optional*, defaults to 64):
54
+ Size of the key, query, value projections per attention head. `d_kv` has to be equal to `d_model //
55
+ num_heads`.
56
+ d_ff (`int`, *optional*, defaults to 2048):
57
+ Size of the intermediate feed forward layer in each `T5Block`.
58
+ num_layers (`int`, *optional*, defaults to 6):
59
+ Number of hidden layers in the Transformer encoder.
60
+ num_decoder_layers (`int`, *optional*):
61
+ Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
62
+ num_heads (`int`, *optional*, defaults to 8):
63
+ Number of attention heads for each attention layer in the Transformer encoder.
64
+ relative_attention_num_buckets (`int`, *optional*, defaults to 32):
65
+ The number of buckets to use for each attention layer.
66
+ relative_attention_max_distance (`int`, *optional*, defaults to 128):
67
+ The maximum distance of the longer sequences for the bucket separation.
68
+ dropout_rate (`float`, *optional*, defaults to 0.1):
69
+ The ratio for all dropout layers.
70
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-6):
71
+ The epsilon used by the layer normalization layers.
72
+ initializer_factor (`float`, *optional*, defaults to 1):
73
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
74
+ testing).
75
+ feed_forward_proj (`string`, *optional*, defaults to `"relu"`):
76
+ Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. T5v1.1 uses the
77
+ `"gated-gelu"` feed forward projection. Original T5 uses `"relu"`.
78
+ use_cache (`bool`, *optional*, defaults to `True`):
79
+ Whether or not the model should return the last key/values attentions (not used by all models).
80
+ additional_vocab_size (`int`, *optional*, defaults to 0):
81
+ Additional vocabulary size of the model, typically for the special "<img>" token. Additional vocab tokens
82
+ are always trainable whereas regular vocab tokens can be frozen or not.
83
+ alpha_initializer (`str`, *optional*, defaults to `"ones"`):
84
+ Initialization type for the alphas.
85
+ alphas_initializer_range (`float`, *optional*, defaults to 0.0):
86
+ The standard deviation of the truncated_normal_initializer for initializing the alphas in the Gated Cross Attention.
87
+ alpha_type (`str`, *optional*, defaults to `"vector"`):
88
+ Whether the gating alphas should be vectors or single floats.
89
+ """
90
+ model_type = "vt5"
91
+ keys_to_ignore_at_inference = ["past_key_values"]
92
+ attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
93
+
94
+ def __init__(
95
+ self,
96
+ vocab_size=32128,
97
+ d_model=512,
98
+ d_kv=64,
99
+ d_ff=2048,
100
+ num_layers=6,
101
+ num_decoder_layers=None,
102
+ num_heads=8,
103
+ relative_attention_num_buckets=32,
104
+ relative_attention_max_distance=128,
105
+ dropout_rate=0.1,
106
+ layer_norm_epsilon=1e-6,
107
+ initializer_factor=1.0,
108
+ feed_forward_proj="relu",
109
+ is_encoder_decoder=True,
110
+ use_cache=True,
111
+ pad_token_id=0,
112
+ eos_token_id=1,
113
+ additional_vocab_size=0,
114
+ alpha_initializer="ones",
115
+ alphas_initializer_range=0.0,
116
+ alpha_type="vector",
117
+ cross_layer_interval=1,
118
+ tie_word_embeddings=False,
119
+ freeze_text_layers=True,
120
+ freeze_lm_head=False,
121
+ freeze_vision_layers=True,
122
+ vision_model_name="google/vit-base-patch16-224",
123
+ vision_model_params="{}",
124
+ vision_embed_dim=768,
125
+ image_token_index=32128,
126
+ **kwargs,
127
+ ):
128
+ self.vocab_size = vocab_size
129
+ self.additional_vocab_size = additional_vocab_size
130
+ self.d_model = d_model
131
+ self.d_kv = d_kv
132
+ self.d_ff = d_ff
133
+ self.num_layers = num_layers
134
+ self.num_decoder_layers = (
135
+ num_decoder_layers if num_decoder_layers is not None else self.num_layers
136
+ ) # default = symmetry
137
+ self.num_heads = num_heads
138
+ self.relative_attention_num_buckets = relative_attention_num_buckets
139
+ self.relative_attention_max_distance = relative_attention_max_distance
140
+ self.dropout_rate = dropout_rate
141
+ self.layer_norm_epsilon = layer_norm_epsilon
142
+ self.initializer_factor = initializer_factor
143
+ self.feed_forward_proj = feed_forward_proj
144
+ self.use_cache = use_cache
145
+
146
+ act_info = self.feed_forward_proj.split("-")
147
+ self.dense_act_fn = act_info[-1]
148
+ self.is_gated_act = act_info[0] == "gated"
149
+
150
+ if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
151
+ raise ValueError(
152
+ f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. "
153
+ "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
154
+ "'gated-gelu' or 'relu'"
155
+ )
156
+
157
+ # for backwards compatibility
158
+ if feed_forward_proj == "gated-gelu":
159
+ self.dense_act_fn = "gelu_new"
160
+
161
+ self.alpha_initializer = alpha_initializer
162
+ self.alphas_initializer_range = alphas_initializer_range
163
+ self.alpha_type = alpha_type
164
+
165
+ self.cross_layer_interval = cross_layer_interval
166
+ self.freeze_vision_layers = freeze_vision_layers
167
+ self.vision_model_name = vision_model_name
168
+ self.vision_model_params = vision_model_params
169
+
170
+ self.tie_word_embeddings = tie_word_embeddings
171
+ self.freeze_text_layers = freeze_text_layers
172
+ self.freeze_lm_head = freeze_lm_head
173
+ self.image_token_index = image_token_index
174
+
175
+ self.vision_embed_dim = vision_embed_dim
176
+
177
+ # IMPORTANT: Do not do any __init__ args-based checks in the constructor, since
178
+ # PretrainedConfig.from_dict first instantiates the class with the config dict and only then
179
+ # updates the config object with `kwargs` from from_pretrained, so during the instantiation
180
+ # of this object many attributes have default values and haven't yet been overridden.
181
+ # Do any required checks inside `from_pretrained` once the superclass' `from_pretrained` was run.
182
+
183
+ super().__init__(
184
+ pad_token_id=pad_token_id,
185
+ eos_token_id=eos_token_id,
186
+ is_encoder_decoder=is_encoder_decoder,
187
+ tie_word_embeddings=tie_word_embeddings,
188
+ **kwargs,
189
+ )
190
+
191
+ def check_compatibilities(self):
192
+ if self.tie_word_embeddings and (self.freeze_text_layers != self.freeze_lm_head):
193
+ raise ValueError(
194
+ "if `tie_word_embeddings` is True, then `freeze_lm_head` and `freeze_text_layers` must be equal."
195
+ )
196
+
197
+ vision_model_params = eval(self.vision_model_params)
198
+ config = AutoConfig.from_pretrained(self.vision_model_name, **vision_model_params)
199
+ if hasattr(config, "vision_config"):
200
+ vision_config = config.vision_config
201
+ else:
202
+ vision_config = config
203
+ vision_embed_dim = vision_config.hidden_size
204
+ if self.vision_embed_dim != vision_embed_dim:
205
+ raise ValueError(
206
+ f"vision_embed_dim ({self.vision_embed_dim}) must match the hidden size of the vision model"
207
+ f" ({vision_embed_dim})"
208
+ )
209
+
210
+ @classmethod
211
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
212
+ outputs = super(VT5Config, cls).from_pretrained(pretrained_model_name_or_path, **kwargs)
213
+ if isinstance(outputs, Tuple):
214
+ # When called with return_unused_kwargs=True, the first item will be the config
215
+ outputs[0].check_compatibilities()
216
+ else:
217
+ outputs.check_compatibilities()
218
+ return outputs
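Review note on `VT5Config.__init__`: the `feed_forward_proj` parsing is easier to reason about as a standalone function. The following is a minimal sketch, assuming the same rules as the constructor in the diff above. The function name `parse_feed_forward_proj` is hypothetical and does not exist in the commit.

```python
def parse_feed_forward_proj(feed_forward_proj):
    """Return (dense_act_fn, is_gated_act) for a `feed_forward_proj` string.

    Hypothetical helper that reproduces the checks in the constructor above.
    """
    act_info = feed_forward_proj.split("-")
    dense_act_fn = act_info[-1]
    is_gated_act = act_info[0] == "gated"

    # Accept only `{ACT_FN}` or `gated-{ACT_FN}`.
    if (len(act_info) > 1 and act_info[0] != "gated") or len(act_info) > 2:
        raise ValueError(
            f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. "
            "Expected the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. 'gated-gelu' or 'relu'"
        )

    # Backwards compatibility: 'gated-gelu' maps to the 'gelu_new' activation.
    if feed_forward_proj == "gated-gelu":
        dense_act_fn = "gelu_new"
    return dense_act_fn, is_gated_act


print(parse_feed_forward_proj("relu"))
print(parse_feed_forward_proj("gated-gelu"))
```

Because `and` binds tighter than `or`, the explicit parentheses in the sketch do not change the behaviour of the original condition; they only make the grouping visible.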
m4/models/vt5/modeling_vt5.py ADDED
The diff for this file is too large to render. See raw diff
m4/models/zero_checkpoint_to_hf.py ADDED
@@ -0,0 +1,87 @@
1
+ #!/usr/bin/env python
2
+
3
+ # This script combines the 2 steps of
4
+ # 1. calling zero_to_fp32.py to reconsolidate the sharded deepspeed checkpoint
5
+ # 2. then resaving it as HF checkpoint, which also takes care of sharding large checkpoints
6
+ #
7
+ # example usage:
8
+ #
9
+ # this will generate the converted checkpoint under save_dir/opt_step-40/unwrapped_model
10
+ #
11
+ # ./m4/models/zero_checkpoint_to_hf.py save_dir/opt_step-40
12
+ #
13
+ # or you can override the destination by passing an explicit target dir, e.g.:
14
+ #
15
+ # ./m4/models/zero_checkpoint_to_hf.py save_dir/opt_step-40 save_dir/opt_step-40/output_dir
16
+
17
+ import argparse
18
+ import sys
19
+ from pathlib import Path
20
+
21
+ import torch
22
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
23
+
24
+
25
+ # auto-append the repo path to load m4 modules from instead of needing to set PYTHONPATH
26
+ repodir = str(Path(__file__).resolve().parents[2])
27
+ sys.path.insert(0, repodir)
28
+
29
+ import m4.models
30
+ from m4.testing_utils import read_json_file
31
+
32
+
33
+ if __name__ == "__main__":
34
+ parser = argparse.ArgumentParser()
35
+ parser.add_argument(
36
+ "checkpoint_dir", type=str, help="path to the desired checkpoint folder, e.g., path/to/opt_step-100"
37
+ )
38
+ parser.add_argument(
39
+ "output_dir",
40
+ type=str,
41
+ nargs="?",
42
+ help="path to pass to save_pretrained, defaults to 'unwrapped_model' relative to the checkpoint_dir argument",
43
+ )
44
+ args = parser.parse_args()
45
+
46
+ checkpoint_dir = Path(args.checkpoint_dir)
47
+ config_dir = checkpoint_dir / "unwrapped_model"
48
+ ds_checkpoint_dir = checkpoint_dir / "accelerator_state"
49
+ config_file_path = config_dir / "config.json"
50
+
51
+ if args.output_dir is None:
52
+ output_dir = checkpoint_dir / "unwrapped_model"
53
+ else:
54
+ output_dir = args.output_dir
55
+
56
+ config = read_json_file(config_file_path)
57
+ config_class = m4.models._SUPPORTED_MODELS.get(config["model_type"], None)
58
+ if config_class is None:
59
+ raise ValueError(f"{config['model_type']=} isn't supported by m4")
60
+ modeling_class = m4.models.model_type_to_modeling_class.get(config["model_type"], None)
61
+
62
+ print(f"Detected {config_class}")
63
+
64
+ print("Reconsolidating fp32 model from checkpoint shards (can take a long time)")
65
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir) # already on cpu
66
+
67
+ # Keeping debug to use if you ever need to debug state dict
68
+ # print("Saved State Dict")
69
+ # for k, v in state_dict.items():
70
+ # print(f"{k} {v.shape}")
71
+
72
+ kwargs = {}
73
+ print(f"Loading config from {config_dir}")
74
+ model_config = config_class.from_pretrained(config_dir)
75
+
76
+ print(f"Instantiating a {modeling_class} model in bf16")
77
+ model = modeling_class.from_pretrained(
78
+ None, config=model_config, state_dict=state_dict, torch_dtype=torch.bfloat16
79
+ )
80
+
81
+ # Keeping debug to use if you ever need to debug state dict
82
+ # print("Model State Dict")
83
+ # for k, v in model.state_dict().items():
84
+ # print(f"{k} {v.shape}")
85
+
86
+ print(f"Saving model to {output_dir}")
87
+ model.save_pretrained(output_dir)
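The path handling in the script above can be looked at in isolation. The following is a minimal sketch, assuming the directory layout the script expects (`unwrapped_model/config.json` and `accelerator_state/` under the checkpoint folder); `resolve_conversion_paths` is a hypothetical helper name used only for illustration and is not part of m4.

```python
from pathlib import Path


def resolve_conversion_paths(checkpoint_dir, output_dir=None):
    """Return the paths the conversion reads from and writes to.

    Hypothetical helper: it only mirrors the defaults used by the script above.
    """
    checkpoint_dir = Path(checkpoint_dir)
    config_dir = checkpoint_dir / "unwrapped_model"
    return {
        "config_file": config_dir / "config.json",
        "ds_checkpoint_dir": checkpoint_dir / "accelerator_state",
        # fall back to unwrapped_model when no explicit target dir is given
        "output_dir": Path(output_dir) if output_dir is not None else config_dir,
    }


paths = resolve_conversion_paths("save_dir/opt_step-40")
print(paths["output_dir"])
```

Passing a second argument, e.g. `resolve_conversion_paths("save_dir/opt_step-40", "save_dir/opt_step-40/output_dir")`, corresponds to the second usage example in the script header.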
m4/scripts/cleanup-checkpoints.py ADDED
@@ -0,0 +1,156 @@
+ #!/usr/bin/env python
+
+ #
+ # This tool deletes checkpoints found at given path that are no longer needed
+ #
+ # each checkpoint has 2 parts to clean up
+ #
+ # 1. the original deepspeed checkpoint
+ # 2. the converted hf checkpoint
+ #
+ # we will start with a combined requirement for eval to be completed and s3 synced to nuke the checkpoint
+ #
+ # Example:
+ #
+ # ./cleanup-checkpoints.py checkpoints-path
+ #
+ # Use `-h` for more options
+
+ import argparse
+ import shutil  # noqa
+ import subprocess
+ import sys
+ import time
+ from pathlib import Path
+
+
+ repo_path = Path(__file__).resolve().parents[2]
+
+ # we have to deal with potentially overlapping slurm jobs running on different nodes, so we can't
+ # rely on PIDs of a running process. Will use a control file instead as the filesystem is shared.
+ #
+ # If that file is there it means:
+ #
+ # 1. either the cleanup is still running
+ # 2. the cleanup got aborted (e.g. cpu-oom)
+ #
+ # to detect aborted cleanups we will check if the control file is older than a reasonable time to perform such a cleanup
+ control_file_name = "started-cleanup-checkpoint"
+ finished_uploading_file_name = "finished-upload-checkpoint"
+ # should fine tune - but surely 1h per checkpoint is plenty
+ reasonable_cleanup_time_in_secs = 1 * 60 * 60
+
+
+ def run_cmd(cmd, check=True):
+     try:
+         response = subprocess.run(
+             cmd,
+             stderr=subprocess.PIPE,
+             stdout=subprocess.PIPE,
+             check=check,
+             encoding="utf-8",
+         ).stdout.strip()
+     except subprocess.CalledProcessError as exc:
+         raise EnvironmentError(exc.stderr)
+
+     return response
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("checkpoints_path", type=str, help="base dir with checkpoints")
+     parser.add_argument("--skip-evals-check", action="store_true", help="skip evals done checks")
+     return parser.parse_args()
+
+
+ def exit(msg):
+     print(msg)
+     sys.exit()
+
+
+ def should_process(path, control_file_path, args):
+     """Heuristics to decide whether to cleanup this opt_step-XXX checkpoint or not"""
+
+     s3_completed_path = path / finished_uploading_file_name
+     eval_completed_paths = [
+         path / "run_evals_0_shots_done",
+         path / "run_evals_4_shots_done",
+         path / "run_evals_perplexity_validation_done",
+         path / "run_evals_0_shots_a_la_flamingo_done",
+     ]
+
+     # check s3 sync is completed
+     if not s3_completed_path.exists():
+         print(f"[N] {path} hasn't been synced to s3 yet. Skipping")
+         return False
+
+     # check evals are completed
+     if not args.skip_evals_check:
+         for eval_path in eval_completed_paths:
+             if not eval_path.exists():
+                 print(f"[N] {path} hasn't been evaled yet. Skipping")
+                 return False
+
+     # complicated checks - has another job already started processing? or did it crash?
+     if control_file_path.exists():
+         if control_file_path.stat().st_mtime < time.time() - reasonable_cleanup_time_in_secs:
+             print(f"[Y] {path} looks stale - probably aborted cleanup job. Deleting")
+             return True
+         else:
+             print(
+                 f"[N] {path} either another job is doing the cleanup or less than"
+                 f" {reasonable_cleanup_time_in_secs} secs has passed since it was launched. Skipping"
+             )
+             return False
+     else:
+         print(f"[Y] {path} completed s3 sync + eval. Deleting")
+         return True
+
+
+ def main():
+     args = get_args()
+
+     checkpoints_path = Path(args.checkpoints_path)
+     if not (checkpoints_path.exists() and checkpoints_path.is_dir()):
+         raise FileNotFoundError(f"can't find a directory '{checkpoints_path}'")
+
+     checkpoint_dirs = list(checkpoints_path.glob("opt_step-*"))
+     if len(checkpoint_dirs) == 0:
+         exit("No checkpoints found, exiting")
+
+     # Check each checkpoint folder in real time to allow for overlapping jobs starting at different times
+     # Additionally do not delete the last 2 checkpoints
+     #
+     # sort numerically to sort correctly different number of digits: opt_step-10, opt_step-100
+     checkpoint_dirs_sorted = sorted(checkpoint_dirs, key=lambda x: int(str(x).split("-")[-1]))
+     for i, checkpoint_dir in enumerate(checkpoint_dirs_sorted):
+         print(f"\n*** Checking {checkpoint_dir}")
+
+         if i + 1 == len(checkpoint_dirs_sorted):
+             print(f"[N] {checkpoint_dir} is the last checkpoint. Skipping")
+             continue
+
+         if i + 2 == len(checkpoint_dirs_sorted):
+             print(f"[N] {checkpoint_dir} is the second-to-last checkpoint. Skipping")
+             continue
+
+         control_file_path = checkpoint_dir / "unwrapped_model" / control_file_name
+
+         if not should_process(checkpoint_dir, control_file_path, args):
+             continue
+
+         print(f"Launching cleanup for {checkpoint_dir}")
+         # we could use flock here, to avoid a race condition, but it'd be pointless since each
+         # cronjob is likely to run on a different node and flock only works within a single node
+         control_file_path.touch()
+
+         # cleanup
+         # The delete only runs after should_process() has confirmed that the s3 sync marker
+         # (finished-upload-checkpoint) and, unless --skip-evals-check is passed, all of the
+         # run_evals_*_done markers exist under save_dir/opt_step-XXX
+         shutil.rmtree(checkpoint_dir, ignore_errors=True)
+         print(f"Checkpoint {checkpoint_dir} deleted")
+
+
+ if __name__ == "__main__":
+     main()
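The selection rule used in `main()` above (sort the `opt_step-*` folders numerically, then never touch the two newest) can be sketched as a small standalone function. This is an illustration only, assuming folder names always end in `-<integer>`; `cleanup_candidates` and `keep_last` are hypothetical names that do not appear in the script.

```python
from pathlib import Path


def cleanup_candidates(checkpoint_dirs, keep_last=2):
    """Sort opt_step-N dirs numerically and drop the newest `keep_last` of them."""
    # a plain string sort would put opt_step-1000 before opt_step-20
    ordered = sorted(checkpoint_dirs, key=lambda p: int(str(p).split("-")[-1]))
    return ordered[:-keep_last] if keep_last > 0 else ordered


dirs = [Path(f"exp/opt_step-{n}") for n in (100, 20, 3, 1000)]
print([d.name for d in cleanup_candidates(dirs)])  # ['opt_step-3', 'opt_step-20']
```

In the script each remaining candidate still has to pass `should_process()` before it is deleted; the sketch covers only the ordering and the keep-the-last-two rule.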
m4/scripts/convert-checkpoints.py ADDED
@@ -0,0 +1,124 @@
+ #!/usr/bin/env python
+
+ #
+ # This tool converts any deepspeed checkpoints found at given path to hf format
+ #
+ # Example:
+ #
+ # ./convert-checkpoints.py checkpoints-path
+ #
+
+ import argparse
+ import subprocess
+ import sys
+ import time
+ from pathlib import Path
+
+
+ repo_path = Path(__file__).resolve().parents[2]
+ zero_checkpoint_to_hf_path = repo_path / "m4/models/zero_checkpoint_to_hf.py"
+
+ # we have to deal with potentially overlapping slurm jobs running on different nodes, so we can't
+ # rely on PIDs of a running process. Will use a control file instead as the filesystem is shared.
+ #
+ # If that file is there it means:
+ #
+ # 1. either the conversion is still running
+ # 2. the conversion got aborted (e.g. cpu-oom)
+ #
+ # to detect aborted conversions we will check if the control file is older than a reasonable time to perform such a conversion
+ control_file_name = "started-convert-checkpoint"
+ # should fine tune - but surely 2h per checkpoint is plenty
+ reasonable_conversion_time_in_secs = 2 * 60 * 60
+
+
+ def run_cmd(cmd, check=True):
+     try:
+         response = subprocess.run(
+             cmd,
+             stderr=subprocess.PIPE,
+             stdout=subprocess.PIPE,
+             check=check,
+             encoding="utf-8",
+         ).stdout.strip()
+     except subprocess.CalledProcessError as exc:
+         raise EnvironmentError(exc.stderr)
+
+     return response
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("checkpoints_path", type=str, help="base dir with checkpoints")
+     parser.add_argument("-f", "--force", action="store_true", help="force rebuilding of all checkpoints")
+     return parser.parse_args()
+
+
+ def exit(msg):
+     print(msg)
+     sys.exit()
+
+
+ def should_process(path, force, control_file_path):
+     """Heuristics to decide whether to convert this opt_step-XXX checkpoint or not"""
+
+     target_dir = path / "unwrapped_model"
+
+     # easy checks - the conversion is clearly completed
+     if (target_dir / "pytorch_model.bin").exists() or (target_dir / "pytorch_model.bin.index.json").exists():
+         print(f"[N] {path} appears to be already converted. Skipping")
+         return False
+
+     if force:
+         print(f"[Y] Forced to re-convert {path}")
+         return True
+
+     # complicated checks - has another job already started processing? or did it crash?
+     control_file_path = target_dir / control_file_name
+     if control_file_path.exists():
+         if control_file_path.stat().st_mtime < time.time() - reasonable_conversion_time_in_secs:
+             print(f"[Y] {path} looks stale - probably aborted job. Re-converting")
+             return True
+         else:
+             print(
+                 f"[N] {path} either another job is converting it or less than"
+                 f" {reasonable_conversion_time_in_secs} secs has passed since it was launched. Skipping"
+             )
+             return False
+     else:
+         print(f"[Y] {path} is a new checkpoint. Converting")
+         return True
+
+
+ def main():
+     args = get_args()
+
+     checkpoints_path = Path(args.checkpoints_path)
+     if not (checkpoints_path.exists() and checkpoints_path.is_dir()):
+         raise FileNotFoundError(f"can't find a directory '{checkpoints_path}'")
+
+     checkpoint_dirs = list(checkpoints_path.glob("opt_step-*"))
+     if len(checkpoint_dirs) == 0:
+         exit("No checkpoints found, exiting")
+
+     # Check each folder in real time to allow for overlapping jobs starting at different times
+     for checkpoint_dir in checkpoint_dirs:
+         print(f"\n*** Checking {checkpoint_dir}")
+
+         control_file_path = checkpoint_dir / "unwrapped_model" / control_file_name
+
+         if not should_process(checkpoint_dir, args.force, control_file_path):
+             continue
+
+         print(f"Launching conversion for {checkpoint_dir} - it could take a long time")
+         cmd = [zero_checkpoint_to_hf_path, checkpoint_dir]
+         # we could use flock here, to avoid a race condition, but it'd be pointless since each
+         # cronjob is likely to run on a different node and flock only works within a single node
+         control_file_path.touch()
+         response = run_cmd(cmd)
+         control_file_path.unlink()
+         print(response)
+
+
+ if __name__ == "__main__":
+     main()
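The control-file heuristic shared by these scripts boils down to three states: no marker, a fresh marker (another job is probably working on it), and a marker older than the allowed time (the job probably crashed). A minimal sketch of that classification follows; `control_file_state` and its `now` parameter are hypothetical and exist only to make the logic easy to exercise, and the marker is backdated with `os.utime` to simulate an aborted job.

```python
import os
import tempfile
import time
from pathlib import Path


def control_file_state(control_file, max_age_secs, now=None):
    """Classify a control file as 'absent', 'in-progress' or 'stale'."""
    control_file = Path(control_file)
    if not control_file.exists():
        return "absent"
    now = time.time() if now is None else now
    # same comparison as in should_process(): mtime older than the allowed window
    if control_file.stat().st_mtime < now - max_age_secs:
        return "stale"
    return "in-progress"


with tempfile.TemporaryDirectory() as tmp:
    marker = Path(tmp) / "started-convert-checkpoint"
    state_before = control_file_state(marker, max_age_secs=2 * 60 * 60)
    marker.touch()
    state_fresh = control_file_state(marker, max_age_secs=2 * 60 * 60)
    # backdate the marker by 3h to simulate a job that was aborted long ago
    old = time.time() - 3 * 60 * 60
    os.utime(marker, (old, old))
    state_old = control_file_state(marker, max_age_secs=2 * 60 * 60)

print(state_before, state_fresh, state_old)
```

Because the check and the subsequent `touch()` are not atomic, two jobs starting at nearly the same moment can still both decide to process the same checkpoint; the comments in the script acknowledge this race.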
m4/scripts/s3-upload-checkpoints.py ADDED
@@ -0,0 +1,194 @@
+ #!/usr/bin/env python
+
+ #
+ # This tool uploads any new deepspeed checkpoints found at given path to s3 (and also various non-checkpoint files, like logs)
+ #
+ # Example:
+ #
+ # ./s3-upload-checkpoints.py checkpoints-path
+ #
+ # Use `-h` for more options
+ #
+
+
+ import argparse
+ import subprocess
+ import sys
+ import time
+ from pathlib import Path
+
+
+ repo_path = Path(__file__).resolve().parents[2]
+ zero_checkpoint_to_hf_path = repo_path / "m4/models/zero_checkpoint_to_hf.py"
+
+ RETRIES = 5
+
+ # what dir/file glob patterns to include in the upload besides checkpoints
+ include_patterns = ["tb_run_*", "logs", "config.yaml"]
+
+
+ # we have to deal with potentially overlapping slurm jobs running on different nodes, so we can't
+ # rely on PIDs of a running process. Will use a control file instead as the filesystem is shared.
+ #
+ # If that file is there it means:
+ #
+ # 1. either the upload is still running
+ # 2. the upload got aborted (e.g. cpu-oom)
+ #
+ # to detect aborted uploads we will check if the control file is older than a reasonable time to perform such an upload
+ control_file_name = "started-upload-checkpoint"
+ finished_uploading_file_name = "finished-upload-checkpoint"
+ # should fine tune - but surely 2h per checkpoint is plenty
+ reasonable_upload_time_in_secs = 2 * 60 * 60
+
+
+ def run_cmd(cmd, check=True):
+     try:
+         response = subprocess.run(
+             cmd,
+             stderr=subprocess.PIPE,
+             stdout=subprocess.PIPE,
+             check=check,
+             encoding="utf-8",
+         ).stdout.strip()
+     except subprocess.CalledProcessError as exc:
+         raise EnvironmentError(exc.stderr)
+
+     return response
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("checkpoints_path", type=str, help="base dir with checkpoints")
+     # parser.add_argument("experiment_name", type=str, help="experiment name as a s3 sub-dir")
+     parser.add_argument("-f", "--force", action="store_true", help="force uploading of all checkpoints")
+     parser.add_argument(
+         "--skip-conversion-check", action="store_true", help="skip checkpoint conversion is done check"
+     )
+     return parser.parse_args()
+
+
+ def exit(msg):
+     print(msg)
+     sys.exit()
+
+
+ def should_process(path, force, control_file_path, finished_uploading_file_path, args):
+     """Heuristics to decide whether to upload this opt_step-XXX checkpoint or not"""
+
+     # check if checkpoint is fully saved
+     finished_saving_path = path / "finished-saving"  # defined in trainer.py
+     if not finished_saving_path.exists():
+         print(f"[N] {path} isn't finished saving. Skipping")
+         return False
+
+     if force:
+         print(f"[Y] Forced to re-process {path}")
+         return True
+
+     # check if already uploaded
+     if finished_uploading_file_path.exists():
+         print(f"[N] {path} has already been uploaded. Skipping")
+         return False
+
+     # check conversion is completed
+     if not args.skip_conversion_check:
+         converted_model_path_1 = path / "unwrapped_model" / "pytorch_model.bin.index.json"
+         converted_model_path_2 = path / "unwrapped_model" / "pytorch_model.bin"
+         if not converted_model_path_1.exists() and not converted_model_path_2.exists():
+             print(f"[N] {path} doesn't have a converted model. Skipping")
+             return False
+
+     # complicated checks - has another job already started uploading? or did it crash?
+     if control_file_path.exists():
+         if control_file_path.stat().st_mtime < time.time() - reasonable_upload_time_in_secs:
+             print(f"[Y] {path} looks stale - probably aborted job. Re-uploading")
+             return True
+         else:
+             print(
+                 f"[N] {path} either another job is uploading it or less than"
+                 f" {reasonable_upload_time_in_secs} secs has passed since it was launched. Skipping"
+             )
+             return False
+     else:
+         print(f"[Y] {path} is a new checkpoint. Uploading")
+         return True
+
+
+ def main():
+     args = get_args()
+
+     checkpoints_path = Path(args.checkpoints_path)
+     if not (checkpoints_path.exists() and checkpoints_path.is_dir()):
+         raise FileNotFoundError(f"can't find a directory '{checkpoints_path}'")
+
+     checkpoint_dirs = list(checkpoints_path.glob("opt_step-*"))
+     if len(checkpoint_dirs) == 0:
+         exit("No checkpoints found, exiting")
+
+     exp_name = checkpoints_path.name
+
+     # Check each folder in real time to allow for overlapping jobs starting at different times
+     for checkpoint_dir in checkpoint_dirs:
+         print(f"\n*** Checking {checkpoint_dir}")
+
+         control_file_path = checkpoint_dir / control_file_name
+         finished_uploading_file_path = checkpoint_dir / finished_uploading_file_name
+
+         if not should_process(checkpoint_dir, args.force, control_file_path, finished_uploading_file_path, args):
+             continue
+
+         opt_step = checkpoint_dir.name
+         bucket_name = "m4-exps"
+         bucket_path = f"{exp_name}/{opt_step}"
+
+         print(f"Launching upload for {checkpoint_dir} - it could take a long time")
+         cmd = f"s5cmd sync {checkpoint_dir}/ s3://{bucket_name}/{bucket_path}/".split()
+         # we could use flock here, to avoid a race condition, but it'd be pointless since each
+         # cronjob is likely to run on a different node and flock only works within a single node
+         control_file_path.touch()
+         # print(f"mock running {cmd}")
+
+         # s5cmd will fail with an error like this when MD5 checksum doesn't match on upload (it won't retry)
+         # ERROR "cp data4.tar s3://m4-datasets/cm4-test/data4.tar": InvalidDigest: The Content-MD5
+         # you specified was invalid. status code: 400, request id: SZEHBJ4QQ33JSMH7, host id:
+         # XTeMYKd2KECiVKbFnwVbXo3LgnuA2OHWk5S+tHKAOKO95Os/pje2ZEbCfO5pojQtCTFOovvnVME=
+
+         tries = 0
+         while tries < RETRIES:
+             tries += 1
+             try:
+                 response = run_cmd(cmd)
+                 print(response)
+                 break
+             except EnvironmentError as e:
+                 # retry only on an MD5 mismatch; re-raise anything else, or once retries run out
+                 if "InvalidDigest" not in str(e):
+                     raise
+                 if tries == RETRIES:
+                     raise
+                 print(f"MD5 checksum failed, upload retry {tries}")
+
+         # for now disable this as large files don't have sha256 checksums
+         # result = integrity_check_recursive(checkpoint_dir, bucket_name, bucket_path)
+         # print(f"Integrity check was {result}")
+
+         control_file_path.unlink()
+         finished_uploading_file_path.touch()
+
+     # now upload non-checkpoint files
+     print("\n*** Uploading non-checkpoint files")
+     upload_dirs = []
+     for pat in include_patterns:
+         upload_dirs += list(checkpoints_path.glob(pat))
+
+     for dir in upload_dirs:
+         print(f"Launching upload for {dir}")
+         cmd = f"s5cmd sync {dir} s3://m4-exps/{exp_name}/".split()
+         print(f"running {cmd}")
+         response = run_cmd(cmd)
+         print(response)
+
+
+ if __name__ == "__main__":
+     main()
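The retry behaviour around the `s5cmd` call can be factored into a small wrapper. The sketch below is an illustration under assumptions, not the script's own code: `run_with_retries` and `flaky_upload` are hypothetical names, and the fake upload stands in for `run_cmd(cmd)` so that the example runs without `s5cmd` or network access. It retries only when the error text contains `InvalidDigest` and re-raises everything else, including the last failed attempt.

```python
def run_with_retries(action, retries=5, retry_marker="InvalidDigest"):
    """Call `action` until it succeeds, retrying only on errors containing `retry_marker`."""
    for attempt in range(1, retries + 1):
        try:
            return action()
        except EnvironmentError as exc:
            # any other failure, or the last allowed attempt, is propagated to the caller
            if retry_marker not in str(exc) or attempt == retries:
                raise
            print(f"MD5 checksum failed, retry {attempt}")


calls = {"n": 0}


def flaky_upload():
    # stand-in for the s5cmd call: fails twice with a checksum error, then succeeds
    calls["n"] += 1
    if calls["n"] < 3:
        raise EnvironmentError("InvalidDigest: The Content-MD5 you specified was invalid")
    return "uploaded"


result = run_with_retries(flaky_upload)
print(result, calls["n"])  # uploaded 3
```

Propagating the final failure matters here because the caller creates the `finished-upload-checkpoint` marker right after the loop, and the cleanup script treats that marker as proof that the checkpoint is safely on s3.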
m4/scripts/s3_checkpoint_download_convert_upload.py ADDED
@@ -0,0 +1,171 @@
+ #!/usr/bin/env python
+
+ #
+ # This tool downloads the given opt_step checkpoints of a run from s3, converts the ones that aren't in hf format yet and uploads the converted model back to s3
+ #
+ # Example:
+ #
+ # ./s3_checkpoint_download_convert_upload.py run_name 1000 2000 repo_path
+ #
+
+ import argparse
+ import subprocess
+ import sys
+ from pathlib import Path
+
+ import boto3
+
+
+ def check_s3_directory(directory_path):
+     s3 = boto3.client("s3")
+
+     # Add a trailing slash to the directory path
+     if not directory_path.endswith("/"):
+         directory_path += "/"
+
+     # Check if any objects exist with the given directory prefix
+     response = s3.list_objects_v2(Bucket="m4-exps", Prefix=directory_path)
+
+     # If any objects are found, the directory exists
+     if "Contents" in response:
+         return True
+
+     return False
+
+
+ def check_s3_file(file_key):
+     s3 = boto3.client("s3")
+
+     try:
+         s3.head_object(Bucket="m4-exps", Key=file_key)
+         return True
+     except Exception:
+         return False
+
+
+ def run_cmd(cmd, check=True):
+     try:
+         response = subprocess.run(
+             cmd,
+             stderr=subprocess.PIPE,
+             stdout=subprocess.PIPE,
+             check=check,
+             encoding="utf-8",
+         ).stdout.strip()
+     except subprocess.CalledProcessError as exc:
+         raise EnvironmentError(exc.stderr)
+
+     return response
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("run_name", type=str, help="run name")
+     parser.add_argument("opt_step_num_list", nargs="+", help="list of opt-steps to download")
+     parser.add_argument("repo_path", type=str, help="repo path")
+
+     parser.add_argument("-f", "--force", action="store_true", help="force rebuilding of all checkpoints")
+     return parser.parse_args()
+
+
+ def exit(msg):
+     print(msg)
+     sys.exit()
+
+
+ def cmd_retry_loop(cmd, max_retries=5):
+     # s5cmd will fail with an error like this when MD5 checksum doesn't match on upload (it won't retry)
+     # ERROR "cp data4.tar s3://m4-datasets/cm4-test/data4.tar": InvalidDigest: The Content-MD5
+     # you specified was invalid. status code: 400, request id: SZEHBJ4QQ33JSMH7, host id:
+     # XTeMYKd2KECiVKbFnwVbXo3LgnuA2OHWk5S+tHKAOKO95Os/pje2ZEbCfO5pojQtCTFOovvnVME=
+
+     tries = 0
+     while tries < max_retries:
+         tries += 1
+         try:
+             response = run_cmd(cmd)
+             print(response)
+             break
+         except EnvironmentError as e:
+             # retry only on an MD5 mismatch; re-raise anything else, or once retries run out
+             if "InvalidDigest" not in str(e):
+                 raise
+             if tries == max_retries:
+                 raise
+             print(f"MD5 checksum failed, retry {tries}")
+     return response
+
+
+ def main():
+     args = get_args()
+
+     run_name = args.run_name
+     opt_step_num_list = args.opt_step_num_list
+     repo_path = Path(args.repo_path)
+     zero_checkpoint_to_hf_path = repo_path / "m4/models/zero_checkpoint_to_hf.py"
+     bucket_name = "m4-exps"
+     opt_step_s3_file_keys = [f"{run_name}/opt_step-{opt_step_num}" for opt_step_num in opt_step_num_list]
+
+     check_s3_directory(run_name)
+
+     # Check each folder in real time to allow for overlapping jobs starting at different times
+     for opt_step_s3_file_key in opt_step_s3_file_keys:
+         print(f"\n*** Checking {opt_step_s3_file_key}")
+         if not check_s3_directory(opt_step_s3_file_key):
+             print(f"The checkpoint {opt_step_s3_file_key} does not exist - skipping")
+             continue
+         unwrapped_model_s3_file_key = f"{opt_step_s3_file_key}/unwrapped_model"
+         bin_s3_file_key = f"{unwrapped_model_s3_file_key}/pytorch_model.bin"
+         index_s3_file_key = f"{unwrapped_model_s3_file_key}/pytorch_model.bin.index.json"
+         is_not_converted = not check_s3_file(bin_s3_file_key) and not check_s3_file(index_s3_file_key)
+         if is_not_converted:
+             print(
+                 f"The checkpoint hasn't been converted, launching download for {opt_step_s3_file_key} - it could take"
+                 " a long time"
+             )
+
+             opt_step_dirname = opt_step_s3_file_key.split("/")[-1]
+             cluster_opt_step_dir = f"/fsx/m4/experiments/local_experiment_dir/s3_async_temporary_checkpoint_folder/{run_name}/{opt_step_dirname}"
+             cmd = f"s5cmd sync s3://{bucket_name}/{opt_step_s3_file_key}/* {cluster_opt_step_dir}".split()
+             download_response_opt_step_dir = cmd_retry_loop(cmd, max_retries=5)
+             print(f"download_response_opt_step_dir: {download_response_opt_step_dir}")
+         else:
+             print(
+                 "The checkpoint has been converted already, downloading only the unwrapped checkpoint and"
+                 " tokenizer dir"
+             )
+             opt_step_dirname = opt_step_s3_file_key.split("/")[-1]
+             cluster_opt_step_dir = f"/fsx/m4/experiments/local_experiment_dir/s3_async_temporary_checkpoint_folder/{run_name}/{opt_step_dirname}"
+             unwrapped_model_dir = f"{cluster_opt_step_dir}/unwrapped_model"
+             tokenizer_dir = f"{cluster_opt_step_dir}/tokenizer"
+             cmd_model = (
+                 f"s5cmd sync s3://{bucket_name}/{opt_step_s3_file_key}/unwrapped_model/* {unwrapped_model_dir}".split()
+             )
+             cmd_tokenizer = f"s5cmd sync s3://{bucket_name}/{opt_step_s3_file_key}/tokenizer/* {tokenizer_dir}".split()
+             download_response_model = cmd_retry_loop(cmd_model, max_retries=5)
+             print(f"download_response_model: {download_response_model}")
+             download_response_tokenizer = cmd_retry_loop(cmd_tokenizer, max_retries=5)
+             print(f"download_response_tokenizer: {download_response_tokenizer}")
+
+         print(f"opt_step_dirname: {opt_step_dirname} downloaded to cluster_opt_step_dir: {cluster_opt_step_dir}")
+
+         if is_not_converted:
+             print(f"Converting {cluster_opt_step_dir}")
+             convert_cmd = [zero_checkpoint_to_hf_path, cluster_opt_step_dir]
+             conversion_response = run_cmd(convert_cmd)
+             print(f"conversion_response: {conversion_response}")
+             print(f"upload converted checkpoint: {cluster_opt_step_dir}")
+             upload_cmd = (
+                 f"s5cmd sync {cluster_opt_step_dir}/unwrapped_model/"
+                 f" s3://{bucket_name}/{opt_step_s3_file_key}/unwrapped_model/ ".split()
+             )
+             upload_response = cmd_retry_loop(upload_cmd, max_retries=5)
+             print(f"upload_response: {upload_response}")
+             print(
+                 f"Uploaded {cluster_opt_step_dir}/unwrapped_model to"
+                 f" s3://{bucket_name}/{opt_step_s3_file_key}/unwrapped_model"
+             )
+
+
+ if __name__ == "__main__":
+     main()
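The "is this checkpoint already converted?" decision in `main()` depends only on two s3 keys. A minimal sketch of that decision, assuming the key layout shown in the script, is below; `converted_model_keys` and `is_converted` are hypothetical names, and the `keys_in_bucket` set stands in for the `head_object` lookups so the example runs without boto3 or AWS credentials.

```python
def converted_model_keys(run_name, opt_step_num):
    """Return the two s3 keys whose presence marks a checkpoint as converted."""
    prefix = f"{run_name}/opt_step-{opt_step_num}/unwrapped_model"
    return [f"{prefix}/pytorch_model.bin", f"{prefix}/pytorch_model.bin.index.json"]


def is_converted(keys_in_bucket, run_name, opt_step_num):
    # `keys_in_bucket` replaces the boto3 head_object calls for illustration only
    return any(key in keys_in_bucket for key in converted_model_keys(run_name, opt_step_num))


# a sharded checkpoint only has the index file, which is enough to count as converted
bucket = {"tr_test/opt_step-1000/unwrapped_model/pytorch_model.bin.index.json"}
print(is_converted(bucket, "tr_test", 1000), is_converted(bucket, "tr_test", 2000))  # True False
```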
m4/scripts/s3_checkpoint_download_convert_upload.slurm ADDED
@@ -0,0 +1,51 @@
+ #!/bin/bash
+ #SBATCH --job-name=tr_test-s3-download-and-convert-checkpoints
+ #SBATCH --ntasks=1
+ #SBATCH --nodes=1
+ #SBATCH --time=3:00:00
+ #SBATCH --partition=production-cluster
+ #SBATCH --output=/fsx/m4/experiments/local_experiment_dir/s3_async_temporary_checkpoint_folder/logs/%x-%j.out
+
+
+ set -e
+
+ # ----------------- Auto-Workdir -----------------
+ if [ -n "$SLURM_JOB_ID" ]; then
+     # check the original location through scontrol and $SLURM_JOB_ID
+     SCRIPT_PATH=$(scontrol show job $SLURM_JOB_ID | awk -F= '/Command=/{print $2}')
+ else
+     # otherwise: started with bash. Get the real location.
+     SCRIPT_PATH=$(realpath $0)
+ fi
+ SCRIPT_DIR=$(dirname ${SCRIPT_PATH})
+ M4_REPO_PATH=$(builtin cd $SCRIPT_DIR/../../; pwd)
+
+ # --------------------------------------------------
+
+ ### EDIT ME START ###
+
+ CONDA_ENV_NAME=shared-m4
+
+ EXPERIMENT_NAME=tr_194_laion_cm4_mix
+
+ opt_step_num_list=(
+     "1000"
+     "2000"
+ )
+
+ ### EDIT ME END ###
+
+
+ echo "START TIME: $(date)"
+
+ source /fsx/m4/start-m4-user
+ conda activate base
+ conda activate $CONDA_ENV_NAME
+ pushd $M4_REPO_PATH
+ export PYTHONPATH=$WORKING_DIR:$PYTHONPATH
+
+ echo "running checkpoint download, convert, upload for opt-steps: ${opt_step_num_list[@]} of experiment: $EXPERIMENT_NAME"
+
+ python $M4_REPO_PATH/m4/scripts/s3_checkpoint_download_convert_upload.py $EXPERIMENT_NAME ${opt_step_num_list[@]} $M4_REPO_PATH
+
+ echo "END TIME: $(date)"
m4/scripts/s3_downloaded_checkpoints_cleanup.slurm ADDED
@@ -0,0 +1,54 @@
+ #!/bin/bash
+ #SBATCH --job-name=tr_test-s3-cleanup-checkpoints
+ #SBATCH --ntasks=1
+ #SBATCH --nodes=1
+ #SBATCH --time=3:00:00
+ #SBATCH --partition=production-cluster
+ #SBATCH --output=/fsx/m4/experiments/local_experiment_dir/s3_async_temporary_checkpoint_folder/logs/%x-%j.out
+
+
+ set -e
+
+ # ----------------- Auto-Workdir -----------------
+ if [ -n "$SLURM_JOB_ID" ]; then
+     # check the original location through scontrol and $SLURM_JOB_ID
+     SCRIPT_PATH=$(scontrol show job $SLURM_JOB_ID | awk -F= '/Command=/{print $2}')
+ else
+     # otherwise: started with bash. Get the real location.
+     SCRIPT_PATH=$(realpath $0)
+ fi
+ SCRIPT_DIR=$(dirname ${SCRIPT_PATH})
+ M4_REPO_PATH=$(builtin cd $SCRIPT_DIR/../../; pwd)
+
+ # --------------------------------------------------
+
+ ### EDIT ME START ###
+
+ CONDA_ENV_NAME=shared-m4
+
+ EXPERIMENT_NAME=tr_194_laion_cm4_mix
+
+ opt_step_num_list=(
+     "1000"
+     "2000"
+ )
+
+ ### EDIT ME END ###
+
+
+ echo "START TIME: $(date)"
+
+ source /fsx/m4/start-m4-user
+ conda activate base
+ conda activate $CONDA_ENV_NAME
+ pushd $M4_REPO_PATH
+ export PYTHONPATH=$WORKING_DIR:$PYTHONPATH
+
+ for opt_step_num in "${opt_step_num_list[@]}"
+ do
+     OPT_STEP_DIR="/fsx/m4/experiments/local_experiment_dir/s3_async_temporary_checkpoint_folder/${EXPERIMENT_NAME}/opt_step-${opt_step_num}"
+     rm -r "$OPT_STEP_DIR"
+     echo "Deleted $OPT_STEP_DIR of experiment: $EXPERIMENT_NAME"
+ done
+
+ echo "END TIME: $(date)"
m4/scripts/schedule-evals.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ #
4
+ # This tool checks if evaluation is needed
5
+ #
6
+
7
+ import argparse
8
+ import os
9
+ import subprocess
10
+ import sys
11
+ import time
12
+ from pathlib import Path
13
+
14
+
15
+ repo_path = Path(__file__).parents[2]
16
+
17
+ # we have to deal with potentially overlapping slurm jobs running on different nodes, so we can't
18
+ # rely on PIDs of a running process. Will use a control file instead as the filesystem is shared.
19
+ #
20
+ # If that file is there it means:
21
+ #
22
+ # 1. either the eval is still running
23
+ # 2. the eval got aborted (e.g. gpu-oom)
24
+ #
25
+
26
+ # should fine tune - but surely 9h per checkpoint is plenty
27
+ reasonable_eval_time_in_secs = 9 * 60 * 60
28
+
29
+
30
+ def run_cmd(cmd, check=True):
31
+ try:
32
+ response = subprocess.run(
33
+ cmd,
34
+ stderr=subprocess.PIPE,
35
+ stdout=subprocess.PIPE,
36
+ check=check,
37
+ encoding="utf-8",
38
+ ).stdout.strip()
39
+ except subprocess.CalledProcessError as exc:
40
+ raise EnvironmentError(exc.stderr)
41
+
42
+ return response
43
+
44
+
45
+ def get_args():
46
+ parser = argparse.ArgumentParser()
47
+ parser.add_argument("checkpoints_path", type=str, help="base dir with checkpoints")
48
+ return parser.parse_args()
49
+
50
+
51
+ def exit(msg):
52
+ print(msg)
53
+ sys.exit()
54
+
55
+
56
+ def check_eval_crash(path):
57
+ """Heuristics to decide whether to restart this opt_step-XXX checkpoint evaluation or not"""
58
+ eval_0_completed_path = path / "start_run_evals_0_shots"
59
+ eval_4_completed_path = path / "start_run_evals_4_shots"
60
+ eval_perplexity_path = path / "start_run_evals_perplexity_validation"
61
+ # complicated checks - has another job already started processing? or did it crash?
62
+ for eval_start_path in [eval_0_completed_path, eval_4_completed_path, eval_perplexity_path]:
63
+ if eval_start_path.exists():
64
+ if eval_start_path.stat().st_mtime < time.time() - reasonable_eval_time_in_secs:
65
+ print(f"[Y] {path} looks stale - Probably crashed - Restart evals")
66
+ os.remove(eval_start_path)
67
+
68
+
69
+ def main():
70
+ args = get_args()
71
+
72
+ checkpoints_path = Path(args.checkpoints_path)
73
+ if not (checkpoints_path.exists() and checkpoints_path.is_dir()):
74
+ raise FileNotFoundError(f"can't find a directory '{checkpoints_path}'")
75
+
76
+ checkpoint_dirs = list(checkpoints_path.glob("opt_step-*"))
77
+ if len(checkpoint_dirs) == 0:
78
+ exit("No checkpoints found, exiting")
79
+
80
+ checkpoint_dirs_sorted = sorted(checkpoint_dirs, key=lambda x: int(str(x).split("-")[-1]))
81
+ for i, checkpoint_dir in enumerate(checkpoint_dirs_sorted):
82
+ print(f"\n*** Checking {checkpoint_dir} for evals")
83
+ check_eval_crash(checkpoint_dir)
84
+
85
+
86
+ if __name__ == "__main__":
87
+ main()
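The staleness heuristic used by `check_eval_crash` can be tried in isolation. The following is a minimal sketch, not part of the script: the helper name `is_stale` and the temporary directory are assumptions made for illustration, and `os.utime` back-dates the control file to simulate a job that crashed long ago.

```python
import os
import tempfile
import time
from pathlib import Path

# Same 9h budget as `reasonable_eval_time_in_secs` in the script above.
STALE_AFTER_SECS = 9 * 60 * 60


def is_stale(control_file: Path, max_age_secs: float = STALE_AFTER_SECS) -> bool:
    """Hypothetical helper: True if the control file exists and its mtime is older than the budget."""
    if not control_file.exists():
        return False
    return control_file.stat().st_mtime < time.time() - max_age_secs


with tempfile.TemporaryDirectory() as tmp:
    marker = Path(tmp) / "start_run_evals_0_shots"
    marker.touch()
    fresh_result = is_stale(marker)  # just created, so not stale

    # Pretend the eval started 10 hours ago and never finished.
    ten_hours_ago = time.time() - 10 * 60 * 60
    os.utime(marker, (ten_hours_ago, ten_hours_ago))
    stale_result = is_stale(marker)
```

Relying on the file's mtime rather than a PID is what lets the check work across nodes that share a filesystem, as the comment at the top of the script explains.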
m4/testing_utils.py ADDED
@@ -0,0 +1,1116 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # adapted from https://stackoverflow.com/a/59041913/9201239
16
+ import asyncio # noqa
17
+ import contextlib
18
+ import importlib.util
19
+ import inspect
20
+ import json
21
+ import logging
22
+ import os
23
+ import random
24
+ import re
25
+ import shutil
26
+ import sys
27
+ import tempfile
28
+ import unittest
29
+ from distutils.util import strtobool
30
+ from io import StringIO
31
+ from pathlib import Path
32
+ from typing import Iterator, Union
33
+ from unittest import mock
34
+ from unittest.case import SkipTest
35
+
36
+ import numpy as np
37
+ from packaging import version
38
+ from parameterized import parameterized
39
+
40
+
41
+ try:
42
+ import torch
43
+
44
+ _torch_available = True
45
+ except Exception:
46
+ _torch_available = False
47
+
48
+
49
+ def is_torch_available():
50
+ return _torch_available
51
+
52
+
53
+ def parse_flag_from_env(key, default=False):
54
+ try:
55
+ value = os.environ[key]
56
+ except KeyError:
57
+ # KEY isn't set, default to `default`.
58
+ _value = default
59
+ else:
60
+ # KEY is set, convert it to True or False.
61
+ try:
62
+ _value = strtobool(value)
63
+ except ValueError:
64
+ # More values are supported, but let's keep the message simple.
65
+ raise ValueError(f"If set, {key} must be yes or no.")
66
+ return _value
67
+
68
+
69
+ def parse_int_from_env(key, default=None):
70
+ try:
71
+ value = os.environ[key]
72
+ except KeyError:
73
+ _value = default
74
+ else:
75
+ try:
76
+ _value = int(value)
77
+ except ValueError:
78
+ raise ValueError(f"If set, {key} must be a int.")
79
+ return _value
80
+
81
+
82
+ def require_torch(test_case):
83
+ """
84
+ Decorator marking a test that requires PyTorch.
85
+
86
+ These tests are skipped when PyTorch isn't installed.
87
+
88
+ """
89
+ if not is_torch_available():
90
+ return unittest.skip("test requires PyTorch")(test_case)
91
+ else:
92
+ return test_case
93
+
94
+
95
+ def require_torch_no_gpus(test_case):
96
+ """
97
+ Decorator marking a test that requires a setup without GPUs (in PyTorch). These tests are skipped on a machine with GPUs.
98
+
99
+ To run *only* the no gpu tests, assuming all test names contain no_gpu: $ pytest -sv ./tests -k "no_gpu"
100
+ """
101
+ import torch
102
+
103
+ if is_torch_available() and torch.cuda.device_count() > 0:
104
+ return unittest.skip("test requires an environment w/o GPUs")(test_case)
105
+ else:
106
+ return test_case
107
+
108
+
109
+ def require_torch_multi_gpu(test_case):
110
+ """
111
+ Decorator marking a test that requires a multi-GPU setup (in PyTorch). These tests are skipped on a machine without
112
+ multiple GPUs.
113
+
114
+ To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests -k "multi_gpu"
115
+ """
116
+ if not is_torch_available():
117
+ return unittest.skip("test requires PyTorch")(test_case)
118
+
119
+ import torch
120
+
121
+ if torch.cuda.device_count() < 2:
122
+ return unittest.skip("test requires multiple GPUs")(test_case)
123
+ else:
124
+ return test_case
125
+
126
+
127
+ def require_torch_non_multi_gpu(test_case):
128
+ """
129
+ Decorator marking a test that requires 0 or 1 GPU setup (in PyTorch).
130
+ """
131
+ if not is_torch_available():
132
+ return unittest.skip("test requires PyTorch")(test_case)
133
+
134
+ import torch
135
+
136
+ if torch.cuda.device_count() > 1:
137
+ return unittest.skip("test requires 0 or 1 GPU")(test_case)
138
+ else:
139
+ return test_case
140
+
141
+
142
+ def require_torch_up_to_2_gpus(test_case):
143
+ """
144
+ Decorator marking a test that requires 0 or 1 or 2 GPU setup (in PyTorch).
145
+ """
146
+ if not is_torch_available():
147
+ return unittest.skip("test requires PyTorch")(test_case)
148
+
149
+ import torch
150
+
151
+ if torch.cuda.device_count() > 2:
152
+ return unittest.skip("test requires 0 or 1 or 2 GPUs")(test_case)
153
+ else:
154
+ return test_case
155
+
156
+
157
+ if is_torch_available():
158
+ # Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
159
+ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
160
+ else:
161
+ torch_device = None
162
+
163
+
164
+ def require_torch_gpu(test_case):
165
+ """Decorator marking a test that requires CUDA and PyTorch."""
166
+ if torch_device != "cuda":
167
+ return unittest.skip("test requires CUDA")(test_case)
168
+ else:
169
+ return test_case
170
+
171
+
172
+ def is_deepspeed_available():
173
+ return importlib.util.find_spec("deepspeed") is not None
174
+
175
+
176
+ def require_deepspeed(test_case):
177
+ """
178
+ Decorator marking a test that requires deepspeed
179
+ """
180
+ if not is_deepspeed_available():
181
+ return unittest.skip("test requires deepspeed")(test_case)
182
+ else:
183
+ return test_case
184
+
185
+
186
+ def is_bnb_available():
187
+ return importlib.util.find_spec("bitsandbytes") is not None
188
+
189
+
190
+ def require_bnb(test_case):
191
+ """
192
+ Decorator marking a test that requires bitsandbytes
193
+ """
194
+ if not is_bnb_available():
195
+ return unittest.skip("test requires bitsandbytes from https://github.com/facebookresearch/bitsandbytes")(
196
+ test_case
197
+ )
198
+ else:
199
+ return test_case
200
+
201
+
202
+ def require_bnb_non_decorator():
203
+ """
204
+ Non-Decorator function that would skip a test if bitsandbytes is missing
205
+ """
206
+ if not is_bnb_available():
207
+ raise SkipTest("Test requires bitsandbytes from https://github.com/facebookresearch/bitsandbytes")
208
+
209
+
210
+ def set_seed(seed: int = 42):
211
+ """
212
+ Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch``
213
+
214
+ Args:
215
+ seed (:obj:`int`): The seed to set.
216
+ """
217
+ random.seed(seed)
218
+ np.random.seed(seed)
219
+ if is_torch_available():
220
+ torch.manual_seed(seed)
221
+ torch.cuda.manual_seed_all(seed)
222
+ # ^^ safe to call this function even if cuda is not available
223
+
224
+
225
+ def get_gpu_count():
226
+ """
227
+ Return the number of available gpus (regardless of whether torch or tf is used)
228
+ """
229
+ if is_torch_available():
230
+ import torch
231
+
232
+ return torch.cuda.device_count()
233
+ else:
234
+ return 0
235
+
236
+
237
+ def torch_assert_equal(actual, expected, **kwargs):
238
+ """
239
+ compare two tensors or non-tensor numbers for their equality
240
+ """
241
+ # assert_close was added around pt-1.9, it does better checks - e.g will check dimensions match
242
+ return torch.testing.assert_close(actual, expected, rtol=0.0, atol=0.0, **kwargs)
243
+
244
+
245
+ def torch_assert_close(actual, expected, **kwargs):
246
+ """
247
+ compare two tensors or non-tensor numbers for their closeness.
248
+ """
249
+ # assert_close was added around pt-1.9, it does better checks - e.g will check dimensions match
250
+ return torch.testing.assert_close(actual, expected, **kwargs)
251
+
252
+
253
+ def is_torch_bf16_available():
254
+ # from https://github.com/huggingface/transformers/blob/26eb566e43148c80d0ea098c76c3d128c0281c16/src/transformers/file_utils.py#L301
255
+ if is_torch_available():
256
+ import torch
257
+
258
+ if not torch.cuda.is_available() or torch.version.cuda is None:
259
+ return False
260
+ if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
261
+ return False
262
+ if int(torch.version.cuda.split(".")[0]) < 11:
263
+ return False
264
+ if not version.parse(torch.__version__) >= version.parse("1.09"):
265
+ return False
266
+ return True
267
+ else:
268
+ return False
269
+
270
+
271
+ def require_torch_bf16(test_case):
272
+ """Decorator marking a test that requires CUDA hardware supporting bf16 and PyTorch >= 1.9."""
273
+ if not is_torch_bf16_available():
274
+ return unittest.skip("test requires CUDA hardware supporting bf16 and PyTorch >= 1.9")(test_case)
275
+ else:
276
+ return test_case
277
+
278
+
279
+ def get_tests_dir(append_path=None):
280
+ """
281
+ Args:
282
+ append_path: optional path to append to the tests dir path
283
+
284
+ Return:
285
+ The full path to the `tests` dir, so that the tests can be invoked from anywhere. Optionally `append_path` is
286
+ joined after the `tests` dir the former is provided.
287
+
288
+ """
289
+ # this function caller's __file__
290
+ caller__file__ = inspect.stack()[1][1]
291
+ tests_dir = os.path.abspath(os.path.dirname(caller__file__))
292
+ if append_path:
293
+ return os.path.join(tests_dir, append_path)
294
+ else:
295
+ return tests_dir
296
+
297
+
298
+ def parameterized_custom_name_func_join_params(func, param_num, param):
299
+ """
300
+ customize the test name generator function as we want all params to appear in the sub-test
301
+ name, as by default it shows only the first param or for multiple params it just uses a unique sequence of ids and no params at all.
302
+
303
+ Usage:
304
+
305
+ @parameterized.expand(
306
+ [
307
+ (0, True),
308
+ (0, False),
309
+ (1, True),
310
+ ],
311
+ name_func=parameterized_custom_name_func_join_params,
312
+ )
313
+ def test_determinism_wrt_rank(self, num_workers, pad_dataset):
314
+
315
+ which gives:
316
+
317
+ test_determinism_wrt_rank_0_true
318
+ test_determinism_wrt_rank_0_false
319
+ test_determinism_wrt_rank_1_true
320
+
321
+ """
322
+ param_based_name = parameterized.to_safe_name("_".join(str(x) for x in param.args))
323
+ return f"{func.__name__}_{param_based_name}"
324
+
325
+
326
+ #
327
+ # Helper functions for dealing with testing text outputs
328
+ # The original code came from:
329
+ # https://github.com/fastai/fastai/blob/master/tests/utils/text.py
330
+
331
+
332
+ # When any function contains print() calls that get overwritten, like progress bars,
333
+ # a special care needs to be applied, since under pytest -s captured output (capsys
334
+ # or contextlib.redirect_stdout) contains any temporary printed strings, followed by
335
+ # \r's. This helper function ensures that the buffer will contain the same output
336
+ # with and without -s in pytest, by turning:
337
+ # foo bar\r tar mar\r final message
338
+ # into:
339
+ # final message
340
+ # it can handle a single string or a multiline buffer
341
+ def apply_print_resets(buf):
342
+ return re.sub(r"^.*\r", "", buf, 0, re.M)
343
+
344
+
345
+ def assert_screenout(out, what):
346
+ out_pr = apply_print_resets(out).lower()
347
+ match_str = out_pr.find(what.lower())
348
+ assert match_str != -1, f"expecting to find {what} in output: f{out_pr}"
349
+
350
+
351
+ class CaptureStd:
352
+ """
353
+ Context manager to capture:
354
+
355
+ - stdout: replay it, clean it up and make it available via ``obj.out``
356
+ - stderr: replay it and make it available via ``obj.err``
357
+ - combined: combined the chosen streams and make it available via ``obj.combined``
358
+
359
+ init arguments:
360
+
361
+ - out - capture stdout:`` True``/``False``, default ``True``
362
+ - err - capture stdout: ``True``/``False``, default ``True``
363
+ - replay - whether to replay or not: ``True``/``False``, default ``True``. By default each
364
+ captured stream gets replayed back on context's exit, so that one can see what the test was
365
+ doing. If this is a not wanted behavior and the captured data shouldn't be replayed, pass
366
+ ``replay=False`` to disable this feature.
367
+
368
+ Examples::
369
+
370
+ # to capture stdout only with auto-replay
371
+ with CaptureStdout() as cs:
372
+ print("Secret message")
373
+ assert "message" in cs.out
374
+
375
+ # to capture stderr only with auto-replay
376
+ import sys
377
+ with CaptureStderr() as cs:
378
+ print("Warning: ", file=sys.stderr)
379
+ assert "Warning" in cs.err
380
+
381
+ # to capture both streams with auto-replay
382
+ with CaptureStd() as cs:
383
+ print("Secret message")
384
+ print("Warning: ", file=sys.stderr)
385
+ assert "message" in cs.out
386
+ assert "Warning" in cs.err
387
+
388
+ # to capture just one of the streams, and not the other, with auto-replay
389
+ with CaptureStd(err=False) as cs:
390
+ print("Secret message")
391
+ assert "message" in cs.out
392
+ # but best use the stream-specific subclasses
393
+
394
+ # to capture without auto-replay
395
+ with CaptureStd(replay=False) as cs:
396
+ print("Secret message")
397
+ assert "message" in cs.out
398
+
399
+ # sometimes it's easier to not try to figure out if it's stdout or stderr, and yet at
400
+ # other times the software may send the same output to stderr or stdout depending on
401
+ # environment, so to make the test robust a combined entry of both streams is available
402
+
403
+ """
404
+
405
+ def __init__(self, out=True, err=True, replay=True):
406
+ self.replay = replay
407
+
408
+ if out:
409
+ self.out_buf = StringIO()
410
+ self.out = "error: CaptureStd context is unfinished yet, called too early"
411
+ else:
412
+ self.out_buf = None
413
+ self.out = "not capturing stdout"
414
+
415
+ if err:
416
+ self.err_buf = StringIO()
417
+ self.err = "error: CaptureStd context is unfinished yet, called too early"
418
+ else:
419
+ self.err_buf = None
420
+ self.err = "not capturing stderr"
421
+
422
+ self.combined = "error: CaptureStd context is unfinished yet, called too early"
423
+
424
+ def __enter__(self):
425
+ if self.out_buf is not None:
426
+ self.out_old = sys.stdout
427
+ sys.stdout = self.out_buf
428
+
429
+ if self.err_buf is not None:
430
+ self.err_old = sys.stderr
431
+ sys.stderr = self.err_buf
432
+
433
+ self.combined = ""
434
+
435
+ return self
436
+
437
+ def __exit__(self, *exc):
438
+ if self.out_buf is not None:
439
+ sys.stdout = self.out_old
440
+ captured = self.out_buf.getvalue()
441
+ if self.replay:
442
+ sys.stdout.write(captured)
443
+ self.out = apply_print_resets(captured)
444
+ self.combined += self.out
445
+
446
+ if self.err_buf is not None:
447
+ sys.stderr = self.err_old
448
+ captured = self.err_buf.getvalue()
449
+ if self.replay:
450
+ sys.stderr.write(captured)
451
+ self.err = captured
452
+ self.combined += self.err
453
+
454
+ def __repr__(self):
455
+ msg = ""
456
+ if self.out_buf:
457
+ msg += f"stdout: {self.out}\n"
458
+ if self.err_buf:
459
+ msg += f"stderr: {self.err}\n"
460
+ return msg
461
+
462
+
463
+ # in tests it's the best to capture only the stream that's wanted, otherwise
464
+ # it's easy to miss things, so unless you need to capture both streams, use the
465
+ # subclasses below (less typing). Or alternatively, configure `CaptureStd` to
466
+ # disable the stream you don't need to test.
467
+
468
+
469
+ class CaptureStdout(CaptureStd):
470
+ """Same as CaptureStd but captures only stdout"""
471
+
472
+ def __init__(self, replay=True):
473
+ super().__init__(err=False, replay=replay)
474
+
475
+
476
+ class CaptureStderr(CaptureStd):
477
+ """Same as CaptureStd but captures only stderr"""
478
+
479
+ def __init__(self, replay=True):
480
+ super().__init__(out=False, replay=replay)
481
+
482
+
483
+ class CaptureLogger:
484
+ """
485
+ Context manager to capture `logging` streams
486
+
487
+ Args:
488
+
489
+ - logger: 'logging` logger object
490
+
491
+ Results:
492
+ The captured output is available via `self.out`
493
+
494
+ Example::
495
+
496
+ >>> from transformers import logging
497
+ >>> from transformers.testing_utils import CaptureLogger
498
+
499
+ >>> msg = "Testing 1, 2, 3"
500
+ >>> logging.set_verbosity_info()
501
+ >>> logger = logging.get_logger("transformers.models.bart.tokenization_bart")
502
+ >>> with CaptureLogger(logger) as cl:
503
+ ... logger.info(msg)
504
+ >>> assert cl.out, msg+"\n"
505
+ """
506
+
507
+ def __init__(self, logger):
508
+ self.logger = logger
509
+ self.io = StringIO()
510
+ self.sh = logging.StreamHandler(self.io)
511
+ self.out = ""
512
+
513
+ def __enter__(self):
514
+ self.logger.addHandler(self.sh)
515
+ return self
516
+
517
+ def __exit__(self, *exc):
518
+ self.logger.removeHandler(self.sh)
519
+ self.out = self.io.getvalue()
520
+
521
+ def __repr__(self):
522
+ return f"captured: {self.out}\n"
523
+
524
+
525
+ @contextlib.contextmanager
526
+ # adapted from https://stackoverflow.com/a/64789046/9201239
527
+ def ExtendSysPath(path: Union[str, os.PathLike]) -> Iterator[None]:
528
+ """
529
+ Temporary add given path to `sys.path`.
530
+
531
+ Usage ::
532
+
533
+ with ExtendSysPath('/path/to/dir'):
534
+ mymodule = importlib.import_module('mymodule')
535
+
536
+ """
537
+
538
+ path = os.fspath(path)
539
+ try:
540
+ sys.path.insert(0, path)
541
+ yield
542
+ finally:
543
+ sys.path.remove(path)
544
+
545
+
546
+ class TestCasePlus(unittest.TestCase):
547
+ """This class extends `unittest.TestCase` with additional features.
548
+
549
+ Feature 1: A set of fully resolved important file and dir path accessors.
550
+
551
+ In tests often we need to know where things are relative to the current test file, and it's not trivial since the
552
+ test could be invoked from more than one directory or could reside in sub-directories with different depths. This
553
+ class solves this problem by sorting out all the basic paths and provides easy accessors to them:
554
+
555
+ * ``pathlib`` objects (all fully resolved):
556
+
557
+ - ``test_file_path`` - the current test file path (=``__file__``)
558
+ - ``test_file_dir`` - the directory containing the current test file
559
+ - ``tests_dir`` - the directory of the ``tests`` test suite
560
+ - ``data_dir`` - the directory of the ``tests/data`` test suite
561
+ - ``repo_root_dir`` - the directory of the repository
562
+ - ``src_dir`` - the directory where the ``m4`` sub-dir resides (same as repo_root_dir in this case)
563
+
564
+ * stringified paths---same as above but these return paths as strings, rather than ``pathlib`` objects:
565
+
566
+ - ``test_file_path_str``
567
+ - ``test_file_dir_str``
568
+ - ``tests_dir_str``
569
+ - ``data_dir_str``
570
+ - ``repo_root_dir_str``
571
+ - ``src_dir_str``
572
+
573
+ Feature 2: Flexible auto-removable temporary dirs which are guaranteed to get removed at the end of test.
574
+
575
+ 1. Create a unique temporary dir:
576
+
577
+ ::
578
+
579
+ def test_whatever(self):
580
+ tmp_dir = self.get_auto_remove_tmp_dir()
581
+
582
+ ``tmp_dir`` will contain the pathlib path to the created temporary dir. It will be automatically
583
+ removed at the end of the test.
584
+
585
+
586
+ 2. Create a temporary dir of my choice, ensure it's empty before the test starts and don't
587
+ empty it after the test.
588
+
589
+ ::
590
+
591
+ def test_whatever(self):
592
+ tmp_dir = self.get_auto_remove_tmp_dir("./xxx")
593
+
594
+ This is useful for debug when you want to monitor a specific directory and want to make sure the previous tests
595
+ didn't leave any data in there.
596
+
597
+ 3. You can override the first two options by directly overriding the ``before`` and ``after`` args, leading to the
598
+ following behavior:
599
+
600
+ ``before=True``: the temporary dir will always be cleared at the beginning of the test.
601
+
602
+ ``before=False``: if the temporary dir already existed, any existing files will remain there.
603
+
604
+ ``after=True``: the temporary dir will always be deleted at the end of the test.
605
+
606
+ ``after=False``: the temporary dir will always be left intact at the end of the test.
607
+
608
+ Use `self.get_auto_remove_tmp_dir_str()` instead if you want the returned value to be a non-pathlib version.
609
+
610
+ Note 1: In order to run the equivalent of ``rm -r`` safely, only subdirs of the project repository checkout are
611
+ allowed if an explicit ``tmp_dir`` is used, so that by mistake no ``/tmp`` or similar important part of the
612
+ filesystem will get nuked. i.e. please always pass paths that start with ``./``
613
+
614
+ Note 2: Each test can register multiple temporary dirs and they all will get auto-removed, unless requested
615
+ otherwise.
616
+
617
+ Feature 3: Get a copy of the ``os.environ`` object that sets up ``PYTHONPATH`` specific to the current test suite.
618
+ This is useful for invoking external programs from the test suite - e.g. distributed training.
619
+
620
+
621
+ ::
622
+ def test_whatever(self):
623
+ env = self.get_env()
624
+
625
+ """
626
+
627
+ def setUp(self):
628
+ # get_auto_remove_tmp_dir feature:
629
+ self.teardown_tmp_dirs = []
630
+
631
+ # figure out the resolved paths for repo_root, tests, etc.
632
+ self._test_file_path = inspect.getfile(self.__class__)
633
+ path = Path(self._test_file_path).resolve()
634
+ self._test_file_dir = path.parents[0]
635
+ for up in [1, 2, 3]:
636
+ tmp_dir = path.parents[up]
637
+ if (tmp_dir / "m4").is_dir() and (tmp_dir / "tests").is_dir():
638
+ break
639
+ if tmp_dir:
640
+ self._repo_root_dir = tmp_dir
641
+ else:
642
+ raise ValueError(f"can't figure out the root of the repo from {self._test_file_path}")
643
+ self._tests_dir = self._repo_root_dir / "tests"
644
+ self._data_dir = self._repo_root_dir / "tests" / "test_data"
645
+ self._src_dir = self._repo_root_dir # m4 doesn't use "src/" prefix in the repo
646
+
647
+ @property
648
+ def test_file_path(self):
649
+ return self._test_file_path
650
+
651
+ @property
652
+ def test_file_path_str(self):
653
+ return str(self._test_file_path)
654
+
655
+ @property
656
+ def test_file_dir(self):
657
+ return self._test_file_dir
658
+
659
+ @property
660
+ def test_file_dir_str(self):
661
+ return str(self._test_file_dir)
662
+
663
+ @property
664
+ def tests_dir(self):
665
+ return self._tests_dir
666
+
667
+ @property
668
+ def tests_dir_str(self):
669
+ return str(self._tests_dir)
670
+
671
+ @property
672
+ def data_dir(self):
673
+ return self._data_dir
674
+
675
+ @property
676
+ def data_dir_str(self):
677
+ return str(self._data_dir)
678
+
679
+ @property
680
+ def repo_root_dir(self):
681
+ return self._repo_root_dir
682
+
683
+ @property
684
+ def repo_root_dir_str(self):
685
+ return str(self._repo_root_dir)
686
+
687
+ @property
688
+ def src_dir(self):
689
+ return self._src_dir
690
+
691
+ @property
692
+ def src_dir_str(self):
693
+ return str(self._src_dir)
694
+
695
+ def get_env(self):
696
+ """
697
+ Return a copy of the ``os.environ`` object that sets up ``PYTHONPATH`` correctly. This is useful
698
+ for invoking external programs from the test suite - e.g. distributed training.
699
+
700
+ It always inserts ``.`` first, then ``./tests`` depending on the test suite type and
701
+ finally the preset ``PYTHONPATH`` if any (all full resolved paths).
702
+
703
+ """
704
+ env = os.environ.copy()
705
+ paths = [self.src_dir_str]
706
+ paths.append(self.tests_dir_str)
707
+ paths.append(env.get("PYTHONPATH", ""))
708
+
709
+ env["PYTHONPATH"] = ":".join(paths)
710
+ return env
711
+
712
+ def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None):
713
+ """
714
+ Args:
715
+ tmp_dir (:obj:`string`, `optional`):
716
+ if :obj:`None`:
717
+
718
+ - a unique temporary path will be created
719
+ - sets ``before=True`` if ``before`` is :obj:`None`
720
+ - sets ``after=True`` if ``after`` is :obj:`None`
721
+ else:
722
+
723
+ - :obj:`tmp_dir` will be created
724
+ - sets ``before=True`` if ``before`` is :obj:`None`
725
+ - sets ``after=False`` if ``after`` is :obj:`None`
726
+ before (:obj:`bool`, `optional`):
727
+ If :obj:`True` and the :obj:`tmp_dir` already exists, make sure to empty it right away if :obj:`False`
728
+ and the :obj:`tmp_dir` already exists, any existing files will remain there.
729
+ after (:obj:`bool`, `optional`):
730
+ If :obj:`True`, delete the :obj:`tmp_dir` at the end of the test if :obj:`False`, leave the
731
+ :obj:`tmp_dir` and its contents intact at the end of the test.
732
+
733
+ Returns:
734
+ tmp_dir(:obj:`string`): either the same value as passed via `tmp_dir` or the path to the auto-selected tmp
735
+ dir
736
+ """
737
+ if tmp_dir is not None:
738
+ # defining the most likely desired behavior for when a custom path is provided.
739
+ # this most likely indicates the debug mode where we want an easily locatable dir that:
740
+ # 1. gets cleared out before the test (if it already exists)
741
+ # 2. is left intact after the test
742
+ if before is None:
743
+ before = True
744
+ if after is None:
745
+ after = False
746
+
747
+ # to avoid nuking parts of the filesystem, only relative paths are allowed
748
+ if not tmp_dir.startswith("./"):
749
+ raise ValueError(
750
+ f"`tmp_dir` can only be a relative path, i.e. `./some/path`, but received `{tmp_dir}`"
751
+ )
752
+
753
+ # using provided path
754
+ tmp_dir = Path(tmp_dir).resolve()
755
+
756
+ # ensure the dir is empty to start with
757
+ if before is True and tmp_dir.exists():
758
+ shutil.rmtree(tmp_dir, ignore_errors=True)
759
+
760
+ tmp_dir.mkdir(parents=True, exist_ok=True)
761
+
762
+ else:
763
+ # defining the most likely desired behavior for when a unique tmp path is auto generated
764
+ # (not a debug mode), here we require a unique tmp dir that:
765
+ # 1. is empty before the test (it will be empty in this situation anyway)
766
+ # 2. gets fully removed after the test
767
+ if before is None:
768
+ before = True
769
+ if after is None:
770
+ after = True
771
+
772
+ # using unique tmp dir (always empty, regardless of `before`)
773
+ tmp_dir = Path(tempfile.mkdtemp())
774
+
775
+ if after is True:
776
+ # register for deletion
777
+ self.teardown_tmp_dirs.append(tmp_dir)
778
+
779
+ return tmp_dir
780
+
781
+ def get_auto_remove_tmp_dir_str(self, *args, **kwargs):
782
+ return str(self.get_auto_remove_tmp_dir(*args, **kwargs))
783
+
784
+ def tearDown(self):
785
+ # get_auto_remove_tmp_dir feature: remove registered temp dirs
786
+ for path in self.teardown_tmp_dirs:
787
+ shutil.rmtree(path, ignore_errors=True)
788
+ self.teardown_tmp_dirs = []
789
+
790
+
791
+ def mockenv(**kwargs):
792
+ """
793
+ this is a convenience wrapper, that allows this ::
794
+
795
+ @mockenv(RUN_SLOW=True, USE_TF=False)
796
+ def test_something():
797
+ run_slow = os.getenv("RUN_SLOW", False)
798
+ use_tf = os.getenv("USE_TF", False)
799
+
800
+ """
801
+ return mock.patch.dict(os.environ, kwargs)
802
+
803
+
804
+ # from https://stackoverflow.com/a/34333710/9201239
805
+ @contextlib.contextmanager
806
+ def mockenv_context(*remove, **update):
807
+ """
808
+ Temporarily updates the ``os.environ`` dictionary in-place. Similar to mockenv
809
+
810
+ The ``os.environ`` dictionary is updated in-place so that the modification is sure to work in all situations.
811
+
812
+ Args:
813
+ remove: Environment variables to remove.
814
+ update: Dictionary of environment variables and values to add/update.
815
+ """
816
+ env = os.environ
817
+ update = update or {}
818
+ remove = remove or []
819
+
820
+ # List of environment variables being updated or removed.
821
+ stomped = (set(update.keys()) | set(remove)) & set(env.keys())
822
+ # Environment variables and values to restore on exit.
823
+ update_after = {k: env[k] for k in stomped}
824
+ # Environment variables and values to remove on exit.
825
+ remove_after = frozenset(k for k in update if k not in env)
826
+
827
+ try:
828
+ env.update(update)
829
+ [env.pop(k, None) for k in remove]
830
+ yield
831
+ finally:
832
+ env.update(update_after)
833
+ [env.pop(k) for k in remove_after]
834
+
835
+
836
+ # --- test network helper functions --- #
837
+
838
+
839
+ def get_xdist_worker_id():
840
+ """
841
+ when run under pytest-xdist returns the worker id (int), otherwise returns 0
842
+ """
843
+ worker_id_string = os.environ.get("PYTEST_XDIST_WORKER", "gw0")
844
+ return int(worker_id_string[2:]) # strip "gw"
845
+
846
+
847
+ DEFAULT_MASTER_PORT = 10999
848
+
849
+
850
+ def get_unique_port_number():
851
+ """
852
+ When the test suite runs under pytest-xdist we need to make sure that concurrent tests won't use
853
+ the same port number. We can accomplish that by using the same base and always adding the xdist
854
+ worker id to it, or 0 if not running under pytest-xdist
855
+ """
856
+ return DEFAULT_MASTER_PORT + get_xdist_worker_id()
857
+
858
+
859
+ # --- test IO helper functions --- #
860
+
861
+
862
+ def write_file(file, content):
863
+ with open(file, "w") as f:
864
+ f.write(content)
865
+
866
+
867
+ def read_json_file(file):
868
+ with open(file, "r") as fh:
869
+ return json.load(fh)
870
+
871
+
872
+ def replace_str_in_file(file, text_to_search, replacement_text):
873
+ file = Path(file)
874
+ text = file.read_text()
875
+ text = text.replace(text_to_search, replacement_text)
876
+ file.write_text(text)
877
+
878
+
879
+ # --- pytest conf functions --- #
880
+
881
+ # to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once
882
+ pytest_opt_registered = {}
883
+
884
+
885
+ def pytest_addoption_shared(parser):
886
+ """
887
+ This function is to be called from `conftest.py` via `pytest_addoption` wrapper that has to be defined there.
888
+
889
+ It allows loading both `conftest.py` files at once without causing a failure due to adding the same `pytest`
890
+ option.
891
+
892
+ """
893
+ option = "--make-reports"
894
+ if option not in pytest_opt_registered:
895
+ parser.addoption(
896
+ option,
897
+ action="store",
898
+ default=False,
899
+ help="generate report files. The value of this option is used as a prefix to report names",
900
+ )
901
+ pytest_opt_registered[option] = 1
902
+
903
+
904
+ def pytest_terminal_summary_main(tr, id):
905
+ """
906
+ Generate multiple reports at the end of test suite run - each report goes into a dedicated file in the current
907
+ directory. The report files are prefixed with the test suite name.
908
+
909
+ This function emulates the `--durations` and `-rA` pytest arguments.
910
+
911
+ This function is to be called from `conftest.py` via `pytest_terminal_summary` wrapper that has to be defined
912
+ there.
913
+
914
+ Args:
915
+ - tr: `terminalreporter` passed from `conftest.py`
916
+ - id: unique id like `tests` or `examples` that will be incorporated into the final reports filenames - this is
917
+ needed as some jobs have multiple runs of pytest, so we can't have them overwrite each other.
918
+
919
+ NB: this function taps into a private _pytest API and while unlikely, it could break should pytest do internal
920
+ changes - also it calls default internal methods of terminalreporter which can be hijacked by various `pytest-`
921
+ plugins and interfere.
922
+
923
+ """
924
+ from _pytest.config import create_terminal_writer
925
+
926
+ if not len(id):
927
+ id = "tests"
928
+
929
+ config = tr.config
930
+ orig_writer = config.get_terminal_writer()
931
+ orig_tbstyle = config.option.tbstyle
932
+ orig_reportchars = tr.reportchars
933
+
934
+ dir = f"reports/{id}"
935
+ Path(dir).mkdir(parents=True, exist_ok=True)
936
+ report_files = {
937
+ k: f"{dir}/{k}.txt"
938
+ for k in [
939
+ "durations",
940
+ "errors",
941
+ "failures_long",
942
+ "failures_short",
943
+ "failures_line",
944
+ "passes",
945
+ "stats",
946
+ "summary_short",
947
+ "warnings",
948
+ ]
949
+ }
950
+
951
+ # custom durations report
952
+ # note: there is no need to call pytest --durations=XX to get this separate report
953
+ # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66
954
+ dlist = []
955
+ for replist in tr.stats.values():
956
+ for rep in replist:
957
+ if hasattr(rep, "duration"):
958
+ dlist.append(rep)
959
+ if dlist:
960
+ dlist.sort(key=lambda x: x.duration, reverse=True)
961
+ with open(report_files["durations"], "w") as f:
962
+ durations_min = 0.05 # sec
963
+ f.write("slowest durations\n")
964
+ for i, rep in enumerate(dlist):
965
+ if rep.duration < durations_min:
966
+ f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted")
967
+ break
968
+ f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n")
969
+
970
+ def summary_failures_short(tr):
971
+ # expecting that the reports were --tb=long (default) so we chop them off here to the last frame
972
+ reports = tr.getreports("failed")
973
+ if not reports:
974
+ return
975
+ tr.write_sep("=", "FAILURES SHORT STACK")
976
+ for rep in reports:
977
+ msg = tr._getfailureheadline(rep)
978
+ tr.write_sep("_", msg, red=True, bold=True)
979
+ # chop off the optional leading extra frames, leaving only the last one
980
+ longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, count=0, flags=re.M | re.S)
981
+ tr._tw.line(longrepr)
982
+ # note: not printing out any rep.sections to keep the report short
983
+
984
+ # use ready-made report funcs, we are just hijacking the filehandle to log to a dedicated file each
985
+ # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814
986
+ # note: some pytest plugins may interfere by hijacking the default `terminalreporter` (e.g.
987
+ # pytest-instafail does that)
988
+
989
+ # report failures with line/short/long styles
990
+ config.option.tbstyle = "auto" # full tb
991
+ with open(report_files["failures_long"], "w") as f:
992
+ tr._tw = create_terminal_writer(config, f)
993
+ tr.summary_failures()
994
+
995
+ # config.option.tbstyle = "short" # short tb
996
+ with open(report_files["failures_short"], "w") as f:
997
+ tr._tw = create_terminal_writer(config, f)
998
+ summary_failures_short(tr)
999
+
1000
+ config.option.tbstyle = "line" # one line per error
1001
+ with open(report_files["failures_line"], "w") as f:
1002
+ tr._tw = create_terminal_writer(config, f)
1003
+ tr.summary_failures()
1004
+
1005
+ with open(report_files["errors"], "w") as f:
1006
+ tr._tw = create_terminal_writer(config, f)
1007
+ tr.summary_errors()
1008
+
1009
+ with open(report_files["warnings"], "w") as f:
1010
+ tr._tw = create_terminal_writer(config, f)
1011
+ tr.summary_warnings() # normal warnings
1012
+ tr.summary_warnings() # final warnings
1013
+
1014
+ tr.reportchars = "wPpsxXEf" # emulate -rA (used in summary_passes() and short_test_summary())
1015
+
1016
+ # Skip the `passes` report, as it starts to take more than 5 minutes, and sometimes it times out on CircleCI if it
1017
+ # takes > 10 minutes (as this part doesn't generate any output on the terminal).
1018
+ # (also, it seems there is no useful information in this report, and we rarely need to read it)
1019
+ # with open(report_files["passes"], "w") as f:
1020
+ # tr._tw = create_terminal_writer(config, f)
1021
+ # tr.summary_passes()
1022
+
1023
+ with open(report_files["summary_short"], "w") as f:
1024
+ tr._tw = create_terminal_writer(config, f)
1025
+ tr.short_test_summary()
1026
+
1027
+ with open(report_files["stats"], "w") as f:
1028
+ tr._tw = create_terminal_writer(config, f)
1029
+ tr.summary_stats()
1030
+
1031
+ # restore:
1032
+ tr._tw = orig_writer
1033
+ tr.reportchars = orig_reportchars
1034
+ config.option.tbstyle = orig_tbstyle
1035
+
1036
+
1037
+ # --- distributed testing functions --- #
1038
+
1039
+
1040
+ class _RunOutput:
1041
+ def __init__(self, returncode, stdout, stderr):
1042
+ self.returncode = returncode
1043
+ self.stdout = stdout
1044
+ self.stderr = stderr
1045
+
1046
+
1047
+ async def _read_stream(stream, callback):
1048
+ while True:
1049
+ line = await stream.readline()
1050
+ if line:
1051
+ callback(line)
1052
+ else:
1053
+ break
1054
+
1055
+
1056
+ async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput:
1057
+ if echo:
1058
+ print("\nRunning: ", " ".join(cmd))
1059
+
1060
+ p = await asyncio.create_subprocess_exec(
1061
+ cmd[0],
1062
+ *cmd[1:],
1063
+ stdin=stdin,
1064
+ stdout=asyncio.subprocess.PIPE,
1065
+ stderr=asyncio.subprocess.PIPE,
1066
+ env=env,
1067
+ )
1068
+
1069
+ # note: there is a warning for a possible deadlock when using `wait` with huge amounts of data in the pipe
1070
+ # https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait
1071
+ #
1072
+ # If it starts hanging, will need to switch to the following code. The problem is that no data
1073
+ # will be seen until it's done and if it hangs for example there will be no debug info.
1074
+ # out, err = await p.communicate()
1075
+ # return _RunOutput(p.returncode, out, err)
1076
+
1077
+ out = []
1078
+ err = []
1079
+
1080
+ def tee(line, sink, pipe, label=""):
1081
+ line = line.decode("utf-8").rstrip()
1082
+ sink.append(line)
1083
+ if not quiet:
1084
+ print(label, line, file=pipe)
1085
+
1086
+ # XXX: the timeout doesn't seem to make any difference here
1087
+ await asyncio.wait(
1088
+ [
1089
+ asyncio.create_task(_read_stream(p.stdout, lambda line: tee(line, out, sys.stdout, label="stdout:"))),
1090
+ asyncio.create_task(_read_stream(p.stderr, lambda line: tee(line, err, sys.stderr, label="stderr:"))),
1091
+ ],
1092
+ timeout=timeout,
1093
+ )
1094
+ return _RunOutput(await p.wait(), out, err)
1095
+
1096
+
1097
+ def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput:
1098
+ loop = asyncio.get_event_loop()
1099
+ result = loop.run_until_complete(
1100
+ _stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)
1101
+ )
1102
+
1103
+ cmd_str = " ".join(cmd)
1104
+ if result.returncode > 0:
1105
+ stderr = "\n".join(result.stderr)
1106
+ raise RuntimeError(
1107
+ f"'{cmd_str}' failed with returncode {result.returncode}\n\n"
1108
+ f"The combined stderr from workers follows:\n{stderr}"
1109
+ )
1110
+
1111
+ # check that the subprocess actually did run and produced some output, should the test rely on
1112
+ # the remote side to do the testing
1113
+ if not result.stdout and not result.stderr:
1114
+ raise RuntimeError(f"'{cmd_str}' produced no output.")
1115
+
1116
+ return result
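The subprocess helpers at the end of `m4/testing_utils.py` read stdout and stderr concurrently while the child process runs. The following is a minimal, self-contained sketch of that pattern, not the file's actual code: `_read_lines`, `run_and_capture` and `run_sync` are hypothetical names introduced here for illustration. It assumes Python 3.7+ and wraps each reader coroutine in a task, since `asyncio.wait` no longer accepts bare coroutines on Python 3.11 and later.

```python
import asyncio
import sys


async def _read_lines(stream, sink):
    # Drain one pipe line by line until EOF, storing decoded lines in `sink`.
    while True:
        line = await stream.readline()
        if not line:
            break
        sink.append(line.decode("utf-8").rstrip())


async def run_and_capture(cmd, timeout=None):
    # Hypothetical helper: start the process and read both pipes concurrently.
    proc = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    out, err = [], []
    # asyncio.wait() requires tasks or futures on Python 3.11+.
    tasks = [
        asyncio.create_task(_read_lines(proc.stdout, out)),
        asyncio.create_task(_read_lines(proc.stderr, err)),
    ]
    await asyncio.wait(tasks, timeout=timeout)
    returncode = await proc.wait()
    return returncode, out, err


def run_sync(cmd, timeout=None):
    # Synchronous entry point, convenient for calling from a test.
    return asyncio.run(run_and_capture(cmd, timeout=timeout))
```

A test could call `run_sync([sys.executable, "-c", "print('hello')"])` and inspect the returned lines. Reading both pipes as separate tasks avoids the deadlock that can occur when one pipe fills up while the other is being drained.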
m4/training/__init__.py ADDED
File without changes
m4/training/config.py ADDED
@@ -0,0 +1,545 @@
1
+ import json
2
+ import logging
3
+ import time
4
+ from dataclasses import InitVar, asdict, dataclass
5
+ from pathlib import Path
6
+ from typing import Any, Dict, List, Optional
7
+
8
+ import git
9
+ import yaml
10
+ from simple_parsing import ArgumentParser, Serializable
11
+ from simple_parsing.helpers import dict_field, list_field
12
+
13
+ from m4.training.types import DatasetNames, DatasetTypes
14
+ from m4.training.utils import FAKE_TOKEN_AROUND_IMAGE_V2, IMAGE_TOKEN, LoggingTypes
15
+
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ @dataclass
21
+ class CfgFileConfig:
22
+ """Config file args"""
23
+
24
+ # path to config file
25
+ config: Optional[Path] = None
26
+ # set to false if you don't want to save config automatically
27
+ save_config: bool = True
28
+
29
+
30
+ @dataclass
31
+ class GlobalBatchSizeRampUp:
32
+ """These are init variables that are used to set up the GBS ramp up protocol"""
33
+
34
+ # global batch size ramp up protocol:
35
+ #
36
+ # 1. start with global batch size `start`
37
+ # 2. every time the number of `samples` is consumed increment global batch size by `increment`
38
+ # 3. repeat step 2 until global batch size reaches `finish`
39
+ start: Optional[int] = None
40
+ finish: Optional[int] = None
41
+ increment: Optional[int] = None
42
+ samples: Optional[int] = None
43
+
44
+
45
+ @dataclass
46
+ class GlobalBatchSizeRampUpRunningParams:
47
+ """These are running variables that are used to tell when to increment GBS and when to stop doing
48
+ that, they are never set directly in the config file, but are calculated when the training starts.
49
+ """
50
+
51
+ global_seen_samples: int = 0
52
+ global_batch_size_current: int = 0
53
+ next_goal_samples: int = 0
54
+ grad_acc_size_current: int = 1
55
+
56
+
57
+ @dataclass
58
+ class Hparams:
59
+ """General Hyperparameters"""
60
+
61
+ # --------------------
62
+ # General parameters
63
+ # --------------------
64
+
65
+ seed: int = 13
66
+ # If set to True, the sole purpose of the job is to pre-process the dataset (i.e. the map
67
+ # operations). The job will exit as soon as the dataset is pre-processed.
68
+ just_preprocess: bool = False
69
+ jz_job_time_sec: Optional[float] = None
70
+ jz_start_time: float = time.time()
71
+ job_id: Optional[int] = None
72
+ timeout: int = 1800 # 30 min
73
+ # set to False to ignore the optimizer states when loading from a checkpoint
74
+ load_optimizer_states: Optional[bool] = True
75
+ # set to False to disable this gpu memory saving method
76
+ gradient_checkpointing: Optional[bool] = True
77
+
78
+ # --------------------
79
+ # Model-related hparams
80
+ # --------------------
81
+ tokenizer_name: str = "gpt2"
82
+ # The value of the string will be evaluated (i.e. interpreted) and must be a dict
83
+ tokenizer_params: str = '{"use_fast":True}'
84
+ tokenizer_add_tokens: str = (
85
+ f'[AddedToken("{FAKE_TOKEN_AROUND_IMAGE_V2}", rstrip=False, lstrip=False), AddedToken("{IMAGE_TOKEN}",'
86
+ " rstrip=False, lstrip=False)]"
87
+ )
88
+ # The value of the string will be evaluated (i.e. interpreted). Unnecessary if tokenizer has a `pad_token`.
89
+ tokenizer_add_special_tokens: str = '{"pad_token": tokenizer.eos_token}'
90
+ model_name: str = "gpt2"
91
+ revision: str = "main"
92
+ model_params: Dict[str, Any] = dict_field(
93
+ dict(
94
+ vision_embed_dim=768,
95
+ vision_image_size=224,
96
+ vision_model_name="google/vit-base-patch16-224",
97
+ # The value of the string will be evaluated (i.e. interpreted) and must be a dict
98
+ vision_model_params="{}",
99
+ # Ties the word embedding with LM head's weights
100
+ # Since word embedding is frozen, use in conjunction with freeze_lm_head=True
101
+ tie_word_embeddings=False,
102
+ # Freeze different parts of the model
103
+ freeze_lm_head=False,
104
+ freeze_text_layers=True,
105
+ freeze_text_module_exceptions=[],
106
+ freeze_vision_layers=True,
107
+ freeze_vision_module_exceptions=[],
108
+ # Perceiver Resampler Parameters
109
+ use_resampler=False,
110
+ resampler_n_latents=64,
111
+ resampler_depth=6,
112
+ resampler_n_heads=16,
113
+ resampler_head_dim=96,
114
+ )
115
+ )
116
+
117
+ # --------------------
118
+ # Training parameters
119
+ # --------------------
120
+ resume_run: Optional[bool] = None
121
+ do_validation: bool = True
122
+
123
+ # deprecated in favor of batch_size_per_gpu
124
+ batch_size: Optional[int] = None
125
+ batch_size_per_gpu: int = 1
126
+ global_batch_size: Optional[int] = None
127
+
128
+ global_batch_size_ramp_up: GlobalBatchSizeRampUp = GlobalBatchSizeRampUp()
129
+ grad_acc_size: Optional[int] = 1
130
+
131
+ grad_clip: float = 1.0
132
+
133
+ # weights by which to multiply the loss of each dataset when accumulating gradients over datasets
134
+ loss_weights_per_dataset: Optional[List[float]] = None
135
+ # int(max_num_tokens / (batch_size * max_seq_len * grad_acc_size * num_processes))
136
+ max_num_opt_steps: Optional[int] = 500_000
137
+ max_num_opt_steps_this_run: Optional[int] = None
138
+ max_num_epochs: Optional[int] = None
139
+
140
+ # If the path appears the program will stop after finishing the current training step
141
+ kill_switch_path: Optional[Path] = None
142
+
143
+ # If the path appears the program will save a checkpoint and immediately delete this flag
144
+ save_switch_path: Optional[Path] = None
145
+
146
+ # --------------------
147
+ # Logging parameters
148
+ # --------------------
149
+ train_logging_opt_steps: int = 50
150
+ train_logging_per_dataset_suffix: str = ""
151
+
152
+ # If a specific logging type is specified, per dataset information will be inserted inside
153
+ # those logs.
154
+ train_logging_per_dataset_info: List[LoggingTypes] = list_field(LoggingTypes.JSONL, LoggingTypes.WANDB)
155
+
156
+ # If `train_logging_activations` is not empty, hooks will be inserted to the model to track
157
+ # the min/max/std/norm of the activations and weights. This will slow down training.
158
+ # See https://huggingface.co/docs/transformers/main/en/debugging#underflow-and-overflow-detection
159
+ train_logging_activations: List[LoggingTypes] = list_field()
160
+ train_logging_activations_opt_steps: Optional[int] = 25
161
+ train_logging_grad_param_deepspeed: List[LoggingTypes] = list_field()
162
+ train_logging_grad_param_deepspeed_opt_steps: int = 50
163
+ val_logging_opt_steps: int = train_logging_opt_steps * 5
164
+ val_inline_logging_opt_steps: int = train_logging_opt_steps
165
+ train_saving_opt_steps: int = train_logging_opt_steps * 5
166
+ save_dir: Optional[Path] = None
167
+ upload_to_s3: bool = False
168
+ train_log_mem_usage: bool = False
169
+ timing_break_down: bool = False
170
+
171
+ save_batch_max_idx: Optional[int] = None
172
+ save_batch_min_idx: Optional[int] = None
173
+
174
+ # ----------------------
175
+ # Wandb Parameters
176
+ # ----------------------
177
+ wandb_enable: bool = False
178
+ # name of the project
179
+ wandb_project: str = "VLOOM"
180
+ # name of the wandb entity
181
+ wandb_entity: str = "huggingfacem4"
182
+ wandb_log_freq: int = 50
183
+ wandb_run_id: str = ""
184
+ wandb_tags: Optional[List[str]] = None
185
+
186
+ repo_commit_id: Optional[str] = None
187
+
188
+ # ----------------------
189
+ # Debug Parameters
190
+ # ----------------------
191
+ use_torch_profiler: bool = False
192
+
193
+
194
+ @dataclass
195
+ class ResumeParams:
196
+ # ----------------------
197
+ # Resume run Parameters
198
+ # ----------------------
199
+ # Need to make sure that resume_run is True to give an input here
200
+ opt_step_dir: Optional[Path] = None
201
+ accelerator_state_dir: Optional[Path] = None
202
+ model_file: Optional[Path] = None
203
+ model_config_file: Optional[Path] = None
204
+ # Automatically resumes last run of the save_dir. Set to False to choose a specific run
205
+ resume_last: bool = True
206
+ train_logs: Dict = dict_field()
207
+ resume_opt_step: int = 0
208
+ resume_epoch: int = 0
209
+ resume_dataset_state: List = list_field()
210
+
211
+ gbs_running: GlobalBatchSizeRampUpRunningParams = GlobalBatchSizeRampUpRunningParams()
212
+
213
+
214
+ @dataclass
215
+ class DatasetParams:
216
+ # This always needs to be specified as it is needed by dataset utils down the line
217
+ dataset_name: DatasetNames
218
+ # max number of images per sample
219
+ max_num_images: int = 5
220
+ # maximum sequence length
221
+ max_seq_len: int = 256
222
+ training_datasets_paths: List[Path] = list_field()
223
+ validation_datasets_paths: List[Path] = list_field()
224
+ # if True, instead of split and pack, each instance in sample will be
225
+ # either truncated or padded to the same length.
226
+ pad_dataset: bool = False
227
+ map_batch_size: int = 64
228
+ # Preprocessing number of processes in map (not useful for processing on the fly)
229
+ map_num_proc: Optional[int] = None
230
+ # Decides how many samples/subsequences should be extracted from the
231
+ # CM4 corpus when the dataset is to be padded; irrelevant otherwise, as full packing
232
+ # is used
233
+ max_num_samples_per_document: int = 10
234
+
235
+ # Strategy for detecting blur, laplacian or fft
236
+ blur_strategy: str = "fft"
237
+ # Threshold for blur detection, 0.0 means disabled. Set 32 for "laplacian" and
238
+ # 10 for "fft" for starters
239
+ blur_threshold: float = 0.0
240
+
241
+ add_begin_of_doc_token: bool = False
242
+ add_end_of_doc_token: bool = True
243
+
244
+ shuffle_after_packing: bool = False
245
+
246
+ # Parameters for T5 MLM
247
+ t5_mlm_noise_density: float = 0.15
248
+ t5_mlm_mean_noise_span_length: int = 3
249
+
250
+ dataset_type: Optional[DatasetTypes] = None
251
+
252
+ # Parameters for webdataset pipeline
253
+ shuffle_initial_urls_list: bool = False
254
+ shuffle_before_split_by_node_buffer_size: Optional[int] = None
255
+ shuffle_before_split_by_worker_buffer_size: Optional[int] = None
256
+ shuffle_after_tarfile_to_samples_buffer_size: Optional[int] = None
257
+ shuffle_after_batching_buffer_size: Optional[int] = None
258
+
259
+
260
+ @dataclass
261
+ class ImageCaptionPairedDatasetParams(DatasetParams):
262
+ # PMD only: This value decides the probability of the image token being at the start
263
+ # of the text or at the end of the text. Set to 0.5 for equal probability.
264
+ # Set to 0 for the image always at start.
265
+ prob_image_at_end: float = 0.5
266
+ # PMD only: Specifies the tolerance for the amount of padding in a sequence. If set
267
+ # to -1, then all padding will be tolerated. If set to 0, then no padding will be tolerated.
268
+ # Continuously increase this value to allow more padding in the sequence.
269
+ padding_tolerance: int = -1
270
+ dataset_type: DatasetTypes = DatasetTypes.IMAGE_CAPTION_PAIRS
271
+
272
+
273
+ @dataclass
274
+ class WebDocumentsDatasetParams(DatasetParams):
275
+ # Decides how often the image attention mask is such that
276
+ # the text attends to the next image. Set to 0 for just preceding images
277
+ # NOTE: For PMD, this option doesn't apply anymore. Use `prob_image_at_end`
278
+ # to control the position of the image relative to the corresponding text.
279
+ p_next: float = 0.5
280
+ dataset_type: DatasetTypes = DatasetTypes.WEB_DOCUMENTS
281
+
282
+
283
+ @dataclass
284
+ class DataParams(Serializable):
285
+ """Data Parameters"""
286
+
287
+ # what software to use for the dataset
288
+ use_webdataset: bool = False
289
+
290
+ # number of workers for dataloaders
291
+ num_workers: int = 1
292
+ # allow async faster data transfer to GPUs (only makes sense when CUDA GPUs are available)
293
+ # known to cause memory issues
294
+ pin_memory: bool = False
295
+ # Whether to use persistent workers for the dataloaders
296
+ persistent_workers: bool = True
297
+ realtime_processing: bool = False
298
+
299
+ train_seed: int = 1
300
+ val_seed: int = 2
301
+
302
+ # can use one config for both train + validation or specific ones if need to be different
303
+ select_n_examples: Optional[int] = None
304
+ select_n_examples_train: Optional[int] = None
305
+ select_n_examples_validation: Optional[int] = None
306
+
307
+ # TODO: Move to per dataset params as it makes more sense there
308
+ proba_interleaving_dataset: Optional[List[float]] = None
309
+
310
+ pmd: ImageCaptionPairedDatasetParams = ImageCaptionPairedDatasetParams(dataset_name=DatasetNames.PMD)
311
+ laion: ImageCaptionPairedDatasetParams = ImageCaptionPairedDatasetParams(dataset_name=DatasetNames.LAION)
312
+ cm4: WebDocumentsDatasetParams = WebDocumentsDatasetParams(dataset_name=DatasetNames.CM4)
313
+ wiki: WebDocumentsDatasetParams = WebDocumentsDatasetParams(dataset_name=DatasetNames.WIKI)
314
+
315
+
316
+ @dataclass
317
+ class OptimizerParams:
318
+ """Optimization parameters"""
319
+
320
+ # --------------------
321
+ # vl optim parameters
322
+ # --------------------
323
+ vl_optim: str = "AdamW"
324
+ vl_optim_params: Dict[str, Any] = dict_field(
325
+ dict(
326
+ # learning rate
327
+ lr=1e-4,
328
+ # betas for adam
329
+ betas=(0.9, 0.999),
330
+ weight_decay=0.1,
331
+ no_decay=["bias", "alpha", "layernorm", "ln", "layer_norm", "perceiver_resampler"],
332
+ )
333
+ )
334
+
335
+ vl_lr_scheduler: str = "get_constant_schedule_with_warmup"
336
+ # number of warmup steps for the learning rate
337
+ vl_lr_scheduler_params: Dict[str, Any] = dict_field(dict(num_warmup_steps=5_000, last_epoch=-1))
338
+ z_loss: float = 0.0
339
+
340
+
341
+ @dataclass
342
+ class Parameters(Serializable):
343
+ """base options."""
344
+
345
+ hparams: Hparams = Hparams()
346
+ optim_param: OptimizerParams = OptimizerParams()
347
+ data_param: DataParams = DataParams()
348
+ resume_param: ResumeParams = ResumeParams()
349
+ should_verify: InitVar[bool] = True
350
+
351
+ def verify(self, should_verify: bool):
352
+ if not should_verify:
353
+ return
354
+
355
+ dict_rep = vars(self)
356
+ expected = vars(self.__class__(should_verify=False))
357
+ for key, value in dict_rep.items():
358
+ if isinstance(value, dict):
359
+ diff = set(value.keys()) - set(asdict(expected[key]).keys())
360
+ raise TypeError(
361
+ f"{key} in {self.__class__.__name__} has extra keys: {diff}. Please fix your config if you are"
362
+ " using one."
363
+ )
364
+ if key not in expected:
365
+ raise ValueError(f"{key} is not a valid parameter for {self.__class__.__name__}")
366
+
367
+ def __post_init__(self, should_verify: bool = True):
368
+ """Post-initialization code"""
369
+ self.verify(should_verify=should_verify)
370
+
371
+ # copy select_n_examples to the more specific ones if the latter haven't been preset
372
+ if self.data_param.select_n_examples is not None:
373
+ if self.data_param.select_n_examples_train is None:
374
+ self.data_param.select_n_examples_train = self.data_param.select_n_examples
375
+ if self.data_param.select_n_examples_validation is None:
376
+ self.data_param.select_n_examples_validation = self.data_param.select_n_examples
377
+
378
+ # Get commit id
379
+ if self.hparams.repo_commit_id is None:
380
+ self.hparams.repo_commit_id = git.Repo(search_parent_directories=True).head.object.hexsha
381
+
382
+ # If processing on the fly, with the current implementation, we can't have `num_workers=0`
383
+ if self.data_param.realtime_processing and self.data_param.num_workers == 0:
384
+ raise ValueError(
385
+ "If doing processing on the fly (and thus using the `IterableDataset`), you can't have `num_workers`"
386
+ " equal to 0."
387
+ )
388
+
389
+ # batch_size deprecation
390
+ if self.hparams.batch_size is not None:
391
+ if self.hparams.batch_size_per_gpu > 1:
392
+ raise ValueError(
393
+ "as hparams.batch_size is deprecated - don't know how to proceed with both hparams.batch_size>1"
394
+ " and hparams.batch_size_per_gpu > 1"
395
+ )
396
+ else:
397
+ logger.warning(
398
+ "will use the deprecated hparams.batch_size, but transition to hparams.batch_size_per_gpu instead"
399
+ )
400
+ self.hparams.batch_size_per_gpu = self.hparams.batch_size
401
+ self.hparams.batch_size = None
402
+
403
+ # Assign batch size to data_param as well for dataloaders
404
+ self.data_param.batch_size = self.hparams.batch_size_per_gpu
405
+
406
+ # note: all global batch_size-related configs including hparams.grad_acc_size will be
407
+ # checked/set in trainer's setup_batch_size_related_configs since we need to know the value
408
+ # of num_processes
409
+
410
+ # Assign loggingtypes given values
411
+ self.hparams.train_logging_activations = [LoggingTypes(val) for val in self.hparams.train_logging_activations]
412
+
413
+ # Check that proba_interleaving_dataset is mutually exclusive to loss_weights_per_dataset
414
+ if self.data_param.proba_interleaving_dataset and self.hparams.loss_weights_per_dataset:
415
+ raise ValueError(
416
+ "Can't have hparams.loss_weights_per_dataset and proba_interleaving_dataset. If we have"
417
+ " loss_weights_per_dataset, it means the gradients are accumulated over datasets. Therefore a batch of"
418
+ " each is given at each update and there is no use of proba_interleaving_dataset"
419
+ )
420
+
421
+ if (
422
+ self.data_param.proba_interleaving_dataset is not None
423
+ and abs(sum(self.data_param.proba_interleaving_dataset) - 1) > 1e-6
424
+ ):
425
+ raise ValueError("proba_interleaving_dataset must sum to 1")
426
+
427
+ self.hparams.train_logging_grad_param_deepspeed = [
428
+ LoggingTypes(val) for val in self.hparams.train_logging_grad_param_deepspeed
429
+ ]
430
+
431
+ # Resume run if there is already an existing folder for this run
432
+ if self.hparams.save_dir is not None and self.hparams.save_dir.exists():
433
+ save_dir_has_checkpoints = (
434
+ len([dir for dir in self.hparams.save_dir.iterdir() if (dir.is_dir() and "opt_step" in str(dir))]) > 0
435
+ )
436
+ if self.hparams.resume_run is not None and not self.hparams.resume_run and save_dir_has_checkpoints:
437
+ logger.warning(
438
+ "`resume_run` was explicitly set to False (i.e. starting from scratch), but the experiment"
439
+ " folder already has been populated with previous runs.\nAlready saved checkpoints will be"
440
+ " overwritten (at best, when `train_saving_opt_steps` is the same) or will be mixed with the new"
441
+ " checkpoints of a potentially brand new experiment. Would it make sense to create a new"
442
+ " `save_dir`?"
443
+ )
444
+ self.hparams.resume_run = save_dir_has_checkpoints
445
+
446
+ # Setup all args needed to resume a run
447
+ if self.hparams.resume_run:
448
+ # Get last step directory
449
+ if self.resume_param.opt_step_dir is None and not self.resume_param.resume_last:
450
+ raise ValueError(
451
+ "`opt_step_dir` cannot be None while `resume_last` is False. Choose which dir you want to resume"
452
+ " from..."
453
+ )
454
+ if self.resume_param.resume_last:
455
+ if self.resume_param.opt_step_dir is not None:
456
+ raise ValueError(
457
+ "`resume_last` cannot be True while `opt_step_dir` is not None. Choose which dir you want to"
458
+ " resume from..."
459
+ )
460
+ latest_path = self.hparams.save_dir / "latest_opt_step_dir"
461
+ with open(latest_path, "r") as fd:
462
+ self.resume_param.opt_step_dir = Path(fd.read().strip())
463
+ if not (self.resume_param.opt_step_dir.exists() and self.resume_param.opt_step_dir.is_dir()):
464
+ raise ValueError(
465
+ f"It appears that the path in the `latest_opt_step_dir` file {latest_path} is invalid. It"
466
+ " either does not exist or is not a directory. Please fix that."
467
+ )
468
+
469
+ with open(self.resume_param.opt_step_dir / "resume_run_infos.json", "r") as f:
470
+ resume_infos = json.load(f)
471
+ logger.info(f"Resuming from {self.resume_param.opt_step_dir}")
472
+ self.resume_param.accelerator_state_dir = self.resume_param.opt_step_dir / "accelerator_state"
473
+ self.resume_param.model_file = self.resume_param.opt_step_dir / "unwrapped_model"
474
+ self.resume_param.model_config_file = self.resume_param.opt_step_dir / "unwrapped_model/config.json"
475
+ self.resume_param.tokenizer = self.resume_param.opt_step_dir / "tokenizer"
476
+
477
+ self.resume_param.train_logs = resume_infos["train_logs"]
478
+ self.resume_param.resume_opt_step = resume_infos["resume_opt_step"]
479
+ self.resume_param.resume_epoch = resume_infos["resume_epoch"]
480
+ self.resume_param.resume_dataset_state = resume_infos.get("resume_dataset_state", list())
481
+
482
+ gbs_running = resume_infos["gbs_running"]
483
+ self.resume_param.gbs_running.global_batch_size_current = gbs_running["global_batch_size_current"]
484
+ self.resume_param.gbs_running.global_seen_samples = gbs_running["global_seen_samples"]
485
+ self.resume_param.gbs_running.next_goal_samples = gbs_running["next_goal_samples"]
486
+ self.resume_param.gbs_running.grad_acc_size_current = gbs_running["grad_acc_size_current"]
487
+
488
+ self.hparams.wandb_run_id = resume_infos["wandb_run_id"]
489
+ self.hparams.seed = resume_infos["seed"]
490
+
491
+ # Should not happen, but this is in case there is a run mixing
492
+ # wandb_enable = True and wandb_enable = False between jobs
493
+ if not self.hparams.wandb_enable:
494
+ self.hparams.wandb_run_id = ""
495
+
496
+ @classmethod
497
+ def parse(cls):
498
+ cfgfile_parser = ArgumentParser(add_help=False)
499
+ cfgfile_parser.add_arguments(CfgFileConfig, dest="cfgfile")
500
+ cfgfile_args, rest = cfgfile_parser.parse_known_args()
501
+
502
+ cfgfile: CfgFileConfig = cfgfile_args.cfgfile
503
+
504
+ file_config: Optional[Parameters] = None
505
+ if cfgfile.config is not None:
506
+ file_config = Parameters.load(cfgfile.config, load_fn=yaml.safe_load)
507
+
508
+ parser = ArgumentParser()
509
+
510
+ # add cfgfile args so they appear in the help message
511
+ parser.add_arguments(CfgFileConfig, dest="cfgfile")
512
+ parser.add_arguments(Parameters, dest="parameters", default=file_config)
513
+
514
+ # XXX: currently when called from tests we don't want to parse pytest arguments. Ideally
515
+ # this whole logic would be rewritten to call `parser.parse_args` only when needed.
516
+ # For now, as a workaround, we use `parse_known_args` and ignore the args which
517
+ # don't belong to this program
518
+ args, unknown = parser.parse_known_args()
519
+
520
+ parameters: Parameters = args.parameters
521
+
522
+ parameters.save_config = cfgfile.save_config
523
+
524
+ return parameters
525
+
526
+ def save_config_state(self):
527
+ if self.save_config:
528
+ self.hparams.save_dir.mkdir(parents=True, exist_ok=True)
529
+ if self.hparams.job_id is not None:
530
+ config_file_name = f"{self.hparams.job_id}_config.yaml"
531
+ else:
532
+ config_file_name = "config.yaml"
533
+ self.save(self.hparams.save_dir / config_file_name, indent=4)
534
+
535
+
536
+ def get_config(print_config: bool = True):
537
+ parameters: Parameters = Parameters.parse()
538
+ if print_config:
539
+ print(parameters)
540
+ return parameters
541
+
542
+
543
+ if __name__ == "__main__":
544
+ config = get_config()
545
+ config.save_config_state()
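The `parse` classmethod above follows a two-pass pattern: a first parser extracts only the config-file options with `parse_known_args`, then a second parser uses the loaded file as defaults so that explicit command-line flags still override it. Below is a minimal sketch of that idea. It uses the standard library's `argparse` and JSON in place of the dataclass-based parser and YAML loader used in the diff. The function `parse_two_pass` and the flags `--lr` and `--batch_size` are hypothetical names chosen only for illustration.

```python
import argparse
import json
import os
import tempfile


def parse_two_pass(argv):
    # Pass 1: only look for the config file, ignore everything else.
    cfg_parser = argparse.ArgumentParser(add_help=False)
    cfg_parser.add_argument("--config", default=None)
    cfg_args, _rest = cfg_parser.parse_known_args(argv)

    # Load defaults from the file if one was given.
    file_defaults = {}
    if cfg_args.config is not None:
        with open(cfg_args.config) as f:
            file_defaults = json.load(f)

    # Pass 2: full parser. Values from the file become defaults,
    # so explicit command-line flags still win.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default=None)  # repeated so it shows in --help
    parser.add_argument("--lr", type=float, default=1e-4)  # hypothetical option
    parser.add_argument("--batch_size", type=int, default=8)  # hypothetical option
    parser.set_defaults(**file_defaults)

    # Unknown args (e.g. pytest flags) are tolerated rather than rejected.
    args, _unknown = parser.parse_known_args(argv)
    return args


def demo():
    # Write a throwaway config file, then override one of its values on the CLI.
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "config.json")
        with open(path, "w") as f:
            json.dump({"lr": 0.01, "batch_size": 32}, f)
        return parse_two_pass(["--config", path, "--batch_size", "64", "--unrelated-flag"])
```

In `demo()`, `lr` comes from the file and `batch_size` from the command line, which is the precedence the two-pass design is meant to give.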
m4/training/dataset_utils.py ADDED
@@ -0,0 +1,352 @@
1
+ import logging
2
+ import random
3
+
4
+ import webdataset as wds
5
+ from webdataset.tariterators import group_by_keys, tar_file_expander, url_opener
6
+
7
+ from m4.training.types import DatasetTypes
8
+
9
+
10
+ meta_prefix = "__"
11
+ meta_suffix = "__"
12
+
13
+ logger = logging.getLogger(__name__)
14
+ trace = False
15
+
16
+
17
+ def webdoc_valid_sample(sample):
18
+ """Check whether a sample is valid.
19
+
20
+ :param sample: sample to be checked
21
+ """
22
+ return (
23
+ sample is not None
24
+ and isinstance(sample, dict)
25
+ and len(list(sample.keys())) > 0
26
+ and not sample.get("__bad__", False)
27
+ and sample_has_all_files(sample)
28
+ )
29
+
30
+
31
+ def sample_has_all_files(current_sample):
32
+ meta = current_sample.get("metadata.value", None)
33
+ if meta is None:
34
+ return False
35
+ meta = meta.decode("utf-8")
36
+ if len(meta) == 0:
37
+ return False
38
+ target_file_list = meta.split("\n")
39
+ fname_keys = [key for key in current_sample.keys() if key.endswith(".fname")]
40
+ fnames = [current_sample[key] for key in fname_keys]
41
+ check = all([fname in fnames for fname in target_file_list])
42
+ if not check:
43
+ return False
44
+ return True
45
+
46
+
47
+ class ImageDecoder:
48
+ def __call__(self, bytes_):
49
+ import io
50
+
51
+ import PIL.Image
52
+
53
+ img = PIL.Image.open(io.BytesIO(bytes_))
54
+ img.load()
55
+ return img
56
+
57
+
58
+ # Taken from https://github.com/mlfoundations/open_clip/blob/c48111dacac55db24878af229d8a5662c03e6f1c/src/training/data.py#L180-L183
59
+ def log_and_continue(exn):
60
+ """Call in an exception handler to ignore any exception, issue a warning, and continue."""
61
+ logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
62
+ return True
63
+
64
+
65
+ # Adapt group_by_keys to our webdocument format in which each sample contains several text and image files
66
+ # https://github.com/webdataset/webdataset/blob/039d74319ae55e5696dcef89829be9671802cf70/webdataset/tariterators.py#L195-L250
67
+ def group_by_keys_interleaved(data, handler=log_and_continue):
68
+ """Return function over iterator that groups key, value pairs into samples."""
69
+ current_sample = None
70
+ for filesample in data:
71
+ try:
72
+ assert isinstance(filesample, dict)
73
+ fname, value = filesample["fname"], filesample["data"]
74
+ fname = fname.strip("./")
75
+ if fname.endswith(".metadata.txt"):
76
+ prefix, data_type, extension = fname.split(".")
77
+ suffix = data_type
78
+ else:
79
+ prefix, idx, data_type, extension = fname.split(".")
80
+ if data_type not in ["text", "image"]:
81
+ raise ValueError(f"{fname}: unknown data type {data_type}")
82
+ suffix = idx
83
+ if trace:
84
+ print(
85
+ f"prefix: {prefix}, idx: {idx}, data_type: {data_type}, extension: {extension}, keys:"
86
+ f" {current_sample.keys() if isinstance(current_sample, dict) else None}"
87
+ )
88
+ if prefix is None:
89
+ continue
90
+ if current_sample is None or prefix != current_sample["__key__"]:
91
+ valid = webdoc_valid_sample(current_sample)
92
+ if valid:
93
+ yield current_sample
94
+ elif current_sample is not None:
95
+ logging.warning(f"{fname}: invalid sample {current_sample} ignored")
96
+ current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
97
+ if f"{suffix}.value" in current_sample:
98
+ raise ValueError(f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}")
99
+ current_sample[f"{suffix}.value"] = value
100
+ current_sample[f"{suffix}.type"] = data_type
101
+ current_sample[f"{suffix}.fname"] = fname
102
+ except Exception as exn:
103
+ exn.args = exn.args + (filesample.get("stream"), filesample.get("url"))
104
+ if handler(exn):
105
+ continue
106
+ else:
107
+ break
108
+
109
+ if webdoc_valid_sample(current_sample):
110
+ yield current_sample
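`group_by_keys_interleaved` relies on a naming convention inside each tar shard: `<prefix>.<idx>.<text|image>.<ext>` for content files and `<prefix>.metadata.txt` for the per-document file list. Below is a minimal sketch of just the name-splitting step. The helper `split_webdoc_fname` and the example file names are hypothetical. Unlike the diff, the sketch removes only a literal leading `./`, because `str.strip("./")` strips any run of `.` and `/` characters from both ends of the string.

```python
def split_webdoc_fname(fname):
    """Split a tar member name into (prefix, suffix, data_type)."""
    # Remove only a literal leading "./".
    if fname.startswith("./"):
        fname = fname[2:]
    if fname.endswith(".metadata.txt"):
        # Metadata files have no index: the suffix is the data type itself.
        prefix, data_type, _extension = fname.split(".")
        return prefix, data_type, data_type
    prefix, idx, data_type, _extension = fname.split(".")
    if data_type not in ("text", "image"):
        raise ValueError(f"{fname}: unknown data type {data_type}")
    return prefix, idx, data_type


print(split_webdoc_fname("./doc42.0.text.txt"))
print(split_webdoc_fname("doc42.1.image.jpg"))
print(split_webdoc_fname("doc42.metadata.txt"))
```

The returned suffix is what the grouping function uses to build the `"<suffix>.value"`, `"<suffix>.type"` and `"<suffix>.fname"` keys of a sample.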
111
+
112
+
113
+ def _tarfile_to_webdocument_samples(src, handler=log_and_continue):
114
+ streams = url_opener(src, handler=handler)
115
+ files = tar_file_expander(streams, handler=handler)
116
+ samples = group_by_keys_interleaved(files, handler=handler)
117
+ return samples
118
+
119
+
120
+ tarfile_to_webdocument_samples = wds.filters.pipelinefilter(_tarfile_to_webdocument_samples)
121
+
122
+
123
+ def _collate_texts_and_images_webdocument(data, handler=log_and_continue):
124
+ for sample in data:
125
+ try:
126
+ max_example_indices = max(
127
+ [int(key.split(".")[0]) for key in sample.keys() if key.endswith(".value") and key != "metadata.value"]
128
+ )
129
+ texts = [None for _ in range(max_example_indices + 1)]
130
+ images = [None for _ in range(max_example_indices + 1)]
131
+ for idx in range(max_example_indices + 1):
132
+ if f"{idx}.value" not in sample:
133
+ continue
134
+ if "text" in sample[f"{idx}.type"]:
135
+ texts[idx] = sample[f"{idx}.value"]
136
+ elif "image" in sample[f"{idx}.type"]:
137
+ images[idx] = sample[f"{idx}.value"]
138
+ else:
139
+ raise ValueError(f"Unknown data type: {sample[f'{idx}.type']}")
140
+ example = {"__key__": sample["__key__"], "__url__": sample["__url__"], "texts": texts, "images": images}
141
+ yield example
142
+ except Exception as exn:
143
+ exn.args = exn.args + (sample.get("stream"), sample.get("url"))
144
+ if handler(exn):
145
+ continue
146
+ else:
147
+ break
148
+
149
+
150
+ collate_texts_and_images_webdocument = wds.filters.pipelinefilter(_collate_texts_and_images_webdocument)
151
+
152
+
153
+ def _decode_image_and_text_webdocument(data, handler=log_and_continue):
154
+ image_decoder = ImageDecoder()
155
+ for sample in data:
156
+ try:
157
+ sample["images"] = [image_decoder(image) if image is not None else None for image in sample["images"]]
158
+ sample["texts"] = [text.decode("utf-8") if text is not None else None for text in sample["texts"]]
159
+ yield sample
160
+ except Exception as exn:
161
+ exn.args = exn.args + (sample.get("stream"), sample.get("url"))
162
+ if handler(exn):
163
+ continue
164
+ else:
165
+ break
166
+
167
+
168
+ decode_image_and_text_webdocument = wds.filters.pipelinefilter(_decode_image_and_text_webdocument)
169
+
170
+
171
+ def collate_dicts(samples):
172
+ keys = samples[0].keys()
173
+ batched_samples = {key: [sample[key] for sample in samples] for key in keys}
174
+ return batched_samples
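`collate_dicts` turns a list of per-sample dicts into a single dict of lists, taking the key set from the first sample. A small illustration with made-up sample values, with the function body copied from the diff:

```python
def collate_dicts(samples):
    # Keys are taken from the first sample; every sample must provide them.
    keys = samples[0].keys()
    return {key: [sample[key] for sample in samples] for key in keys}


batch = collate_dicts(
    [
        {"text": "a cat", "image": "img0"},
        {"text": "a dog", "image": "img1"},
    ]
)
print(batch)
```

Because only the first sample's keys are used, a later sample missing one of them raises `KeyError`, and extra keys in later samples are silently dropped.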
175
+
176
+
177
+ def get_webdocuments_webdataset(
178
+ urls,
179
+ batch_size,
180
+ shuffle_initial_urls_list=False,
181
+ shuffle_before_split_by_node_buffer_size=100,
182
+ shuffle_before_split_by_worker_buffer_size=100,
183
+ shuffle_after_tarfile_to_samples_buffer_size=100,
184
+ shuffle_after_batching_buffer_size=1000,
185
+ ):
186
+ if shuffle_initial_urls_list:
187
+ random.shuffle(urls)
188
+
189
+ pipeline_list = [wds.SimpleShardList(urls)]
190
+
191
+ if shuffle_before_split_by_node_buffer_size is not None:
192
+ pipeline_list.append(wds.shuffle(shuffle_before_split_by_node_buffer_size))
193
+
194
+ pipeline_list.append(wds.split_by_node)
195
+
196
+ if shuffle_before_split_by_worker_buffer_size is not None:
197
+ pipeline_list.append(wds.shuffle(shuffle_before_split_by_worker_buffer_size))
198
+
199
+ pipeline_list.extend(
200
+ [
201
+ wds.split_by_worker,
202
+ tarfile_to_webdocument_samples(),
203
+ ]
204
+ )
205
+
206
+ if shuffle_after_tarfile_to_samples_buffer_size is not None:
207
+ pipeline_list.append(wds.shuffle(shuffle_after_tarfile_to_samples_buffer_size))
208
+
209
+ pipeline_list.extend(
210
+ [
211
+ collate_texts_and_images_webdocument(),
212
+ decode_image_and_text_webdocument(),
213
+ wds.batched(batch_size, collation_fn=collate_dicts, partial=True),
214
+ ]
215
+ )
216
+
217
+ if shuffle_after_batching_buffer_size is not None:
218
+ pipeline_list.append(wds.shuffle(shuffle_after_batching_buffer_size))
219
+
220
+ dataset = wds.DataPipeline(pipeline_list)
221
+ return dataset
222
+
223
+
224
+ def split_keep_2(x):
225
+ x = x.strip("./")
226
+ x_splitter = x.split(".")
227
+ return x_splitter[0], x_splitter[1]
228
+
229
+
230
+ def _tarfile_to_pair_samples(src, handler=log_and_continue):
231
+ streams = url_opener(src, handler=handler)
232
+ files = tar_file_expander(streams, handler=handler)
233
+ samples = group_by_keys(files, keys=split_keep_2, handler=handler)
234
+ return samples
235
+
236
+
237
+ tarfile_to_pair_samples = wds.filters.pipelinefilter(_tarfile_to_pair_samples)
238
+
239
+
240
+ def _decode_image_and_text_pairs(data, handler=log_and_continue):
241
+ image_decoder = ImageDecoder()
242
+ for sample in data:
243
+ try:
244
+ sample["image"] = image_decoder(sample["image"])
245
+ sample["text"] = sample["text"].decode("utf-8")
246
+ yield sample
247
+ except Exception as exn:
248
+ exn.args = exn.args + (sample.get("stream"), sample.get("url"))
249
+ if handler(exn):
250
+ continue
251
+ else:
252
+ break
253
+
254
+
255
+ decode_image_and_text_pairs = wds.filters.pipelinefilter(_decode_image_and_text_pairs)
256
+
257
+
258
+ def get_image_caption_pairs_webdataset(
259
+ urls,
260
+ batch_size,
261
+ shuffle_initial_urls_list=False,
262
+ shuffle_before_split_by_node_buffer_size=100,
263
+ shuffle_before_split_by_worker_buffer_size=100,
264
+ shuffle_after_tarfile_to_samples_buffer_size=100,
265
+ shuffle_after_batching_buffer_size=1000,
266
+ ):
267
+ if shuffle_initial_urls_list:
268
+ random.shuffle(urls)
269
+
270
+ pipeline_list = [wds.SimpleShardList(urls)]
271
+
272
+ if shuffle_before_split_by_node_buffer_size is not None:
273
+ pipeline_list.append(wds.shuffle(shuffle_before_split_by_node_buffer_size))
274
+
275
+ pipeline_list.append(wds.split_by_node)
276
+
277
+ if shuffle_before_split_by_worker_buffer_size is not None:
278
+ pipeline_list.append(wds.shuffle(shuffle_before_split_by_worker_buffer_size))
279
+
280
+ pipeline_list.extend(
281
+ [
282
+ wds.split_by_worker,
283
+ tarfile_to_pair_samples(handler=log_and_continue),
284
+ ]
285
+ )
286
+
287
+ if shuffle_after_tarfile_to_samples_buffer_size is not None:
288
+ pipeline_list.append(wds.shuffle(shuffle_after_tarfile_to_samples_buffer_size))
289
+
290
+ pipeline_list.extend(
291
+ [
292
+ decode_image_and_text_pairs(),
293
+ wds.batched(batch_size, collation_fn=collate_dicts, partial=True), # todo: check if partial is needed
294
+ ]
295
+ )
296
+
297
+ if shuffle_after_batching_buffer_size is not None:
298
+ pipeline_list.append(wds.shuffle(shuffle_after_batching_buffer_size))
299
+
300
+ dataset = wds.DataPipeline(pipeline_list)
301
+ return dataset
302
+
303
+
304
+ def get_webdataset(
305
+ urls,
306
+ ds_type: DatasetTypes,
307
+ batch_size: int,
308
+ shuffle_initial_urls_list,
309
+ shuffle_before_split_by_node_buffer_size,
310
+ shuffle_before_split_by_worker_buffer_size,
311
+ shuffle_after_tarfile_to_samples_buffer_size,
312
+ shuffle_after_batching_buffer_size,
313
+ ):
314
+ if ds_type == DatasetTypes.WEB_DOCUMENTS:
315
+ return get_webdocuments_webdataset(
316
+ urls,
317
+ batch_size,
318
+ shuffle_initial_urls_list,
319
+ shuffle_before_split_by_node_buffer_size,
320
+ shuffle_before_split_by_worker_buffer_size,
321
+ shuffle_after_tarfile_to_samples_buffer_size,
322
+ shuffle_after_batching_buffer_size,
323
+ )
324
+ elif ds_type == DatasetTypes.IMAGE_CAPTION_PAIRS:
325
+ return get_image_caption_pairs_webdataset(
326
+ urls,
327
+ batch_size,
328
+ shuffle_initial_urls_list,
329
+ shuffle_before_split_by_node_buffer_size,
330
+ shuffle_before_split_by_worker_buffer_size,
331
+ shuffle_after_tarfile_to_samples_buffer_size,
332
+ shuffle_after_batching_buffer_size,
333
+ )
334
+ else:
335
+ raise ValueError(f"Unknown dataset type: {ds_type}")
336
+
337
+
338
+ def check_webdataset_command(command):
339
+ if "s3:/" not in command:
340
+ return True
341
+
342
+ command = command.strip()
343
+ if not command.startswith("pipe:bash"):
344
+ return False
345
+
346
+ if not command.endswith(".tar"):
347
+ return False
348
+
349
+ if "get_file.sh" not in command:
350
+ return False
351
+
352
+ return True
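To make the rules in `check_webdataset_command` concrete, the sketch below copies the function and runs it on three shard specifications. The paths, bucket name and exact command strings are hypothetical examples, not values taken from the repository.

```python
def check_webdataset_command(command):
    # Anything that does not reference S3 is accepted as-is.
    if "s3:/" not in command:
        return True

    command = command.strip()
    # S3 shards must be fetched through a `pipe:bash ... get_file.sh ... .tar` command.
    if not command.startswith("pipe:bash"):
        return False
    if not command.endswith(".tar"):
        return False
    if "get_file.sh" not in command:
        return False
    return True


local_ok = check_webdataset_command("/data/shard-000.tar")
s3_ok = check_webdataset_command("pipe:bash get_file.sh s3://my-bucket/shard-000.tar")
s3_bad = check_webdataset_command("pipe:aws s3 cp s3://my-bucket/shard-000.tar -")
print(local_ok, s3_ok, s3_bad)
```

The check is purely syntactic: it does not verify that the script or the shard exists.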
m4/training/debug_utils.py ADDED
@@ -0,0 +1,34 @@
1
+ """ Trainer debug utils """
2
+
3
+
4
+ def dump_optim_states(self):
5
+ """dumps basic information about the state of the optimizer"""
6
+
7
+ print("*** Optim States Dump:")
8
+ param_groups_cnt = len(self.vl_optim.param_groups)
9
+ # state dict has more than param_groups info, so extract only the param groups
10
+ param_group_states = list(self.vl_optim.state.values())[:param_groups_cnt]
11
+ for i, state in enumerate(param_group_states):
12
+ print(f"param group: {i}")
13
+ print(f" step={state['step']}")
14
+ print(f" exp_avg all_zero={all(state['exp_avg'] == 0)}")
15
+ print(f" exp_avg_sq all_zero={all(state['exp_avg_sq'] == 0)}")
16
+
17
+ # can also dump LR state if need be
18
+ # print(f"LR={self.vl_scheduler.get_last_lr()}")
19
+
20
+
21
+ def validate_optim_states_are_reset(self):
22
+ """
23
+ for a new or fully reset optimizer we expect all zeros `exp_avg` and `exp_avg_sq` state tensors and step=1
24
+ """
25
+
26
+ param_groups_cnt = len(self.vl_optim.param_groups)
27
+ param_group_states = list(self.vl_optim.state.values())[:param_groups_cnt]
28
+ for i, state in enumerate(param_group_states):
29
+ if state["step"] != 1:
30
+ raise ValueError(f"optimizer reset didn't seem to work: state={i} step={state['step']}")
31
+ if not all(state["exp_avg"] == 0):
32
+ raise ValueError(f"optimizer reset didn't seem to work: state={i} exp_avg={state['exp_avg']}")
33
+ if not all(state["exp_avg_sq"] == 0):
34
+ raise ValueError(f"optimizer reset didn't seem to work: state={i} exp_avg_sq={state['exp_avg_sq']}")
m4/training/packing.py ADDED
@@ -0,0 +1,755 @@
1
+ import logging
2
+ from bisect import bisect_left
3
+ from collections import OrderedDict
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+
9
+ from m4.training.utils import FAKE_TOKEN_AROUND_IMAGE_V2, IMAGE_TOKEN, _convert_to_rgb
10
+
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ # Hyper-parameters
16
+ _IMAGE_BONUS_VALUE = 2 # The bonus value for tokens preceding the image token
17
+ _MIN_LENGTH_DOCUMENTS_TO_PACK = (
18
+ 5 # Minimum length of documents to pack together (length is measured in number of tokens)
19
+ )
20
+
21
+
22
+ def incremental_to_binary_attention_mask(incremental_mask, num_classes=-1):
23
+ # This function converts: [-1, 0, 1] => [[0, 0], [1, 0], [0, 1]]
24
+
25
+ # If any image index is greater than or equal to num_classes, set it to -1.
26
+ # Words after the max number of images allowed have been seen don't attend to anything
27
+ if num_classes != -1:
28
+ incremental_mask[incremental_mask >= num_classes] = -1
29
+
30
+ negatives = incremental_mask == -1
31
+ incremental_mask[negatives] = 0
32
+ attn_mask = torch.nn.functional.one_hot(incremental_mask, num_classes=num_classes)
33
+ attn_mask[negatives, :] = 0
34
+ return attn_mask
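The mapping performed by `incremental_to_binary_attention_mask` can be illustrated without `torch`. The sketch below is a pure-Python analogue for a flat list of image indices, not the implementation from the diff. The name `incremental_to_binary` is hypothetical, and it assumes `num_classes` is given explicitly.

```python
def incremental_to_binary(incremental_mask, num_classes):
    """Pure-Python analogue of the one-hot conversion for a flat list."""
    rows = []
    for idx in incremental_mask:
        row = [0] * num_classes
        # -1 (no image) and indices beyond the allowed number of images
        # both map to an all-zero row.
        if 0 <= idx < num_classes:
            row[idx] = 1
        rows.append(row)
    return rows


no_overflow = incremental_to_binary([-1, 0, 1], num_classes=2)
overflow = incremental_to_binary([0, 1, 2], num_classes=2)
print(no_overflow)  # [[0, 0], [1, 0], [0, 1]]
print(overflow)     # [[1, 0], [0, 1], [0, 0]]
```

One difference worth noting: the `torch` version in the diff writes into `incremental_mask` in place, so callers that need the original tensor afterwards would have to pass a clone.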
35
+
36
+
37
+ def image_attention_mask_for_packed_input_ids(input_ids, tokenizer):
38
+ image_attention_mask = torch.full_like(input_ids, fill_value=-1)
39
+ next_image_attention_mask = torch.full_like(input_ids, fill_value=-1)
40
+ image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
41
+ eod_token_id = tokenizer.eos_token_id
42
+ for batch_idx in range(input_ids.size(0)):
43
+ count = -1
44
+ seen_eod = False
45
+ for idx, token_id in enumerate(input_ids[batch_idx]):
46
+ if token_id == image_token_id:
47
+ count += 1
48
+ image_attention_mask[batch_idx][idx] = count
49
+ seen_eod = False
50
+ else:
51
+ image_attention_mask[batch_idx][idx] = count
52
+
53
+ if seen_eod:
54
+ image_attention_mask[batch_idx][idx] = -1
55
+
56
+ if token_id == eod_token_id:
57
+ seen_eod = True
58
+
59
+ for batch_idx in range(input_ids.size(0)):
60
+ count = -1
61
+ seen_eod = False
62
+ for idx in range(input_ids[batch_idx].size(0) - 1, -1, -1):
63
+ token_id = input_ids[batch_idx][idx]
64
+ if token_id == image_token_id:
65
+ count += 1
66
+ next_image_attention_mask[batch_idx][idx] = count
67
+ seen_eod = False
68
+ else:
69
+ next_image_attention_mask[batch_idx][idx] = count
70
+
71
+ if token_id == eod_token_id:
72
+ seen_eod = True
73
+
74
+ if seen_eod:
75
+ next_image_attention_mask[batch_idx][idx] = -1
76
+
77
+ non_negative_indices = next_image_attention_mask[batch_idx] != -1
78
+ next_image_attention_mask[batch_idx][non_negative_indices] -= count
79
+ next_image_attention_mask[batch_idx][non_negative_indices] *= -1
80
+
81
+ return image_attention_mask, next_image_attention_mask
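The two passes in `image_attention_mask_for_packed_input_ids` can be hard to follow in tensor form. The sketch below re-implements the same logic on a single list of string tokens, in pure Python, to show what each mask contains. The function name `image_masks` and the token strings `<image>` and `<eod>` are stand-ins chosen for illustration.

```python
IMAGE, EOD = "<image>", "<eod>"


def image_masks(tokens):
    n = len(tokens)
    current = [-1] * n   # index of the most recent image for each token
    upcoming = [-1] * n  # index of the next image for each token

    # Forward pass: tokens attend to the latest image seen so far. After an
    # end-of-document token, they attend to nothing until a new image appears.
    count, seen_eod = -1, False
    for i, tok in enumerate(tokens):
        if tok == IMAGE:
            count += 1
            seen_eod = False
        current[i] = -1 if seen_eod else count
        if tok == EOD:
            seen_eod = True

    # Backward pass: same idea from right to left, so each token points at
    # the next image in its own document. The EOD token itself is masked.
    count, seen_eod = -1, False
    for i in range(n - 1, -1, -1):
        tok = tokens[i]
        if tok == IMAGE:
            count += 1
            seen_eod = False
        if tok == EOD:
            seen_eod = True
        upcoming[i] = -1 if seen_eod else count

    # Images were numbered from the end; flip so they are numbered from the start.
    upcoming = [v if v == -1 else count - v for v in upcoming]
    return current, upcoming


tokens = [IMAGE, "a", "cat", EOD, "hello", IMAGE, "dog"]
current, upcoming = image_masks(tokens)
print(current)   # [0, 0, 0, 0, -1, 1, 1]
print(upcoming)  # [0, -1, -1, -1, 1, 1, -1]
```

In the toy sequence, `hello` sits after an end-of-document token and before the second image, so it has no current image (`-1`) but does have a next image (`1`).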
82
+
83
+
84
+ def laplacian_blur_detection(image, threshold=0.0):
85
+ # compute the focus measure of the image, which is simply the variance of its
86
+ # Laplacian, and flag the image as blurred when it falls below `threshold`
87
+ if threshold == 0.0:
88
+ return False
89
+
90
+ image = np.array(image)
91
+
92
+ if len(image.shape) == 3 and image.shape[2] == 3:
93
+ gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
94
+ return cv2.Laplacian(gray, cv2.CV_64F).var() < threshold
95
+ else:
96
+ # Don't remove grayscale images
97
+ return False
98
+
99
+
100
+ def fft_blur_detection(image, size=50, threshold=0.0):
101
+ if threshold == 0.0:
102
+ return False
103
+ (h, w) = image.shape
104
+ (cX, cY) = (int(w / 2.0), int(h / 2.0))
105
+ fft = np.fft.fft2(image)
106
+ fftShift = np.fft.fftshift(fft)
107
+ fftShift[cY - size : cY + size, cX - size : cX + size] = 0
108
+ fftShift = np.fft.ifftshift(fftShift)
109
+ recon = np.fft.ifft2(fftShift)
110
+ magnitude = 20 * np.log(np.abs(recon))
111
+ mean = np.mean(magnitude)
112
+ return mean < threshold
113
+
114
+
115
+ def split_pack_and_pad(
116
+ sample,
117
+ tokenizer,
118
+ max_seq_len,
119
+ image_transform,
120
+ max_num_images,
121
+ max_num_samples_per_document=10,
122
+ prefix_seed=(0, 0),
123
+ is_blurred_fn=None,
124
+ blur_threshold=0.0,
125
+ add_begin_of_doc_token=False,
126
+ add_end_of_doc_token=True,
127
+ max_num_images_per_document=None,
128
+ ):
129
+ """
130
+ Return a batch of samples in the format expected by the model, which
131
+ includes `input_ids`, `pixel_values`, `attention_mask`, `image_attention_mask`,
132
+ and `next_image_attention_mask`. The `input_ids` are sampled from the document so that
133
+ each sample has `max_seq_len` tokens; documents shorter than that are packed together.
134
+ For each document, we sample at most `min(max_num_samples_per_document, max_num_samples_for_curr_document)`
135
+ (where the latter is proportional to the length of the document and inversely proportional to `max_seq_len`)
136
+ `input_ids` sequences of length `max_seq_len` from the document. This means that
137
+ each sample can have a different start index. Based on the start index of the sample that
138
+ has been drawn, we also take a maximum of `max_num_images` images from the document.
139
+ If there are fewer than `max_num_images` images in the sample, we pad the images with zeros.
140
+ The start indexes are skewed towards subsequences that contain images.
141
+
142
+ Args:
143
+ sample (Dict): A sample object containing the document with images and text.
144
+ tokenizer (PretrainedTokenizer): Text tokenizer to be used.
145
+ max_seq_len (int): Maximum sequence length of the returned text tokens.
146
+ image_transform (Callable): Transform to be applied on the images
147
+ max_num_images (int): Maximum number of images to be sampled per sample. If less, they are padded with zeros.
148
+ max_num_samples_per_document (int, optional): Maximum number of samples per document to be sampled. Defaults to 10.
149
+ prefix_seed: Prefix seed sequence for "reproducible randomness" in calls to `np.random.choice`
150
+
151
+ Returns:
152
+ _type_: _description_
153
+ """
154
+ text_batch = sample["texts"]
155
+
156
+ image_batch = sample.get("image_embeddings", None)
157
+ is_raw_images = False
158
+ if image_batch is None:
159
+ image_batch = sample.get("images", None)
160
+ is_raw_images = True
161
+ if image_batch is None:
162
+ raise ValueError("Either image_embeddings or images must be present in the sample")
163
+
164
+ image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
165
+ last_was_image = False
166
+
167
+ if is_blurred_fn is None:
168
+ is_blurred_fn = fft_blur_detection
169
+
170
+ all_images = []
171
+ all_texts = []
172
+ for raw_images, raw_texts in zip(image_batch, text_batch):
173
+ # Filter out documents that don't have at least one image and one text
174
+ if not any(raw_images) or not any(raw_texts):
175
+ continue
176
+
177
+ if max_num_images_per_document:
178
+ num_images = sum([1 if image is not None else 0 for image in raw_images])
179
+ if num_images > max_num_images_per_document:
180
+ continue
181
+
182
+ any_blurred = False
183
+
184
+ if is_raw_images and blur_threshold > 0.0:
185
+ for image in raw_images:
186
+ if image is not None:
187
+ image = _convert_to_rgb(image)
188
+ any_blurred = any_blurred or is_blurred_fn(image, threshold=blur_threshold)
189
+ if any_blurred:
190
+ break
191
+
192
+ if any_blurred:
193
+ continue
194
+
195
+ inds_of_texts_to_split = [
196
+ i
197
+ for i, text in enumerate(raw_texts)
198
+ if text is not None and isinstance(text, str) and "END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED" in text
199
+ ]
200
+ if inds_of_texts_to_split:
201
+ splitted_raw_images, splitted_raw_texts = [], []
202
+ previous_i = 0
203
+ for i in inds_of_texts_to_split:
204
+ splitting = raw_texts[i].split("END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED")
205
+ part1, part2 = splitting[0], splitting[-1]
206
+
207
+ sub_doc_images = raw_images[previous_i:i] + [None]
208
+ sub_doc_texts = raw_texts[previous_i:i] + [part1.strip()]
209
+ if not any(sub_doc_images): # This can happen if all images in raw_images[previous_i:i] are None
210
+ continue
211
+
212
+ splitted_raw_images.append(sub_doc_images)
213
+ splitted_raw_texts.append(sub_doc_texts)
214
+
215
+ if part2.strip() == "":
216
+ previous_i = i + 1
217
+ else:
218
+ raw_texts[i] = part2.strip()
219
+ previous_i = i
220
+
221
+ if previous_i < len(raw_images) and any(raw_images[previous_i:]):
222
+ splitted_raw_images.append(raw_images[previous_i:])
223
+ splitted_raw_texts.append(raw_texts[previous_i:])
224
+
225
+ else:
226
+ splitted_raw_images, splitted_raw_texts = [raw_images], [raw_texts]
227
+
228
+ # Sanity check
229
+ if [len(ims) for ims in splitted_raw_images] != [len(txts) for txts in splitted_raw_texts]:
230
+ raise ValueError(
231
+ "Number of images and texts don't match after splitting on `END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED`."
232
+ " Something core went wrong during the splitting and needs to be fixed."
233
+ )
234
+
235
+ for s_r_ims, s_r_txts in zip(splitted_raw_images, splitted_raw_texts):
236
+ images, web_text = [], ""
237
+ for image, text in zip(s_r_ims, s_r_txts):
238
+ if text is None and image is None:
239
+ continue
240
+
241
+ if image is not None:
242
+ web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}{IMAGE_TOKEN}"
243
+ if is_raw_images:
244
+ images.append(image_transform(image))
245
+ else:
246
+ images.append(torch.tensor(image))
247
+ last_was_image = True
248
+ elif text is not None:
249
+ if last_was_image:
250
+ web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}{text}"
251
+ last_was_image = False
252
+ else:
253
+ web_text += f" {text}" if web_text != "" else text
254
+
255
+ if last_was_image:
256
+ web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}"
257
+
258
+ web_text = web_text.strip(" ")
259
+
260
+ # This is mostly a sanity check. Cases like that should not happen at that point.
261
+ if web_text == "" or len(images) == 0:
262
+ continue
263
+
264
+ images = torch.stack(images)
265
+ all_images.append(images)
266
+
267
+ web_text_ids = tokenizer.encode(web_text, add_special_tokens=False)
268
+ if add_end_of_doc_token:
269
+ web_text_ids += [tokenizer.eos_token_id]
270
+
271
+ if add_begin_of_doc_token:
272
+ web_text_ids = [tokenizer.bos_token_id] + web_text_ids
273
+ all_texts.append(web_text_ids)
274
+
275
+ output_input_ids = []
276
+ output_images = []
277
+ output_attention_masks = []
278
+ output_num_images = []
279
+ output_num_text_tokens = []
280
+
281
+ input_ids_to_pack = []
282
+ images_to_pack = []
283
+ for images, text in zip(all_images, all_texts):
284
+ # We save all the documents which are shorter than the max_seq_len to pack them together.
285
+ if len(text) <= max_seq_len:
286
+ if len(text) < _MIN_LENGTH_DOCUMENTS_TO_PACK: # Filter out extremely short sequences
287
+ continue
288
+ input_ids_to_pack.extend(text)
289
+ images_to_pack.extend(images)
290
+ else:
291
+ # Computing the bonus scores for tokens near images to skew the sampling towards them
292
+ # The main idea is to give a bonus to tokens that closely precede an image token, so that these tokens have a higher chance of being sampled.
293
+ # Bonuses are computed for each image, which means a given token can receive bonuses from multiple images if it closely precedes multiple images.
294
+ # We sum all the bonuses and L1-normalize along the seq_len axis to get a probability distribution.
295
+ # Each token starts with a base score of 1, which corresponds to the uniform distribution over the sequence when no bonuses are added.
296
+
297
+ # Now the remaining question is which preceding tokens we distribute bonuses to.
298
+ # We first observe that for the sampled sub-sequence to be considered valid (i.e. sub-sequence contains an image), the start index can only be among [image_idx - max_seq_len + 1, image_idx].
299
+ # For the sake of the explanation, let's split the [image_idx - max_seq_len + 1, image_idx] interval in 3 parts: left, middle and right (in increasing order).
300
+ # If we give bonuses to the tokens just before the image (right part), then we are favoring p_next=0 because only the tokens after the image have an image to attend to.
301
+ # In practice, images will tend to be at the beginning of the sampled sub-sequence.
302
+ # If we give bonuses very far before the image (left part), then we are favoring p_next=1 because only the tokens before the image have an image to attend to.
303
+ # In practice, images will tend to be at the end of the sampled sub-sequence.
304
+ # To avoid favoring either p_next=0 or p_next=1, we can give bonuses to the tokens in the middle part.
305
+ # In practice, images will tend to be in the middle of the sampled sub-sequence.
306
+
307
+ # Ultimately, we don't want to skew the distribution fed to the model in that way (i.e. whether images are in the beginning, middle or end of the sampled sub-sequence),
308
+ # and would rather have all these cases represented equally in the data. So the easiest is to distribute a bonus to all of the max_seq_len tokens preceding the image.
309
+ all_scores = np.array([1] * len(text))
310
+ for img_token_idx in np.where(np.array(text) == image_token_id)[0]:
311
+ all_scores[max(0, img_token_idx - max_seq_len) : img_token_idx + 1] += _IMAGE_BONUS_VALUE
312
+ # all_scores = np.clip(all_scores, a_min=1, a_max=3 * _IMAGE_BONUS_VALUE * max_num_images + 1) # We can optionally clip the bonuses to avoid having too high values (i.e. outliers documents)
313
+ all_scores = all_scores[:-_MIN_LENGTH_DOCUMENTS_TO_PACK]
314
+
315
+ # The number of samples is proportional to the length of the text and inversely proportional to the maximum sequence length
316
+ max_num_samples_for_curr_document = len(text) // max_seq_len
317
+ # Set "reproducible randomness" by creating an np.random.default_rng seeded by (main seed, epoch, rank_idx, worker_idx, mapped_batch_index, text len)
318
+ choices = np.random.default_rng(seed=list(prefix_seed) + [len(text)]).choice(
319
+ range(len(text) - _MIN_LENGTH_DOCUMENTS_TO_PACK), # shorter sub-sequences are reserved for packing
320
+ min(
321
+ len(text) - max_seq_len, 2 * max_num_samples_per_document
322
+ ), # Sampling more than necessary and then breaking out of the for loop once we have enough samples
323
+ p=all_scores / np.linalg.norm(all_scores, ord=1),
324
+ replace=False,
325
+ )
326
+
327
+ nb_effective_sequences_out_of_sampling = 0
328
+ for start_index in choices:
329
+ image_start_index = text[:start_index].count(image_token_id)
330
+ text_sub_sequence = text[start_index : start_index + max_seq_len]
331
+ image_count = text_sub_sequence.count(image_token_id)
332
+ if image_count == 0:
333
+ # Skip if there are no images in the sequence
334
+ continue
335
+
336
+ if len(text_sub_sequence) < max_seq_len:
337
+ # If the sub-sequence is shorter than max_seq_len, we reserve it for packing
338
+ # It necessarily means that the sub-sequence was sampled towards the end of the document,
339
+ # which implies that we only need the `image_start_index` and not the `image_end_index`
340
+ if text_sub_sequence.count(image_token_id) != len(images[image_start_index:]):
341
+ # A safeguard for this
342
+ logger.warning(
343
+ "Skipping this sample because of mismatch in actual number of images and "
344
+ "the '<image>' tokens in the text"
345
+ )
346
+ continue
347
+ input_ids_to_pack.extend(text_sub_sequence)
348
+ images_to_pack.extend(images[image_start_index:])
349
+ continue
350
+
351
+ current_images = images[image_start_index : image_start_index + min(max_num_images, image_count)]
352
+ if len(current_images) != min(max_num_images, image_count):
353
+ # A safeguard in case something is off about this document: maybe an `<image>` tag that
354
+ # was already there before, or some issue in parsing the image?
355
+ logger.warning(
356
+ "Skipping this sample because of mismatch in actual number of images and "
357
+ "the '<image>' tokens in the text"
358
+ )
359
+ break
360
+ padded_image_tensor = torch.zeros(max_num_images, *images.size()[1:])
361
+ padded_image_tensor[: min(max_num_images, image_count)] = current_images
362
+ output_images.append(padded_image_tensor)
363
+ output_num_images.append(min(max_num_images, image_count))
364
+
365
+ output_input_ids.append(torch.tensor(text_sub_sequence))
366
+ output_num_text_tokens.append(len(text_sub_sequence))
367
+
368
+ attention_mask = torch.ones((max_seq_len,), dtype=torch.long)
369
+ output_attention_masks.append(attention_mask)
370
+
371
+ nb_effective_sequences_out_of_sampling += 1
372
+ if nb_effective_sequences_out_of_sampling >= min(
373
+ max_num_samples_for_curr_document, max_num_samples_per_document
374
+ ):
375
+ # We got all the samples we need for this document, so breaking out
376
+ break
377
+
378
+ # Pack the remaining sequences from `input_ids_to_pack` x `images_to_pack`
379
+ if input_ids_to_pack:
380
+ image_counter = 0
381
+ for i in range(0, len(input_ids_to_pack), max_seq_len):
382
+ current_input_ids = input_ids_to_pack[i : i + max_seq_len]
383
+ unpadded_seq_len = len(current_input_ids)
384
+ num_images = current_input_ids.count(image_token_id)
385
+ if num_images == 0:
386
+ continue
387
+ current_images = images_to_pack[image_counter : image_counter + num_images]
388
+ image_counter += num_images
389
+ if unpadded_seq_len < max_seq_len:
390
+ padded_input_ids = [tokenizer.pad_token_id] * max_seq_len
391
+ padded_input_ids[:unpadded_seq_len] = current_input_ids
392
+ current_input_ids = padded_input_ids
393
+ elif unpadded_seq_len > max_seq_len:
394
+ # This case has no purpose other than safeguard
395
+ continue
396
+ try:
397
+ current_images = torch.stack(current_images)[:max_num_images]
398
+ except Exception:
399
+ continue
400
+ padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:])
401
+ padded_image_tensor[: current_images.size(0)] = current_images
402
+ attention_mask = torch.zeros((max_seq_len,), dtype=torch.long)
403
+ attention_mask[:unpadded_seq_len] = 1
404
+
405
+ output_images.append(padded_image_tensor)
406
+ output_input_ids.append(torch.tensor(current_input_ids))
407
+ output_num_text_tokens.append(unpadded_seq_len)
408
+ output_num_images.append(min(max_num_images, num_images))
409
+
410
+ output_attention_masks.append(attention_mask)
411
+
412
+ if len(output_images) == 0 or len(output_input_ids) == 0:
413
+ result = {
414
+ "input_ids": torch.tensor([], dtype=torch.long),
415
+ "attention_mask": torch.tensor([], dtype=torch.bool),
416
+ "image_attention_mask": torch.tensor([], dtype=torch.bool),
417
+ "next_image_attention_mask": torch.tensor([], dtype=torch.bool),
418
+ "num_images": torch.tensor([], dtype=torch.long),
419
+ "num_text_tokens": torch.tensor([], dtype=torch.long),
420
+ }
421
+ if is_raw_images:
422
+ result["pixel_values"] = torch.tensor([], dtype=torch.float32)
423
+ else:
424
+ result["image_embeddings"] = torch.tensor([], dtype=torch.float32)
425
+ return result
426
+
427
+ output_input_ids = torch.stack(output_input_ids)
428
+ output_images = torch.stack(output_images)
429
+ output_attention_masks = torch.stack(output_attention_masks)
430
+
431
+ image_attention_mask, next_image_attention_mask = image_attention_mask_for_packed_input_ids(
432
+ output_input_ids, tokenizer
433
+ )
434
+ image_attention_mask = incremental_to_binary_attention_mask(image_attention_mask, num_classes=max_num_images)
435
+ next_image_attention_mask = incremental_to_binary_attention_mask(
436
+ next_image_attention_mask, num_classes=max_num_images
437
+ )
438
+
439
+ result = {
440
+ "input_ids": output_input_ids,
441
+ "attention_mask": output_attention_masks,
442
+ "image_attention_mask": image_attention_mask,
443
+ "next_image_attention_mask": next_image_attention_mask,
444
+ "num_images": torch.tensor(output_num_images),
445
+ "num_text_tokens": torch.tensor(output_num_text_tokens),
446
+ }
447
+ if is_raw_images:
448
+ result["pixel_values"] = output_images
449
+ else:
450
+ result["image_embeddings"] = output_images
451
+ return result
452
+
453
+
454
+ def split_and_pad_pmd(
455
+ sample,
456
+ tokenizer,
457
+ max_seq_len,
458
+ image_transform,
459
+ max_num_images,
460
+ prefix_seed=(0, 0),
461
+ is_blurred_fn=None,
462
+ blur_threshold=0.0,
463
+ prob_image_at_end=0.5, # If 1, the <image> token is always added at the end of the text
464
+ # If set to -1, all padding will be tolerated. If set to 0, no padding will be tolerated.
465
+ padding_tolerance=-1,
466
+ add_begin_of_doc_token=False,
467
+ add_end_of_doc_token=True,
468
+ ):
469
+ if is_blurred_fn is None:
470
+ is_blurred_fn = fft_blur_detection
471
+
472
+ text_batch = sample["text"]
473
+ image_batch = sample.get("image_embedding", None)
474
+ is_raw_images = False
475
+ if image_batch is None:
476
+ image_batch = sample.get("image", None)
477
+ is_raw_images = True
478
+
479
+ filtered_image_batch = []
480
+ filtered_input_ids = []
481
+
482
+ # Decide whether, for the current PMD batch, the images will be at the start or at the end.
483
+ rng = np.random.default_rng(seed=list(prefix_seed))
484
+ is_image_at_end = False
485
+
486
+ # rng.random is between 0 and 1, so if prob_image_at_end is 1, random value will
487
+ # always be less than `prob_image_at_end` and `is_image_at_end` will always be True.
488
+ # This means that images will always be at the end of the text.
489
+ if rng.random() < prob_image_at_end:
490
+ is_image_at_end = True
491
+
492
+ for image, text in zip(image_batch, text_batch):
493
+ if text is None or image is None:
494
+ continue
495
+
496
+ if is_raw_images and is_blurred_fn(image, threshold=blur_threshold):
497
+ continue
498
+
499
+ sample_text = f"{FAKE_TOKEN_AROUND_IMAGE_V2}{IMAGE_TOKEN}{FAKE_TOKEN_AROUND_IMAGE_V2}"
500
+
501
+ # Remove trailing and leading whitespaces, including newlines and tabs
502
+ text = text.strip()
503
+
504
+ if is_image_at_end:
505
+ sample_text = f"{text}{sample_text}"
506
+ else:
507
+ sample_text = f"{sample_text}{text}"
508
+
509
+ sample_input_ids = tokenizer.encode(sample_text, add_special_tokens=False)
510
+ if add_end_of_doc_token:
511
+ sample_input_ids += [tokenizer.eos_token_id]
512
+
513
+ if add_begin_of_doc_token:
514
+ sample_input_ids = [tokenizer.bos_token_id] + sample_input_ids
515
+
516
+ filtered_image_batch.append(image)
517
+ filtered_input_ids.append(sample_input_ids)
518
+
519
+ # sort by length of text and save same length elements in a mapping so we
520
+ # can retrieve candidates later.
521
+ filtered_image_batch, filtered_input_ids = zip(
522
+ *sorted(zip(filtered_image_batch, filtered_input_ids), key=lambda x: len(x[1]))
523
+ )
524
+ mapping_by_len = OrderedDict()
525
+ for i, sample_input_ids in enumerate(filtered_input_ids):
526
+ if len(sample_input_ids) not in mapping_by_len:
527
+ mapping_by_len[len(sample_input_ids)] = []
528
+ mapping_by_len[len(sample_input_ids)].append((filtered_image_batch[i], sample_input_ids))
529
+
530
+ all_images = []
531
+ all_texts = []
532
+ all_attention_masks = []
533
+ all_num_images = []
534
+ all_num_text_tokens = []
535
+ current_text = []
536
+ current_images = []
537
+
538
+ while True:
539
+ current_lens = list(mapping_by_len.keys())
540
+ if len(current_text) > 0:
541
+ # Now we try to do a binary search to find the biggest sequence that
542
+ # we can fit into the current sequence.
543
+ # This will eventually use up bigger sequences faster which is good
544
+ # and leave smaller sequences to pack with each other later.
545
+ diff = max_seq_len - len(current_text)
546
+ if len(current_lens) == 0:
547
+ possible_index = -1
548
+ else:
549
+ possible_index = bisect_left(current_lens, diff)
550
+ if possible_index == len(current_lens) or current_lens[possible_index] != diff:
551
+ possible_index -= 1
552
+
553
+ if possible_index >= 0:
554
+ best_possible_length = current_lens[possible_index]
555
+ image, sample_input_ids = mapping_by_len[best_possible_length].pop(0)
556
+
557
+ # If we have used up all the samples of a certain length, remove
558
+ # that length from the mapping.
559
+ if len(mapping_by_len[best_possible_length]) == 0:
560
+ del mapping_by_len[best_possible_length]
561
+ current_text.extend(sample_input_ids)
562
+ if is_raw_images:
563
+ current_images.append(image_transform(image))
564
+ else:
565
+ current_images.append(torch.tensor(image))
566
+ elif diff > padding_tolerance and padding_tolerance != -1:
567
+ # If we are here, it means that we still have padding left
568
+ # and we have exhausted our current unique options that will allow us to
569
+ # fill this sequence completely.
570
+ # So, we will try to fill the sequence with whatever we get from the unchanged
571
+ # copy of all sequences.
572
+ while diff > padding_tolerance:
573
+ # Find a random sequence to fit
574
+ # Why do we need to add more entries to the prefix seed?
575
+ # prefix_seed is the same throughout a batch, so it might sample the
576
+ # same thing again and again if there are multiple cases of padding in the
577
+ # same batch. We therefore make this part as random as possible.
578
+ rng = np.random.default_rng(
579
+ prefix_seed
580
+ + (
581
+ diff,
582
+ len(current_text),
583
+ len(all_texts),
584
+ all_num_images,
585
+ )
586
+ )
587
+ choice = rng.choice(range(len(filtered_input_ids)))
588
+ image, sample_input_ids = filtered_image_batch[choice], filtered_input_ids[choice]
589
+ current_text.extend(sample_input_ids)
590
+ if is_raw_images:
591
+ current_images.append(image_transform(image))
592
+ else:
593
+ current_images.append(torch.tensor(image))
594
+ diff = max_seq_len - len(current_text)
595
+ # In the next top-level while loop iteration, this should go into the else
596
+ # clause which should also handle the sequences longer than max_seq_len
597
+ else:
598
+ current_images = torch.stack(current_images)
599
+ padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:])
600
+ padded_image_tensor[: current_images.size(0)] = current_images[
601
+ : min(max_num_images, current_images.size(0))
602
+ ]
603
+ all_num_images.append(min(max_num_images, current_images.size(0)))
604
+ all_images.append(padded_image_tensor)
605
+
606
+ padded_input_ids = torch.full((max_seq_len,), tokenizer.pad_token_id)
607
+ current_max_len = min(max_seq_len, len(current_text))
608
+ padded_input_ids[:current_max_len] = torch.tensor(current_text)[:current_max_len]
609
+ all_num_text_tokens.append(current_max_len)
610
+ all_texts.append(padded_input_ids)
611
+
612
+ attention_mask = torch.zeros((max_seq_len,), dtype=torch.long)
613
+ attention_mask[: len(current_text)] = 1
614
+ all_attention_masks.append(attention_mask)
615
+
616
+ # Make sure to reset the current text and images.
617
+ current_images = []
618
+ current_text = []
619
+ if len(current_lens) == 0:
620
+ break
621
+ else:
622
+ # A case where we might not have any samples left over after the initial filtering step.
623
+ if len(current_lens) == 0:
624
+ break
625
+ image, sample_input_ids = mapping_by_len[current_lens[-1]].pop(0)
626
+ if len(mapping_by_len[current_lens[-1]]) == 0:
627
+ del mapping_by_len[current_lens[-1]]
628
+ current_text = sample_input_ids[:max_seq_len]
629
+ if is_raw_images:
630
+ current_images = [image_transform(image)]
631
+ else:
632
+ current_images = [torch.tensor(image)]
633
+
634
+ if len(all_images) == 0 or len(all_texts) == 0:
635
+ result = {
636
+ "input_ids": torch.tensor([], dtype=torch.long),
637
+ "attention_mask": torch.tensor([], dtype=torch.bool),
638
+ "image_attention_mask": torch.tensor([], dtype=torch.bool),
639
+ "num_images": torch.tensor([], dtype=torch.long),
640
+ "num_text_tokens": torch.tensor([], dtype=torch.long),
641
+ }
642
+ if is_raw_images:
643
+ result["pixel_values"] = torch.tensor([], dtype=torch.float32)
644
+ else:
645
+ result["image_embeddings"] = torch.tensor([], dtype=torch.float32)
646
+ return result
647
+
648
+ all_texts = torch.stack(all_texts)
649
+ all_images = torch.stack(all_images)
650
+ all_attention_masks = torch.stack(all_attention_masks)
651
+
652
+ image_attention_mask, next_image_attention_mask = image_attention_mask_for_packed_input_ids(all_texts, tokenizer)
653
+ image_attention_mask = incremental_to_binary_attention_mask(image_attention_mask, num_classes=max_num_images)
654
+ next_image_attention_mask = incremental_to_binary_attention_mask(
655
+ next_image_attention_mask, num_classes=max_num_images
656
+ )
657
+
658
+ output = {
659
+ "input_ids": all_texts,
660
+ "attention_mask": all_attention_masks,
661
+ "image_attention_mask": image_attention_mask,
662
+ "num_images": torch.tensor(all_num_images),
663
+ "num_text_tokens": torch.tensor(all_num_text_tokens),
664
+ }
665
+ if is_raw_images:
666
+ output["pixel_values"] = all_images
667
+ else:
668
+ output["image_embeddings"] = all_images
669
+
670
+ if is_image_at_end:
671
+ # Set the correct attention mask based on whether the image is at the start
672
+ # or not. When it is at the end, we need next image attention mask.
673
+ output["image_attention_mask"] = next_image_attention_mask
674
+
675
+ return output
676
+
677
+
678
+ # Copied from https://github.com/google-research/text-to-text-transfer-transformer/blob/main/t5/data/preprocessors.py
679
+ def random_spans_helper(
680
+ inputs_length,
681
+ noise_density,
682
+ mean_noise_span_length,
683
+ extra_tokens_per_span_inputs,
684
+ extra_tokens_per_span_targets,
685
+ verbose=False,
686
+ ):
687
+ """Training parameters to avoid padding with random_spans_noise_mask.
688
+
689
+ When training a model with random_spans_noise_mask, we would like to set the
690
+ other training hyperparameters in a way that avoids padding. This function
691
+ helps us compute these hyperparameters.
692
+
693
+ We assume that each noise span in the input is replaced by
694
+ extra_tokens_per_span_inputs sentinel tokens, and each non-noise span in the
695
+ targets is replaced by extra_tokens_per_span_targets sentinel tokens.
696
+
697
+ This function tells us the required number of tokens in the raw example (for
698
+ split_tokens()) as well as the length of the encoded targets.
699
+
700
+ Note that this function assumes the inputs and targets will have EOS appended
701
+ and includes that in the reported length.
702
+
703
+ Args:
704
+ inputs_length: an integer - desired length of the tokenized inputs sequence
705
+ noise_density: a float
706
+ mean_noise_span_length: a float
707
+ extra_tokens_per_span_inputs: an integer
708
+ extra_tokens_per_span_targets: an integer
709
+ verbose: a bool indicating whether to log sequence lengths
710
+ Returns:
711
+ tokens_length: length of original text in tokens
712
+ targets_length: an integer - length in tokens of encoded targets sequence
713
+ """
714
+
715
+ if extra_tokens_per_span_inputs != 1:
716
+ raise NotImplementedError(
717
+ "extra_tokens_per_span_inputs != 1 not supported yet. You need to check"
718
+ " `get_model_tflops_per_batch_per_gpu` of `VT5ForConditionalGeneration` if you change it."
719
+ )
720
+ if extra_tokens_per_span_targets != 1:
721
+ raise NotImplementedError(
722
+ "extra_tokens_per_span_targets != 1 not supported yet. You need to check"
723
+ " `get_model_tflops_per_batch_per_gpu` of `VT5ForConditionalGeneration` if you change it."
724
+ )
725
+
726
+ def _tokens_length_to_inputs_length_targets_length(tokens_length):
727
+ num_noise_tokens = int(round(tokens_length * noise_density))
728
+ num_nonnoise_tokens = tokens_length - num_noise_tokens
729
+ num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
730
+ # inputs contain all nonnoise tokens, sentinels for all noise spans
731
+ # and one EOS token.
732
+ return (
733
+ num_nonnoise_tokens + num_noise_spans * extra_tokens_per_span_inputs + 1,
734
+ num_noise_tokens + num_noise_spans * extra_tokens_per_span_targets + 1,
735
+ )
736
+
737
+ tokens_length = inputs_length - 1
738
+ while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length:
739
+ tokens_length += 1
740
+ inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length)
741
+ # minor hack to get the targets length to be equal to inputs length
742
+ # which is more likely to have been set to a nice round number.
743
+ if noise_density == 0.5 and targets_length > inputs_length:
744
+ tokens_length -= 1
745
+ targets_length -= 1
746
+ if verbose:
747
+ logging.info(
748
+ "tokens_length=%s inputs_length=%s targets_length=%s noise_density=%s mean_noise_span_length=%s ",
749
+ tokens_length,
750
+ inputs_length,
751
+ targets_length,
752
+ noise_density,
753
+ mean_noise_span_length,
754
+ )
755
+ return tokens_length, targets_length
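As a worked instance of the length computation in `random_spans_helper` above, the sketch below restates only its arithmetic as a standalone function. The name `span_corruption_lengths` is hypothetical, the `NotImplementedError` checks and the logging are left out, and it assumes one sentinel token per span, which is the only case the helper accepts.

```python
def span_corruption_lengths(inputs_length, noise_density, mean_noise_span_length):
    """Standalone restatement of the arithmetic in random_spans_helper (one sentinel per span)."""

    def lengths(tokens_length):
        num_noise_tokens = int(round(tokens_length * noise_density))
        num_nonnoise_tokens = tokens_length - num_noise_tokens
        num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
        # The trailing +1 on each side accounts for the appended EOS token.
        return (
            num_nonnoise_tokens + num_noise_spans + 1,
            num_noise_tokens + num_noise_spans + 1,
        )

    # Grow the raw chunk until the encoded inputs would exceed inputs_length.
    tokens_length = inputs_length - 1
    while lengths(tokens_length + 1)[0] <= inputs_length:
        tokens_length += 1
    inputs_length, targets_length = lengths(tokens_length)
    if noise_density == 0.5 and targets_length > inputs_length:
        tokens_length -= 1
        targets_length -= 1
    return tokens_length, targets_length


tokens_length, targets_length = span_corruption_lengths(512, 0.15, 3.0)
print(tokens_length, targets_length)  # 568 114
```

With 512 encoder positions, a noise density of 0.15 and a mean span length of 3, raw chunks of 568 tokens fill the encoder input exactly and the targets take 114 tokens.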
m4/training/setup_language_model.py ADDED
@@ -0,0 +1,38 @@
1
+ import re
2
+
3
+ from m4.models.vbloom.configuration_vbloom import VBloomConfig
4
+ from m4.models.vbloom.modeling_vbloom import VBloomForCausalLM
5
+ from m4.models.vgpt2.configuration_vgpt2 import VGPT2Config
6
+ from m4.models.vgpt2.modeling_vgpt2 import VGPT2LMHeadModel
7
+ from m4.models.vgpt_neo.configuration_vgpt_neo import VGPTNeoConfig
8
+ from m4.models.vgpt_neo.modeling_vgpt_neo import VGPTNeoForCausalLM
9
+ from m4.models.vllama.configuration_vllama import VLlamaConfig
10
+ from m4.models.vllama.modeling_vllama import VLlamaForCausalLM
11
+ from m4.models.vopt.configuration_vopt import VOPTConfig
12
+ from m4.models.vopt.modeling_vopt import VOPTForCausalLM
13
+ from m4.models.vt5.configuration_vt5 import VT5Config
14
+ from m4.models.vt5.modeling_vt5 import VT5ForConditionalGeneration
15
+
16
+
17
+ model_name2classes = {
18
+ r"bloom|bigscience-small-testing": [VBloomConfig, VBloomForCausalLM],
19
+ r"gpt-neo|gptneo": [VGPTNeoConfig, VGPTNeoForCausalLM],
20
+ r"gpt2": [VGPT2Config, VGPT2LMHeadModel],
21
+ r"opt": [VOPTConfig, VOPTForCausalLM],
22
+ r"t5": [VT5Config, VT5ForConditionalGeneration],
23
+ r"llama": [VLlamaConfig, VLlamaForCausalLM],
24
+ }
25
+
26
+
27
+ def model_name_to_classes(model_name_or_path):
28
+ """returns config_class, model_class for a given model name or path"""
29
+
30
+ model_name_lowcase = model_name_or_path.lower()
31
+ for rx, classes in model_name2classes.items():
32
+ if re.search(rx, model_name_lowcase):
33
+ return classes
34
+ else:
35
+ raise ValueError(
36
+ f"Unknown type of backbone LM. Got {model_name_or_path}, supported regexes:"
37
+ f" {list(model_name2classes.keys())}."
38
+ )
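The lookup in `model_name_to_classes` above returns the first pattern that matches, so the order of the dictionary entries matters. The sketch below reproduces that behavior with placeholder strings standing in for the `[config_class, model_class]` pairs; the function name `lookup` and the model names passed to it are only illustrative.

```python
import re

# Placeholder strings stand in for the [config_class, model_class] pairs.
patterns = {
    r"bloom|bigscience-small-testing": "vbloom",
    r"gpt-neo|gptneo": "vgpt_neo",
    r"gpt2": "vgpt2",
    r"opt": "vopt",
    r"t5": "vt5",
    r"llama": "vllama",
}


def lookup(model_name_or_path):
    name = model_name_or_path.lower()
    for rx, family in patterns.items():
        if re.search(rx, name):
            return family  # the first matching pattern wins
    raise ValueError(f"Unknown backbone: {model_name_or_path}")


print(lookup("facebook/opt-350m"))        # vopt
print(lookup("EleutherAI/gpt-neo-125M"))  # vgpt_neo
print(lookup("my-optimized-llama"))       # vopt, because "opt" is tested before "llama"
```

Because `re.search` matches anywhere in the string, a path that happens to contain a short pattern such as `opt` or `t5` resolves to that family, so checkpoint directory names are worth choosing with this in mind.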
m4/training/setup_vision_model.py ADDED
@@ -0,0 +1,33 @@
1
+ import re
2
+
3
+ from transformers import AutoModel
4
+
5
+
6
+ # map to check the supported cv archs and also how to extract the model - in some arch, we want to
7
+ # go through a specific prefix to get to the model as in `model.vision_model` for clip
8
+ vision_model_name2model = {
9
+ r"clip": lambda model: model.vision_model,
10
+ r"vit": lambda model: model,
11
+ }
12
+
13
+
14
+ def vision_model_name_to_model(model_name_or_path, model):
15
+ """returns the model if supported, raises ValueError otherwise"""
16
+
17
+ model_name_lowcase = model_name_or_path.lower()
18
+ for rx, lookup in vision_model_name2model.items():
19
+ if re.search(rx, model_name_lowcase):
20
+ return lookup(model)
21
+ else:
22
+ raise ValueError(
23
+ f"Unknown type of backbone vision model. Got {model_name_or_path}, supported regexes:"
24
+ f" {list(vision_model_name2model.keys())}."
25
+ )
26
+
27
+
28
+ def get_vision_model(config):
29
+ vision_model_name = config.vision_model_name
30
+ vision_model_params = eval(config.vision_model_params)
31
+
32
+ model = AutoModel.from_pretrained(vision_model_name, **vision_model_params)
33
+ return vision_model_name_to_model(vision_model_name, model)
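`get_vision_model` above parses `config.vision_model_params` with `eval`. If those parameters are always plain Python literals (an assumption, since the config format is not shown here), `ast.literal_eval` from the standard library parses the same kind of string without executing arbitrary code. The helper name `parse_model_params` and the sample string are hypothetical.

```python
import ast


def parse_model_params(params_str):
    """Parse a string such as '{"revision": "main"}' into a dict of keyword arguments."""
    params = ast.literal_eval(params_str)  # accepts literals only, never calls functions
    if not isinstance(params, dict):
        raise TypeError(f"Expected a dict of keyword arguments, got {type(params).__name__}")
    return params


print(parse_model_params('{"revision": "main"}'))  # {'revision': 'main'}
```

This would not cover parameter strings that construct objects, such as tokenizer parameters containing `AddedToken(...)`, which need a real `eval`.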
m4/training/types.py ADDED
@@ -0,0 +1,13 @@
1
+ from enum import Enum
2
+
3
+
4
+ class DatasetNames(Enum):
5
+ PMD = "pmd"
6
+ LAION = "laion"
7
+ CM4 = "cm4"
8
+ WIKI = "wiki"
9
+
10
+
11
+ class DatasetTypes(Enum):
12
+ WEB_DOCUMENTS = "wd"
13
+ IMAGE_CAPTION_PAIRS = "icp"
m4/training/utils.py ADDED
@@ -0,0 +1,539 @@
1
+ import dataclasses
2
+ import gc
3
+ import json
4
+ import logging
5
+ from contextlib import contextmanager
6
+ from enum import Enum
7
+
8
+ import accelerate
9
+ import psutil
10
+ import pynvml
11
+ import torch
12
+ import torch.nn as nn
13
+ import torchvision.transforms as transforms
14
+ from accelerate.state import AcceleratorState
15
+ from PIL import Image
16
+ from transformers import ( # AddedToken is needed for the eval of the tokenizer params # noqa: F401
17
+ AddedToken,
18
+ AutoTokenizer,
19
+ )
20
+
21
+
22
+ IMAGE_TOKEN = "<image>"
23
+ FAKE_TOKEN_AROUND_IMAGE_V2 = "<fake_token_around_image>"
24
+ FAKE_TOKEN_AROUND_IMAGE_V1 = "\n\n"
25
+ # Originally taken from the values used in OpenCLIP
26
+ IMAGE_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
27
+ IMAGE_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ class LoggingTypes(Enum):
32
+ """Types of logging to use for the gradient and parameter statistics"""
33
+
34
+ JSONL = "jsonl"
35
+ WANDB = "wandb"
36
+ PRINT = "print"
37
+
38
+
39
+ class JSONEncoderForDataclasses(json.JSONEncoder):
40
+ """
41
+ Use to serialize dataclass object, like so:
42
+ json.dump(data, fp, indent=2, cls=JSONEncoderForDataclasses)
43
+ """
44
+
45
+ def default(self, obj):
46
+ if dataclasses.is_dataclass(obj):
47
+ return dataclasses.asdict(obj)
48
+ return super().default(obj)
49
+
50
+
51
+ def freeze_model(model, module_exceptions=[]):
52
+ mapping = {
53
+ "LayerNorm": nn.LayerNorm,
54
+ "Linear": nn.Linear,
55
+ "Embedding": nn.Embedding,
56
+ }
57
+ module_exceptions_mapped = [mapping[m] for m in module_exceptions]
58
+ for module in model.modules():
59
+ if module_exceptions and any([isinstance(module, t) for t in module_exceptions_mapped]):
60
+ module.requires_grad_(True) # Explicitly setting it to true to avoid any mistakes
61
+ else:
62
+ module.requires_grad_(False)
63
+ return model
64
+
65
+
66
+ def _convert_to_rgb(image):
67
+ # `image.convert("RGB")` would only work for .jpg images, as it creates
68
+ # a wrong background for transparent images. The call to `alpha_composite`
69
+ # handles this case
70
+ if image.mode == "RGB":
71
+ return image
72
+ image_rgba = image.convert("RGBA")
73
+ background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
74
+ alpha_composite = Image.alpha_composite(background, image_rgba)
75
+ alpha_composite = alpha_composite.convert("RGB")
76
+ return alpha_composite
77
+
78
+
79
+ # TODO(aps): Take parameters from config
80
+ def build_image_transform(image_size=224, eval=False):
81
+ return transforms.Compose(
82
+ [
83
+ _convert_to_rgb,
84
+ (
85
+ transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC)
86
+ if eval
87
+ else transforms.RandomResizedCrop(
88
+ (image_size, image_size), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC
89
+ )
90
+ ),
91
+ transforms.ToTensor(),
92
+ transforms.Normalize(mean=IMAGE_DATASET_MEAN, std=IMAGE_DATASET_STD),
93
+ ]
94
+ )
95
+
96
+
97
+ def get_tokenizer(
98
+ tokenizer_name: str,
99
+ tokenizer_add_tokens,
100
+ tokenizer_add_special_tokens,
101
+ tokenizer_params,
102
+ additional_vocab_size,
103
+ model_vocab_size=None,
104
+ ):
105
+ """
106
+ We artificially separate `tokenizer_add_tokens` and `tokenizer_add_special_tokens`: the latter is a dictionary whose keys only cover the named special tokens (eos, pad, cls, etc.).
107
+ In contrast, `tokenizer_add_tokens` is a list of strings or `AddedToken` objects.
108
+ In practice, we use `tokenizer.add_special_tokens` to add all of these new special tokens or update the existing ones.
109
+
110
+ NB: we constrain the tokenizer to be a fast tokenizer because with the slow tokenizer, we can't set the arguments of the added tokens (cf `.add_tokens`) and by default, the separators are stripped.
111
+ """
112
+ tokenizer_params = eval(tokenizer_params)
113
+ assert isinstance(tokenizer_params, dict)
114
+
115
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **tokenizer_params)
116
+
117
+ if model_vocab_size is not None:
118
+ if model_vocab_size > len(tokenizer):
119
+ logger.warning(
120
+ f"The model vocabulary size ({model_vocab_size}) is larger than the tokenizer vocabulary size "
121
+ f"({len(tokenizer)}). Updating the tokenizer to match."
122
+ )
123
+ if "additional_special_tokens" in tokenizer_params:
124
+ raise ValueError(
125
+ "You can't use `additional_special_tokens` in `tokenizer_params` with a model vocab "
126
+ "size > tokenizer vocab size. We need to adjust tokenizer before adding special "
127
+ "tokens. Please use `tokenizer_add_tokens` instead."
128
+ )
129
+ # We need to pad the tokenizer vocab with fake tokens
130
+ tokenizer.add_tokens(["<fake_token_{}>".format(i) for i in range(model_vocab_size - len(tokenizer))])
131
+
132
+ assert str(eval(tokenizer_add_tokens)[-1]) == IMAGE_TOKEN
133
+ assert str(eval(tokenizer_add_tokens)[-2]) == FAKE_TOKEN_AROUND_IMAGE_V2
134
+ # This check ensures that the image token and the fake token around it will be in the `DecoupledEmbedding.additional_weight`.
135
+ existing_special_tokens = (
136
+ [*tokenizer.special_tokens_map_extended["additional_special_tokens"]]
137
+ if "additional_special_tokens" in tokenizer.special_tokens_map_extended
138
+ else []
139
+ )
140
+ add_special_tokens_dict = {"additional_special_tokens": existing_special_tokens + eval(tokenizer_add_tokens)}
141
+ if tokenizer_add_special_tokens is not None:
142
+ add_special_tokens_dict.update(eval(tokenizer_add_special_tokens))
143
+
144
+ tokenizer.add_special_tokens(add_special_tokens_dict)
145
+
146
+ assert IMAGE_TOKEN in tokenizer.convert_ids_to_tokens(
147
+ [idx for idx in range(len(tokenizer) - additional_vocab_size, len(tokenizer))]
148
+ )
149
+ assert FAKE_TOKEN_AROUND_IMAGE_V2 in tokenizer.convert_ids_to_tokens(
150
+ [idx for idx in range(len(tokenizer) - additional_vocab_size, len(tokenizer))]
151
+ )
152
+ # This verifies that `<image>` was correctly added to the tokenizer vocabulary
153
+ # XXX: opt-1.3b fails here
154
+ # assert tokenizer.is_fast == tokenizer_params.get("use_fast", True)
155
+
156
+ return tokenizer
157
+
158
+
159
+ def pynmvl_handle(accelerator):
160
+ if not torch.cuda.is_available():
161
+ return None
162
+
163
+ pynvml.nvmlInit()
164
+ return pynvml.nvmlDeviceGetHandleByIndex(accelerator.local_process_index)
165
+
166
+
167
+ def pynvml_get_total_energy_in_joules(handle):
168
+ if not torch.cuda.is_available():
169
+ return 0
170
+ return pynvml.nvmlDeviceGetTotalEnergyConsumption(handle) / 1000
171
+
172
+
173
+ def compute_tflops_per_batch_per_gpu(
174
+ num_layers,
175
+ batch_size,
176
+ q_seq_len,
177
+ k_seq_len,
178
+ hidden_size,
179
+ kv_in_dim,
180
+ ff_exp_factor=None,
181
+ grad_acc_size=1,
182
+ swiglu=False,
183
+ vocab_size=None,
184
+ count_backward=False,
185
+ use_grad_checkpointing=False,
186
+ ):
187
+ multiply_add_factor = torch.tensor(2)
188
+ query_transformation = multiply_add_factor * batch_size * q_seq_len * hidden_size**2
189
+ # k_seq_len == v_seq_len
190
+ key_value_transformation = multiply_add_factor * batch_size * k_seq_len * (2 * hidden_size * kv_in_dim)
191
+ attention_matrix_computation = multiply_add_factor * batch_size * q_seq_len * k_seq_len * hidden_size
192
+ attention_softmax = multiply_add_factor * q_seq_len * k_seq_len
193
+ att_over_values_computation = multiply_add_factor * batch_size * q_seq_len * k_seq_len * hidden_size
194
+ post_attention_linear_proj = multiply_add_factor * batch_size * q_seq_len * hidden_size**2
195
+
196
+ # There are usually 2 expansion_linear_layers because first one expands, and second one retracts back to hidden_size
197
+ # When using a classic decoder, some blocks don't have those feed-forward layers
198
+ # Swiglu duplicates the first linear layer, so we have to account for 3 of them instead of 2
199
+ if ff_exp_factor and swiglu:
200
+ expansion_linear_layers = 3 * (
201
+ multiply_add_factor * batch_size * q_seq_len * (hidden_size * ff_exp_factor) * hidden_size
202
+ )
203
+ elif ff_exp_factor:
204
+ expansion_linear_layers = 2 * (
205
+ multiply_add_factor * batch_size * q_seq_len * (hidden_size * ff_exp_factor) * hidden_size
206
+ )
207
+ else:
208
+ expansion_linear_layers = torch.tensor(0)
209
+
210
+ transformer_block_flops = (
211
+ query_transformation
212
+ + key_value_transformation
213
+ + attention_matrix_computation
214
+ + attention_softmax
215
+ + att_over_values_computation
216
+ + post_attention_linear_proj
217
+ + expansion_linear_layers
218
+ )
219
+
220
+ # This computation should only be added if the model has a language head
221
+ if vocab_size:
222
+ language_head_computation = multiply_add_factor * batch_size * q_seq_len * hidden_size * vocab_size
223
+ else:
224
+ language_head_computation = torch.tensor(0)
225
+
226
+ forward_fact = 1
227
+ backward_factor = 2 if count_backward else 0
228
+ grad_checkpointing_factor = 1 if use_grad_checkpointing else 0
229
+ model_flops = (forward_fact + backward_factor + grad_checkpointing_factor) * (
230
+ num_layers * transformer_block_flops + language_head_computation
231
+ )
232
+ model_tflops = model_flops / (10**12)
233
+
234
+ return model_tflops
235
+
236
+
237
+ def compute_perceiver_tflops_per_batch_per_gpu(
238
+ num_layers,
239
+ batch_size,
240
+ q_seq_len,
241
+ vision_embed_seq_len,
242
+ q_k_v_input_dim,
243
+ attention_hidden_size,
244
+ ff_exp_factor=None,
245
+ count_backward=False,
246
+ use_grad_checkpointing=False,
247
+ ):
248
+ multiply_add_factor = torch.tensor(2)
249
+ query_transformation = multiply_add_factor * batch_size * q_seq_len * q_k_v_input_dim * attention_hidden_size
250
+ # k_seq_len == v_seq_len
251
+ key_value_transformation = (
252
+ multiply_add_factor * batch_size * vision_embed_seq_len * (2 * attention_hidden_size * q_k_v_input_dim)
253
+ )
254
+
255
+ k_seq_len = vision_embed_seq_len + q_seq_len
256
+ attention_matrix_computation = multiply_add_factor * batch_size * q_seq_len * k_seq_len * attention_hidden_size
257
+ attention_softmax = multiply_add_factor * q_seq_len * k_seq_len
258
+ att_over_values_computation = multiply_add_factor * batch_size * q_seq_len * k_seq_len * attention_hidden_size
259
+ post_attention_linear_proj = multiply_add_factor * batch_size * q_seq_len * attention_hidden_size * q_k_v_input_dim
260
+
261
+ # There are usually 2 expansion_linear_layers because first one expands, and second one retracts back to hidden_size
262
+ # When using a classic decoder, some blocks don't have those feed-forward layers
263
+ if ff_exp_factor:
264
+ expansion_linear_layers = 2 * (
265
+ multiply_add_factor * batch_size * q_seq_len * (q_k_v_input_dim * ff_exp_factor) * q_k_v_input_dim
266
+ )
267
+ else:
268
+ expansion_linear_layers = torch.tensor(0)
269
+
270
+ transformer_block_flops = (
271
+ query_transformation
272
+ + key_value_transformation
273
+ + attention_matrix_computation
274
+ + attention_softmax
275
+ + att_over_values_computation
276
+ + post_attention_linear_proj
277
+ + expansion_linear_layers
278
+ )
279
+
280
+ forward_fact = 1
281
+ backward_factor = 2 if count_backward else 0
282
+ grad_checkpointing_factor = 1 if use_grad_checkpointing else 0
283
+ model_flops = (forward_fact + backward_factor + grad_checkpointing_factor) * (num_layers * transformer_block_flops)
284
+ model_tflops = model_flops / (10**12)
285
+
286
+ return model_tflops
287
+
288
+
289
+ def mem_usage_formatted(logging_type=LoggingTypes.PRINT):
290
+ # adapted from deepspeed's see_memory_usage
291
+
292
+ torch.cuda.empty_cache()
293
+
294
+ # python doesn't do real-time garbage collection so do it explicitly to get the correct usage reports
295
+ gc.collect()
296
+ vm_stats = psutil.virtual_memory()
297
+
298
+ mem = {
299
+ "gpu mem alloc": f"{torch.cuda.memory_allocated()/2**30:0.2f}GB",
300
+ "max alloc": f"{torch.cuda.max_memory_allocated()/2**30:0.2f}GB",
301
+ "reserv": f"{torch.cuda.memory_reserved()/2**30:0.2f}GB",
302
+ "max reserv": f"{torch.cuda.max_memory_reserved()/2**30:0.2f}GB",
303
+ "cpu vm used": f"{(vm_stats.total-vm_stats.available)/2**30:0.2f}GB {vm_stats.percent}%",
304
+ }
305
+
306
+ if logging_type == LoggingTypes.PRINT:
307
+ mem = " | ".join([f"{k}: {v}" for k, v in mem.items()]) + " | "
308
+
309
+ # get the peak memory to report correct data, so reset the max_memory_allocated counter for the next call
310
+ torch.cuda.reset_peak_memory_stats()
311
+
312
+ return mem
313
+
314
+
315
+ def is_deepspeed_used():
316
+ deepspeed_plugin = get_deepspeed_plugin()
317
+ return deepspeed_plugin is not None
318
+
319
+
320
+ def get_deepspeed_stage():
321
+ deepspeed_plugin = get_deepspeed_plugin()
322
+ if deepspeed_plugin is None:
323
+ return 0
324
+ ds_config = deepspeed_plugin.deepspeed_config
325
+ stage = ds_config.get("zero_optimization", {}).get("stage", 0)
326
+ # from accelerate>=0.17.1 can do instead:
327
+ # stage = deepspeed_plugin.zero_stage
328
+ return stage
329
+
330
+
331
+ def is_deepspeed_zero3_used():
332
+ return get_deepspeed_stage() == 3
333
+
334
+
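The chained `.get` lookup in `get_deepspeed_stage` can be tried on plain dictionaries. This is a minimal sketch: `zero_stage_from_config` is a hypothetical helper written for this example, and the sample configs are made up.

```python
def zero_stage_from_config(ds_config):
    """Return the ZeRO stage from a DeepSpeed-style config dict, defaulting to 0."""
    # Each .get supplies a fallback ({} or 0), so a missing section never raises KeyError.
    return ds_config.get("zero_optimization", {}).get("stage", 0)


stage_set = zero_stage_from_config({"zero_optimization": {"stage": 3}})
stage_missing = zero_stage_from_config({"fp16": {"enabled": True}})
stage_empty = zero_stage_from_config({})
print(stage_set, stage_missing, stage_empty)  # 3 0 0
```

Defaulting to stage 0 means a config without a `zero_optimization` section is treated the same as one with ZeRO turned off.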
335
+ def accelerate_torch_dtype():
336
+ """
337
+ derive and return `torch_dtype` to be used in `from_pretrained` from either Deepspeed config or if
338
+ Deepspeed isn't used than accelerator state
339
+ """
340
+ if not is_accelerate_initialized():
341
+ return None
342
+
343
+ accelerator_state = AcceleratorState()
344
+
345
+ if is_deepspeed_used():
346
+ deepspeed_plugin = accelerator_state.deepspeed_plugin
347
+ ds_config = deepspeed_plugin.deepspeed_config
348
+ if ds_config.get("fp16", {}).get("enabled", False):
349
+ torch_dtype = torch.float16
350
+ elif ds_config.get("bf16", {}).get("enabled", False):
351
+ torch_dtype = torch.bfloat16
352
+ else:
353
+ torch_dtype = None
354
+ else: # no Deepspeed
355
+ if accelerator_state.mixed_precision == "fp16":
356
+ torch_dtype = torch.float16
357
+ elif accelerator_state.mixed_precision == "bf16":
358
+ torch_dtype = torch.bfloat16
359
+ else:
360
+ torch_dtype = None
361
+
362
+ return torch_dtype
363
+
364
+
365
+ def is_accelerate_initialized():
366
+ return accelerate.state.is_initialized()
367
+
368
+
369
+ def get_deepspeed_plugin():
370
+ if is_accelerate_initialized():
371
+ return AcceleratorState().deepspeed_plugin
372
+ else:
373
+ return None
374
+
375
+
376
+ def get_deepspeed_engine(accelerator):
377
+ return accelerator.deepspeed_engine_wrapped.engine
378
+
379
+
380
+ def is_deepspeed_zero_init_enabled():
381
+ deepspeed_plugin = get_deepspeed_plugin()
382
+ if deepspeed_plugin is not None:
383
+ return deepspeed_plugin.is_zero3_init_enabled()
384
+ else:
385
+ return False
386
+
387
+
388
+ @contextmanager
389
+ def hf_trainer_disable_zero3_init_context_manager():
390
+ # monkey patch hack to emulate a context that has zero_init disabled as it's used in
391
+ # modeling_utils.py in transformers for from_config and from_pretrained.
392
+ import transformers.modeling_utils # noqa
393
+
394
+ orig = transformers.modeling_utils.is_deepspeed_zero3_enabled
395
+ transformers.modeling_utils.is_deepspeed_zero3_enabled = lambda: False
396
+ yield
397
+ transformers.modeling_utils.is_deepspeed_zero3_enabled = orig
398
+
399
+
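The monkey-patching pattern used by this context manager can be exercised without `transformers`. The sketch below patches a stand-in namespace: `fake_module`, `is_feature_enabled`, and `feature_disabled` are all hypothetical names introduced for the example. It wraps the `yield` in `try`/`finally` so that the patched attribute is restored even when the body of the `with` block raises.

```python
from contextlib import contextmanager
from types import SimpleNamespace

# Stand-in for a module exposing a feature-check function (hypothetical, not transformers).
fake_module = SimpleNamespace(is_feature_enabled=lambda: True)


@contextmanager
def feature_disabled(module):
    """Temporarily replace module.is_feature_enabled with a function returning False."""
    orig = module.is_feature_enabled
    module.is_feature_enabled = lambda: False
    try:
        yield
    finally:
        # Runs on normal exit and when the body raises, so the patch never leaks.
        module.is_feature_enabled = orig


with feature_disabled(fake_module):
    inside = fake_module.is_feature_enabled()

try:
    with feature_disabled(fake_module):
        raise RuntimeError("simulated failure inside the context")
except RuntimeError:
    pass

after = fake_module.is_feature_enabled()
print(inside, after)  # False True
```

Without the `finally` clause, the simulated failure would leave `is_feature_enabled` patched for the rest of the process.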
400
+ def deepspeed_zero_init_disabled_context_manager():
401
+ """
402
+ returns either a context list that includes one that will disable zero.Init or an empty context list
403
+ """
404
+ deepspeed_plugin = get_deepspeed_plugin()
405
+ if deepspeed_plugin is not None:
406
+ return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
407
+ else:
408
+ return [hf_trainer_disable_zero3_init_context_manager()]
409
+
410
+
411
+ def deepspeed_gathered_parameters_context_manager(params, modify=True):
412
+ """
413
+ Under zero.Init returns a context manager that will gather the sharded param, otherwise returns an empty list
414
+
415
+ If `modify` is `True`, gather the shards and once the context exits update the shards with the
416
+ modified data - one wants that when modifying the gathered param. If one wants to just gather
417
+ the shards in order to read the param and no modifications are done to it, use `modify=False` as
418
+ it's more efficient.
419
+
420
+ `params` - can be a single parameter, a list, or a tuple of parameters to collect.
421
+
422
+ Example:
423
+
424
+ from transformers.utils import ContextManagers
425
+ from m4.training.utils import deepspeed_gathered_parameters_context_manager
426
+ with ContextManagers(deepspeed_gathered_parameters_context_manager(module.weight, modify=True)):
427
+ module.weight.data.normal_(mean=0.0, std=std)
428
+ if module.padding_idx is not None:
429
+ module.weight.data[module.padding_idx].zero_()
430
+
431
+
432
+ """
433
+ if is_deepspeed_zero_init_enabled():
434
+ import deepspeed
435
+
436
+ # 0 is for updating `params` shards after modifying it, `None` is for read-only (only gather)
437
+ modifier_rank = 0 if modify else None
438
+ return [deepspeed.zero.GatheredParameters(params, modifier_rank=modifier_rank)]
439
+ else:
440
+ return []
441
+
442
+
443
+ # adapted from https://github.com/huggingface/transformers/blob/a081f292ca8479eaf66d7396186021268f128829/src/transformers/modeling_utils.py#L438-L496
444
+ # as it appears to be a private function
445
+ def load_state_dict_into_model(model_to_load, state_dict, start_prefix):
446
+ # Convert old format to new format if needed from a PyTorch state_dict
447
+ old_keys = []
448
+ new_keys = []
449
+ for key in state_dict.keys():
450
+ new_key = None
451
+ if "gamma" in key:
452
+ new_key = key.replace("gamma", "weight")
453
+ if "beta" in key:
454
+ new_key = key.replace("beta", "bias")
455
+ if new_key:
456
+ old_keys.append(key)
457
+ new_keys.append(new_key)
458
+ for old_key, new_key in zip(old_keys, new_keys):
459
+ state_dict[new_key] = state_dict.pop(old_key)
460
+
461
+ # copy state_dict so _load_from_state_dict can modify it
462
+ metadata = getattr(state_dict, "_metadata", None)
463
+ state_dict = state_dict.copy()
464
+ if metadata is not None:
465
+ state_dict._metadata = metadata
466
+
467
+ error_msgs = []
468
+
469
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
470
+ # so we need to apply the function recursively.
471
+ def load(module: torch.nn.Module, state_dict, prefix=""):
472
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
473
+ args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
474
+ # Parameters of module and children will start with prefix. We can exit early if there are none in this
475
+ # state_dict
476
+ if len([key for key in state_dict if key.startswith(prefix)]) > 0:
477
+ if is_deepspeed_zero_init_enabled():
478
+ import deepspeed
479
+
480
+ # In sharded models, each shard has only part of the full state_dict, so only gather
481
+ # parameters that are in the current state_dict.
482
+ named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
483
+ params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
484
+ if len(params_to_gather) > 0:
485
+ # because zero3 puts placeholders in model params, this context
486
+ # manager gathers (unpartitions) the params of the current layer, then loads from
487
+ # the state dict and then re-partitions them again
488
+ with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
489
+ if torch.distributed.get_rank() == 0:
490
+ module._load_from_state_dict(*args)
491
+ else:
492
+ module._load_from_state_dict(*args)
493
+
494
+ for name, child in module._modules.items():
495
+ if child is not None:
496
+ load(child, state_dict, prefix + name + ".")
497
+
498
+ load(model_to_load, state_dict, prefix=start_prefix)
499
+ # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
500
+ # it's safe to delete it.
501
+ del state_dict
502
+
503
+ return error_msgs
504
+
505
+
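The legacy-key conversion at the top of `load_state_dict_into_model` can be followed on a plain dictionary. This is a sketch only: `rename_legacy_keys` is a hypothetical helper, floats stand in for tensors, and the key names are invented for the example.

```python
def rename_legacy_keys(state_dict):
    """Rename legacy 'gamma'/'beta' keys to 'weight'/'bias' in place and return the dict."""
    renames = []
    for key in state_dict.keys():
        new_key = None
        if "gamma" in key:
            new_key = key.replace("gamma", "weight")
        if "beta" in key:
            new_key = key.replace("beta", "bias")
        if new_key:
            renames.append((key, new_key))
    # Mutate only after iteration: changing a dict's keys while iterating it raises RuntimeError.
    for old_key, new_key in renames:
        state_dict[new_key] = state_dict.pop(old_key)
    return state_dict


# Plain floats stand in for tensors; the key names are made up for the example.
sd = {"layer_norm.gamma": 1.0, "layer_norm.beta": 0.0, "dense.weight": 2.0}
print(sorted(rename_legacy_keys(sd)))
# ['dense.weight', 'layer_norm.bias', 'layer_norm.weight']
```

Collecting the renames first and applying them in a second loop is the same two-pass structure as the `old_keys`/`new_keys` lists above. Note that the match is a substring test, so any key that merely contains `gamma` or `beta` is renamed as well.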
506
+ def get_stats(var, ctx):
507
+ if var is None:
508
+ return {}
509
+ var = var.float()
510
+ abs_var = var.abs()
511
+ return {
512
+ f"{ctx}_var_min": var.min().item(),
513
+ f"{ctx}_var_max": var.max().item(),
514
+ f"{ctx}_var_mean": var.mean().item(),
515
+ f"{ctx}_var_std": var.std().item(),
516
+ f"{ctx}_abs_var_min": abs_var.min().item(),
517
+ f"{ctx}_abs_var_max": abs_var.max().item(),
518
+ f"{ctx}_abs_var_mean": abs_var.mean().item(),
519
+ f"{ctx}_abs_var_std": abs_var.std().item(),
520
+ f"{ctx}_var_norm_2": (var.norm(p=2) / var.numel()).item(),
521
+ f"{ctx}_var_norm_1": (var.norm(p=1) / var.numel()).item(),
522
+ f"{ctx}_nonzero": (var != 0).sum().item(),
523
+ }
524
+
525
+
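To make the normalisation in `get_stats` concrete, the sketch below recomputes a few of the same quantities for a plain Python list. `list_stats`, the `"grad"` prefix, and the sample values are hypothetical; the real function operates on tensors and reports more fields.

```python
import math


def list_stats(values, ctx):
    """Pure-Python analogue of a few get_stats entries, for a flat list of floats."""
    n = len(values)
    return {
        f"{ctx}_var_min": min(values),
        f"{ctx}_var_max": max(values),
        # As in get_stats, both norms are divided by the number of elements.
        f"{ctx}_var_norm_2": math.sqrt(sum(v * v for v in values)) / n,
        f"{ctx}_var_norm_1": sum(abs(v) for v in values) / n,
        f"{ctx}_nonzero": sum(1 for v in values if v != 0),
    }


# Hypothetical values: L2 norm is 5, L1 norm is 7, over 3 elements.
stats = list_stats([3.0, -4.0, 0.0], ctx="grad")
print(stats["grad_var_min"], stats["grad_var_max"], stats["grad_nonzero"])  # -4.0 3.0 2
```

Because the L2 norm is divided by the element count rather than its square root, `var_norm_2` is not a root-mean-square, whereas `var_norm_1` equals the mean absolute value.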
526
+ def get_stats_format(ctx):
527
+ return {
528
+ f"{ctx}_var_min": "e",
529
+ f"{ctx}_var_max": "e",
530
+ f"{ctx}_var_mean": "e",
531
+ f"{ctx}_var_std": "e",
532
+ f"{ctx}_abs_var_min": "e",
533
+ f"{ctx}_abs_var_max": "e",
534
+ f"{ctx}_abs_var_mean": "e",
535
+ f"{ctx}_abs_var_std": "e",
536
+ f"{ctx}_var_norm_2": "e",
537
+ f"{ctx}_var_norm_1": "e",
538
+ f"{ctx}_nonzero": "",
539
+ }
m4/utils/__init__.py ADDED
File without changes