Spaces:
Runtime error
Runtime error
VictorSanh
commited on
Commit
•
217780a
1
Parent(s):
15d7f9b
Update visualization
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +19 -5
- app.py +793 -0
- app_bis.py +857 -0
- app_dialogue.py +395 -0
- elon_musk.md +23 -0
- images/bear.jpg +0 -0
- m4/__init__.py +1 -0
- m4/models/__init__.py +28 -0
- m4/models/common.py +104 -0
- m4/models/custom_modules.py +337 -0
- m4/models/perceiver/perceiver.py +141 -0
- m4/models/vbloom/__init__.py +0 -0
- m4/models/vbloom/configuration_vbloom.py +235 -0
- m4/models/vbloom/modeling_vbloom.py +1396 -0
- m4/models/vgpt2/__init__.py +0 -0
- m4/models/vgpt2/configuration_vgpt2.py +288 -0
- m4/models/vgpt2/modeling_vgpt2.py +1384 -0
- m4/models/vgpt_neo/__init__.py +0 -0
- m4/models/vgpt_neo/configuration_vgpt_neo.py +250 -0
- m4/models/vgpt_neo/modeling_vgpt_neo.py +1182 -0
- m4/models/vllama/configuration_vllama.py +204 -0
- m4/models/vllama/make_tiny_llama.py +51 -0
- m4/models/vllama/make_tiny_model.py +114 -0
- m4/models/vllama/modeling_vllama.py +1260 -0
- m4/models/vopt/__init__.py +0 -0
- m4/models/vopt/configuration_vopt.py +250 -0
- m4/models/vopt/make_tiny_model.py +114 -0
- m4/models/vopt/modeling_vopt.py +1513 -0
- m4/models/vt5/__init__.py +0 -0
- m4/models/vt5/configuration_vt5.py +218 -0
- m4/models/vt5/modeling_vt5.py +0 -0
- m4/models/zero_checkpoint_to_hf.py +87 -0
- m4/scripts/cleanup-checkpoints.py +156 -0
- m4/scripts/convert-checkpoints.py +124 -0
- m4/scripts/s3-upload-checkpoints.py +194 -0
- m4/scripts/s3_checkpoint_download_convert_upload.py +171 -0
- m4/scripts/s3_checkpoint_download_convert_upload.slurm +51 -0
- m4/scripts/s3_downloaded_checkpoints_cleanup.slurm +54 -0
- m4/scripts/schedule-evals.py +87 -0
- m4/testing_utils.py +1116 -0
- m4/training/__init__.py +0 -0
- m4/training/config.py +545 -0
- m4/training/dataset_utils.py +352 -0
- m4/training/debug_utils.py +34 -0
- m4/training/packing.py +755 -0
- m4/training/setup_language_model.py +38 -0
- m4/training/setup_vision_model.py +33 -0
- m4/training/types.py +13 -0
- m4/training/utils.py +539 -0
- m4/utils/__init__.py +0 -0
README.md
CHANGED
@@ -1,12 +1,26 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
emoji: 🐨
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: m4-dialogue
|
3 |
emoji: 🐨
|
4 |
+
colorFrom: red
|
5 |
+
colorTo: indigo
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 3.12.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
+
# M4 Visualization
|
13 |
+
|
14 |
+
For visualizations, we have a main [app](https://huggingface.co/spaces/HuggingFaceM4/m4-demo) which calls multiple child apps to retrieve generations via [Gradio API](https://gradio.app/using-blocks-like-functions/). This allows us to parallelize calls to multiple models at the same time instead of running them sequentially.
|
15 |
+
|
16 |
+
|
17 |
+
## How to?
|
18 |
+
|
19 |
+
The process of adding a model to the main space:
|
20 |
+
|
21 |
+
- Use `huggingface-cli login` to login with an auth token that has a read/write access to the `HuggingFaceM4` org on the hub.
|
22 |
+
- Use `./upload_checkpoint_to_hub_gcs.sh` script to upload a checkpoint from GCP store to the hub. An example command to upload checkpoint for step 3000 from `tr_121ter` to the hub: `./m4/visualization/upload_checkpoint_to_hub_gcs.sh gs://hf-science-m4-cold/local_experiment_dir/tr_121ter/opt_step-3000`. This will create model repo under the `HuggingFaceM4` repo on the hub. If you are on the cluster, use `./upload_checkpoint_to_hub_s3.sh` instead. I recommend being on a compute node to avoid disk space issues (uploading to the hub consists in downloading locally the checkpoint, creating a repo on the hub, copying it locally, filling it with the weights and commiting the weights to the hub repo).
|
23 |
+
- [MANUAL] Go to the hub, create a repo of type `space` with the same name as the model. In the space's settings, add a secret `HF_AUTH_TOKEN` with a token which has read access to the `HuggingFaceM4` repo. This step can be potentially automated in the future.
|
24 |
+
- [MANUAL] Edit `m4/visualization/app.py`'s three dictionary to include your model in the existing formats of those dictionaries.
|
25 |
+
- Run `m4/visualization/sync-repo.sh <name_of_the_space_on_the_hub>` to sync the repo with the local setting. This will automatically update the space to have the latest code as in the `m4/visualization/app.py`.
|
26 |
+
- Run `m4/visualization/sync-repo.sh main` to update the main repo as well with the new model.
|
app.py
ADDED
@@ -0,0 +1,793 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
import time
|
5 |
+
from io import BytesIO
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
import requests
|
9 |
+
import torch
|
10 |
+
import transformers
|
11 |
+
from accelerate.utils import get_max_memory
|
12 |
+
from joblib import Parallel, delayed
|
13 |
+
from PIL import Image
|
14 |
+
from transformers import AutoTokenizer
|
15 |
+
|
16 |
+
from m4.models.vbloom import configuration_vbloom, modeling_vbloom
|
17 |
+
from m4.models.vgpt2 import configuration_vgpt2, modeling_vgpt2
|
18 |
+
from m4.models.vgpt_neo import configuration_vgpt_neo, modeling_vgpt_neo
|
19 |
+
from m4.models.vllama import configuration_vllama, modeling_vllama
|
20 |
+
from m4.models.vopt import configuration_vopt, modeling_vopt
|
21 |
+
from m4.training.packing import image_attention_mask_for_packed_input_ids, incremental_to_binary_attention_mask
|
22 |
+
from m4.training.utils import build_image_transform
|
23 |
+
|
24 |
+
|
25 |
+
logging.basicConfig(level=logging.INFO)
|
26 |
+
logger = logging.getLogger()
|
27 |
+
|
28 |
+
CURRENT_MODEL = "<replace_model>"
|
29 |
+
# CURRENT_MODEL = "tr_177_4datasets_alpha_baseline_opt_step-2000"
|
30 |
+
|
31 |
+
MAX_TRIES = 3
|
32 |
+
TOKENIZER_FAST = True
|
33 |
+
MAX_SEQ_LEN = 1024
|
34 |
+
model, tokenizer = None, None
|
35 |
+
|
36 |
+
|
37 |
+
MODEL_TO_DISPLAY_NAME = {
|
38 |
+
"tr_199_w_xattn_opt_step-65000": "VLlama - tr_199_w_xattn_opt_step-65000",
|
39 |
+
# "tr_201_sft_on_lrv_opt_step-15000": "VLlama - tr_201_sft_on_lrv_opt_step-15000",
|
40 |
+
# "tr_202bis_ift_llava_all_unfrozen_opt_step-14128": "VLlama - tr_202bis_ift_llava_all_unfrozen_opt_step-14128",
|
41 |
+
# "tr_203_ift_m3it_opt_step-50000": "VLlama - tr_203_ift_m3it_opt_step-50000",
|
42 |
+
# "tr_205_sft_ultrachat_opt_step-20000": "VLlama - tr_205_sft_ultrachat_opt_step-20000",
|
43 |
+
# "tr_207_ift_svit_opt_step-14627": "VLlama - tr_207_ift_svit_opt_step-14627",
|
44 |
+
"tr_209_ift_mixture_opt_step-14000": "VLlama - tr_209_ift_mixture_opt_step-14000",
|
45 |
+
}
|
46 |
+
MODEL_TO_MODEL_CLASS = {
|
47 |
+
"tr_199_w_xattn_opt_step-65000": "VLlamaForCausalLM",
|
48 |
+
# "tr_201_sft_on_lrv_opt_step-15000": "VLlamaForCausalLM",
|
49 |
+
# "tr_202bis_ift_llava_all_unfrozen_opt_step-14128": "VLlamaForCausalLM",
|
50 |
+
# "tr_203_ift_m3it_opt_step-50000": "VLlamaForCausalLM",
|
51 |
+
# "tr_205_sft_ultrachat_opt_step-20000": "VLlamaForCausalLM",
|
52 |
+
# "tr_207_ift_svit_opt_step-14627": "VLlamaForCausalLM",
|
53 |
+
"tr_209_ift_mixture_opt_step-14000": "VLlamaForCausalLM",
|
54 |
+
}
|
55 |
+
|
56 |
+
MODEL_TO_CONFIG_CLASS = {
|
57 |
+
"tr_199_w_xattn_opt_step-65000": "VLlamaConfig",
|
58 |
+
# "tr_201_sft_on_lrv_opt_step-15000": "VLlamaConfig",
|
59 |
+
# "tr_202bis_ift_llava_all_unfrozen_opt_step-14128": "VLlamaConfig",
|
60 |
+
# "tr_203_ift_m3it_opt_step-50000": "VLlamaConfig",
|
61 |
+
# "tr_205_sft_ultrachat_opt_step-20000": "VLlamaConfig",
|
62 |
+
# "tr_207_ift_svit_opt_step-14627": "VLlamaConfig",
|
63 |
+
"tr_209_ift_mixture_opt_step-14000": "VLlamaConfig",
|
64 |
+
}
|
65 |
+
|
66 |
+
|
67 |
+
def load_tokenizer_model(model_name, model_class):
|
68 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
69 |
+
model_name,
|
70 |
+
use_fast=TOKENIZER_FAST,
|
71 |
+
use_auth_token=os.getenv("HF_AUTH_TOKEN", True), # `use_fast=False` for 1B3 OPT, True for all the other models
|
72 |
+
)
|
73 |
+
tokenizer.padding_side = "left"
|
74 |
+
config_class = MODEL_TO_CONFIG_CLASS[model_name.split("/")[-1]]
|
75 |
+
|
76 |
+
# assert tokenizer.is_fast
|
77 |
+
|
78 |
+
supported_custom_modules = {
|
79 |
+
"vgpt2": modeling_vgpt2,
|
80 |
+
"vbloom": modeling_vbloom,
|
81 |
+
"vgptneo": modeling_vgpt_neo,
|
82 |
+
"vopt": modeling_vopt,
|
83 |
+
"vllama": modeling_vllama,
|
84 |
+
}
|
85 |
+
supported_custom_configs = {
|
86 |
+
"vgpt2": configuration_vgpt2,
|
87 |
+
"vbloom": configuration_vbloom,
|
88 |
+
"vgptneo": configuration_vgpt_neo,
|
89 |
+
"vopt": configuration_vopt,
|
90 |
+
"vllama": configuration_vllama,
|
91 |
+
}
|
92 |
+
parent_config_class = (
|
93 |
+
[v for k, v in supported_custom_configs.items() if k in model_class.lower()] + [transformers]
|
94 |
+
)[0]
|
95 |
+
parent_model_class = (
|
96 |
+
[v for k, v in supported_custom_modules.items() if k in model_class.lower()] + [transformers]
|
97 |
+
)[0]
|
98 |
+
config_class = getattr(parent_config_class, config_class)
|
99 |
+
model_class = getattr(parent_model_class, model_class)
|
100 |
+
config = config_class.from_pretrained(model_name, use_auth_token=os.getenv("HF_AUTH_TOKEN", True))
|
101 |
+
max_memory_map = get_max_memory()
|
102 |
+
for key in max_memory_map.keys():
|
103 |
+
if key != "cpu":
|
104 |
+
# Get this in GB
|
105 |
+
max_memory_map[key] = max_memory_map[key] // (1024 * 1024 * 1024)
|
106 |
+
# Decrease 2 for Pytorch overhead and 2 for the forward to be safe
|
107 |
+
max_memory_map[key] = f"{max_memory_map[key] - 4} GiB"
|
108 |
+
model = model_class.from_pretrained(
|
109 |
+
model_name,
|
110 |
+
use_auth_token=os.getenv("HF_AUTH_TOKEN", True),
|
111 |
+
device_map="auto",
|
112 |
+
offload_folder="./offload",
|
113 |
+
torch_dtype=config.torch_dtype,
|
114 |
+
max_memory=max_memory_map,
|
115 |
+
)
|
116 |
+
model.eval()
|
117 |
+
print("Current device map:", model.hf_device_map)
|
118 |
+
print("Model default generation config:", model.generation_config)
|
119 |
+
# TODO: the device_map looks very inefficien right now. that could be improved
|
120 |
+
# it typically looks like that
|
121 |
+
# {
|
122 |
+
# 'model.embed_tokens': 0,
|
123 |
+
# 'model.vision_model': 0,
|
124 |
+
# 'model.layers.0': 0,
|
125 |
+
# 'model.layers.1': 0,
|
126 |
+
# 'model.layers.2': 0,
|
127 |
+
# 'model.layers.3': 0,
|
128 |
+
# 'model.layers.4': 0,
|
129 |
+
# 'model.layers.5': 0,
|
130 |
+
# 'model.layers.6': 1,
|
131 |
+
# 'model.layers.7': 1,
|
132 |
+
# 'model.layers.8': 1,
|
133 |
+
# 'model.layers.9': 1,
|
134 |
+
# 'model.layers.10': 1,
|
135 |
+
# 'model.layers.11': 1,
|
136 |
+
# 'model.layers.12': 1,
|
137 |
+
# 'model.layers.13': 1,
|
138 |
+
# 'model.layers.14': 1,
|
139 |
+
# 'model.layers.15': 1,
|
140 |
+
# 'model.layers.16': 1,
|
141 |
+
# 'model.layers.17': 2,
|
142 |
+
# 'model.layers.18': 2,
|
143 |
+
# 'model.layers.19': 2,
|
144 |
+
# 'model.layers.20': 2,
|
145 |
+
# 'model.layers.21': 2,
|
146 |
+
# 'model.layers.22': 2,
|
147 |
+
# 'model.layers.23': 2,
|
148 |
+
# 'model.layers.24': 2,
|
149 |
+
# 'model.layers.25': 2,
|
150 |
+
# 'model.layers.26': 2,
|
151 |
+
# 'model.layers.27': 2,
|
152 |
+
# 'model.layers.28': 3,
|
153 |
+
# 'model.layers.29': 3,
|
154 |
+
# 'model.layers.30': 3,
|
155 |
+
# 'model.layers.31': 3,
|
156 |
+
# 'model.gated_cross_attn_layers.0': 3,
|
157 |
+
# 'model.gated_cross_attn_layers.1': 3,
|
158 |
+
# 'model.gated_cross_attn_layers.2': 3,
|
159 |
+
# 'model.gated_cross_attn_layers.3': 3,
|
160 |
+
# 'model.gated_cross_attn_layers.4': 3,
|
161 |
+
# 'model.gated_cross_attn_layers.5': 3,
|
162 |
+
# 'model.gated_cross_attn_layers.6': 3,
|
163 |
+
# 'model.gated_cross_attn_layers.7': 3,
|
164 |
+
# 'model.gated_cross_attn_layers.8': 4,
|
165 |
+
# 'model.gated_cross_attn_layers.9': 4,
|
166 |
+
# 'model.gated_cross_attn_layers.10': 4,
|
167 |
+
# 'model.gated_cross_attn_layers.11': 4,
|
168 |
+
# 'model.gated_cross_attn_layers.12': 4,
|
169 |
+
# 'model.gated_cross_attn_layers.13': 4,
|
170 |
+
# 'model.gated_cross_attn_layers.14': 4,
|
171 |
+
# 'model.gated_cross_attn_layers.15': 4,
|
172 |
+
# 'model.norm': 4,
|
173 |
+
# 'lm_head': 4
|
174 |
+
# } which means there is a lot of things going around between the gated cross attention layers and the LM layers...
|
175 |
+
return tokenizer, model
|
176 |
+
|
177 |
+
|
178 |
+
MODEL_TO_SPACE_MAPPING = {}
|
179 |
+
IS_MAIN_SPACE = CURRENT_MODEL not in MODEL_TO_MODEL_CLASS
|
180 |
+
if IS_MAIN_SPACE:
|
181 |
+
for model in MODEL_TO_MODEL_CLASS:
|
182 |
+
MODEL_TO_SPACE_MAPPING[model] = gr.Blocks.load(
|
183 |
+
name=f"spaces/HuggingFaceM4/{model}", api_key=os.getenv("HF_AUTH_TOKEN", True)
|
184 |
+
)
|
185 |
+
else:
|
186 |
+
model_path = f"HuggingFaceM4/{CURRENT_MODEL}"
|
187 |
+
tokenizer, model = load_tokenizer_model(model_path, MODEL_TO_MODEL_CLASS[CURRENT_MODEL])
|
188 |
+
|
189 |
+
|
190 |
+
def fetch_images(url_images):
|
191 |
+
images = []
|
192 |
+
for url in url_images:
|
193 |
+
if isinstance(url, str):
|
194 |
+
images.append(Image.open(BytesIO(requests.get(url, stream=True).content)))
|
195 |
+
else:
|
196 |
+
images.append(url)
|
197 |
+
return images
|
198 |
+
|
199 |
+
|
200 |
+
def model_generation(
|
201 |
+
prompt,
|
202 |
+
images,
|
203 |
+
tokenizer,
|
204 |
+
model,
|
205 |
+
temperature,
|
206 |
+
no_repeat_ngram_size,
|
207 |
+
max_new_tokens,
|
208 |
+
min_length,
|
209 |
+
ban_tokens,
|
210 |
+
eos_tokens,
|
211 |
+
force_words,
|
212 |
+
repetition_penalty,
|
213 |
+
hide_special_tokens,
|
214 |
+
decoding_strategy,
|
215 |
+
num_beams,
|
216 |
+
length_penalty,
|
217 |
+
top_k,
|
218 |
+
top_p,
|
219 |
+
penalty_alpha,
|
220 |
+
):
|
221 |
+
# Preparing inputs
|
222 |
+
tokens = tokenizer(
|
223 |
+
[prompt],
|
224 |
+
truncation=True,
|
225 |
+
max_length=MAX_SEQ_LEN,
|
226 |
+
padding=True,
|
227 |
+
add_special_tokens=False,
|
228 |
+
)
|
229 |
+
|
230 |
+
input_ids = torch.tensor([[tokenizer.bos_token_id] + tokens.input_ids[0]])
|
231 |
+
attention_mask = torch.tensor([[1] + tokens.attention_mask[0]])
|
232 |
+
|
233 |
+
image_attention_mask = [
|
234 |
+
incremental_to_binary_attention_mask(
|
235 |
+
image_attention_mask_for_packed_input_ids(input_ids[0].unsqueeze(0), tokenizer)[0], num_classes=len(images)
|
236 |
+
)
|
237 |
+
]
|
238 |
+
|
239 |
+
image_transform = build_image_transform(eval=True)
|
240 |
+
pixel_values = [torch.stack([image_transform(img) for img in images])]
|
241 |
+
|
242 |
+
input_ids = input_ids.to(0)
|
243 |
+
attention_mask = attention_mask.to(0)
|
244 |
+
pixel_values = torch.stack(pixel_values).to(0)
|
245 |
+
image_attention_mask = torch.cat(image_attention_mask, 0).to(0)
|
246 |
+
|
247 |
+
# Excluding some words from the generation
|
248 |
+
bad_words_ids = None
|
249 |
+
ban_tokens = ban_tokens.replace("\\n", "\n")
|
250 |
+
bad_words = ban_tokens.split(";")
|
251 |
+
if len(bad_words) > 0:
|
252 |
+
bad_words_ids = tokenizer(bad_words, add_special_tokens=False).input_ids
|
253 |
+
|
254 |
+
# Forcing some words in the generation
|
255 |
+
force_words_ids = None
|
256 |
+
if force_words != "":
|
257 |
+
force_words = force_words.replace("\\n", "\n")
|
258 |
+
force_words = force_words.split(";")
|
259 |
+
if len(force_words) > 0:
|
260 |
+
force_words_ids = tokenizer(force_words, add_special_tokens=False).input_ids
|
261 |
+
|
262 |
+
eos_token_ids = None
|
263 |
+
if eos_tokens != "":
|
264 |
+
eos_tokens = eos_tokens.replace("\\n", "\n")
|
265 |
+
eos_tokens = eos_tokens.split(";")
|
266 |
+
if len(eos_tokens) > 0:
|
267 |
+
eos_token_ids = []
|
268 |
+
for eos_token in eos_tokens:
|
269 |
+
tokenized_eos_token = tokenizer(eos_token, add_special_tokens=False).input_ids
|
270 |
+
if len(tokenized_eos_token) > 1:
|
271 |
+
raise ValueError(
|
272 |
+
f"eos_tokens should be one token, here {eos_token} is {len(tokenized_eos_token)} tokens:"
|
273 |
+
f" {tokenized_eos_token}"
|
274 |
+
)
|
275 |
+
eos_token_ids += tokenized_eos_token
|
276 |
+
|
277 |
+
# Inputs
|
278 |
+
input_args = {
|
279 |
+
"input_ids": input_ids,
|
280 |
+
"attention_mask": attention_mask,
|
281 |
+
"pixel_values": pixel_values,
|
282 |
+
"image_attention_mask": image_attention_mask,
|
283 |
+
}
|
284 |
+
# Common parameters to all decoding strategies
|
285 |
+
# This documentation is useful to read: https://huggingface.co/docs/transformers/main/en/generation_strategies
|
286 |
+
generation_args = {
|
287 |
+
"temperature": temperature,
|
288 |
+
"no_repeat_ngram_size": no_repeat_ngram_size,
|
289 |
+
"max_new_tokens": max_new_tokens,
|
290 |
+
"min_length": min_length,
|
291 |
+
"bad_words_ids": bad_words_ids,
|
292 |
+
"force_words_ids": force_words_ids,
|
293 |
+
"repetition_penalty": repetition_penalty,
|
294 |
+
"eos_token_id": eos_token_ids,
|
295 |
+
}
|
296 |
+
|
297 |
+
assert decoding_strategy in [
|
298 |
+
"greedy",
|
299 |
+
"beam_search",
|
300 |
+
"beam_sampling",
|
301 |
+
"sampling_top_k",
|
302 |
+
"sampling_top_p",
|
303 |
+
"contrastive_sampling",
|
304 |
+
]
|
305 |
+
if decoding_strategy == "greedy":
|
306 |
+
pass
|
307 |
+
elif decoding_strategy == "beam_search":
|
308 |
+
generation_args["num_beams"] = num_beams
|
309 |
+
generation_args["length_penalty"] = length_penalty
|
310 |
+
assert generation_args["num_beams"] > 1
|
311 |
+
elif decoding_strategy == "beam_sampling":
|
312 |
+
generation_args["num_beams"] = num_beams
|
313 |
+
generation_args["length_penalty"] = length_penalty
|
314 |
+
generation_args["do_sample"] = True
|
315 |
+
assert generation_args["num_beams"] > 1
|
316 |
+
elif decoding_strategy == "sampling_top_k":
|
317 |
+
generation_args["do_sample"] = True
|
318 |
+
generation_args["top_k"] = top_k
|
319 |
+
elif decoding_strategy == "sampling_top_p":
|
320 |
+
generation_args["do_sample"] = True
|
321 |
+
generation_args["top_p"] = top_p
|
322 |
+
elif decoding_strategy == "contrastive_sampling":
|
323 |
+
generation_args["do_sample"] = True
|
324 |
+
generation_args["penalty_alpha"] = penalty_alpha
|
325 |
+
generation_args["top_k"] = top_k
|
326 |
+
|
327 |
+
generated_tokens = model.generate(
|
328 |
+
**input_args,
|
329 |
+
**generation_args,
|
330 |
+
)
|
331 |
+
tokens = tokenizer.convert_ids_to_tokens(generated_tokens[0])
|
332 |
+
decoded_skip_special_tokens = repr(
|
333 |
+
tokenizer.batch_decode(generated_tokens, skip_special_tokens=hide_special_tokens)[0]
|
334 |
+
)
|
335 |
+
decoded = repr(tokenizer.batch_decode(generated_tokens)[0])
|
336 |
+
logger.info(
|
337 |
+
"Result: \n"
|
338 |
+
f"Prompt: `{prompt}`\n"
|
339 |
+
f"Tokens ids from prompt + generation: `{generated_tokens[0].tolist()}`\n"
|
340 |
+
f"Tokens (converted) from prompt + generation: `{tokens}`\n"
|
341 |
+
f"String decoded with skipped special tokens: `{decoded_skip_special_tokens}`\n"
|
342 |
+
f"String decoded: `{decoded}`\n"
|
343 |
+
f"Generation mode: `{decoding_strategy}`\n"
|
344 |
+
f"Generation parameters: `{generation_args}`\n"
|
345 |
+
)
|
346 |
+
|
347 |
+
original_prompt = generated_tokens[:, : input_ids.shape[-1]]
|
348 |
+
actual_generated_tokens = generated_tokens[:, input_ids.shape[-1] :]
|
349 |
+
|
350 |
+
first_end_token = len(actual_generated_tokens[0])
|
351 |
+
actual_generated_tokens = actual_generated_tokens[:, :first_end_token]
|
352 |
+
displayed_tokens = torch.cat([original_prompt, actual_generated_tokens], dim=-1)
|
353 |
+
generated_text = tokenizer.batch_decode(displayed_tokens, skip_special_tokens=hide_special_tokens)[0]
|
354 |
+
return generated_text
|
355 |
+
|
356 |
+
|
357 |
+
def model_inference(
|
358 |
+
files,
|
359 |
+
prompt,
|
360 |
+
temperature,
|
361 |
+
no_repeat_ngram_size,
|
362 |
+
max_new_tokens,
|
363 |
+
min_length,
|
364 |
+
ban_tokens,
|
365 |
+
eos_tokens,
|
366 |
+
force_words,
|
367 |
+
repetition_penalty,
|
368 |
+
hide_special_tokens,
|
369 |
+
decoding_strategy,
|
370 |
+
num_beams,
|
371 |
+
length_penalty,
|
372 |
+
top_k,
|
373 |
+
top_p,
|
374 |
+
penalty_alpha,
|
375 |
+
):
|
376 |
+
if isinstance(files, str) and len(files) == 0:
|
377 |
+
files = None
|
378 |
+
|
379 |
+
prompt = prompt.strip()
|
380 |
+
prompt = prompt.replace("\\n", "\n")
|
381 |
+
file_idx = 0
|
382 |
+
url_images = re.findall(r"<image(.*?)>", prompt)
|
383 |
+
for idx, url_image in enumerate(url_images):
|
384 |
+
if len(url_image) == 0:
|
385 |
+
url_images[idx] = Image.open(files[file_idx].name if hasattr(files[file_idx], "name") else files[file_idx])
|
386 |
+
file_idx += 1
|
387 |
+
else:
|
388 |
+
prompt = prompt.replace(url_image, "")
|
389 |
+
url_images[idx] = url_images[idx][1:]
|
390 |
+
images = fetch_images(url_images)
|
391 |
+
|
392 |
+
global model, tokenizer
|
393 |
+
|
394 |
+
generated_text = model_generation(
|
395 |
+
prompt=prompt,
|
396 |
+
images=images,
|
397 |
+
tokenizer=tokenizer,
|
398 |
+
model=model,
|
399 |
+
temperature=temperature,
|
400 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
401 |
+
max_new_tokens=max_new_tokens,
|
402 |
+
min_length=min_length,
|
403 |
+
ban_tokens=ban_tokens,
|
404 |
+
eos_tokens=eos_tokens,
|
405 |
+
force_words=force_words,
|
406 |
+
repetition_penalty=repetition_penalty,
|
407 |
+
hide_special_tokens=hide_special_tokens,
|
408 |
+
decoding_strategy=decoding_strategy,
|
409 |
+
num_beams=num_beams,
|
410 |
+
length_penalty=length_penalty,
|
411 |
+
top_k=top_k,
|
412 |
+
top_p=top_p,
|
413 |
+
penalty_alpha=penalty_alpha,
|
414 |
+
)
|
415 |
+
return generated_text.strip()
|
416 |
+
|
417 |
+
|
418 |
+
def try_model_inference(
|
419 |
+
model,
|
420 |
+
files,
|
421 |
+
prompt,
|
422 |
+
temperature,
|
423 |
+
no_repeat_ngram_size,
|
424 |
+
max_new_tokens,
|
425 |
+
min_length,
|
426 |
+
ban_tokens,
|
427 |
+
eos_tokens,
|
428 |
+
force_words,
|
429 |
+
repetition_penalty,
|
430 |
+
hide_special_tokens,
|
431 |
+
decoding_strategy,
|
432 |
+
num_beams,
|
433 |
+
length_penalty,
|
434 |
+
top_k,
|
435 |
+
top_p,
|
436 |
+
penalty_alpha,
|
437 |
+
):
|
438 |
+
count = 0
|
439 |
+
while count < MAX_TRIES:
|
440 |
+
try:
|
441 |
+
return MODEL_TO_SPACE_MAPPING[model](
|
442 |
+
files,
|
443 |
+
prompt,
|
444 |
+
temperature,
|
445 |
+
no_repeat_ngram_size,
|
446 |
+
max_new_tokens,
|
447 |
+
min_length,
|
448 |
+
ban_tokens,
|
449 |
+
eos_tokens,
|
450 |
+
force_words,
|
451 |
+
repetition_penalty,
|
452 |
+
hide_special_tokens,
|
453 |
+
decoding_strategy,
|
454 |
+
num_beams,
|
455 |
+
length_penalty,
|
456 |
+
top_k,
|
457 |
+
top_p,
|
458 |
+
penalty_alpha,
|
459 |
+
api_name="model_inference",
|
460 |
+
)
|
461 |
+
except KeyError:
|
462 |
+
# Gradio return {'error': None} some times.
|
463 |
+
time.sleep(3)
|
464 |
+
count += 1
|
465 |
+
pass
|
466 |
+
|
467 |
+
|
468 |
+
def all_model_inference(
|
469 |
+
prompt,
|
470 |
+
temperature,
|
471 |
+
no_repeat_ngram_size,
|
472 |
+
max_new_tokens,
|
473 |
+
min_length,
|
474 |
+
ban_tokens,
|
475 |
+
eos_tokens,
|
476 |
+
force_words,
|
477 |
+
repetition_penalty,
|
478 |
+
hide_special_tokens,
|
479 |
+
decoding_strategy,
|
480 |
+
num_beams,
|
481 |
+
length_penalty,
|
482 |
+
top_k,
|
483 |
+
top_p,
|
484 |
+
penalty_alpha,
|
485 |
+
):
|
486 |
+
outputs = []
|
487 |
+
print(
|
488 |
+
prompt,
|
489 |
+
temperature,
|
490 |
+
no_repeat_ngram_size,
|
491 |
+
max_new_tokens,
|
492 |
+
min_length,
|
493 |
+
ban_tokens,
|
494 |
+
eos_tokens,
|
495 |
+
force_words,
|
496 |
+
repetition_penalty,
|
497 |
+
hide_special_tokens,
|
498 |
+
decoding_strategy,
|
499 |
+
num_beams,
|
500 |
+
length_penalty,
|
501 |
+
top_k,
|
502 |
+
top_p,
|
503 |
+
penalty_alpha,
|
504 |
+
)
|
505 |
+
outputs = Parallel(n_jobs=len(MODEL_TO_SPACE_MAPPING), backend="threading")(
|
506 |
+
delayed(try_model_inference)(
|
507 |
+
model,
|
508 |
+
os.path.join(os.path.dirname(__file__), "images", "bear.jpg"),
|
509 |
+
prompt,
|
510 |
+
temperature,
|
511 |
+
no_repeat_ngram_size,
|
512 |
+
max_new_tokens,
|
513 |
+
min_length,
|
514 |
+
ban_tokens,
|
515 |
+
eos_tokens,
|
516 |
+
force_words,
|
517 |
+
repetition_penalty,
|
518 |
+
hide_special_tokens,
|
519 |
+
decoding_strategy,
|
520 |
+
num_beams,
|
521 |
+
length_penalty,
|
522 |
+
top_k,
|
523 |
+
top_p,
|
524 |
+
penalty_alpha,
|
525 |
+
)
|
526 |
+
for model in MODEL_TO_SPACE_MAPPING
|
527 |
+
)
|
528 |
+
if len(outputs) == 1:
|
529 |
+
outputs = outputs[0]
|
530 |
+
return outputs
|
531 |
+
|
532 |
+
|
533 |
+
examples = [
|
534 |
+
[
|
535 |
+
None,
|
536 |
+
"""The following is a conversation between a highly knowledgeable and intelligent AI assistant, called Assistant, and a human user, called User. In the following interactions, User and Assistant will converse in natural language, and Assistant will do its best to answer User’s questions. Assistant was built to be respectful, polite and inclusive. It knows a lot, and always tells the truth. When prompted with an image, it does not make up facts.
|
537 |
+
|
538 |
+
The conversation begins:
|
539 |
+
User:<fake_token_around_image><image:https://i1.pickpik.com/photos/515/906/355/kittens-cats-pet-cute-preview.jpg><fake_token_around_image>Describe this image.</s>
|
540 |
+
Assistant: Five kittens are standing together in the center of the photograph. From the left to right, there is one orange kitte, two kittens white and black stripes, and two brown kittens with an orange head. They are in bright green grass and it looks like they are moving forward.</s>
|
541 |
+
User:How about this image?<fake_token_around_image><image:https://cdn.pixabay.com/photo/2017/09/25/13/12/puppy-2785074_1280.jpg><fake_token_around_image>Can you describe it too?</s>
|
542 |
+
Assistant: A dog is lying on the floor, looking at the camera. It is looking directly at you. The dog has a white body and brown patches on its face and ears. Its eyes are dark. Its nose is black, and it has long, floppy ears, white paws, long fur, big eyes.</s>
|
543 |
+
User: What kind of breed is it?</s>
|
544 |
+
Assistant: It looks like a Cavalier King Charles Spaniel. These breeds are known for their gentle, affectionate nature. Their long ears and big, brown eyes give them a sweet, expressive face. Cavaliers are small dogs weighing between 13 and 20 pounds. They have a silky, medium-length coat that comes in various colors, including black and tan, tricolor, red and white, or black.</s>
|
545 |
+
User: What can you tell me about this breed of dogs?</s>
|
546 |
+
Assistant: One specific characteristic of the Cavalier King Charles Spaniel is its friendly and affectionate nature. This breed is known for being extremely sociable and forming strong bonds with their owners. They are often described as "velcro dogs" because they love to be close to their human companions, whether it's sitting on their laps, cuddling on the couch, or simply following them around the house.</s>
|
547 |
+
User: ghjkhjabnufs</s>Assistant: That doesn’t seem to be a word. Could you ask me another way?</s>
|
548 |
+
User: Do you like Cavalier King Charles Spaniel?</s>
|
549 |
+
Assistant: I do not have personal opinions as I’m just a computer program. However, cavaliers are known for being excellent family pets due to their gentle and patient demeanor, making them great with children and other pets. Their affectionate nature also makes them ideal therapy dogs, as they can provide comfort and emotional support to those in need.</s>
|
550 |
+
User: How many dogs do you see in this image?<fake_token_around_image><image:https://i.dailymail.co.uk/i/pix/2011/07/01/article-2010308-0CD22A8300000578-496_634x414.jpg><fake_token_around_image></s>
|
551 |
+
Assistant: There is no dogs in this image. The picture shows a tennis player jumping to volley the ball.</s>
|
552 |
+
User:<fake_token_around_image><image:https://m.media-amazon.com/images/M/MV5BMjE4MTcwMTM1Nl5BMl5BanBnXkFtZTcwMTIwMzMzMw@@._V1_.jpg><fake_token_around_image>Describe all of the parts of this image.</s>
|
553 |
+
Assistant:""",
|
554 |
+
1.0,
|
555 |
+
0,
|
556 |
+
256,
|
557 |
+
10,
|
558 |
+
"<image>;<fake_token_around_image>",
|
559 |
+
"</s>;User",
|
560 |
+
"",
|
561 |
+
1.0,
|
562 |
+
False,
|
563 |
+
"greedy",
|
564 |
+
1,
|
565 |
+
1.0,
|
566 |
+
50,
|
567 |
+
0.5,
|
568 |
+
0.95,
|
569 |
+
],
|
570 |
+
# [
|
571 |
+
# None,
|
572 |
+
# """This is a conversation between a human, User, and an intelligent visual AI, Bot. User sends images, and Bot answer the questions from the user.
|
573 |
+
# User: <fake_token_around_image><image:https://m.media-amazon.com/images/M/MV5BMjE4MTcwMTM1Nl5BMl5BanBnXkFtZTcwMTIwMzMzMw@@._V1_.jpg><fake_token_around_image>
|
574 |
+
# Describe this image.
|
575 |
+
# Bot:""",
|
576 |
+
# 1,
|
577 |
+
# 2,
|
578 |
+
# 64,
|
579 |
+
# 10,
|
580 |
+
# "<image>;<fake_token_around_image>;User;user;Bot;bot;Question;question;Answer;answer;\n",
|
581 |
+
# False,
|
582 |
+
# False,
|
583 |
+
# True,
|
584 |
+
# ],
|
585 |
+
# [
|
586 |
+
# None,
|
587 |
+
# """This is a conversation between a human, User, and an intelligent visual AI, Bot. User sends images, and Bot answer the questions from the user.
|
588 |
+
# User: <fake_token_around_image><image:https://i.redd.it/hsktcp4nv1g01.jpg><fake_token_around_image>
|
589 |
+
# Why do people find this image funny?
|
590 |
+
# Bot:""",
|
591 |
+
# 1,
|
592 |
+
# 2,
|
593 |
+
# 64,
|
594 |
+
# 10,
|
595 |
+
# "<image>;<fake_token_around_image>;User;user;Bot;bot;Question;question;Answer;answer;\n",
|
596 |
+
# False,
|
597 |
+
# False,
|
598 |
+
# True,
|
599 |
+
# ],
|
600 |
+
# [
|
601 |
+
# None,
|
602 |
+
# """This is a conversation between a human, User, and an intelligent visual AI, Bot. User sends images, and Bot answer the questions from the user.
|
603 |
+
# User: <fake_token_around_image><image:https://pbs.twimg.com/media/FooD7oyakAIU5_Q?format=jpg&name=large><fake_token_around_image>
|
604 |
+
# Describe what's in this image.
|
605 |
+
# Bot:""",
|
606 |
+
# 1,
|
607 |
+
# 2,
|
608 |
+
# 64,
|
609 |
+
# 10,
|
610 |
+
# "<image>;<fake_token_around_image>;User;user;Bot;bot;Question;question;Answer;answer;\n",
|
611 |
+
# False,
|
612 |
+
# False,
|
613 |
+
# True,
|
614 |
+
# ],
|
615 |
+
# [
|
616 |
+
# None,
|
617 |
+
# """This is a conversation between a human, User, and an intelligent visual AI, Bot. User sends images, and Bot answer the questions from the user.
|
618 |
+
# User: <fake_token_around_image><image:https://www.tutorialride.com/images/non-verbal-analogy-questions/non-verbal-analogy-logical-reasoning-1.jpg><fake_token_around_image>
|
619 |
+
# What's the correct answer? A, B, C or D?
|
620 |
+
# Bot:""",
|
621 |
+
# 1,
|
622 |
+
# 2,
|
623 |
+
# 64,
|
624 |
+
# 10,
|
625 |
+
# "<image>;<fake_token_around_image>;User;user;Bot;bot;Question;question;Answer;answer;\n",
|
626 |
+
# False,
|
627 |
+
# False,
|
628 |
+
# True,
|
629 |
+
# ],
|
630 |
+
]
|
631 |
+
|
632 |
+
|
633 |
+
MSG_MAIN = """
|
634 |
+
# Text generation with Vllama models
|
635 |
+
|
636 |
+
### Help to write prompts:
|
637 |
+
|
638 |
+
Put the urls to the images inside the image tokens, it will be converted into the real image tokens. Put <fake_token_around_image> before and after each image token WITHOUT space. The texts \\n will be converted into real newline characters. See examples and additional details below.
|
639 |
+
"""
|
640 |
+
# MSG_DETAILS = """
|
641 |
+
# ### Additional details
|
642 |
+
# - if the model was trained with the template 1 (`\\n\\n<image>\\n\\n`), then `<fake_token_around_image>` will be replaced with `\\n\\n`. This is particularly useful if you are comparing the performance of different models trained with different templates.
|
643 |
+
# - special tokens are not automatically added to the prompt, so add them manually.
|
644 |
+
# - with the first template `\\n\\n<image>\\n\\n` , the sequence isn't necessary tokenized as `["\\n\\n", "<image>", "\\n\\n"]` to enforce this behavior, you can use the "Integrate image sequence as ids" parameter.
|
645 |
+
# """
|
646 |
+
# if ~IS_MAIN_SPACE:
|
647 |
+
# MSG_DETAILS += (
|
648 |
+
# "- alternatively, you can upload images and then directly specify them via \<image\> tag in the prompt."
|
649 |
+
# )
|
650 |
+
|
651 |
+
with gr.Blocks() as demo:
|
652 |
+
gr.Markdown(MSG_MAIN)
|
653 |
+
with gr.Row():
|
654 |
+
with gr.Column():
|
655 |
+
gr.Markdown("## Input")
|
656 |
+
if not IS_MAIN_SPACE:
|
657 |
+
images = gr.File(label="Images", file_count="multiple")
|
658 |
+
prompt = gr.Textbox(label="Prompt", placeholder="Enter the prompt here")
|
659 |
+
|
660 |
+
gr.Markdown("## Common parameters to all decoding strategy")
|
661 |
+
temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=1.0, label="Softmax temperature")
|
662 |
+
no_repeat_ngram_size = gr.Slider(
|
663 |
+
minimum=0,
|
664 |
+
maximum=10,
|
665 |
+
step=1,
|
666 |
+
value=0,
|
667 |
+
label="The size of an n-gram that cannot occur more than once (0=infinity)",
|
668 |
+
)
|
669 |
+
max_new_tokens = gr.Slider(
|
670 |
+
minimum=0, maximum=512, step=1, value=256, label="Maximum number of new tokens to generate"
|
671 |
+
)
|
672 |
+
min_length = gr.Slider(
|
673 |
+
minimum=0, maximum=512, step=1, value=16, label="Minimum length of the sequence to be generated"
|
674 |
+
)
|
675 |
+
ban_tokens = gr.Textbox(
|
676 |
+
label='Tokens to prevent from being generated (separated by ";")',
|
677 |
+
value="<image>;<fake_token_around_image>",
|
678 |
+
)
|
679 |
+
eos_tokens = gr.Textbox(label="EOS tokens", value="</s>")
|
680 |
+
force_words = gr.Textbox(label='Force words to be generated (separated by ";")', value="")
|
681 |
+
repetition_penalty = gr.Slider(
|
682 |
+
minimum=0, maximum=10, step=0.01, value=1, label="repetition_penalty. CTRL paper suggests 1.2."
|
683 |
+
)
|
684 |
+
hide_special_tokens = gr.Checkbox(label="Hide special tokens in the text", value=False)
|
685 |
+
|
686 |
+
gr.Markdown("## Decoding strategy and its specific parameters")
|
687 |
+
decoding_strategy = gr.Dropdown(
|
688 |
+
["greedy", "beam_search", "beam_sampling", "sampling_top_k", "sampling_top_p", "contrastive_sampling"],
|
689 |
+
label="Decoding strategy",
|
690 |
+
value="greedy",
|
691 |
+
)
|
692 |
+
num_beams = gr.Slider(
|
693 |
+
minimum=0,
|
694 |
+
maximum=10,
|
695 |
+
step=1,
|
696 |
+
value=3,
|
697 |
+
label="Beam size",
|
698 |
+
info="Only used if `decoding_strategy` is `beam_search` or `beam_sampling`",
|
699 |
+
)
|
700 |
+
length_penalty = gr.Slider(
|
701 |
+
minimum=-1000,
|
702 |
+
maximum=1000,
|
703 |
+
step=0.1,
|
704 |
+
value=1,
|
705 |
+
label=(
|
706 |
+
"length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter"
|
707 |
+
" sequences. Only used if `decoding_strategy` is `beam_search` or `beam_sampling`"
|
708 |
+
),
|
709 |
+
)
|
710 |
+
top_k = gr.Slider(
|
711 |
+
minimum=0,
|
712 |
+
maximum=500,
|
713 |
+
step=1,
|
714 |
+
value=50,
|
715 |
+
label="Top k",
|
716 |
+
info="Only used if `decoding_strategy` is `sampling_top_k` or `contrastive_sampling`",
|
717 |
+
)
|
718 |
+
top_p = gr.Slider(
|
719 |
+
minimum=0,
|
720 |
+
maximum=1,
|
721 |
+
step=0.01,
|
722 |
+
value=0.95,
|
723 |
+
label="Top p",
|
724 |
+
info="Only used if `decoding_strategy` is `sampling_top_p`",
|
725 |
+
)
|
726 |
+
penalty_alpha = gr.Slider(
|
727 |
+
minimum=0,
|
728 |
+
maximum=1,
|
729 |
+
step=0.01,
|
730 |
+
value=0.95,
|
731 |
+
label="Penalty alpha",
|
732 |
+
info="Only used if `decoding_strategy` is `contrastive_sampling`",
|
733 |
+
)
|
734 |
+
|
735 |
+
submit = gr.Button(label="Generate")
|
736 |
+
|
737 |
+
with gr.Column():
|
738 |
+
if IS_MAIN_SPACE:
|
739 |
+
outputs = [
|
740 |
+
gr.Textbox(label=MODEL_TO_DISPLAY_NAME[model], multiline=True, readonly=True)
|
741 |
+
for model in MODEL_TO_MODEL_CLASS
|
742 |
+
]
|
743 |
+
inference_func = all_model_inference
|
744 |
+
inputs = [
|
745 |
+
prompt,
|
746 |
+
temperature,
|
747 |
+
no_repeat_ngram_size,
|
748 |
+
max_new_tokens,
|
749 |
+
min_length,
|
750 |
+
ban_tokens,
|
751 |
+
eos_tokens,
|
752 |
+
force_words,
|
753 |
+
repetition_penalty,
|
754 |
+
hide_special_tokens,
|
755 |
+
decoding_strategy,
|
756 |
+
num_beams,
|
757 |
+
length_penalty,
|
758 |
+
top_k,
|
759 |
+
top_p,
|
760 |
+
penalty_alpha,
|
761 |
+
]
|
762 |
+
|
763 |
+
examples = [example[1:] for example in examples]
|
764 |
+
else:
|
765 |
+
outputs = gr.Textbox(label="Generated text", interactive=False)
|
766 |
+
inference_func = model_inference
|
767 |
+
inputs = [
|
768 |
+
images,
|
769 |
+
prompt,
|
770 |
+
temperature,
|
771 |
+
no_repeat_ngram_size,
|
772 |
+
max_new_tokens,
|
773 |
+
min_length,
|
774 |
+
ban_tokens,
|
775 |
+
eos_tokens,
|
776 |
+
force_words,
|
777 |
+
repetition_penalty,
|
778 |
+
hide_special_tokens,
|
779 |
+
decoding_strategy,
|
780 |
+
num_beams,
|
781 |
+
length_penalty,
|
782 |
+
top_k,
|
783 |
+
top_p,
|
784 |
+
penalty_alpha,
|
785 |
+
]
|
786 |
+
with gr.Row():
|
787 |
+
gr.Examples(inputs=inputs, examples=examples)
|
788 |
+
# gr.Markdown(MSG_DETAILS)
|
789 |
+
|
790 |
+
submit.click(inference_func, inputs=inputs, outputs=outputs, api_name="model_inference")
|
791 |
+
|
792 |
+
demo.queue()
|
793 |
+
demo.launch()
|
app_bis.py
ADDED
@@ -0,0 +1,857 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
|
5 |
+
import time
|
6 |
+
from io import BytesIO
|
7 |
+
|
8 |
+
import gradio as gr
|
9 |
+
import requests
|
10 |
+
import torch
|
11 |
+
import transformers
|
12 |
+
from accelerate.utils import get_max_memory
|
13 |
+
|
14 |
+
from joblib import Parallel, delayed
|
15 |
+
from PIL import Image
|
16 |
+
from transformers import AutoTokenizer
|
17 |
+
|
18 |
+
from m4.models.vbloom import configuration_vbloom, modeling_vbloom
|
19 |
+
from m4.models.vgpt2 import configuration_vgpt2, modeling_vgpt2
|
20 |
+
from m4.models.vgpt_neo import configuration_vgpt_neo, modeling_vgpt_neo
|
21 |
+
from m4.models.vllama import configuration_vllama, modeling_vllama
|
22 |
+
from m4.models.vopt import configuration_vopt, modeling_vopt
|
23 |
+
from m4.training.packing import image_attention_mask_for_packed_input_ids, incremental_to_binary_attention_mask
|
24 |
+
from m4.training.utils import build_image_transform
|
25 |
+
|
26 |
+
|
27 |
+
logging.basicConfig(level=logging.INFO)
|
28 |
+
logger = logging.getLogger()
|
29 |
+
|
30 |
+
CURRENT_MODEL = "tr_209_ift_mixture_opt_step-2000"
|
31 |
+
|
32 |
+
MAX_TRIES = 3
|
33 |
+
TOKENIZER_FAST = True
|
34 |
+
MAX_SEQ_LEN = 1024
|
35 |
+
model, tokenizer = None, None
|
36 |
+
|
37 |
+
|
38 |
+
MODEL_TO_DISPLAY_NAME = {
|
39 |
+
"tr_199_w_xattn_opt_step-65000": "VLlama - tr_199_w_xattn_opt_step-65000",
|
40 |
+
"tr_201_sft_on_lrv_opt_step-15000": "VLlama - tr_201_sft_on_lrv_opt_step-15000",
|
41 |
+
"tr_202bis_ift_llava_all_unfrozen_opt_step-14128": "VLlama - tr_202bis_ift_llava_all_unfrozen_opt_step-14128",
|
42 |
+
"tr_203_ift_m3it_opt_step-50000": "VLlama - tr_203_ift_m3it_opt_step-50000",
|
43 |
+
"tr_205_sft_ultrachat_opt_step-20000": "VLlama - tr_205_sft_ultrachat_opt_step-20000",
|
44 |
+
"tr_207_ift_svit_opt_step-14627": "VLlama - tr_207_ift_svit_opt_step-14627",
|
45 |
+
"tr_209_ift_mixture_opt_step-2000": "VLlama - tr_209_ift_mixture_opt_step-2000",
|
46 |
+
}
|
47 |
+
MODEL_TO_MODEL_CLASS = {
|
48 |
+
"tr_199_w_xattn_opt_step-65000": "VLlamaForCausalLM",
|
49 |
+
"tr_201_sft_on_lrv_opt_step-15000": "VLlamaForCausalLM",
|
50 |
+
"tr_202bis_ift_llava_all_unfrozen_opt_step-14128": "VLlamaForCausalLM",
|
51 |
+
"tr_203_ift_m3it_opt_step-50000": "VLlamaForCausalLM",
|
52 |
+
"tr_205_sft_ultrachat_opt_step-20000": "VLlamaForCausalLM",
|
53 |
+
"tr_207_ift_svit_opt_step-14627": "VLlamaForCausalLM",
|
54 |
+
"tr_209_ift_mixture_opt_step-2000": "VLlamaForCausalLM",
|
55 |
+
}
|
56 |
+
|
57 |
+
MODEL_TO_CONFIG_CLASS = {
|
58 |
+
"tr_199_w_xattn_opt_step-65000": "VLlamaConfig",
|
59 |
+
"tr_201_sft_on_lrv_opt_step-15000": "VLlamaConfig",
|
60 |
+
"tr_202bis_ift_llava_all_unfrozen_opt_step-14128": "VLlamaConfig",
|
61 |
+
"tr_203_ift_m3it_opt_step-50000": "VLlamaConfig",
|
62 |
+
"tr_205_sft_ultrachat_opt_step-20000": "VLlamaConfig",
|
63 |
+
"tr_207_ift_svit_opt_step-14627": "VLlamaConfig",
|
64 |
+
"tr_209_ift_mixture_opt_step-2000": "VLlamaConfig",
|
65 |
+
}
|
66 |
+
|
67 |
+
|
68 |
+
def load_tokenizer_model(model_name, model_class):
|
69 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
70 |
+
model_name,
|
71 |
+
use_fast=TOKENIZER_FAST,
|
72 |
+
use_auth_token=os.getenv("HF_AUTH_TOKEN", True), # `use_fast=False` for 1B3 OPT, True for all the other models
|
73 |
+
)
|
74 |
+
tokenizer.padding_side = "left"
|
75 |
+
config_class = MODEL_TO_CONFIG_CLASS[model_name.split("/")[-1]]
|
76 |
+
|
77 |
+
# assert tokenizer.is_fast
|
78 |
+
|
79 |
+
supported_custom_modules = {
|
80 |
+
"vgpt2": modeling_vgpt2,
|
81 |
+
"vbloom": modeling_vbloom,
|
82 |
+
"vgptneo": modeling_vgpt_neo,
|
83 |
+
"vopt": modeling_vopt,
|
84 |
+
"vllama": modeling_vllama,
|
85 |
+
}
|
86 |
+
supported_custom_configs = {
|
87 |
+
"vgpt2": configuration_vgpt2,
|
88 |
+
"vbloom": configuration_vbloom,
|
89 |
+
"vgptneo": configuration_vgpt_neo,
|
90 |
+
"vopt": configuration_vopt,
|
91 |
+
"vllama": configuration_vllama,
|
92 |
+
}
|
93 |
+
parent_config_class = (
|
94 |
+
[v for k, v in supported_custom_configs.items() if k in model_class.lower()] + [transformers]
|
95 |
+
)[0]
|
96 |
+
parent_model_class = (
|
97 |
+
[v for k, v in supported_custom_modules.items() if k in model_class.lower()] + [transformers]
|
98 |
+
)[0]
|
99 |
+
config_class = getattr(parent_config_class, config_class)
|
100 |
+
model_class = getattr(parent_model_class, model_class)
|
101 |
+
config = config_class.from_pretrained(model_name, use_auth_token=os.getenv("HF_AUTH_TOKEN", True))
|
102 |
+
max_memory_map = get_max_memory()
|
103 |
+
for key in max_memory_map.keys():
|
104 |
+
if key != "cpu":
|
105 |
+
# Get this in GB
|
106 |
+
max_memory_map[key] = max_memory_map[key] // (1024 * 1024 * 1024)
|
107 |
+
# Decrease 2 for Pytorch overhead and 2 for the forward to be safe
|
108 |
+
max_memory_map[key] = f"{max_memory_map[key] - 4} GiB"
|
109 |
+
model = model_class.from_pretrained(
|
110 |
+
model_name,
|
111 |
+
use_auth_token=os.getenv("HF_AUTH_TOKEN", True),
|
112 |
+
device_map="auto",
|
113 |
+
offload_folder="./offload",
|
114 |
+
torch_dtype=config.torch_dtype,
|
115 |
+
max_memory=max_memory_map,
|
116 |
+
)
|
117 |
+
model.eval()
|
118 |
+
print("Current device map:", model.hf_device_map)
|
119 |
+
print("Model default generation config:", model.generation_config)
|
120 |
+
# TODO: the device_map looks very inefficien right now. that could be improved
|
121 |
+
# it typically looks like that
|
122 |
+
# {
|
123 |
+
# 'model.embed_tokens': 0,
|
124 |
+
# 'model.vision_model': 0,
|
125 |
+
# 'model.layers.0': 0,
|
126 |
+
# 'model.layers.1': 0,
|
127 |
+
# 'model.layers.2': 0,
|
128 |
+
# 'model.layers.3': 0,
|
129 |
+
# 'model.layers.4': 0,
|
130 |
+
# 'model.layers.5': 0,
|
131 |
+
# 'model.layers.6': 1,
|
132 |
+
# 'model.layers.7': 1,
|
133 |
+
# 'model.layers.8': 1,
|
134 |
+
# 'model.layers.9': 1,
|
135 |
+
# 'model.layers.10': 1,
|
136 |
+
# 'model.layers.11': 1,
|
137 |
+
# 'model.layers.12': 1,
|
138 |
+
# 'model.layers.13': 1,
|
139 |
+
# 'model.layers.14': 1,
|
140 |
+
# 'model.layers.15': 1,
|
141 |
+
# 'model.layers.16': 1,
|
142 |
+
# 'model.layers.17': 2,
|
143 |
+
# 'model.layers.18': 2,
|
144 |
+
# 'model.layers.19': 2,
|
145 |
+
# 'model.layers.20': 2,
|
146 |
+
# 'model.layers.21': 2,
|
147 |
+
# 'model.layers.22': 2,
|
148 |
+
# 'model.layers.23': 2,
|
149 |
+
# 'model.layers.24': 2,
|
150 |
+
# 'model.layers.25': 2,
|
151 |
+
# 'model.layers.26': 2,
|
152 |
+
# 'model.layers.27': 2,
|
153 |
+
# 'model.layers.28': 3,
|
154 |
+
# 'model.layers.29': 3,
|
155 |
+
# 'model.layers.30': 3,
|
156 |
+
# 'model.layers.31': 3,
|
157 |
+
# 'model.gated_cross_attn_layers.0': 3,
|
158 |
+
# 'model.gated_cross_attn_layers.1': 3,
|
159 |
+
# 'model.gated_cross_attn_layers.2': 3,
|
160 |
+
# 'model.gated_cross_attn_layers.3': 3,
|
161 |
+
# 'model.gated_cross_attn_layers.4': 3,
|
162 |
+
# 'model.gated_cross_attn_layers.5': 3,
|
163 |
+
# 'model.gated_cross_attn_layers.6': 3,
|
164 |
+
# 'model.gated_cross_attn_layers.7': 3,
|
165 |
+
# 'model.gated_cross_attn_layers.8': 4,
|
166 |
+
# 'model.gated_cross_attn_layers.9': 4,
|
167 |
+
# 'model.gated_cross_attn_layers.10': 4,
|
168 |
+
# 'model.gated_cross_attn_layers.11': 4,
|
169 |
+
# 'model.gated_cross_attn_layers.12': 4,
|
170 |
+
# 'model.gated_cross_attn_layers.13': 4,
|
171 |
+
# 'model.gated_cross_attn_layers.14': 4,
|
172 |
+
# 'model.gated_cross_attn_layers.15': 4,
|
173 |
+
# 'model.norm': 4,
|
174 |
+
# 'lm_head': 4
|
175 |
+
# } which means there is a lot of things going around between the gated cross attention layers and the LM layers...
|
176 |
+
return tokenizer, model
|
177 |
+
|
178 |
+
|
179 |
+
MODEL_TO_SPACE_MAPPING = {}
|
180 |
+
IS_MAIN_SPACE = CURRENT_MODEL not in MODEL_TO_MODEL_CLASS
|
181 |
+
if IS_MAIN_SPACE:
|
182 |
+
for model in MODEL_TO_MODEL_CLASS:
|
183 |
+
MODEL_TO_SPACE_MAPPING[model] = gr.Blocks.load(
|
184 |
+
name=f"spaces/HuggingFaceM4/{model}", api_key=os.getenv("HF_AUTH_TOKEN", True)
|
185 |
+
)
|
186 |
+
else:
|
187 |
+
model_path = f"HuggingFaceM4/{CURRENT_MODEL}"
|
188 |
+
tokenizer, model = load_tokenizer_model(model_path, MODEL_TO_MODEL_CLASS[CURRENT_MODEL])
|
189 |
+
|
190 |
+
|
191 |
+
def fetch_images(url_images):
|
192 |
+
images = []
|
193 |
+
for url in url_images:
|
194 |
+
if isinstance(url, str):
|
195 |
+
images.append(Image.open(BytesIO(requests.get(url, stream=True).content)))
|
196 |
+
else:
|
197 |
+
images.append(url)
|
198 |
+
return images
|
199 |
+
|
200 |
+
|
201 |
+
def model_generation(
|
202 |
+
prompt,
|
203 |
+
images,
|
204 |
+
tokenizer,
|
205 |
+
model,
|
206 |
+
temperature,
|
207 |
+
no_repeat_ngram_size,
|
208 |
+
max_new_tokens,
|
209 |
+
min_length,
|
210 |
+
ban_tokens,
|
211 |
+
forced_eos_token_id,
|
212 |
+
eos_tokens,
|
213 |
+
force_words,
|
214 |
+
length_penalty,
|
215 |
+
repetition_penalty,
|
216 |
+
hide_special_tokens,
|
217 |
+
stop_generation,
|
218 |
+
decoding_strategy,
|
219 |
+
num_beams,
|
220 |
+
top_k,
|
221 |
+
top_p,
|
222 |
+
penalty_alpha,
|
223 |
+
):
|
224 |
+
# Preparing inputs
|
225 |
+
tokens = tokenizer(
|
226 |
+
[prompt],
|
227 |
+
truncation=True,
|
228 |
+
max_length=MAX_SEQ_LEN,
|
229 |
+
padding=True,
|
230 |
+
add_special_tokens=False,
|
231 |
+
)
|
232 |
+
|
233 |
+
input_ids = torch.tensor([[tokenizer.bos_token_id] + tokens.input_ids[0]])
|
234 |
+
attention_mask = torch.tensor([[1] + tokens.attention_mask[0]])
|
235 |
+
|
236 |
+
image_attention_mask = [
|
237 |
+
incremental_to_binary_attention_mask(
|
238 |
+
image_attention_mask_for_packed_input_ids(input_ids[0].unsqueeze(0), tokenizer)[0], num_classes=len(images)
|
239 |
+
)
|
240 |
+
]
|
241 |
+
|
242 |
+
image_transform = build_image_transform(eval=True)
|
243 |
+
pixel_values = [torch.stack([image_transform(img) for img in images])]
|
244 |
+
|
245 |
+
input_ids = input_ids.to(0)
|
246 |
+
attention_mask = attention_mask.to(0)
|
247 |
+
pixel_values = torch.stack(pixel_values).to(0)
|
248 |
+
image_attention_mask = torch.cat(image_attention_mask, 0).to(0)
|
249 |
+
|
250 |
+
# Excluding some words from the generation
|
251 |
+
bad_words_ids = None
|
252 |
+
ban_tokens = ban_tokens.replace("\\n", "\n")
|
253 |
+
bad_words = ban_tokens.split(";")
|
254 |
+
if len(bad_words) > 0:
|
255 |
+
bad_words_ids = tokenizer(bad_words, add_special_tokens=False).input_ids
|
256 |
+
|
257 |
+
# Forcing some words in the generation
|
258 |
+
force_words_ids = None
|
259 |
+
if force_words != "":
|
260 |
+
force_words = force_words.replace("\\n", "\n")
|
261 |
+
force_words = force_words.split(";")
|
262 |
+
if len(force_words) > 0:
|
263 |
+
force_words_ids = tokenizer(force_words, add_special_tokens=False).input_ids
|
264 |
+
|
265 |
+
# eos_token_ids = None
|
266 |
+
# if eos_tokens != "":
|
267 |
+
# eos_tokens = eos_tokens.replace("\\n", "\n")
|
268 |
+
# eos_tokens = eos_tokens.split(";")
|
269 |
+
# if len(eos_tokens) > 0:
|
270 |
+
# eos_token_ids = []
|
271 |
+
# for eos_token in eos_tokens:
|
272 |
+
# tokenized_eos_token = tokenizer(eos_token, add_special_tokens=False).input_ids
|
273 |
+
# if len(tokenized_eos_token) > 1:
|
274 |
+
# raise ValueError(
|
275 |
+
# f"eos_tokens should be one token, here {eos_token} is {len(tokenized_eos_token)} tokens:"
|
276 |
+
# f" {tokenized_eos_token}"
|
277 |
+
# )
|
278 |
+
# eos_token_ids += tokenized_eos_token
|
279 |
+
|
280 |
+
# if forced_eos_token_id and eos_token_ids is None:
|
281 |
+
# raise ValueError("You can't use forced_eos_token_id without eos_tokens")
|
282 |
+
# elif forced_eos_token_id:
|
283 |
+
# forced_eos_token_id = eos_token_ids
|
284 |
+
# else:
|
285 |
+
# forced_eos_token_id = None
|
286 |
+
|
287 |
+
# Inputs
|
288 |
+
input_args = {
|
289 |
+
"input_ids": input_ids,
|
290 |
+
"attention_mask": attention_mask,
|
291 |
+
"pixel_values": pixel_values,
|
292 |
+
"image_attention_mask": image_attention_mask,
|
293 |
+
}
|
294 |
+
# Common parameters to all decoding strategies
|
295 |
+
# This documentation is useful to read: https://huggingface.co/docs/transformers/main/en/generation_strategies
|
296 |
+
generation_args = {
|
297 |
+
"temperature": temperature,
|
298 |
+
"no_repeat_ngram_size": no_repeat_ngram_size,
|
299 |
+
"max_new_tokens": max_new_tokens,
|
300 |
+
"min_length": min_length,
|
301 |
+
"bad_words_ids": bad_words_ids,
|
302 |
+
# "forced_eos_token_id": forced_eos_token_id,
|
303 |
+
"force_words_ids": force_words_ids,
|
304 |
+
"length_penalty": length_penalty,
|
305 |
+
"repetition_penalty": repetition_penalty,
|
306 |
+
"eos_token_id": tokenizer.eos_token_id,
|
307 |
+
}
|
308 |
+
|
309 |
+
assert decoding_strategy in [
|
310 |
+
"greedy",
|
311 |
+
"beam_search",
|
312 |
+
"beam_sampling",
|
313 |
+
"sampling_top_k",
|
314 |
+
"sampling_top_p",
|
315 |
+
"contrastive_sampling",
|
316 |
+
]
|
317 |
+
if decoding_strategy == "greedy":
|
318 |
+
pass
|
319 |
+
elif decoding_strategy == "beam_search":
|
320 |
+
generation_args["num_beams"] = num_beams
|
321 |
+
assert generation_args["num_beams"] > 1
|
322 |
+
elif decoding_strategy == "beam_sampling":
|
323 |
+
generation_args["num_beams"] = num_beams
|
324 |
+
generation_args["do_sample"] = True
|
325 |
+
assert generation_args["num_beams"] > 1
|
326 |
+
elif decoding_strategy == "sampling_top_k":
|
327 |
+
generation_args["do_sample"] = True
|
328 |
+
generation_args["top_k"] = top_k
|
329 |
+
elif decoding_strategy == "sampling_top_p":
|
330 |
+
generation_args["do_sample"] = True
|
331 |
+
generation_args["top_p"] = top_p
|
332 |
+
elif decoding_strategy == "contrastive_sampling":
|
333 |
+
generation_args["do_sample"] = True
|
334 |
+
generation_args["penalty_alpha"] = penalty_alpha
|
335 |
+
generation_args["top_k"] = top_k
|
336 |
+
|
337 |
+
generated_tokens = model.generate(
|
338 |
+
**input_args,
|
339 |
+
**generation_args,
|
340 |
+
)
|
341 |
+
tokens = tokenizer.convert_ids_to_tokens(generated_tokens[0])
|
342 |
+
decoded_skip_special_tokens = repr(
|
343 |
+
tokenizer.batch_decode(generated_tokens, skip_special_tokens=hide_special_tokens)[0]
|
344 |
+
)
|
345 |
+
decoded = repr(tokenizer.batch_decode(generated_tokens)[0])
|
346 |
+
logger.info(
|
347 |
+
"Result: \n"
|
348 |
+
f"Prompt: `{prompt}`\n"
|
349 |
+
f"Tokens ids from prompt + generation: `{generated_tokens[0].tolist()}`\n"
|
350 |
+
f"Tokens (converted) from prompt + generation: `{tokens}`\n"
|
351 |
+
f"String decoded with skipped special tokens: `{decoded_skip_special_tokens}`\n"
|
352 |
+
f"String decoded: `{decoded}`\n"
|
353 |
+
f"Generation mode: `{decoding_strategy}`\n"
|
354 |
+
f"Generation parameters: `{generation_args}`\n"
|
355 |
+
)
|
356 |
+
|
357 |
+
original_prompt = generated_tokens[:, : input_ids.shape[-1]]
|
358 |
+
actual_generated_tokens = generated_tokens[:, input_ids.shape[-1] :]
|
359 |
+
|
360 |
+
if stop_generation:
|
361 |
+
# Additional stopping criteria: generating <image> token, <end_of_text> token or <begin_of_text> token
|
362 |
+
assert tokenizer.additional_special_tokens[-1] == "<image>"
|
363 |
+
image_token_id = tokenizer.additional_special_tokens_ids[-1]
|
364 |
+
end_of_text_token_id = tokenizer.eos_token_id
|
365 |
+
begin_of_text_token_id = tokenizer.bos_token_id
|
366 |
+
|
367 |
+
image_token_ids = (actual_generated_tokens == image_token_id).nonzero(as_tuple=True)[1]
|
368 |
+
end_of_text_token_ids = (actual_generated_tokens == end_of_text_token_id).nonzero(as_tuple=True)[1]
|
369 |
+
begin_of_text_token_ids = (actual_generated_tokens == begin_of_text_token_id).nonzero(as_tuple=True)[1]
|
370 |
+
|
371 |
+
first_end_token = min(
|
372 |
+
image_token_ids[0] if len(image_token_ids) else len(actual_generated_tokens[0]),
|
373 |
+
end_of_text_token_ids[0] if len(end_of_text_token_ids) else len(actual_generated_tokens[0]),
|
374 |
+
begin_of_text_token_ids[0] if len(begin_of_text_token_ids) else len(actual_generated_tokens[0]),
|
375 |
+
)
|
376 |
+
else:
|
377 |
+
first_end_token = len(actual_generated_tokens[0])
|
378 |
+
|
379 |
+
actual_generated_tokens = actual_generated_tokens[:, :first_end_token]
|
380 |
+
displayed_tokens = torch.cat([original_prompt, actual_generated_tokens], dim=-1)
|
381 |
+
generated_text = tokenizer.batch_decode(displayed_tokens, skip_special_tokens=hide_special_tokens)[0]
|
382 |
+
return generated_text
|
383 |
+
|
384 |
+
|
385 |
+
def model_inference(
|
386 |
+
files,
|
387 |
+
prompt,
|
388 |
+
temperature,
|
389 |
+
no_repeat_ngram_size,
|
390 |
+
max_new_tokens,
|
391 |
+
min_length,
|
392 |
+
ban_tokens,
|
393 |
+
forced_eos_token_id,
|
394 |
+
eos_tokens,
|
395 |
+
force_words,
|
396 |
+
length_penalty,
|
397 |
+
repetition_penalty,
|
398 |
+
hide_special_tokens,
|
399 |
+
stop_generation,
|
400 |
+
decoding_strategy,
|
401 |
+
num_beams,
|
402 |
+
top_k,
|
403 |
+
top_p,
|
404 |
+
penalty_alpha,
|
405 |
+
):
|
406 |
+
if isinstance(files, str) and len(files) == 0:
|
407 |
+
files = None
|
408 |
+
|
409 |
+
prompt = prompt.strip()
|
410 |
+
prompt = prompt.replace("\\n", "\n")
|
411 |
+
file_idx = 0
|
412 |
+
url_images = re.findall(r"<image(.*?)>", prompt)
|
413 |
+
for idx, url_image in enumerate(url_images):
|
414 |
+
if len(url_image) == 0:
|
415 |
+
url_images[idx] = Image.open(files[file_idx].name if hasattr(files[file_idx], "name") else files[file_idx])
|
416 |
+
file_idx += 1
|
417 |
+
else:
|
418 |
+
prompt = prompt.replace(url_image, "")
|
419 |
+
url_images[idx] = url_images[idx][1:]
|
420 |
+
images = fetch_images(url_images)
|
421 |
+
|
422 |
+
global model, tokenizer
|
423 |
+
|
424 |
+
generated_text = model_generation(
|
425 |
+
prompt=prompt,
|
426 |
+
images=images,
|
427 |
+
tokenizer=tokenizer,
|
428 |
+
model=model,
|
429 |
+
temperature=temperature,
|
430 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
431 |
+
max_new_tokens=max_new_tokens,
|
432 |
+
min_length=min_length,
|
433 |
+
ban_tokens=ban_tokens,
|
434 |
+
forced_eos_token_id=forced_eos_token_id,
|
435 |
+
eos_tokens=eos_tokens,
|
436 |
+
force_words=force_words,
|
437 |
+
length_penalty=length_penalty,
|
438 |
+
repetition_penalty=repetition_penalty,
|
439 |
+
hide_special_tokens=hide_special_tokens,
|
440 |
+
stop_generation=stop_generation,
|
441 |
+
decoding_strategy=decoding_strategy,
|
442 |
+
num_beams=num_beams,
|
443 |
+
top_k=top_k,
|
444 |
+
top_p=top_p,
|
445 |
+
penalty_alpha=penalty_alpha,
|
446 |
+
)
|
447 |
+
return generated_text.strip()
|
448 |
+
|
449 |
+
|
450 |
+
def try_model_inference(
|
451 |
+
model,
|
452 |
+
files,
|
453 |
+
prompt,
|
454 |
+
temperature,
|
455 |
+
no_repeat_ngram_size,
|
456 |
+
max_new_tokens,
|
457 |
+
min_length,
|
458 |
+
ban_tokens,
|
459 |
+
forced_eos_token_id,
|
460 |
+
eos_tokens,
|
461 |
+
force_words,
|
462 |
+
length_penalty,
|
463 |
+
repetition_penalty,
|
464 |
+
hide_special_tokens,
|
465 |
+
stop_generation,
|
466 |
+
decoding_strategy,
|
467 |
+
num_beams,
|
468 |
+
top_k,
|
469 |
+
top_p,
|
470 |
+
penalty_alpha,
|
471 |
+
):
|
472 |
+
count = 0
|
473 |
+
while count < MAX_TRIES:
|
474 |
+
try:
|
475 |
+
return MODEL_TO_SPACE_MAPPING[model](
|
476 |
+
files,
|
477 |
+
prompt,
|
478 |
+
temperature,
|
479 |
+
no_repeat_ngram_size,
|
480 |
+
max_new_tokens,
|
481 |
+
min_length,
|
482 |
+
ban_tokens,
|
483 |
+
forced_eos_token_id,
|
484 |
+
eos_tokens,
|
485 |
+
force_words,
|
486 |
+
length_penalty,
|
487 |
+
repetition_penalty,
|
488 |
+
hide_special_tokens,
|
489 |
+
stop_generation,
|
490 |
+
decoding_strategy,
|
491 |
+
num_beams,
|
492 |
+
top_k,
|
493 |
+
top_p,
|
494 |
+
penalty_alpha,
|
495 |
+
api_name="model_inference",
|
496 |
+
)
|
497 |
+
except KeyError:
|
498 |
+
# Gradio return {'error': None} some times.
|
499 |
+
time.sleep(3)
|
500 |
+
count += 1
|
501 |
+
pass
|
502 |
+
|
503 |
+
|
504 |
+
def all_model_inference(
|
505 |
+
prompt,
|
506 |
+
temperature,
|
507 |
+
no_repeat_ngram_size,
|
508 |
+
max_new_tokens,
|
509 |
+
min_length,
|
510 |
+
ban_tokens,
|
511 |
+
forced_eos_token_id,
|
512 |
+
eos_tokens,
|
513 |
+
force_words,
|
514 |
+
length_penalty,
|
515 |
+
repetition_penalty,
|
516 |
+
hide_special_tokens,
|
517 |
+
stop_generation,
|
518 |
+
decoding_strategy,
|
519 |
+
num_beams,
|
520 |
+
top_k,
|
521 |
+
top_p,
|
522 |
+
penalty_alpha,
|
523 |
+
):
|
524 |
+
outputs = []
|
525 |
+
print(
|
526 |
+
prompt,
|
527 |
+
temperature,
|
528 |
+
no_repeat_ngram_size,
|
529 |
+
max_new_tokens,
|
530 |
+
min_length,
|
531 |
+
ban_tokens,
|
532 |
+
forced_eos_token_id,
|
533 |
+
eos_tokens,
|
534 |
+
force_words,
|
535 |
+
length_penalty,
|
536 |
+
repetition_penalty,
|
537 |
+
hide_special_tokens,
|
538 |
+
stop_generation,
|
539 |
+
decoding_strategy,
|
540 |
+
num_beams,
|
541 |
+
top_k,
|
542 |
+
top_p,
|
543 |
+
penalty_alpha,
|
544 |
+
)
|
545 |
+
outputs = Parallel(n_jobs=len(MODEL_TO_SPACE_MAPPING), backend="threading")(
|
546 |
+
delayed(try_model_inference)(
|
547 |
+
model,
|
548 |
+
os.path.join(os.path.dirname(__file__), "images", "bear.jpg"),
|
549 |
+
prompt,
|
550 |
+
temperature,
|
551 |
+
no_repeat_ngram_size,
|
552 |
+
max_new_tokens,
|
553 |
+
min_length,
|
554 |
+
ban_tokens,
|
555 |
+
forced_eos_token_id,
|
556 |
+
eos_tokens,
|
557 |
+
force_words,
|
558 |
+
length_penalty,
|
559 |
+
repetition_penalty,
|
560 |
+
hide_special_tokens,
|
561 |
+
stop_generation,
|
562 |
+
decoding_strategy,
|
563 |
+
num_beams,
|
564 |
+
top_k,
|
565 |
+
top_p,
|
566 |
+
penalty_alpha,
|
567 |
+
)
|
568 |
+
for model in MODEL_TO_SPACE_MAPPING
|
569 |
+
)
|
570 |
+
if len(outputs) == 1:
|
571 |
+
outputs = outputs[0]
|
572 |
+
return outputs
|
573 |
+
|
574 |
+
|
575 |
+
examples = [
|
576 |
+
[
|
577 |
+
None,
|
578 |
+
"""This is a conversation between a human, User, and an intelligent visual AI, Assistant. User sends images, and Assistant answers the questions from the user. The assistant should be friendly, informative and should not change the topic if it's not asked to.
|
579 |
+
|
580 |
+
Here's an example of a conversation:
|
581 |
+
User:<fake_token_around_image><image:https://hips.hearstapps.com/hmg-prod/images/cute-photos-of-cats-in-grass-1593184777.jpg><fake_token_around_image>Describe this image.
|
582 |
+
Assistant: Two kittens are cuddling in the center of the photograph. They are surrounded by tall, bright green grass, and the background is blurred.
|
583 |
+
User:<fake_token_around_image><image:https://cdn.pixabay.com/photo/2017/09/25/13/12/puppy-2785074_1280.jpg><fake_token_around_image>How about this image? Can you describe it too?
|
584 |
+
Assistant: A dog is lying on the floor, looking at the camera. It is looking directly at you, tilting its head to the side. The dog has a white body and brown patches on its face and ears. Its eyes are brown. Its nose is black, and it has long, floppy ears, short legs, white paws, long fur, big eyes, and black eyebrows.
|
585 |
+
User: What kind of breed is it?
|
586 |
+
Assistant: It looks like a Cavalier King Charles Spaniel. These breeds are known for their gentle, affectionate nature. Their long ears and big, brown eyes give them a sweet, expressive face. Cavaliers are small dogs weighing between 13 and 20 pounds. They have a silky, medium-length coat that comes in various colors, including black and tan, tricolor, red and white, or black.
|
587 |
+
---
|
588 |
+
User:<fake_token_around_image><image:https://m.media-amazon.com/images/M/MV5BMjE4MTcwMTM1Nl5BMl5BanBnXkFtZTcwMTIwMzMzMw@@._V1_.jpg><fake_token_around_image>Describe all of the parts of this image.
|
589 |
+
Assistant:""",
|
590 |
+
1.0,
|
591 |
+
0,
|
592 |
+
256,
|
593 |
+
10,
|
594 |
+
"<image>;<fake_token_around_image>",
|
595 |
+
False,
|
596 |
+
"</s>",
|
597 |
+
"",
|
598 |
+
1.0,
|
599 |
+
1.0,
|
600 |
+
False,
|
601 |
+
True,
|
602 |
+
"greedy",
|
603 |
+
1,
|
604 |
+
50,
|
605 |
+
0.5,
|
606 |
+
0.95,
|
607 |
+
],
|
608 |
+
# [
|
609 |
+
# None,
|
610 |
+
# """This is a conversation between a human, User, and an intelligent visual AI, Bot. User sends images, and Bot answer the questions from the user.
|
611 |
+
# User: <fake_token_around_image><image:https://m.media-amazon.com/images/M/MV5BMjE4MTcwMTM1Nl5BMl5BanBnXkFtZTcwMTIwMzMzMw@@._V1_.jpg><fake_token_around_image>
|
612 |
+
# Describe this image.
|
613 |
+
# Bot:""",
|
614 |
+
# 1,
|
615 |
+
# 2,
|
616 |
+
# 64,
|
617 |
+
# 10,
|
618 |
+
# "<image>;<fake_token_around_image>;User;user;Bot;bot;Question;question;Answer;answer;\n",
|
619 |
+
# False,
|
620 |
+
# False,
|
621 |
+
# True,
|
622 |
+
# ],
|
623 |
+
# [
|
624 |
+
# None,
|
625 |
+
# """This is a conversation between a human, User, and an intelligent visual AI, Bot. User sends images, and Bot answer the questions from the user.
|
626 |
+
# User: <fake_token_around_image><image:https://i.redd.it/hsktcp4nv1g01.jpg><fake_token_around_image>
|
627 |
+
# Why do people find this image funny?
|
628 |
+
# Bot:""",
|
629 |
+
# 1,
|
630 |
+
# 2,
|
631 |
+
# 64,
|
632 |
+
# 10,
|
633 |
+
# "<image>;<fake_token_around_image>;User;user;Bot;bot;Question;question;Answer;answer;\n",
|
634 |
+
# False,
|
635 |
+
# False,
|
636 |
+
# True,
|
637 |
+
# ],
|
638 |
+
# [
|
639 |
+
# None,
|
640 |
+
# """This is a conversation between a human, User, and an intelligent visual AI, Bot. User sends images, and Bot answer the questions from the user.
|
641 |
+
# User: <fake_token_around_image><image:https://pbs.twimg.com/media/FooD7oyakAIU5_Q?format=jpg&name=large><fake_token_around_image>
|
642 |
+
# Describe what's in this image.
|
643 |
+
# Bot:""",
|
644 |
+
# 1,
|
645 |
+
# 2,
|
646 |
+
# 64,
|
647 |
+
# 10,
|
648 |
+
# "<image>;<fake_token_around_image>;User;user;Bot;bot;Question;question;Answer;answer;\n",
|
649 |
+
# False,
|
650 |
+
# False,
|
651 |
+
# True,
|
652 |
+
# ],
|
653 |
+
# [
|
654 |
+
# None,
|
655 |
+
# """This is a conversation between a human, User, and an intelligent visual AI, Bot. User sends images, and Bot answer the questions from the user.
|
656 |
+
# User: <fake_token_around_image><image:https://www.tutorialride.com/images/non-verbal-analogy-questions/non-verbal-analogy-logical-reasoning-1.jpg><fake_token_around_image>
|
657 |
+
# What's the correct answer? A, B, C or D?
|
658 |
+
# Bot:""",
|
659 |
+
# 1,
|
660 |
+
# 2,
|
661 |
+
# 64,
|
662 |
+
# 10,
|
663 |
+
# "<image>;<fake_token_around_image>;User;user;Bot;bot;Question;question;Answer;answer;\n",
|
664 |
+
# False,
|
665 |
+
# False,
|
666 |
+
# True,
|
667 |
+
# ],
|
668 |
+
]
|
669 |
+
|
670 |
+
|
671 |
+
title = """<head><title><h1 align='center'>🔮✍️ Text generation with IDEFICS models 🦙📚</h1></title></head>"""
|
672 |
+
|
673 |
+
|
674 |
+
MSG_MAIN = """
|
675 |
+
# Text generation with Vllama models
|
676 |
+
|
677 |
+
### Help to write prompts:
|
678 |
+
|
679 |
+
Put the urls to the images inside the image tokens, it will be converted into the real image tokens. Put <fake_token_around_image> before and after each image token WITHOUT space. The texts \\n will be converted into real newline characters. See examples and additional details below.
|
680 |
+
"""
|
681 |
+
# MSG_DETAILS = """
|
682 |
+
# ### Additional details
|
683 |
+
# - if the model was trained with the template 1 (`\\n\\n<image>\\n\\n`), then `<fake_token_around_image>` will be replaced with `\\n\\n`. This is particularly useful if you are comparing the performance of different models trained with different templates.
|
684 |
+
# - special tokens are not automatically added to the prompt, so add them manually.
|
685 |
+
# - with the first template `\\n\\n<image>\\n\\n` , the sequence isn't necessary tokenized as `["\\n\\n", "<image>", "\\n\\n"]` to enforce this behavior, you can use the "Integrate image sequence as ids" parameter.
|
686 |
+
# """
|
687 |
+
# if ~IS_MAIN_SPACE:
|
688 |
+
# MSG_DETAILS += (
|
689 |
+
# "- alternatively, you can upload images and then directly specify them via \<image\> tag in the prompt."
|
690 |
+
# )
|
691 |
+
|
692 |
+
with gr.Blocks() as demo:
|
693 |
+
gr.HTML(title)
|
694 |
+
gr.HTML("""<h3 align='center'>Help to write prompts:🙌</h3><br>
|
695 |
+
<p>Put the urls to the images inside the image tokens,
|
696 |
+
it will be converted into the real image tokens.
|
697 |
+
Put <fake_token_around_image> before and after each
|
698 |
+
image token WITHOUT space. The texts \\n will be
|
699 |
+
converted into real newline characters.
|
700 |
+
See examples and additional details below.""")
|
701 |
+
|
702 |
+
#gr.HTML("<h3 align='center'>Help to write prompts:🙌</h3><br>Put the urls to the images inside the image tokens, it will be converted into the real image tokens. Put <fake_token_around_image> before and after each image token WITHOUT space. The texts \\n will be converted into real newline characters. See examples and additional details below.")
|
703 |
+
#gr.Markdown(MSG_MAIN)
|
704 |
+
#with gr.Row():
|
705 |
+
#with gr.Column():
|
706 |
+
gr.Markdown("## Input")
|
707 |
+
with gr.Row():
|
708 |
+
if not IS_MAIN_SPACE:
|
709 |
+
images = gr.File(label="Images", file_count="multiple")
|
710 |
+
prompt = gr.Textbox(label="Prompt", placeholder="Enter the prompt here", lines=5)
|
711 |
+
|
712 |
+
#gr.Markdown("## Common parameters to all decoding strategy")
|
713 |
+
with gr.Row():
|
714 |
+
with gr.Accordion("Common parameters to all decoding strategy", open=False, elem_id="common_params"):
|
715 |
+
temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=1.0, label="Softmax temperature")
|
716 |
+
no_repeat_ngram_size = gr.Slider(
|
717 |
+
minimum=0,
|
718 |
+
maximum=10,
|
719 |
+
step=1,
|
720 |
+
value=0,
|
721 |
+
label="The size of an n-gram that cannot occur more than once (0=infinity)",
|
722 |
+
)
|
723 |
+
max_new_tokens = gr.Slider(
|
724 |
+
minimum=0, maximum=512, step=1, value=256, label="Maximum number of new tokens to generate"
|
725 |
+
)
|
726 |
+
min_length = gr.Slider(
|
727 |
+
minimum=0, maximum=512, step=1, value=16, label="Minimum length of the sequence to be generated"
|
728 |
+
)
|
729 |
+
ban_tokens = gr.Textbox(
|
730 |
+
label='Tokens to prevent from being generated (separated by ";")',
|
731 |
+
value="<image>;<fake_token_around_image>",
|
732 |
+
)
|
733 |
+
forced_eos_token_id = gr.Checkbox(label="Forced eos token id", value=False)
|
734 |
+
eos_tokens = gr.Textbox(label="EOS tokens", value="</s>")
|
735 |
+
force_words = gr.Textbox(label='Force words to be generated (separated by ";")', value="")
|
736 |
+
length_penalty = gr.Slider(
|
737 |
+
minimum=-1000,
|
738 |
+
maximum=1000,
|
739 |
+
step=0.1,
|
740 |
+
value=1,
|
741 |
+
label=(
|
742 |
+
"length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter"
|
743 |
+
" sequences."
|
744 |
+
),
|
745 |
+
)
|
746 |
+
repetition_penalty = gr.Slider(
|
747 |
+
minimum=0, maximum=10, step=0.01, value=1, label="repetition_penalty. CTRL paper suggests 1.2."
|
748 |
+
)
|
749 |
+
hide_special_tokens = gr.Checkbox(label="Hide special tokens in the text", value=False)
|
750 |
+
stop_generation = gr.Checkbox(
|
751 |
+
label="Stop generation when an image token, a bos or a eos token is generated", value=False
|
752 |
+
)
|
753 |
+
|
754 |
+
#gr.Markdown("## Decoding strategy and its specific parameters")
|
755 |
+
with gr.Accordion("Decoding strategy and its specific parameters", open=False, elem_id="decoding_params"):
|
756 |
+
decoding_strategy = gr.Dropdown(
|
757 |
+
["greedy", "beam_search", "beam_sampling", "sampling_top_k", "sampling_top_p", "contrastive_sampling"],
|
758 |
+
label="Decoding strategy",
|
759 |
+
value="greedy",
|
760 |
+
)
|
761 |
+
num_beams = gr.Slider(
|
762 |
+
minimum=0,
|
763 |
+
maximum=10,
|
764 |
+
step=1,
|
765 |
+
value=3,
|
766 |
+
label="Beam size",
|
767 |
+
info="Only used if `decoding_strategy` is `beam_search` or `beam_sampling`",
|
768 |
+
)
|
769 |
+
top_k = gr.Slider(
|
770 |
+
minimum=0,
|
771 |
+
maximum=500,
|
772 |
+
step=1,
|
773 |
+
value=50,
|
774 |
+
label="Top k",
|
775 |
+
info="Only used if `decoding_strategy` is `sampling_top_k` or `contrastive_sampling`",
|
776 |
+
)
|
777 |
+
top_p = gr.Slider(
|
778 |
+
minimum=0,
|
779 |
+
maximum=1,
|
780 |
+
step=0.01,
|
781 |
+
value=0.95,
|
782 |
+
label="Top p",
|
783 |
+
info="Only used if `decoding_strategy` is `sampling_top_p`",
|
784 |
+
)
|
785 |
+
penalty_alpha = gr.Slider(
|
786 |
+
minimum=0,
|
787 |
+
maximum=1,
|
788 |
+
step=0.01,
|
789 |
+
value=0.95,
|
790 |
+
label="Penalty alpha",
|
791 |
+
info="Only used if `decoding_strategy` is `contrastive_sampling`",
|
792 |
+
)
|
793 |
+
|
794 |
+
submit = gr.Button(label="Generate")
|
795 |
+
|
796 |
+
#with gr.Column():
|
797 |
+
with gr.Row():
|
798 |
+
if IS_MAIN_SPACE:
|
799 |
+
outputs = [
|
800 |
+
gr.Textbox(label=MODEL_TO_DISPLAY_NAME[model], multiline=True, readonly=True)
|
801 |
+
for model in MODEL_TO_MODEL_CLASS
|
802 |
+
]
|
803 |
+
inference_func = all_model_inference
|
804 |
+
inputs = [
|
805 |
+
prompt,
|
806 |
+
temperature,
|
807 |
+
no_repeat_ngram_size,
|
808 |
+
max_new_tokens,
|
809 |
+
min_length,
|
810 |
+
ban_tokens,
|
811 |
+
forced_eos_token_id,
|
812 |
+
eos_tokens,
|
813 |
+
force_words,
|
814 |
+
length_penalty,
|
815 |
+
repetition_penalty,
|
816 |
+
hide_special_tokens,
|
817 |
+
stop_generation,
|
818 |
+
decoding_strategy,
|
819 |
+
num_beams,
|
820 |
+
top_k,
|
821 |
+
top_p,
|
822 |
+
penalty_alpha,
|
823 |
+
]
|
824 |
+
|
825 |
+
# examples = [example[1:] for example in examples]
|
826 |
+
else:
|
827 |
+
outputs = gr.Textbox(label="Generated text", interactive=False, lines=5)
|
828 |
+
inference_func = model_inference
|
829 |
+
inputs = [
|
830 |
+
images,
|
831 |
+
prompt,
|
832 |
+
temperature,
|
833 |
+
no_repeat_ngram_size,
|
834 |
+
max_new_tokens,
|
835 |
+
min_length,
|
836 |
+
ban_tokens,
|
837 |
+
forced_eos_token_id,
|
838 |
+
eos_tokens,
|
839 |
+
force_words,
|
840 |
+
length_penalty,
|
841 |
+
repetition_penalty,
|
842 |
+
hide_special_tokens,
|
843 |
+
stop_generation,
|
844 |
+
decoding_strategy,
|
845 |
+
num_beams,
|
846 |
+
top_k,
|
847 |
+
top_p,
|
848 |
+
penalty_alpha,
|
849 |
+
]
|
850 |
+
with gr.Row():
|
851 |
+
gr.Examples(inputs=inputs, examples=examples)
|
852 |
+
# gr.Markdown(MSG_DETAILS)
|
853 |
+
|
854 |
+
submit.click(inference_func, inputs=inputs, outputs=outputs, api_name="model_inference")
|
855 |
+
|
856 |
+
demo.queue()
|
857 |
+
demo.launch()
|
app_dialogue.py
ADDED
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gradio as gr
|
3 |
+
import requests
|
4 |
+
|
5 |
+
|
6 |
+
models = [
|
7 |
+
"HuggingFaceM4/tr_209_ift_mixture_opt_step-14000"
|
8 |
+
]
|
9 |
+
|
10 |
+
SYSTEM_PROMPT = """The following is a conversation between a highly knowledgeable and intelligent AI assistant, called Assistant, and a human user, called User. In the following interactions, User and Assistant will converse in natural language, and Assistant will do its best to answer User’s questions. Assistant was built to be respectful, polite and inclusive. It knows a lot, and always tells the truth. When prompted with an image, it does not make up facts.
|
11 |
+
|
12 |
+
The conversation begins:
|
13 |
+
User:<fake_token_around_image><image:https://i1.pickpik.com/photos/515/906/355/kittens-cats-pet-cute-preview.jpg><fake_token_around_image>Describe this image.</s>
|
14 |
+
Assistant: Five kittens are standing together in the center of the photograph. From the left to right, there is one orange kitte, two kittens white and black stripes, and two brown kittens with an orange head. They are in bright green grass and it looks like they are moving forward.</s>
|
15 |
+
User:How about this image?<fake_token_around_image><image:https://cdn.pixabay.com/photo/2017/09/25/13/12/puppy-2785074_1280.jpg><fake_token_around_image>Can you describe it too?</s>
|
16 |
+
Assistant: A dog is lying on the floor, looking at the camera. It is looking directly at you. The dog has a white body and brown patches on its face and ears. Its eyes are dark. Its nose is black, and it has long, floppy ears, white paws, long fur, big eyes.</s>
|
17 |
+
User: What kind of breed is it?</s>
|
18 |
+
Assistant: It looks like a Cavalier King Charles Spaniel. These breeds are known for their gentle, affectionate nature. Their long ears and big, brown eyes give them a sweet, expressive face. Cavaliers are small dogs weighing between 13 and 20 pounds. They have a silky, medium-length coat that comes in various colors, including black and tan, tricolor, red and white, or black.</s>
|
19 |
+
User: What can you tell me about this breed of dogs?</s>
|
20 |
+
Assistant: One specific characteristic of the Cavalier King Charles Spaniel is its friendly and affectionate nature. This breed is known for being extremely sociable and forming strong bonds with their owners. They are often described as "velcro dogs" because they love to be close to their human companions, whether it's sitting on their laps, cuddling on the couch, or simply following them around the house.</s>
|
21 |
+
User: ghjkhjabnufs</s>
|
22 |
+
Assistant: That doesn’t seem to be a word. Could you ask me another way?</s>
|
23 |
+
User: Do you like Cavalier King Charles Spaniel?</s>
|
24 |
+
Assistant: I do not have personal opinions as I’m just a computer program. However, cavaliers are known for being excellent family pets due to their gentle and patient demeanor, making them great with children and other pets. Their affectionate nature also makes them ideal therapy dogs, as they can provide comfort and emotional support to those in need.</s>
|
25 |
+
User: How many dogs do you see in this image?<fake_token_around_image><image:https://i.dailymail.co.uk/i/pix/2011/07/01/article-2010308-0CD22A8300000578-496_634x414.jpg><fake_token_around_image></s>
|
26 |
+
Assistant: There is no dogs in this image. The picture shows a tennis player jumping to volley the ball.</s>"""
|
27 |
+
|
28 |
+
BAN_TOKENS = "<image>;<fake_token_around_image>"
|
29 |
+
EOS_TOKENS = "</s>;User"
|
30 |
+
|
31 |
+
from accelerate.utils import get_max_memory
|
32 |
+
from transformers import AutoTokenizer
|
33 |
+
from m4.models.vllama.configuration_vllama import VLlamaConfig
|
34 |
+
from m4.models.vllama.modeling_vllama import VLlamaForCausalLM
|
35 |
+
|
36 |
+
import logging
|
37 |
+
from PIL import Image
|
38 |
+
from io import BytesIO
|
39 |
+
from m4.training.packing import image_attention_mask_for_packed_input_ids, incremental_to_binary_attention_mask
|
40 |
+
from m4.training.utils import build_image_transform
|
41 |
+
import torch
|
42 |
+
import re
|
43 |
+
|
44 |
+
TOKENIZER_FAST = True
|
45 |
+
MAX_SEQ_LEN = 1024
|
46 |
+
|
47 |
+
logging.basicConfig(level=logging.INFO)
|
48 |
+
logger = logging.getLogger()
|
49 |
+
|
50 |
+
|
51 |
+
def load_tokenizer_model(model_name):
|
52 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
53 |
+
model_name,
|
54 |
+
use_fast=TOKENIZER_FAST,
|
55 |
+
use_auth_token=os.getenv("HF_AUTH_TOKEN", True), # `use_fast=False` for 1B3 OPT, True for all the other models
|
56 |
+
)
|
57 |
+
tokenizer.padding_side = "left"
|
58 |
+
|
59 |
+
config = VLlamaConfig.from_pretrained(model_name, use_auth_token=os.getenv("HF_AUTH_TOKEN", True))
|
60 |
+
max_memory_map = get_max_memory()
|
61 |
+
|
62 |
+
for key in max_memory_map.keys():
|
63 |
+
if key != "cpu":
|
64 |
+
# Get this in GB
|
65 |
+
max_memory_map[key] = max_memory_map[key] // (1024 * 1024 * 1024)
|
66 |
+
# Decrease 2 for Pytorch overhead and 2 for the forward to be safe
|
67 |
+
max_memory_map[key] = f"{max_memory_map[key] - 4} GiB"
|
68 |
+
|
69 |
+
model = VLlamaForCausalLM.from_pretrained(
|
70 |
+
model_name,
|
71 |
+
use_auth_token=os.getenv("HF_AUTH_TOKEN", True),
|
72 |
+
device_map="auto",
|
73 |
+
offload_folder="./offload",
|
74 |
+
torch_dtype=config.torch_dtype,
|
75 |
+
max_memory=max_memory_map,
|
76 |
+
)
|
77 |
+
model.eval()
|
78 |
+
print("Current device map:", model.hf_device_map)
|
79 |
+
print("Model default generation config:", model.generation_config)
|
80 |
+
# TODO: the device_map looks very inefficien right now. that could be improved
|
81 |
+
return tokenizer, model
|
82 |
+
|
83 |
+
|
84 |
+
def fetch_images(url_images):
|
85 |
+
images = []
|
86 |
+
for url in url_images:
|
87 |
+
if isinstance(url, str):
|
88 |
+
images.append(Image.open(BytesIO(requests.get(url, stream=True).content)))
|
89 |
+
else:
|
90 |
+
images.append(url)
|
91 |
+
return images
|
92 |
+
|
93 |
+
|
94 |
+
def model_generation(
|
95 |
+
prompt,
|
96 |
+
images,
|
97 |
+
tokenizer,
|
98 |
+
model,
|
99 |
+
temperature,
|
100 |
+
no_repeat_ngram_size,
|
101 |
+
max_new_tokens,
|
102 |
+
min_length,
|
103 |
+
ban_tokens,
|
104 |
+
eos_tokens,
|
105 |
+
force_words,
|
106 |
+
repetition_penalty,
|
107 |
+
hide_special_tokens,
|
108 |
+
decoding_strategy,
|
109 |
+
num_beams,
|
110 |
+
length_penalty,
|
111 |
+
top_k,
|
112 |
+
top_p,
|
113 |
+
penalty_alpha,
|
114 |
+
):
|
115 |
+
# Preparing inputs
|
116 |
+
tokens = tokenizer(
|
117 |
+
[prompt],
|
118 |
+
truncation=True,
|
119 |
+
max_length=MAX_SEQ_LEN,
|
120 |
+
padding=True,
|
121 |
+
add_special_tokens=False,
|
122 |
+
)
|
123 |
+
|
124 |
+
input_ids = torch.tensor([[tokenizer.bos_token_id] + tokens.input_ids[0]])
|
125 |
+
attention_mask = torch.tensor([[1] + tokens.attention_mask[0]])
|
126 |
+
|
127 |
+
image_attention_mask = [
|
128 |
+
incremental_to_binary_attention_mask(
|
129 |
+
image_attention_mask_for_packed_input_ids(input_ids[0].unsqueeze(0), tokenizer)[0], num_classes=len(images)
|
130 |
+
)
|
131 |
+
]
|
132 |
+
|
133 |
+
image_transform = build_image_transform(eval=True)
|
134 |
+
pixel_values = [torch.stack([image_transform(img) for img in images])]
|
135 |
+
|
136 |
+
input_ids = input_ids.to(0)
|
137 |
+
attention_mask = attention_mask.to(0)
|
138 |
+
pixel_values = torch.stack(pixel_values).to(0)
|
139 |
+
image_attention_mask = torch.cat(image_attention_mask, 0).to(0)
|
140 |
+
|
141 |
+
# Excluding some words from the generation
|
142 |
+
bad_words_ids = None
|
143 |
+
ban_tokens = ban_tokens.replace("\\n", "\n")
|
144 |
+
bad_words = ban_tokens.split(";")
|
145 |
+
if len(bad_words) > 0:
|
146 |
+
bad_words_ids = tokenizer(bad_words, add_special_tokens=False).input_ids
|
147 |
+
|
148 |
+
# Forcing some words in the generation
|
149 |
+
force_words_ids = None
|
150 |
+
if force_words != "":
|
151 |
+
force_words = force_words.replace("\\n", "\n")
|
152 |
+
force_words = force_words.split(";")
|
153 |
+
if len(force_words) > 0:
|
154 |
+
force_words_ids = tokenizer(force_words, add_special_tokens=False).input_ids
|
155 |
+
|
156 |
+
eos_token_ids = None
|
157 |
+
if eos_tokens != "":
|
158 |
+
eos_tokens = eos_tokens.replace("\\n", "\n")
|
159 |
+
eos_tokens = eos_tokens.split(";")
|
160 |
+
if len(eos_tokens) > 0:
|
161 |
+
eos_token_ids = []
|
162 |
+
for eos_token in eos_tokens:
|
163 |
+
tokenized_eos_token = tokenizer(eos_token, add_special_tokens=False).input_ids
|
164 |
+
if len(tokenized_eos_token) > 1:
|
165 |
+
raise ValueError(
|
166 |
+
f"eos_tokens should be one token, here {eos_token} is {len(tokenized_eos_token)} tokens:"
|
167 |
+
f" {tokenized_eos_token}"
|
168 |
+
)
|
169 |
+
eos_token_ids += tokenized_eos_token
|
170 |
+
|
171 |
+
# Inputs
|
172 |
+
input_args = {
|
173 |
+
"input_ids": input_ids,
|
174 |
+
"attention_mask": attention_mask,
|
175 |
+
"pixel_values": pixel_values,
|
176 |
+
"image_attention_mask": image_attention_mask,
|
177 |
+
}
|
178 |
+
# Common parameters to all decoding strategies
|
179 |
+
# This documentation is useful to read: https://huggingface.co/docs/transformers/main/en/generation_strategies
|
180 |
+
generation_args = {
|
181 |
+
"temperature": temperature,
|
182 |
+
"no_repeat_ngram_size": no_repeat_ngram_size,
|
183 |
+
"max_new_tokens": max_new_tokens,
|
184 |
+
"min_length": min_length,
|
185 |
+
"bad_words_ids": bad_words_ids,
|
186 |
+
"force_words_ids": force_words_ids,
|
187 |
+
"repetition_penalty": repetition_penalty,
|
188 |
+
"eos_token_id": eos_token_ids,
|
189 |
+
}
|
190 |
+
|
191 |
+
assert decoding_strategy in [
|
192 |
+
"greedy",
|
193 |
+
"beam_search",
|
194 |
+
"beam_sampling",
|
195 |
+
"sampling_top_k",
|
196 |
+
"sampling_top_p",
|
197 |
+
"contrastive_sampling",
|
198 |
+
]
|
199 |
+
if decoding_strategy == "greedy":
|
200 |
+
pass
|
201 |
+
elif decoding_strategy == "beam_search":
|
202 |
+
generation_args["num_beams"] = num_beams
|
203 |
+
generation_args["length_penalty"] = length_penalty
|
204 |
+
assert generation_args["num_beams"] > 1
|
205 |
+
elif decoding_strategy == "beam_sampling":
|
206 |
+
generation_args["num_beams"] = num_beams
|
207 |
+
generation_args["length_penalty"] = length_penalty
|
208 |
+
generation_args["do_sample"] = True
|
209 |
+
assert generation_args["num_beams"] > 1
|
210 |
+
elif decoding_strategy == "sampling_top_k":
|
211 |
+
generation_args["do_sample"] = True
|
212 |
+
generation_args["top_k"] = top_k
|
213 |
+
elif decoding_strategy == "sampling_top_p":
|
214 |
+
generation_args["do_sample"] = True
|
215 |
+
generation_args["top_p"] = top_p
|
216 |
+
elif decoding_strategy == "contrastive_sampling":
|
217 |
+
generation_args["do_sample"] = True
|
218 |
+
generation_args["penalty_alpha"] = penalty_alpha
|
219 |
+
generation_args["top_k"] = top_k
|
220 |
+
|
221 |
+
generated_tokens = model.generate(
|
222 |
+
**input_args,
|
223 |
+
**generation_args,
|
224 |
+
)
|
225 |
+
|
226 |
+
tokens = tokenizer.convert_ids_to_tokens(generated_tokens[0])
|
227 |
+
decoded_skip_special_tokens = repr(
|
228 |
+
tokenizer.batch_decode(generated_tokens, skip_special_tokens=hide_special_tokens)[0]
|
229 |
+
)
|
230 |
+
|
231 |
+
actual_generated_tokens = generated_tokens[:, input_ids.shape[-1] :]
|
232 |
+
first_end_token = len(actual_generated_tokens[0])
|
233 |
+
actual_generated_tokens = actual_generated_tokens[:, :first_end_token]
|
234 |
+
generated_text = tokenizer.batch_decode(actual_generated_tokens, skip_special_tokens=hide_special_tokens)[0]
|
235 |
+
|
236 |
+
logger.info(
|
237 |
+
"Result: \n"
|
238 |
+
f"----Prompt: `{prompt}`\n"
|
239 |
+
f"----Tokens ids - prompt + generation: `{generated_tokens[0].tolist()}`\n"
|
240 |
+
f"----Tokens converted - prompt + generation: `{tokens}`\n"
|
241 |
+
f"----String decoded with skipped special tokens - prompt + generation: `{decoded_skip_special_tokens}`\n"
|
242 |
+
f"----Token ids - generation: `{actual_generated_tokens[0].tolist()}`"
|
243 |
+
f"----Tokens converted - generation: `{tokenizer.convert_ids_to_tokens(actual_generated_tokens[0])}`"
|
244 |
+
f"----String decoded with skipped special tokens - generation: `{generated_text}`\n"
|
245 |
+
f"----Generation mode: `{decoding_strategy}`\n"
|
246 |
+
f"----Generation parameters: `{generation_args}`\n"
|
247 |
+
)
|
248 |
+
|
249 |
+
return generated_text
|
250 |
+
|
251 |
+
|
252 |
+
textbox = gr.Textbox(
|
253 |
+
show_label=False,
|
254 |
+
value="<fake_token_around_image><image:https://m.media-amazon.com/images/M/MV5BMjE4MTcwMTM1Nl5BMl5BanBnXkFtZTcwMTIwMzMzMw@@._V1_.jpg><fake_token_around_image>Describe all of the parts of this image.",
|
255 |
+
placeholder="To input images, use the following syntax: `<fake_token_around_image><image:URL_IMAGE><fake_token_around_image>textexttext`",
|
256 |
+
visible=True,
|
257 |
+
container=False
|
258 |
+
)
|
259 |
+
with gr.Blocks(title="LLaVA", theme=gr.themes.Base()) as demo:
|
260 |
+
# state = gr.State()
|
261 |
+
|
262 |
+
with gr.Row():
|
263 |
+
with gr.Column(scale=3):
|
264 |
+
with gr.Row(elem_id="model_selector_row"):
|
265 |
+
model_selector = gr.Dropdown(
|
266 |
+
choices=models,
|
267 |
+
value=models[0] if len(models) > 0 else "",
|
268 |
+
interactive=True,
|
269 |
+
show_label=False,
|
270 |
+
container=False)
|
271 |
+
tokenizer, model = load_tokenizer_model(model_selector.value)
|
272 |
+
|
273 |
+
imagebox = gr.Image(type="pil", label="Image input - This image box is not supported yet! To include images, do through the text by adding `<fake_token_around_image><image:IMAGE_URL><fake_token_around_image>`. The backend takes care of parsing that <image:URL> and download the correponding image. That way, you can technically interleave as many images and texts as you want. No need to add space before and after `<fake_token_around_image>`")
|
274 |
+
|
275 |
+
with gr.Accordion("Parameters", open=False, visible=True) as parameter_row:
|
276 |
+
temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
|
277 |
+
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
|
278 |
+
max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
|
279 |
+
|
280 |
+
with gr.Column(scale=6):
|
281 |
+
chatbot = gr.Chatbot(
|
282 |
+
elem_id="chatbot",
|
283 |
+
label="Idefics Chatbot",
|
284 |
+
visible=True,
|
285 |
+
height=550,
|
286 |
+
value=[
|
287 |
+
[
|
288 |
+
"Where can we typically find animals like this one?<fake_token_around_image><image:https://upload.wikimedia.org/wikipedia/commons/thumb/d/db/Alpaca_%2831562329701%29.jpg/640px-Alpaca_%2831562329701%29.jpg><fake_token_around_image>",
|
289 |
+
"Animals like the one in the image, which is a llama, can typically be found in rural areas, such as farms or ranches. Llamas are often used as pack animals in mountainous regions, as they are well-adapted to the challenging terrain and can carry heavy loads. They are also valued for their wool, which is used to make clothing and other textiles. In addition, llamas are sometimes kept as pets or for their therapeutic benefits, as they are known to be calm and gentle animals."
|
290 |
+
],
|
291 |
+
[
|
292 |
+
"How many of these animals can we fit into an engine like that<fake_token_around_image><image:https://upload.wikimedia.org/wikipedia/commons/thumb/4/4e/Nassau_County_Police_Bell_407.jpg/1200px-Nassau_County_Police_Bell_407.jpg><fake_token_around_image>?",
|
293 |
+
"The image shows a helicopter with a large engine, but it is not possible to determine the exact number of animals that can fit into it based on the image alone. The size and capacity of the helicopter's engine would depend on various factors, such as the size of the animals, the weight of the animals, and the size of the helicopter itself. However, it is safe to assume that the helicopter is designed to carry a limited number of animals, and it is not intended to be used as a means of transporting large groups of animals."
|
294 |
+
]
|
295 |
+
]
|
296 |
+
)
|
297 |
+
with gr.Row():
|
298 |
+
with gr.Column(scale=8):
|
299 |
+
textbox.render()
|
300 |
+
with gr.Column(scale=1, min_width=60):
|
301 |
+
submit_btn = gr.Button(value="Submit", visible=True)
|
302 |
+
with gr.Column(scale=1, min_width=20):
|
303 |
+
clear_btn = gr.ClearButton([textbox, chatbot])
|
304 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
305 |
+
# gr.Examples(examples=[
|
306 |
+
# [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
|
307 |
+
# [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
|
308 |
+
# ], inputs=[textbox])
|
309 |
+
# gr.Examples(
|
310 |
+
# examples=[
|
311 |
+
# [
|
312 |
+
# ("How many of these animals can we fit into an engine like that<fake_token_around_image><image:https://upload.wikimedia.org/wikipedia/commons/thumb/4/4e/Nassau_County_Police_Bell_407.jpg/1200px-Nassau_County_Police_Bell_407.jpg><fake_token_around_image>?", "The image shows a helicopter with a large engine, but it is not possible to determine the exact number of animals that can fit into it based on the image alone. The size and capacity of the helicopter's engine would depend on various factors, such as the size of the animals, the weight of the animals, and the size of the helicopter itself. However, it is safe to assume that the helicopter is designed to carry a limited number of animals, and it is not intended to be used as a means of transporting large groups of animals."),
|
313 |
+
# ],
|
314 |
+
# ],
|
315 |
+
# inputs = [chatbot]
|
316 |
+
# )
|
317 |
+
|
318 |
+
def format_prompt_with_history_and_system_conditioning(current_user_prompt, history):
|
319 |
+
resulting_text = SYSTEM_PROMPT
|
320 |
+
for turn in history:
|
321 |
+
user_utterance, assistant_utterance = turn
|
322 |
+
resulting_text += f"\nUser:{user_utterance}</s>\nAssistant:{assistant_utterance}"
|
323 |
+
resulting_text += f"\nUser:{current_user_prompt}</s>\nAssistant:"
|
324 |
+
return resulting_text
|
325 |
+
|
326 |
+
def model_inference(
|
327 |
+
user_prompt,
|
328 |
+
chat_history,
|
329 |
+
):
|
330 |
+
global model, tokenizer
|
331 |
+
|
332 |
+
temperature = 1.0
|
333 |
+
no_repeat_ngram_size = 0
|
334 |
+
max_new_tokens = 256
|
335 |
+
min_length = 16
|
336 |
+
force_words = ""
|
337 |
+
repetition_penalty = 1.0
|
338 |
+
hide_special_tokens = False
|
339 |
+
decoding_strategy = "greedy"
|
340 |
+
num_beams = 3
|
341 |
+
length_penalty = 1.
|
342 |
+
top_k = 50,
|
343 |
+
top_p = 0.95
|
344 |
+
penalty_alpha = 0.95
|
345 |
+
|
346 |
+
formated_prompt = format_prompt_with_history_and_system_conditioning(
|
347 |
+
current_user_prompt=user_prompt.strip(),
|
348 |
+
history=chat_history,
|
349 |
+
)
|
350 |
+
|
351 |
+
url_images = re.findall(r"<image(.*?)>", formated_prompt)
|
352 |
+
for idx, url_image in enumerate(url_images):
|
353 |
+
formated_prompt = formated_prompt.replace(url_image, "")
|
354 |
+
url_images[idx] = url_images[idx][1:]
|
355 |
+
images = fetch_images(url_images)
|
356 |
+
|
357 |
+
generated_text = model_generation(
|
358 |
+
prompt=formated_prompt,
|
359 |
+
images=images,
|
360 |
+
tokenizer=tokenizer,
|
361 |
+
model=model,
|
362 |
+
temperature=temperature,
|
363 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
364 |
+
max_new_tokens=max_new_tokens,
|
365 |
+
min_length=min_length,
|
366 |
+
ban_tokens=BAN_TOKENS,
|
367 |
+
eos_tokens=EOS_TOKENS,
|
368 |
+
force_words=force_words,
|
369 |
+
repetition_penalty=repetition_penalty,
|
370 |
+
hide_special_tokens=hide_special_tokens,
|
371 |
+
decoding_strategy=decoding_strategy,
|
372 |
+
num_beams=num_beams,
|
373 |
+
length_penalty=length_penalty,
|
374 |
+
top_k=top_k,
|
375 |
+
top_p=top_p,
|
376 |
+
penalty_alpha=penalty_alpha,
|
377 |
+
)
|
378 |
+
|
379 |
+
chat_history.append((user_prompt, generated_text.strip()))
|
380 |
+
return "", chat_history
|
381 |
+
|
382 |
+
|
383 |
+
textbox.submit(
|
384 |
+
fn=model_inference,
|
385 |
+
inputs=[textbox, chatbot],
|
386 |
+
outputs=[textbox, chatbot],
|
387 |
+
)
|
388 |
+
submit_btn.click(
|
389 |
+
fn=model_inference,
|
390 |
+
inputs=[textbox, chatbot],
|
391 |
+
outputs=[textbox, chatbot],
|
392 |
+
)
|
393 |
+
|
394 |
+
demo.queue()
|
395 |
+
demo.launch(share=True)
|
elon_musk.md
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
This is a conversation between a human, User, and an intelligent visual AI, Bot. User sends images, and Bot answers the questions from the user. The bot should be friendly, informative and should not change the topic if it's not asked to.
|
2 |
+
|
3 |
+
Here's an example of a conversation:
|
4 |
+
User: <fake_token_around_image> <image> <fake_token_around_image> Describe this image.
|
5 |
+
Bot: Two kittens are cuddling in the center of the photograph. They are surrounded by tall, bright green grass, and the background is blurred.
|
6 |
+
User: <fake_token_around_image> <image> <fake_token_around_image> How about this image? Can you describe it too?
|
7 |
+
Bot: A dog is lying on the floor, looking at the camera. It is looking directly at you, tilting its head to the side. The dog has a white body and brown patches on its face and ears. Its eyes are brown. Its nose is black, and it has long, floppy ears, short legs, white paws, long fur, big eyes, and black eyebrows.
|
8 |
+
User: What kind of breed is it?
|
9 |
+
Bot: It looks like a Cavalier King Charles Spaniel. These breeds are known for their gentle, affectionate nature. Their long ears and big, brown eyes give them a sweet, expressive face. Cavaliers are small dogs weighing between 13 and 20 pounds. They have a silky, medium-length coat that comes in various colors, including black and tan, tricolor, red and white, or black.
|
10 |
+
---
|
11 |
+
User: <fake_token_around_image> <image> <fake_token_around_image>
|
12 |
+
Describe this image.
|
13 |
+
Bot: This image shows a robot named Wall-E. He is standing in front of a pile of garbage, holding a plant in his hands. His eyes and mouth are open, as if he is surprised. Wall E is made of metal and has two arms and two legs. There is also a small light on his head. In the distance, there are some buildings and a large mountain. This scene takes place on Earth, which is covered in trash.
|
14 |
+
User: What happens in the film? Explain in a few sentences.
|
15 |
+
Bot: WallE is an animated film about a lonely robot who falls in love with another robot. When they meet, they discover that they have been programmed to clean up the Earth. Together, the robots work to restore the planet to its former glory. Along the way, Wall e learns about friendship, love, responsibility, sacrifice, hope, courage, determination, perseverance, teamwork, loyalty, honesty, trust, forgiveness, compassion, empathy, selflessness, kindness and generosity. Ultimately, he realizes that he can make a difference in this world.
|
16 |
+
User: In which year was the film released?
|
17 |
+
Bot: The film was released in 3D in June 08, 4 years ago.
|
18 |
+
User: Is a future in which the earth is not liveable anymore (similarly as in the movie) conceivable?
|
19 |
+
Bot: Yes, it is possible. If we continue to pollute the environment and destroy the natural resources, we will eventually make the world uninhabitable. We need to take action now to protect the future of our planet. One way to do this is to reduce our carbon footprint by using less energy and recycling more. Another way is by supporting companies that are working to develop sustainable technologies. Finally, you can also get involved in local environmental initiatives, such as planting trees or cleaning up parks. By taking these steps, each of us can help to make sure that our children and grandchildren will be able to enjoy a healthy and beautiful planet for many years to come.
|
20 |
+
User: I think this is a lost cause. <fake_token_around_image> <image> <fake_token_around_image> Like Elon Musk, I think we should instead all move to another planet now.
|
21 |
+
Bot: El ón Mus ́k is the CEO of SpaceX and Tesla Motors. Space X is developing a reusable rocket system that could be used to transport people and cargo to Mars. Mus k believes that colonizing Mars is essential for the survival of humanity, because it would provide a backup plan in case something goes wrong on earth. However, many people are skeptical about the feasibility of such a project. Some argue that it will take too long and cost too much money to establish a colony on Mars, while others say that the risks are too great. Nevertheless, Mus K remains optimistic about his vision for Mars colonization. "I think there is good reason to believe that we can build a self-sustaining city on mars," he said in an interview with Time magazine. “It would be quite an adventure.”
|
22 |
+
----
|
23 |
+
The project is still in development. I'm currently working on improving the accuracy of image recognition and adding more features to it.</s>
|
images/bear.jpg
ADDED
m4/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from m4.utils import logging
|
m4/models/__init__.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from m4.models.custom_modules import DecoupledEmbedding, DecoupledLinear
|
2 |
+
from m4.models.vbloom.configuration_vbloom import VBloomConfig
|
3 |
+
from m4.models.vbloom.modeling_vbloom import VBloomForCausalLM
|
4 |
+
from m4.models.vgpt2.configuration_vgpt2 import VGPT2Config
|
5 |
+
from m4.models.vgpt2.modeling_vgpt2 import VGPT2LMHeadModel
|
6 |
+
from m4.models.vllama.configuration_vllama import VLlamaConfig
|
7 |
+
from m4.models.vllama.modeling_vllama import VLlamaForCausalLM
|
8 |
+
from m4.models.vopt.configuration_vopt import VOPTConfig
|
9 |
+
from m4.models.vopt.modeling_vopt import VOPTForCausalLM
|
10 |
+
from m4.models.vt5.configuration_vt5 import VT5Config
|
11 |
+
from m4.models.vt5.modeling_vt5 import VT5ForConditionalGeneration
|
12 |
+
|
13 |
+
|
14 |
+
_SUPPORTED_MODELS = {
|
15 |
+
"vgpt2": VGPT2Config,
|
16 |
+
"vt5": VT5Config,
|
17 |
+
"vbloom": VBloomConfig,
|
18 |
+
"vopt": VOPTConfig,
|
19 |
+
"vllama": VLlamaConfig,
|
20 |
+
}
|
21 |
+
|
22 |
+
model_type_to_modeling_class = {
|
23 |
+
"vgpt2": VGPT2LMHeadModel,
|
24 |
+
"vt5": VT5ForConditionalGeneration,
|
25 |
+
"vbloom": VBloomForCausalLM,
|
26 |
+
"vopt": VOPTForCausalLM,
|
27 |
+
"vllama": VLlamaForCausalLM,
|
28 |
+
}
|
m4/models/common.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def expand_inputs_for_generation(
|
5 |
+
input_ids,
|
6 |
+
expand_size=1,
|
7 |
+
is_encoder_decoder=False,
|
8 |
+
attention_mask=None,
|
9 |
+
encoder_outputs=None,
|
10 |
+
**model_kwargs,
|
11 |
+
):
|
12 |
+
expanded_return_idx = (
|
13 |
+
torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
|
14 |
+
)
|
15 |
+
input_ids = input_ids.index_select(0, expanded_return_idx)
|
16 |
+
|
17 |
+
if "token_type_ids" in model_kwargs:
|
18 |
+
token_type_ids = model_kwargs["token_type_ids"]
|
19 |
+
model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)
|
20 |
+
|
21 |
+
if attention_mask is not None:
|
22 |
+
model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)
|
23 |
+
model_kwargs["image_attention_mask"] = model_kwargs["image_attention_mask"].index_select(
|
24 |
+
0, expanded_return_idx
|
25 |
+
)
|
26 |
+
model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select(0, expanded_return_idx)
|
27 |
+
|
28 |
+
if is_encoder_decoder:
|
29 |
+
if encoder_outputs is None:
|
30 |
+
raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
|
31 |
+
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
|
32 |
+
0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
|
33 |
+
)
|
34 |
+
model_kwargs["encoder_outputs"] = encoder_outputs
|
35 |
+
return input_ids, model_kwargs
|
36 |
+
|
37 |
+
|
38 |
+
def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
|
39 |
+
# must have this key set to at least None
|
40 |
+
model_kwargs["past_key_values"] = model_kwargs.get("past_key_values", None)
|
41 |
+
|
42 |
+
# update past
|
43 |
+
if "past_key_values" in outputs:
|
44 |
+
model_kwargs["past"] = outputs.past_key_values
|
45 |
+
elif "mems" in outputs:
|
46 |
+
model_kwargs["past"] = outputs.mems
|
47 |
+
elif "past_buckets_states" in outputs:
|
48 |
+
model_kwargs["past"] = outputs.past_buckets_states
|
49 |
+
else:
|
50 |
+
model_kwargs["past"] = None
|
51 |
+
|
52 |
+
# update token_type_ids with last value
|
53 |
+
if "token_type_ids" in model_kwargs:
|
54 |
+
token_type_ids = model_kwargs["token_type_ids"]
|
55 |
+
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
|
56 |
+
|
57 |
+
# update attention masks
|
58 |
+
if not is_encoder_decoder:
|
59 |
+
if "attention_mask" in model_kwargs:
|
60 |
+
attention_mask = model_kwargs["attention_mask"]
|
61 |
+
model_kwargs["attention_mask"] = torch.cat(
|
62 |
+
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
63 |
+
)
|
64 |
+
if "image_attention_mask" in model_kwargs:
|
65 |
+
image_attention_mask = model_kwargs["image_attention_mask"]
|
66 |
+
last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
|
67 |
+
model_kwargs["image_attention_mask"] = last_mask
|
68 |
+
|
69 |
+
return model_kwargs
|
70 |
+
|
71 |
+
|
72 |
+
def prepare_inputs_for_generation(input_ids, past=None, **kwargs):
|
73 |
+
token_type_ids = kwargs.get("token_type_ids", None)
|
74 |
+
# only last token for inputs_ids if past is defined in kwargs
|
75 |
+
if past:
|
76 |
+
input_ids = input_ids[:, -1].unsqueeze(-1)
|
77 |
+
if token_type_ids is not None:
|
78 |
+
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
|
79 |
+
|
80 |
+
attention_mask = kwargs.get("attention_mask", None)
|
81 |
+
position_ids = kwargs.get("position_ids", None)
|
82 |
+
|
83 |
+
if attention_mask is not None and position_ids is None:
|
84 |
+
# create position_ids on the fly for batch generation
|
85 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
86 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
87 |
+
if past:
|
88 |
+
position_ids = position_ids[:, -1].unsqueeze(-1)
|
89 |
+
|
90 |
+
pixel_values = kwargs.get("pixel_values", None)
|
91 |
+
image_attention_mask = kwargs.get("image_attention_mask", None)
|
92 |
+
if pixel_values is None or image_attention_mask is None:
|
93 |
+
raise ValueError("pixel values and image attention mask cannot be None")
|
94 |
+
|
95 |
+
return {
|
96 |
+
"input_ids": input_ids,
|
97 |
+
"past_key_values": past,
|
98 |
+
"use_cache": kwargs.get("use_cache"),
|
99 |
+
"position_ids": position_ids,
|
100 |
+
"attention_mask": attention_mask,
|
101 |
+
"token_type_ids": token_type_ids,
|
102 |
+
"pixel_values": pixel_values,
|
103 |
+
"image_attention_mask": image_attention_mask,
|
104 |
+
}
|
m4/models/custom_modules.py
ADDED
@@ -0,0 +1,337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch import nn
|
6 |
+
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
7 |
+
from transformers.utils import ContextManagers
|
8 |
+
|
9 |
+
from m4.training.setup_vision_model import vision_model_name_to_model
|
10 |
+
from m4.training.utils import (
|
11 |
+
deepspeed_zero_init_disabled_context_manager,
|
12 |
+
is_deepspeed_zero_init_enabled,
|
13 |
+
load_state_dict_into_model,
|
14 |
+
)
|
15 |
+
|
16 |
+
|
17 |
+
# from pathlib import Path
|
18 |
+
|
19 |
+
|
20 |
+
class VLOOMPreTrainedModelBase(PreTrainedModel):
|
21 |
+
# The problem we are trying to solve is 2 nested zero.Init thanks to fetching from_pretrained(vision_model_name)
|
22 |
+
# and then one more zero.Init to override from_pretrained(vision_model_name) once again as it was done in the original - this breaks deepspeed zero3 w/ zero.Init
|
23 |
+
# So one solution is this:
|
24 |
+
# a. replace from_pretrained(vision_model_name) with from_config(vision_model_name) while hacking to disable zero.Init context
|
25 |
+
# b. instead of straight replacement of model.vision_model = from_pretrained(vision_model_name) when it gets updated, we first do from_pretrained(vision_model_name) and then update the existing model with weights using the already zero.Init'ed pre-sharded weights
|
26 |
+
#
|
27 |
+
# there are a few variations to get_vision_model_from_config - all need to bypass zero.Init under zero3
|
28 |
+
# 1. one variant is to hack into accelerate's deepspeed_plugin and turn off zero.Init while loading the vision model
|
29 |
+
# 2. the other variant is to override _from_config method with our version that doesn't do zero.Init
|
30 |
+
|
31 |
+
@classmethod
|
32 |
+
def override_vision_model(cls, model, vision_model_name, vision_model_params, torch_dtype):
|
33 |
+
# 1. fetch the pretrained vision model w/o zero.Init
|
34 |
+
with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
|
35 |
+
vision_model = AutoModel.from_pretrained(vision_model_name, **vision_model_params, torch_dtype=torch_dtype)
|
36 |
+
|
37 |
+
# this extracts the desired submodule if the part we want is nested (e.g. as in clip)
|
38 |
+
real_vision_model = vision_model_name_to_model(vision_model_name, vision_model)
|
39 |
+
|
40 |
+
# 2. now override the weights already sharded by zero.Init with the weights from the real_vision_model
|
41 |
+
# by gradually gathering sharded weights and replacing with new weights
|
42 |
+
if is_deepspeed_zero_init_enabled():
|
43 |
+
state_dict = real_vision_model.state_dict()
|
44 |
+
load_state_dict_into_model(model.vision_model, state_dict, start_prefix="")
|
45 |
+
else:
|
46 |
+
model.vision_model = real_vision_model
|
47 |
+
|
48 |
+
@classmethod
|
49 |
+
def from_config(cls, config, **kwargs):
|
50 |
+
# torch_dtype is crucial for using the minimal amount of memory at load time
|
51 |
+
torch_dtype = kwargs.get("torch_dtype", None)
|
52 |
+
|
53 |
+
vision_model_name = config.vision_model_name
|
54 |
+
vision_model_params = eval(config.vision_model_params)
|
55 |
+
|
56 |
+
# 1. create an uninitialized vision_model to insert into the main model.
|
57 |
+
# It has to be created outside lm's `from_pretrained` and w/o zero.Init so that zero3+zero.Init works
|
58 |
+
with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
|
59 |
+
vision_model_config = AutoConfig.from_pretrained(vision_model_name, **vision_model_params)
|
60 |
+
vision_model_from_config = AutoModel.from_config(vision_model_config, torch_dtype=torch_dtype)
|
61 |
+
# this extracts the desired submodule if the part we want is nested (e.g. as in clip)
|
62 |
+
kwargs["vision_model"] = vision_model_name_to_model(vision_model_name, vision_model_from_config)
|
63 |
+
|
64 |
+
# 2. create the main class's model, passing the uninitialized vision_model to it
|
65 |
+
model = cls(config, **kwargs)
|
66 |
+
|
67 |
+
return model
|
68 |
+
|
69 |
+
@classmethod
|
70 |
+
def from_pretrained_models(cls, *args, **kwargs):
|
71 |
+
"""
|
72 |
+
Use this method when creating a new vloom model that hasn't been yet trained and it'll be
|
73 |
+
composed of 2 pre-trained models - hence `pretrained_models`.
|
74 |
+
"""
|
75 |
+
|
76 |
+
return cls.from_pretrained(*args, **kwargs, new_model=True)
|
77 |
+
|
78 |
+
@classmethod
|
79 |
+
def from_pretrained(cls, *model_args, is_resume=False, new_model=False, **kwargs):
|
80 |
+
"""
|
81 |
+
Use this method when loading an already pretrained vloom model - either from a checkpoint or from hub.
|
82 |
+
For creating an untrained model use `pretrained_models` instead.
|
83 |
+
"""
|
84 |
+
|
85 |
+
is_untrained_vloom_model = False
|
86 |
+
is_pretrained_vloom_model_resumed = False
|
87 |
+
is_pretrained_vloom_model_from_hub_or_path = False
|
88 |
+
|
89 |
+
# we have 3 use cases:
|
90 |
+
# 1. is_untrained_vloom_model - a totally new vloom model
|
91 |
+
# 2. is_pretrained_vloom_model_resumed - a pretrained vloom model being resumed from a
|
92 |
+
# checkpoint (instantiate a random empty model in this case)
|
93 |
+
# 3. is_pretrained_vloom_model_from_hub_or_path - a pretrained vloom model loaded from hub or local path
|
94 |
+
if new_model:
|
95 |
+
is_untrained_vloom_model = True
|
96 |
+
elif is_resume:
|
97 |
+
is_pretrained_vloom_model_resumed = True
|
98 |
+
else:
|
99 |
+
is_pretrained_vloom_model_from_hub_or_path = True
|
100 |
+
|
101 |
+
# torch_dtype is crucial for using the minimal amount of memory at load time
|
102 |
+
torch_dtype = kwargs.get("torch_dtype", None)
|
103 |
+
|
104 |
+
# config is:
|
105 |
+
# 1. either not passed and then we use the model's default config (used by tests)
|
106 |
+
# 2. passed and in which case it's one of:
|
107 |
+
# 2a. `PretrainedConfig` (a new m4 model)
|
108 |
+
# 2b. path to a json config (an already pretrained m4 model, usually resumed training)
|
109 |
+
config = kwargs.get("config", None)
|
110 |
+
if config is None:
|
111 |
+
config = cls.config_class.from_pretrained(*model_args, **kwargs, return_unused_kwargs=False)
|
112 |
+
elif not isinstance(config, PretrainedConfig):
|
113 |
+
# adapted from https://github.com/huggingface/transformers/blob/d0acc9537829e7d067edbb791473bbceb2ecf056/src/transformers/modeling_utils.py#L1920
|
114 |
+
assert isinstance(config, os.PathLike)
|
115 |
+
config_path = str(config)
|
116 |
+
config = cls.config_class.from_pretrained(
|
117 |
+
config_path,
|
118 |
+
return_unused_kwargs=False,
|
119 |
+
**kwargs,
|
120 |
+
)
|
121 |
+
|
122 |
+
vision_model_name = config.vision_model_name
|
123 |
+
vision_model_params = eval(config.vision_model_params)
|
124 |
+
|
125 |
+
# 1. create an uninitialized vision_model to insert into the main model.
|
126 |
+
# It has to be created outside lm's `from_pretrained` and w/o zero.Init so that zero3+zero.Init works
|
127 |
+
with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
|
128 |
+
vision_model_config = AutoConfig.from_pretrained(vision_model_name, **vision_model_params)
|
129 |
+
vision_model_from_config = AutoModel.from_config(vision_model_config, torch_dtype=torch_dtype)
|
130 |
+
# this extracts the desired submodule if the part we want is nested (e.g. as in clip)
|
131 |
+
kwargs["vision_model"] = vision_model_name_to_model(vision_model_name, vision_model_from_config)
|
132 |
+
|
133 |
+
# 2. create the vloom model
|
134 |
+
if is_untrained_vloom_model or is_pretrained_vloom_model_from_hub_or_path:
|
135 |
+
model = super().from_pretrained(*model_args, **kwargs)
|
136 |
+
elif is_pretrained_vloom_model_resumed:
|
137 |
+
# in the case of resume under deepspeed we create an empty model, and get deepspeed
|
138 |
+
# to load the weights from the checkpoint
|
139 |
+
# but not all models have these keys so handle the case they don't have them
|
140 |
+
_ = kwargs.pop("config", None)
|
141 |
+
model = super().from_pretrained(None, config=config, state_dict={}, **kwargs)
|
142 |
+
|
143 |
+
# 3. if is_untrained_vloom_model, now override the uninitialized vision_model with one with pretrained weights
|
144 |
+
if is_untrained_vloom_model:
|
145 |
+
cls.override_vision_model_wrapper(model, config, vision_model_name, vision_model_params, torch_dtype)
|
146 |
+
|
147 |
+
return model
|
148 |
+
|
149 |
+
|
150 |
+
class DecoupledEmbedding(nn.Embedding):
|
151 |
+
# Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
|
152 |
+
"""
|
153 |
+
Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings.
|
154 |
+
In practise, the regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0, then it will create `num_additional_embeddings` additional parameters that are always trained.
|
155 |
+
If `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
|
156 |
+
"""
|
157 |
+
|
158 |
+
def __init__(
|
159 |
+
self,
|
160 |
+
num_embeddings,
|
161 |
+
num_additional_embeddings,
|
162 |
+
embedding_dim,
|
163 |
+
partially_freeze=False,
|
164 |
+
device=None,
|
165 |
+
dtype=None,
|
166 |
+
padding_idx=None,
|
167 |
+
**kwargs,
|
168 |
+
) -> None:
|
169 |
+
"""
|
170 |
+
num_additional_embeddings: int. Number of additional embeddings. Only useful when you `partially_freeze=True`.
|
171 |
+
partially_freeze: bool. If True, the regular `weight` will be frozen. `additional_weight` is never frozen.
|
172 |
+
|
173 |
+
Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`, `max_norm` or `norm_type`. We are not supporting these.
|
174 |
+
"""
|
175 |
+
if padding_idx is not None and padding_idx > num_embeddings:
|
176 |
+
raise ValueError(f"padding_idx must be within num_embeddings. Got {padding_idx} and {num_embeddings}")
|
177 |
+
super().__init__(
|
178 |
+
num_embeddings=num_embeddings,
|
179 |
+
embedding_dim=embedding_dim,
|
180 |
+
device=device,
|
181 |
+
dtype=dtype,
|
182 |
+
padding_idx=padding_idx,
|
183 |
+
**kwargs,
|
184 |
+
)
|
185 |
+
self.num_embeddings = num_embeddings
|
186 |
+
self.padding_idx = padding_idx
|
187 |
+
self.num_additional_embeddings = num_additional_embeddings
|
188 |
+
self.partially_freeze = partially_freeze
|
189 |
+
|
190 |
+
if partially_freeze:
|
191 |
+
self.weight.requires_grad_(False)
|
192 |
+
|
193 |
+
if self.num_additional_embeddings > 0:
|
194 |
+
self.additional_embedding = nn.Embedding(
|
195 |
+
num_embeddings=self.num_additional_embeddings,
|
196 |
+
embedding_dim=embedding_dim,
|
197 |
+
device=device,
|
198 |
+
dtype=dtype,
|
199 |
+
)
|
200 |
+
|
201 |
+
def forward(self, input_ids):
|
202 |
+
"""
|
203 |
+
we have 2 embeddings, with different indices - one pretrained self.weight and another
|
204 |
+
self.additional_embedding.weight that is being trained.
|
205 |
+
|
206 |
+
in order to make a lookup of the input ids, we:
|
207 |
+
1. find out the indices of the entries belonging to the 2nd embedding
|
208 |
+
2. extract those values while subtracting the size of the first embedding (num_embeddings),
|
209 |
+
since the 2nd embedding starts from 0 and not num_embeddings
|
210 |
+
3. perform the 2nd embedding lookup
|
211 |
+
4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index
|
212 |
+
5. perform the 1st embedding lookup
|
213 |
+
6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup
|
214 |
+
|
215 |
+
note: for the 1st embedding lookup we could have looked up only the low indices and not do
|
216 |
+
the padding, but then we have to create a new tensor and populate it with 2 tensors that are
|
217 |
+
spread out across various indices - i.e. not a simple concat - I haven't benchmarked the
|
218 |
+
complex case if it's any faster, given that seqlens are usually relatively short it's
|
219 |
+
probably not faster or if faster not by much - but might be a good idea to measure.
|
220 |
+
|
221 |
+
"""
|
222 |
+
if self.num_additional_embeddings == 0:
|
223 |
+
return F.embedding(input_ids, self.weight)
|
224 |
+
|
225 |
+
# Clone so that we don't modify the original input_ids later on
|
226 |
+
input_ids = input_ids.clone()
|
227 |
+
additional_vocab_indices = torch.where(input_ids >= self.num_embeddings)
|
228 |
+
input_ids_additional_vocab = input_ids[additional_vocab_indices]
|
229 |
+
additional_embeddings = self.additional_embedding(input_ids_additional_vocab - self.num_embeddings)
|
230 |
+
|
231 |
+
# for successful lookup replace input_ids with 0, the results of these will be discarded anyway
|
232 |
+
input_ids[additional_vocab_indices] = 0
|
233 |
+
full_vector = F.embedding(input_ids, self.weight)
|
234 |
+
|
235 |
+
# overwrite the records with high indices
|
236 |
+
full_vector[additional_vocab_indices] = additional_embeddings
|
237 |
+
|
238 |
+
return full_vector
|
239 |
+
|
240 |
+
def extra_repr(self) -> str:
|
241 |
+
return "num_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
|
242 |
+
self.num_embeddings,
|
243 |
+
self.num_additional_embeddings,
|
244 |
+
self.embedding_dim,
|
245 |
+
self.partially_freeze,
|
246 |
+
)
|
247 |
+
|
248 |
+
@classmethod
|
249 |
+
def from_pretrained(cls, embeddings, freeze=True, **kwargs):
|
250 |
+
raise NotImplementedError
|
251 |
+
|
252 |
+
|
253 |
+
class DecoupledLinear(nn.Linear):
|
254 |
+
# Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
|
255 |
+
"""
|
256 |
+
Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters.
|
257 |
+
In practise, the regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `out_additional_features` > 0, then it will create `out_additional_features * in_features` additional parameters that are always trained.
|
258 |
+
If `out_additional_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
|
259 |
+
"""
|
260 |
+
|
261 |
+
def __init__(
|
262 |
+
self,
|
263 |
+
in_features: int,
|
264 |
+
out_features: int,
|
265 |
+
out_additional_features: int = 0,
|
266 |
+
bias: bool = True,
|
267 |
+
partially_freeze: bool = True,
|
268 |
+
device=None,
|
269 |
+
dtype=None,
|
270 |
+
) -> None:
|
271 |
+
"""
|
272 |
+
out_additional_features: int. Number of additional trainable dimensions. Only makes sense when `partially_freeze=True`.
|
273 |
+
partially_freeze: bool. If True, the regular `weight` will be frozen and extra parameters (if any) will be trainable. If False, default to the regular behavior of nn.Linear.
|
274 |
+
"""
|
275 |
+
super().__init__(in_features, out_features, bias, device, dtype)
|
276 |
+
self.out_additional_features = out_additional_features
|
277 |
+
self.partially_freeze = partially_freeze
|
278 |
+
|
279 |
+
self.in_features = in_features
|
280 |
+
self.out_features = out_features
|
281 |
+
|
282 |
+
if partially_freeze:
|
283 |
+
self.weight.requires_grad_(False)
|
284 |
+
if bias:
|
285 |
+
self.bias.requires_grad_(False)
|
286 |
+
|
287 |
+
if out_additional_features > 0:
|
288 |
+
self.additional_fc = nn.Linear(
|
289 |
+
in_features=in_features,
|
290 |
+
out_features=out_additional_features,
|
291 |
+
bias=bias,
|
292 |
+
device=device,
|
293 |
+
dtype=dtype,
|
294 |
+
)
|
295 |
+
|
296 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
297 |
+
output = F.linear(input, self.weight, self.bias)
|
298 |
+
|
299 |
+
if self.out_additional_features > 0:
|
300 |
+
additional_features = F.linear(input, self.additional_fc.weight, self.additional_fc.bias)
|
301 |
+
output = torch.cat((output, additional_features), -1)
|
302 |
+
|
303 |
+
return output
|
304 |
+
|
305 |
+
def extra_repr(self) -> str:
|
306 |
+
"""Overwriting `nn.Linear.extra_repr` to include new parameters."""
|
307 |
+
return "in_features={}, out_features={}, out_additional_features={}, bias={}, partially_freeze={}".format(
|
308 |
+
self.in_features,
|
309 |
+
self.out_features,
|
310 |
+
self.out_additional_features,
|
311 |
+
self.bias is not None,
|
312 |
+
self.partially_freeze,
|
313 |
+
)
|
314 |
+
|
315 |
+
|
316 |
+
if __name__ == "__main__":
|
317 |
+
emb = DecoupledEmbedding(num_embeddings=10, num_additional_embeddings=3, embedding_dim=5, partially_freeze=True)
|
318 |
+
for n, p in emb.named_parameters():
|
319 |
+
print(n, p.requires_grad)
|
320 |
+
idx = torch.tensor([[11, 1, 3]])
|
321 |
+
y = emb(idx)
|
322 |
+
loss = y.sum()
|
323 |
+
loss.backward()
|
324 |
+
print(emb.weight, emb.weight.grad)
|
325 |
+
print(emb.additional_embedding, emb.additional_embedding.grad)
|
326 |
+
|
327 |
+
lin = DecoupledLinear(in_features=3, out_features=4, out_additional_features=2, bias=True, partially_freeze=True)
|
328 |
+
for n, p in lin.named_parameters():
|
329 |
+
print(n, p.requires_grad)
|
330 |
+
x = torch.randn(12, 3)
|
331 |
+
y = lin(x)
|
332 |
+
loss = y.sum()
|
333 |
+
loss.backward()
|
334 |
+
print("Weight w and grad:", lin.weight, lin.weight.grad)
|
335 |
+
print("bias w and grad:", lin.bias, lin.bias.grad)
|
336 |
+
print("additional_fc.weight w and grad:", lin.additional_fc.weight, lin.additional_fc.weight.grad)
|
337 |
+
print("additional_bias w and grad:", lin.additional_fc.bias, lin.additional_fc.bias.grad)
|
m4/models/perceiver/perceiver.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
perceiver.py
|
3 |
+
Generic interface to various configurations of the Perceiver Resampler, that simply takes in a series of (potentially
|
4 |
+
time-indexed) contextual embeddings, and "resamples" (compresses) them down to a pre-specified number of latents!
|
5 |
+
Note that the Perceiver in general resamples based solely off the *long-range* context; there's a nice opportunity here
|
6 |
+
to prime the Perceiver Resampler with say a single layer's worth of language embeddings (the target domain), and use
|
7 |
+
that to softly "retrieve & compress" what we need --> this would be a novel contribution we should explore.
|
8 |
+
References:
|
9 |
+
- DeepMind's Flamingo: https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model
|
10 |
+
- Code borrowed w/ love from: https://github.com/lucidrains/flamingo-pytorch
|
11 |
+
"""
|
12 |
+
from typing import Optional, Tuple
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
from einops import rearrange, repeat
|
17 |
+
|
18 |
+
|
19 |
+
class PerceiverResampler(nn.Module):
|
20 |
+
def __init__(self, config, embed_dim: int, depth: int, n_heads: int, head_dim: int, n_latents: int) -> None:
|
21 |
+
"""
|
22 |
+
Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say from a ResNet or ViT or
|
23 |
+
MAE) of a given dimension, performs `depth` blocks of cross-attention with a fixed `n_latents` inputs, then
|
24 |
+
returns a Tensor of shape [bsz, n_latents, embed_dim].
|
25 |
+
:param embed_dim: Dimensionality of embeddings being fed to the Perceiver Resampler (also dimensionality of
|
26 |
+
latent embeddings *returned* by the Perceiver Resampler. Could be e.g., VIT embed_dim, ResNet
|
27 |
+
pool dim, and so on.
|
28 |
+
:param depth: Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3).
|
29 |
+
:param n_heads: Number of heads in each Transformer block (for multi-headed self-attention).
|
30 |
+
:param head_dim: Dimensionality of each head projection in the Transformer block.
|
31 |
+
:param n_latents: Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
|
32 |
+
"""
|
33 |
+
super().__init__()
|
34 |
+
self.embed_dim, self.n_heads, self.head_dim, self.n_latents = embed_dim, n_heads, head_dim, n_latents
|
35 |
+
self.qk_layer_norms = config.qk_layer_norms_perceiver
|
36 |
+
|
37 |
+
# Create Latents for Perceiver
|
38 |
+
self.latents = nn.Parameter(torch.randn(self.n_latents, self.embed_dim), requires_grad=True)
|
39 |
+
|
40 |
+
self.intermediate_dim = (
|
41 |
+
self.embed_dim * 4 if not hasattr(config, "vision_embed_dim") else config.vision_embed_dim * 4
|
42 |
+
)
|
43 |
+
# Create Transformer Blocks
|
44 |
+
self.blocks = nn.ModuleList(
|
45 |
+
[
|
46 |
+
nn.ModuleList(
|
47 |
+
[
|
48 |
+
PerceiverAttention(self.embed_dim, self.n_heads, self.head_dim, self.qk_layer_norms),
|
49 |
+
MLP(self.intermediate_dim, config),
|
50 |
+
]
|
51 |
+
)
|
52 |
+
for _ in range(depth)
|
53 |
+
]
|
54 |
+
)
|
55 |
+
self.layer_norm = nn.LayerNorm(self.embed_dim)
|
56 |
+
|
57 |
+
def forward(self, context: torch.Tensor) -> torch.Tensor:
|
58 |
+
"""Resample arbitrary length context & *compress* down to self.n_latents latent embeddings"""
|
59 |
+
latents = repeat(self.latents, "seq embed -> bsz seq embed", bsz=context.shape[0])
|
60 |
+
|
61 |
+
# Feed through Perceiver Attention blocks...
|
62 |
+
for attn, ff in self.blocks:
|
63 |
+
latents = attn(context, latents) + latents
|
64 |
+
latents = ff(latents) + latents
|
65 |
+
|
66 |
+
return self.layer_norm(latents)
|
67 |
+
|
68 |
+
|
69 |
+
class PerceiverAttention(nn.Module):
|
70 |
+
def __init__(self, embed_dim: int, n_heads: int, head_dim: int, qk_layer_norms: bool) -> None:
|
71 |
+
"""Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
|
72 |
+
super().__init__()
|
73 |
+
self.embed_dim, self.n_heads, self.head_dim = embed_dim, n_heads, head_dim
|
74 |
+
self.qk_layer_norms = qk_layer_norms
|
75 |
+
# Normalization & Scaling
|
76 |
+
self.context_layer_norm = nn.LayerNorm(self.embed_dim)
|
77 |
+
self.latents_layer_norm = nn.LayerNorm(self.embed_dim)
|
78 |
+
if self.qk_layer_norms:
|
79 |
+
self.q_layer_norm = nn.LayerNorm(self.head_dim)
|
80 |
+
self.k_layer_norm = nn.LayerNorm(self.head_dim)
|
81 |
+
|
82 |
+
self.qk_scale = self.head_dim**-0.5
|
83 |
+
|
84 |
+
# Q, K, V Projection (no bias -- detail from Perceiver/Flamingo Papers).
|
85 |
+
self.q_proj = nn.Linear(self.embed_dim, self.n_heads * self.head_dim, bias=False)
|
86 |
+
self.k_proj = nn.Linear(self.embed_dim, self.n_heads * self.head_dim, bias=False)
|
87 |
+
self.v_proj = nn.Linear(self.embed_dim, self.n_heads * self.head_dim, bias=False)
|
88 |
+
|
89 |
+
self.output_proj = nn.Linear(self.n_heads * self.head_dim, embed_dim, bias=False)
|
90 |
+
|
91 |
+
def forward(self, context: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
|
92 |
+
"""
|
93 |
+
Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension!
|
94 |
+
:param context: Tensor of shape [bsz, seq, embed_dim] representing long-form context to resample.
|
95 |
+
:param latents: Tensor of shape [bsz, n_latents, embed_dim] representing fixed length latents to compress to.
|
96 |
+
:return: Tensor of shape [bsz, n_latents, embed_dim] representing attention over latents w/ cross from context.
|
97 |
+
"""
|
98 |
+
context = self.context_layer_norm(context)
|
99 |
+
latents = self.latents_layer_norm(latents)
|
100 |
+
|
101 |
+
# Query, Key, Value Projections --> Note that in Flamingo, latents are *concatenated* with context prior to attn!
|
102 |
+
# Note: This results in queries w/ `seq = n_latents`, and keys, values with `seq = len(context) + n_latents`
|
103 |
+
q = self.q_proj(latents)
|
104 |
+
k = self.k_proj(torch.cat([context, latents], dim=-2))
|
105 |
+
v = self.v_proj(torch.cat([context, latents], dim=-2))
|
106 |
+
|
107 |
+
# Multiheaded Self-Attention w/ stable softmax (subtract per-row max -- `amax` -- before softmax call)
|
108 |
+
# =>> `attn` should be a 2D matrix of shape [n_latents x (context + n_latents)]
|
109 |
+
q, k, v = [rearrange(x, "bsz seq (heads embed) -> bsz heads seq embed", heads=self.n_heads) for x in (q, k, v)]
|
110 |
+
if self.qk_layer_norms:
|
111 |
+
q = self.q_layer_norm(q)
|
112 |
+
k = self.k_layer_norm(k)
|
113 |
+
|
114 |
+
scores = torch.einsum("... i d, ... j d -> ... i j", q * self.qk_scale, k)
|
115 |
+
stabilized_scores = scores - (scores.amax(dim=-1, keepdim=True).detach())
|
116 |
+
attn = stabilized_scores.softmax(dim=-1)
|
117 |
+
|
118 |
+
# Attend & project back to output...
|
119 |
+
resampled = torch.einsum("... i j, ... j d -> ... i d", attn, v)
|
120 |
+
return self.output_proj(
|
121 |
+
rearrange(resampled, "bsz heads seq embed -> bsz seq (heads embed)", heads=self.n_heads)
|
122 |
+
)
|
123 |
+
|
124 |
+
|
125 |
+
class MLP(nn.Module):
|
126 |
+
def __init__(self, intermediate_size, config):
|
127 |
+
"""Simple MLP block with intermediate_size and embedding size"""
|
128 |
+
super().__init__()
|
129 |
+
self.embed_dim = config.vision_embed_dim
|
130 |
+
self.ln = nn.LayerNorm(self.embed_dim)
|
131 |
+
self.fc = nn.Linear(self.embed_dim, intermediate_size, bias=False)
|
132 |
+
self.act = nn.ReLU()
|
133 |
+
self.c_proj = nn.Linear(intermediate_size, self.embed_dim, bias=False)
|
134 |
+
|
135 |
+
def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
|
136 |
+
hidden_states = self.ln(hidden_states)
|
137 |
+
hidden_states = self.fc(hidden_states)
|
138 |
+
hidden_states = self.act(hidden_states)
|
139 |
+
hidden_states = self.c_proj(hidden_states)
|
140 |
+
|
141 |
+
return hidden_states
|
m4/models/vbloom/__init__.py
ADDED
File without changes
|
m4/models/vbloom/configuration_vbloom.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 the Big Science Workshop and HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" VBloom configuration"""
|
16 |
+
import os
|
17 |
+
from typing import Tuple, Union
|
18 |
+
|
19 |
+
from transformers import AutoConfig
|
20 |
+
from transformers.configuration_utils import PretrainedConfig
|
21 |
+
from transformers.utils import logging
|
22 |
+
|
23 |
+
|
24 |
+
logger = logging.get_logger(__name__)
|
25 |
+
|
26 |
+
BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
27 |
+
"bigscience/bloom": "https://huggingface.co/bigscience/bloom/resolve/main/config.json",
|
28 |
+
"bigscience/bloom-560m": "https://huggingface.co/bigscience/bloom-560m/blob/main/config.json",
|
29 |
+
"bigscience/bloom-1b1": "https://huggingface.co/bigscience/bloom-1b1/blob/main/config.json",
|
30 |
+
"bigscience/bloom-1b7": "https://huggingface.co/bigscience/bloom-1b7/blob/main/config.json",
|
31 |
+
"bigscience/bloom-3b": "https://huggingface.co/bigscience/bloom-3b/blob/main/config.json",
|
32 |
+
"bigscience/bloom-7b1": "https://huggingface.co/bigscience/bloom-7b1/blob/main/config.json",
|
33 |
+
}
|
34 |
+
|
35 |
+
|
36 |
+
class VBloomConfig(PretrainedConfig):
|
37 |
+
"""
|
38 |
+
This is the configuration class to store the configuration of a [`BloomModel`]. It is used to instantiate a Bloom
|
39 |
+
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
40 |
+
defaults will yield a similar configuration to the Bloom architecture
|
41 |
+
[bigscience/bloom](https://huggingface.co/bigscience/bloom).
|
42 |
+
|
43 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
44 |
+
documentation from [`PretrainedConfig`] for more information.
|
45 |
+
|
46 |
+
TODO: this doc is completely out of sync with the actual args
|
47 |
+
|
48 |
+
Args:
|
49 |
+
vocab_size (`int`, *optional*, defaults to 50257):
|
50 |
+
Vocabulary size of the Bloom model. Defines the number of different tokens that can be represented by the
|
51 |
+
`inputs_ids` passed when calling [`BloomModel`].
|
52 |
+
additional_vocab_size (`int`, *optional`, defaults to 0):
|
53 |
+
Additional vocabulary size of the model, typically for the special "<img>" token. Additional vocab tokens
|
54 |
+
are always trainable whereas regular vocab tokens can be frozen or not.
|
55 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
56 |
+
Dimensionality of the embeddings and hidden states.
|
57 |
+
n_layer (`int`, *optional*, defaults to 12):
|
58 |
+
Number of hidden layers in the Transformer encoder.
|
59 |
+
n_head (`int`, *optional*, defaults to 12):
|
60 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
61 |
+
attn_pdrop (`float`, *optional*, defaults to 0.1):
|
62 |
+
The dropout ratio for the attention.
|
63 |
+
layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
|
64 |
+
The epsilon to use in the layer normalization layers.
|
65 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
66 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
67 |
+
alpha_initializer (`str`, *optional*, defaults to `"ones"`):
|
68 |
+
Initialization type for the alphas.
|
69 |
+
alphas_initializer_range (`float`, *optional*, defaults to 0.0):
|
70 |
+
The standard deviation of the truncated_normal_initializer for initializing the alphas in the Gated Cross Attention.
|
71 |
+
alpha_type (`str`, *optional*, defaults to `"vector"`):
|
72 |
+
Whether the gating alphas should be vectors or single floats.
|
73 |
+
apply_residual_connection_post_layernorm (`bool`, *optional*, defaults to `False`):
|
74 |
+
If enabled, use the layer norm of the hidden states as the residual in the transformer blocks
|
75 |
+
skip_bias_add (`bool`, *optional*, defaults to `True`):
|
76 |
+
If set to `True`, it will skip bias add for each linear layer in the transformer blocks
|
77 |
+
skip_bias_add_qkv (`bool`, *optional*, defaults to `False`):
|
78 |
+
If set to `True`, it will skip bias add for the first linear layer in the transformer blocks
|
79 |
+
hidden_dropout (`float`, *optional*, defaults to 0.1):
|
80 |
+
Dropout rate of the dropout function on the bias dropout.
|
81 |
+
attention_dropout (`float`, *optional*, defaults to 0.1):
|
82 |
+
Dropout rate applied to the attention probs
|
83 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
84 |
+
Whether or not the model should return the last key/values attentions (not used by all models).
|
85 |
+
pretraining_tp (`int`, *optional*, defaults to `1`):
|
86 |
+
Experimental feature. Tensor parallelism rank used during pretraining with Megatron. Please refer to [this
|
87 |
+
document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
|
88 |
+
necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
|
89 |
+
issue](https://github.com/pytorch/pytorch/issues/76232). Note also that this is enabled only when
|
90 |
+
`slow_but_exact=True`.
|
91 |
+
slow_but_exact (`bool`, *optional*, defaults to `False`):
|
92 |
+
Experimental feature. Whether to use slow but exact implementation of the attention mechanism. While
|
93 |
+
merging the TP rank tensors, due to slicing operations the results may be slightly different between the
|
94 |
+
model trained on Megatron and our model. Please refer to [this
|
95 |
+
issue](https://github.com/pytorch/pytorch/issues/76232). A solution to obtain more accurate results is to
|
96 |
+
enable this feature. Enabling this will hurt the computational time of the inference. Will be probably
|
97 |
+
resolved in the future once the main model has been fine-tuned with TP_rank=1.
|
98 |
+
|
99 |
+
Example:
|
100 |
+
|
101 |
+
```python
|
102 |
+
>>> from transformers import BloomModel, BloomConfig
|
103 |
+
|
104 |
+
>>> # Initializing a Bloom configuration
|
105 |
+
>>> configuration = BloomConfig()
|
106 |
+
|
107 |
+
>>> # Initializing a model from the configuration
|
108 |
+
>>> model = BloomModel(configuration)
|
109 |
+
|
110 |
+
>>> # Accessing the model configuration
|
111 |
+
>>> configuration = model.config
|
112 |
+
```"""
|
113 |
+
|
114 |
+
model_type = "vbloom"
|
115 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
116 |
+
attribute_map = {
|
117 |
+
"num_hidden_layers": "n_layer",
|
118 |
+
"num_attention_heads": "n_head",
|
119 |
+
}
|
120 |
+
|
121 |
+
def __init__(
|
122 |
+
self,
|
123 |
+
vocab_size=250880,
|
124 |
+
additional_vocab_size=0,
|
125 |
+
hidden_size=64,
|
126 |
+
n_layer=2,
|
127 |
+
n_head=8,
|
128 |
+
layer_norm_epsilon=1e-5,
|
129 |
+
initializer_range=0.02,
|
130 |
+
alpha_initializer="ones",
|
131 |
+
alphas_initializer_range=0.0,
|
132 |
+
alpha_type="vector",
|
133 |
+
use_cache=False,
|
134 |
+
bos_token_id=1,
|
135 |
+
eos_token_id=2,
|
136 |
+
apply_residual_connection_post_layernorm=False,
|
137 |
+
hidden_dropout=0.0,
|
138 |
+
attention_dropout=0.0,
|
139 |
+
pretraining_tp=1, # TP rank used when training with megatron
|
140 |
+
slow_but_exact=False,
|
141 |
+
cross_layer_interval=1,
|
142 |
+
tie_word_embeddings=False,
|
143 |
+
freeze_text_layers=True,
|
144 |
+
freeze_lm_head=False,
|
145 |
+
freeze_vision_layers=True,
|
146 |
+
vision_model_name="google/vit-base-patch16-224",
|
147 |
+
vision_model_params="{}",
|
148 |
+
vision_embed_dim=768,
|
149 |
+
image_token_index=250880,
|
150 |
+
use_resampler=False,
|
151 |
+
resampler_n_latents=64,
|
152 |
+
resampler_depth=6,
|
153 |
+
resampler_n_heads=16,
|
154 |
+
resampler_head_dim=96,
|
155 |
+
**kwargs,
|
156 |
+
):
|
157 |
+
self.vocab_size = vocab_size
|
158 |
+
self.additional_vocab_size = additional_vocab_size
|
159 |
+
# Backward compatibility with n_embed kwarg
|
160 |
+
n_embed = kwargs.pop("n_embed", None)
|
161 |
+
self.hidden_size = hidden_size if n_embed is None else n_embed
|
162 |
+
self.n_layer = n_layer
|
163 |
+
self.n_head = n_head
|
164 |
+
self.layer_norm_epsilon = layer_norm_epsilon
|
165 |
+
self.initializer_range = initializer_range
|
166 |
+
self.alpha_initializer = alpha_initializer
|
167 |
+
self.alphas_initializer_range = alphas_initializer_range
|
168 |
+
self.alpha_type = alpha_type
|
169 |
+
self.use_cache = use_cache
|
170 |
+
self.pretraining_tp = pretraining_tp
|
171 |
+
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
|
172 |
+
self.hidden_dropout = hidden_dropout
|
173 |
+
self.attention_dropout = attention_dropout
|
174 |
+
|
175 |
+
self.bos_token_id = bos_token_id
|
176 |
+
self.eos_token_id = eos_token_id
|
177 |
+
self.slow_but_exact = slow_but_exact
|
178 |
+
|
179 |
+
self.cross_layer_interval = cross_layer_interval
|
180 |
+
self.freeze_vision_layers = freeze_vision_layers
|
181 |
+
self.vision_model_name = vision_model_name
|
182 |
+
self.vision_model_params = vision_model_params
|
183 |
+
|
184 |
+
self.tie_word_embeddings = tie_word_embeddings
|
185 |
+
self.freeze_text_layers = freeze_text_layers
|
186 |
+
self.freeze_lm_head = freeze_lm_head
|
187 |
+
self.image_token_index = image_token_index
|
188 |
+
|
189 |
+
self.vision_embed_dim = vision_embed_dim
|
190 |
+
|
191 |
+
# Resampler params
|
192 |
+
self.use_resampler = use_resampler
|
193 |
+
self.resampler_n_latents = resampler_n_latents
|
194 |
+
self.resampler_depth = resampler_depth
|
195 |
+
self.resampler_n_heads = resampler_n_heads
|
196 |
+
self.resampler_head_dim = resampler_head_dim
|
197 |
+
|
198 |
+
# IMPORTANT: Do not do any __init__ args-based checks in the constructor, since
|
199 |
+
# PretrainedConfig.from_dict first instantiates the class with the config dict and only then
|
200 |
+
# updates the config object with `kwargs` from from_pretrained, so during the instantiation
|
201 |
+
# of this object many attributes have default values and haven't yet been overridden.
|
202 |
+
# Do any required checks inside `from_pretrained` once the superclass' `from_pretrained` was run.
|
203 |
+
|
204 |
+
super().__init__(
|
205 |
+
bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
|
206 |
+
)
|
207 |
+
|
208 |
+
def check_compatibilities(self):
|
209 |
+
if self.tie_word_embeddings and (self.freeze_text_layers != self.freeze_lm_head):
|
210 |
+
raise ValueError(
|
211 |
+
"if `tie_word_embeddings` is True, then `freeze_lm_head` and `freeze_text_layers` must be equal."
|
212 |
+
)
|
213 |
+
|
214 |
+
vision_model_params = eval(self.vision_model_params)
|
215 |
+
config = AutoConfig.from_pretrained(self.vision_model_name, **vision_model_params)
|
216 |
+
if hasattr(config, "vision_config"):
|
217 |
+
vison_config = config.vision_config
|
218 |
+
else:
|
219 |
+
vison_config = config
|
220 |
+
vision_embed_dim = vison_config.hidden_size
|
221 |
+
if self.vision_embed_dim != vision_embed_dim:
|
222 |
+
raise ValueError(
|
223 |
+
f"vision_embed_dim ({self.vision_embed_dim}) must match the hidden size of the vision model"
|
224 |
+
f" ({vision_embed_dim})"
|
225 |
+
)
|
226 |
+
|
227 |
+
@classmethod
|
228 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
229 |
+
outputs = super(VBloomConfig, cls).from_pretrained(pretrained_model_name_or_path, **kwargs)
|
230 |
+
if isinstance(outputs, Tuple):
|
231 |
+
# When called with return_unused_kwargs=True, the first item will be the config
|
232 |
+
outputs[0].check_compatibilities()
|
233 |
+
else:
|
234 |
+
outputs.check_compatibilities()
|
235 |
+
return outputs
|
m4/models/vbloom/modeling_vbloom.py
ADDED
@@ -0,0 +1,1396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""PyTorch BLOOM model."""
|
16 |
+
|
17 |
+
import math
|
18 |
+
import warnings
|
19 |
+
from typing import Optional, Tuple, Union
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.utils.checkpoint
|
23 |
+
from torch import nn
|
24 |
+
from torch.nn import CrossEntropyLoss, LayerNorm
|
25 |
+
from torch.nn import functional as F
|
26 |
+
from transformers.file_utils import (
|
27 |
+
add_code_sample_docstrings,
|
28 |
+
add_start_docstrings,
|
29 |
+
add_start_docstrings_to_model_forward,
|
30 |
+
)
|
31 |
+
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
|
32 |
+
|
33 |
+
from m4.models import DecoupledEmbedding, DecoupledLinear
|
34 |
+
from m4.models.common import (
|
35 |
+
expand_inputs_for_generation,
|
36 |
+
prepare_inputs_for_generation,
|
37 |
+
update_model_kwargs_for_generation,
|
38 |
+
)
|
39 |
+
from m4.models.custom_modules import VLOOMPreTrainedModelBase
|
40 |
+
from m4.models.perceiver.perceiver import PerceiverResampler
|
41 |
+
from m4.models.vbloom.configuration_vbloom import VBloomConfig
|
42 |
+
from m4.training.utils import (
|
43 |
+
compute_perceiver_tflops_per_batch_per_gpu,
|
44 |
+
compute_tflops_per_batch_per_gpu,
|
45 |
+
freeze_model,
|
46 |
+
)
|
47 |
+
from m4.utils import logging
|
48 |
+
|
49 |
+
|
50 |
+
logger = logging.get_logger(__name__)
|
51 |
+
|
52 |
+
_CHECKPOINT_FOR_DOC = "bigscience/bloom-560m"
|
53 |
+
_CONFIG_FOR_DOC = "VBloomConfig"
|
54 |
+
_TOKENIZER_FOR_DOC = "BloomTokenizerFast"
|
55 |
+
|
56 |
+
BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
57 |
+
"bigscience/bigscience-small-testing",
|
58 |
+
"bigscience/bloom-560m",
|
59 |
+
"bigscience/bloom-1b1",
|
60 |
+
"bigscience/bloom-1b7",
|
61 |
+
"bigscience/bloom-3b",
|
62 |
+
"bigscience/bloom-7b1",
|
63 |
+
"bigscience/bloom",
|
64 |
+
]
|
65 |
+
|
66 |
+
|
67 |
+
def _make_causal_mask(
|
68 |
+
input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
|
69 |
+
) -> torch.BoolTensor:
|
70 |
+
"""
|
71 |
+
Make causal mask used for self-attention.
|
72 |
+
"""
|
73 |
+
batch_size, target_length = input_ids_shape
|
74 |
+
mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
|
75 |
+
# ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
|
76 |
+
seq_ids = torch.arange(target_length, device=device)
|
77 |
+
mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
|
78 |
+
|
79 |
+
if past_key_values_length > 0:
|
80 |
+
mask[:, :past_key_values_length] = False
|
81 |
+
|
82 |
+
expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
|
83 |
+
return expanded_mask
|
84 |
+
|
85 |
+
|
86 |
+
def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
|
87 |
+
"""
|
88 |
+
Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
|
89 |
+
"""
|
90 |
+
batch_size, src_length = mask.shape
|
91 |
+
tgt_length = tgt_length if tgt_length is not None else src_length
|
92 |
+
|
93 |
+
expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
|
94 |
+
return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
|
95 |
+
|
96 |
+
|
97 |
+
def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
|
98 |
+
"""
|
99 |
+
Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
|
100 |
+
relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
|
101 |
+
`softmax(l+a) = softmax(l)`. Based on
|
102 |
+
https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
|
103 |
+
TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
|
104 |
+
|
105 |
+
Args:
|
106 |
+
Returns tensor shaped (batch_size * num_heads, 1, max_seq_len)
|
107 |
+
attention_mask (`torch.Tensor`):
|
108 |
+
Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
|
109 |
+
num_heads (`int`, *required*):
|
110 |
+
number of heads
|
111 |
+
dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
|
112 |
+
dtype of the output tensor
|
113 |
+
"""
|
114 |
+
batch_size, seq_length = attention_mask.shape
|
115 |
+
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
|
116 |
+
base = torch.tensor(
|
117 |
+
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
|
118 |
+
)
|
119 |
+
powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
|
120 |
+
slopes = torch.pow(base, powers)
|
121 |
+
|
122 |
+
if closest_power_of_2 != num_heads:
|
123 |
+
extra_base = torch.tensor(
|
124 |
+
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
|
125 |
+
)
|
126 |
+
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
|
127 |
+
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
|
128 |
+
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
129 |
+
|
130 |
+
# Note: alibi will added to the attention bias that will be applied to the query, key product of attention
|
131 |
+
# => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
|
132 |
+
# => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
|
133 |
+
# => the query_length dimension will then be broadcasted correctly
|
134 |
+
# This is more or less identical to T5's relative position bias:
|
135 |
+
# https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
|
136 |
+
arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
|
137 |
+
alibi = slopes[..., None] * arange_tensor
|
138 |
+
return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
|
139 |
+
|
140 |
+
|
141 |
+
def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
|
142 |
+
"""
|
143 |
+
Dropout add function
|
144 |
+
|
145 |
+
Args:
|
146 |
+
x (`torch.tensor`, *required*):
|
147 |
+
input tensor
|
148 |
+
residual (`torch.tensor`, *required*):
|
149 |
+
esidual tensor
|
150 |
+
prob (`float`, *required*):
|
151 |
+
dropout probability
|
152 |
+
training (`bool`, *required*):
|
153 |
+
training mode
|
154 |
+
"""
|
155 |
+
out = F.dropout(x, p=prob, training=training)
|
156 |
+
out = residual + out
|
157 |
+
return out
|
158 |
+
|
159 |
+
|
160 |
+
def bloom_gelu_forward(x: torch.Tensor) -> torch.Tensor:
|
161 |
+
"""
|
162 |
+
Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple implementation (inference) to
|
163 |
+
make the model jitable.
|
164 |
+
|
165 |
+
Args:
|
166 |
+
x (`torch.tensor`, *required*):
|
167 |
+
input hidden states
|
168 |
+
"""
|
169 |
+
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
|
170 |
+
|
171 |
+
|
172 |
+
def bloom_gelu_back(g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
173 |
+
"""
|
174 |
+
gradient of tanh approximation of gelu gradient of actual gelu is: 0.5 * (1. + torch.erf(x * 0.70710678)) +
|
175 |
+
0.3989423 * x * torch.exp(-0.5 * x * x)
|
176 |
+
|
177 |
+
Args:
|
178 |
+
g (`torch.tensor`, *required*):
|
179 |
+
gradient output tensor
|
180 |
+
x (`torch.tensor`, *required*):
|
181 |
+
input tensor
|
182 |
+
"""
|
183 |
+
x = x[0] # x is a tuple of 1 element, needs to unpack it first
|
184 |
+
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
|
185 |
+
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
|
186 |
+
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
|
187 |
+
return ff * g
|
188 |
+
|
189 |
+
|
190 |
+
class GeLUFunction(torch.autograd.Function):
|
191 |
+
@staticmethod
|
192 |
+
def forward(ctx, input: torch.Tensor) -> torch.Tensor:
|
193 |
+
ctx.save_for_backward(input)
|
194 |
+
return bloom_gelu_forward(input)
|
195 |
+
|
196 |
+
@staticmethod
|
197 |
+
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
|
198 |
+
input = ctx.saved_tensors
|
199 |
+
tmp = bloom_gelu_back(grad_output, input)
|
200 |
+
return tmp
|
201 |
+
|
202 |
+
|
203 |
+
class BloomGelu(nn.Module):
|
204 |
+
"""
|
205 |
+
BloomBiasGelu wrapper function that make use of the simple function on inference mode to make the model
|
206 |
+
torchscriptable and use the autograd function in training mode to get the accurate results of the gradients Partly
|
207 |
+
copied from Megatron-DeepSpeed code and adapted for our needs
|
208 |
+
|
209 |
+
See here why autograd functions are not torchscriptable: https://github.com/pytorch/pytorch/issues/22329
|
210 |
+
"""
|
211 |
+
|
212 |
+
def __init__(self):
|
213 |
+
super().__init__()
|
214 |
+
|
215 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
216 |
+
if self.training:
|
217 |
+
return GeLUFunction.apply(x)
|
218 |
+
else:
|
219 |
+
return bloom_gelu_forward(x)
|
220 |
+
|
221 |
+
|
222 |
+
class BloomAttention(nn.Module):
|
223 |
+
def __init__(self, config: VBloomConfig, is_cross_attention=False):
|
224 |
+
super().__init__()
|
225 |
+
|
226 |
+
self.pretraining_tp = config.pretraining_tp
|
227 |
+
self.slow_but_exact = config.slow_but_exact
|
228 |
+
|
229 |
+
self.hidden_size = config.hidden_size
|
230 |
+
self.num_heads = config.n_head
|
231 |
+
self.head_dim = self.hidden_size // self.num_heads
|
232 |
+
self.split_size = self.hidden_size
|
233 |
+
self.hidden_dropout = config.hidden_dropout
|
234 |
+
|
235 |
+
if self.head_dim * self.num_heads != self.hidden_size:
|
236 |
+
raise ValueError(
|
237 |
+
f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
|
238 |
+
f" {self.num_heads})."
|
239 |
+
)
|
240 |
+
|
241 |
+
# Layer-wise attention scaling
|
242 |
+
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
|
243 |
+
self.beta = 1.0
|
244 |
+
|
245 |
+
self.is_cross_attention = is_cross_attention
|
246 |
+
|
247 |
+
if self.is_cross_attention:
|
248 |
+
self.query = nn.Linear(self.hidden_size, 1 * self.hidden_size, bias=True)
|
249 |
+
kv_input_dim = self.hidden_size if not hasattr(config, "vision_embed_dim") else config.vision_embed_dim
|
250 |
+
self.key_value = nn.Linear(kv_input_dim, 2 * self.hidden_size, bias=True)
|
251 |
+
else:
|
252 |
+
self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
|
253 |
+
|
254 |
+
self.dense = nn.Linear(self.hidden_size, self.hidden_size)
|
255 |
+
self.attention_dropout = nn.Dropout(config.attention_dropout)
|
256 |
+
|
257 |
+
if self.is_cross_attention:
|
258 |
+
# The alpha stuff
|
259 |
+
self.act = nn.Tanh()
|
260 |
+
|
261 |
+
if config.alpha_initializer == "zeros":
|
262 |
+
if config.alpha_type == "vector":
|
263 |
+
self.alpha_cross_attn = nn.Parameter(torch.zeros(1, 1, self.hidden_size))
|
264 |
+
elif config.alpha_type == "float":
|
265 |
+
self.alpha_cross_attn = nn.Parameter(torch.zeros(1))
|
266 |
+
else:
|
267 |
+
raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
|
268 |
+
|
269 |
+
elif config.alpha_initializer == "ones":
|
270 |
+
if config.alpha_type == "vector":
|
271 |
+
self.alpha_cross_attn = nn.Parameter(torch.ones(1, 1, self.hidden_size))
|
272 |
+
elif config.alpha_type == "float":
|
273 |
+
self.alpha_cross_attn = nn.Parameter(torch.ones(1))
|
274 |
+
else:
|
275 |
+
raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
|
276 |
+
|
277 |
+
elif config.alpha_initializer in {"normal", "gaussian", "random"}:
|
278 |
+
if config.alpha_type == "vector":
|
279 |
+
self.alpha_cross_attn = nn.Parameter(
|
280 |
+
torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1, 1, self.hidden_size))
|
281 |
+
)
|
282 |
+
elif config.alpha_type == "float":
|
283 |
+
self.alpha_cross_attn = nn.Parameter(
|
284 |
+
torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1))
|
285 |
+
)
|
286 |
+
else:
|
287 |
+
raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
|
288 |
+
|
289 |
+
else:
|
290 |
+
raise NotImplementedError(
|
291 |
+
f"Alpha initialization scheme {config.alpha_initializer} not yet implemented!"
|
292 |
+
)
|
293 |
+
|
294 |
+
def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
295 |
+
"""
|
296 |
+
Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
|
297 |
+
storage as `fused_qkv`
|
298 |
+
|
299 |
+
Args:
|
300 |
+
fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
|
301 |
+
|
302 |
+
Returns:
|
303 |
+
query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
|
304 |
+
value: [batch_size, seq_length, num_heads, head_dim]
|
305 |
+
"""
|
306 |
+
batch_size, seq_length, n_times_hidden_size = fused_qkv.shape
|
307 |
+
n = int(n_times_hidden_size / self.hidden_size)
|
308 |
+
fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, n, self.head_dim)
|
309 |
+
outputs = ()
|
310 |
+
for i in range(n):
|
311 |
+
outputs += (fused_qkv[..., i, :],)
|
312 |
+
return outputs
|
313 |
+
|
314 |
+
def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
|
315 |
+
"""
|
316 |
+
Merge heads together over the last dimenstion
|
317 |
+
|
318 |
+
Args:
|
319 |
+
x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
|
320 |
+
|
321 |
+
Returns:
|
322 |
+
torch.tensor: [batch_size, seq_length, num_heads * head_dim]
|
323 |
+
"""
|
324 |
+
# What we want to achieve is:
|
325 |
+
# batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
|
326 |
+
batch_size_and_num_heads, seq_length, _ = x.shape
|
327 |
+
batch_size = batch_size_and_num_heads // self.num_heads
|
328 |
+
|
329 |
+
# First view to decompose the batch size
|
330 |
+
# batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
|
331 |
+
x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)
|
332 |
+
|
333 |
+
# batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
|
334 |
+
x = x.permute(0, 2, 1, 3)
|
335 |
+
|
336 |
+
# batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
|
337 |
+
return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)
|
338 |
+
|
339 |
+
def forward(
|
340 |
+
self,
|
341 |
+
hidden_states: torch.Tensor,
|
342 |
+
residual: torch.Tensor,
|
343 |
+
alibi: torch.Tensor,
|
344 |
+
attention_mask: torch.Tensor,
|
345 |
+
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
346 |
+
head_mask: Optional[torch.Tensor] = None,
|
347 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
348 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
349 |
+
use_cache: bool = False,
|
350 |
+
output_attentions: bool = False,
|
351 |
+
):
|
352 |
+
if not self.is_cross_attention:
|
353 |
+
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
|
354 |
+
|
355 |
+
# 3 x [batch_size, seq_length, num_heads, head_dim]
|
356 |
+
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
|
357 |
+
else:
|
358 |
+
if encoder_hidden_states is not None:
|
359 |
+
attention_mask = encoder_attention_mask
|
360 |
+
q = self.query(hidden_states)
|
361 |
+
kv = self.key_value(encoder_hidden_states)
|
362 |
+
|
363 |
+
query_layer = self._split_heads(q)[0]
|
364 |
+
key_layer, value_layer = self._split_heads(kv)
|
365 |
+
|
366 |
+
batch_size, q_length, _, _ = query_layer.shape
|
367 |
+
_, kv_length, _, _ = key_layer.shape
|
368 |
+
|
369 |
+
query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
|
370 |
+
key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, kv_length)
|
371 |
+
value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, kv_length, self.head_dim)
|
372 |
+
if layer_past is not None:
|
373 |
+
past_key, past_value = layer_past
|
374 |
+
# concatenate along seq_length dimension:
|
375 |
+
# - key: [batch_size * self.num_heads, head_dim, kv_length]
|
376 |
+
# - value: [batch_size * self.num_heads, kv_length, head_dim]
|
377 |
+
key_layer = torch.cat((past_key, key_layer), dim=2)
|
378 |
+
value_layer = torch.cat((past_value, value_layer), dim=1)
|
379 |
+
_, _, kv_length = key_layer.shape
|
380 |
+
|
381 |
+
if use_cache is True:
|
382 |
+
present = (key_layer, value_layer)
|
383 |
+
else:
|
384 |
+
present = None
|
385 |
+
|
386 |
+
# [batch_size * num_heads, q_length, kv_length]
|
387 |
+
# we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
|
388 |
+
if alibi is None:
|
389 |
+
alibi = torch.empty(
|
390 |
+
batch_size * self.num_heads, q_length, kv_length, dtype=query_layer.dtype, device=query_layer.device
|
391 |
+
)
|
392 |
+
|
393 |
+
matmul_result = alibi.baddbmm(
|
394 |
+
batch1=query_layer,
|
395 |
+
batch2=key_layer,
|
396 |
+
beta=0.0 if self.is_cross_attention else self.beta,
|
397 |
+
alpha=self.inv_norm_factor,
|
398 |
+
)
|
399 |
+
|
400 |
+
# change view to [batch_size, num_heads, q_length, kv_length]
|
401 |
+
attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
|
402 |
+
|
403 |
+
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
|
404 |
+
input_dtype = attention_scores.dtype
|
405 |
+
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
|
406 |
+
if input_dtype == torch.float16:
|
407 |
+
attention_scores = attention_scores.to(torch.float)
|
408 |
+
attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
|
409 |
+
attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
|
410 |
+
|
411 |
+
# [batch_size, num_heads, q_length, kv_length]
|
412 |
+
attention_probs = self.attention_dropout(attention_probs)
|
413 |
+
|
414 |
+
if head_mask is not None:
|
415 |
+
attention_probs = attention_probs * head_mask
|
416 |
+
|
417 |
+
# change view [batch_size x num_heads, q_length, kv_length]
|
418 |
+
attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
|
419 |
+
|
420 |
+
# matmul: [batch_size * num_heads, q_length, head_dim]
|
421 |
+
context_layer = torch.bmm(attention_probs_reshaped, value_layer)
|
422 |
+
|
423 |
+
# change view [batch_size, num_heads, q_length, head_dim]
|
424 |
+
context_layer = self._merge_heads(context_layer)
|
425 |
+
|
426 |
+
# aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
|
427 |
+
if self.pretraining_tp > 1 and self.slow_but_exact:
|
428 |
+
slices = self.hidden_size / self.pretraining_tp
|
429 |
+
output_tensor = torch.zeros_like(context_layer)
|
430 |
+
for i in range(self.pretraining_tp):
|
431 |
+
output_tensor = output_tensor + F.linear(
|
432 |
+
context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
|
433 |
+
self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
|
434 |
+
)
|
435 |
+
else:
|
436 |
+
output_tensor = self.dense(context_layer)
|
437 |
+
|
438 |
+
if not self.is_cross_attention:
|
439 |
+
output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
|
440 |
+
else:
|
441 |
+
output_tensor = dropout_add(
|
442 |
+
self.act(self.alpha_cross_attn) * output_tensor, residual, self.hidden_dropout, self.training
|
443 |
+
)
|
444 |
+
|
445 |
+
outputs = (output_tensor, present)
|
446 |
+
if output_attentions:
|
447 |
+
outputs += (attention_probs,)
|
448 |
+
|
449 |
+
return outputs
|
450 |
+
|
451 |
+
|
452 |
+
class BloomMLP(nn.Module):
|
453 |
+
def __init__(self, config: VBloomConfig, is_gated=False):
|
454 |
+
super().__init__()
|
455 |
+
hidden_size = config.hidden_size
|
456 |
+
|
457 |
+
self.pretraining_tp = config.pretraining_tp
|
458 |
+
self.slow_but_exact = config.slow_but_exact
|
459 |
+
self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
|
460 |
+
self.gelu_impl = BloomGelu()
|
461 |
+
self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
|
462 |
+
self.hidden_dropout = config.hidden_dropout
|
463 |
+
|
464 |
+
# The alpha stuff
|
465 |
+
self.is_gated = is_gated
|
466 |
+
if is_gated:
|
467 |
+
self.act = nn.Tanh()
|
468 |
+
|
469 |
+
if config.alpha_initializer == "zeros":
|
470 |
+
if config.alpha_type == "vector":
|
471 |
+
self.alpha_dense = nn.Parameter(torch.zeros(1, 1, hidden_size))
|
472 |
+
elif config.alpha_type == "float":
|
473 |
+
self.alpha_dense = nn.Parameter(torch.zeros(1))
|
474 |
+
else:
|
475 |
+
raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
|
476 |
+
|
477 |
+
elif config.alpha_initializer == "ones":
|
478 |
+
if config.alpha_type == "vector":
|
479 |
+
self.alpha_dense = nn.Parameter(torch.ones(1, 1, hidden_size))
|
480 |
+
elif config.alpha_type == "float":
|
481 |
+
self.alpha_dense = nn.Parameter(torch.ones(1))
|
482 |
+
else:
|
483 |
+
raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
|
484 |
+
|
485 |
+
elif config.alpha_initializer in {"normal", "gaussian", "random"}:
|
486 |
+
if config.alpha_type == "vector":
|
487 |
+
self.alpha_dense = nn.Parameter(
|
488 |
+
torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1, 1, hidden_size))
|
489 |
+
)
|
490 |
+
elif config.alpha_type == "float":
|
491 |
+
self.alpha_dense = nn.Parameter(
|
492 |
+
torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1))
|
493 |
+
)
|
494 |
+
else:
|
495 |
+
raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
|
496 |
+
|
497 |
+
else:
|
498 |
+
raise NotImplementedError(
|
499 |
+
f"Alpha initialization scheme {config.alpha_initializer} not yet implemented!"
|
500 |
+
)
|
501 |
+
|
502 |
+
def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
|
503 |
+
hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
|
504 |
+
|
505 |
+
if self.pretraining_tp > 1 and self.slow_but_exact:
|
506 |
+
intermediate_output = torch.zeros_like(residual)
|
507 |
+
slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
|
508 |
+
for i in range(self.pretraining_tp):
|
509 |
+
intermediate_output = intermediate_output + F.linear(
|
510 |
+
hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],
|
511 |
+
self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)],
|
512 |
+
)
|
513 |
+
else:
|
514 |
+
intermediate_output = self.dense_4h_to_h(hidden_states)
|
515 |
+
|
516 |
+
if not self.is_gated:
|
517 |
+
output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
|
518 |
+
else:
|
519 |
+
output = dropout_add(
|
520 |
+
self.act(self.alpha_dense) * intermediate_output, residual, self.hidden_dropout, self.training
|
521 |
+
)
|
522 |
+
|
523 |
+
return output
|
524 |
+
|
525 |
+
|
526 |
+
class BloomBlock(nn.Module):
|
527 |
+
def __init__(self, config: VBloomConfig):
|
528 |
+
super().__init__()
|
529 |
+
hidden_size = config.hidden_size
|
530 |
+
|
531 |
+
self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
532 |
+
self.num_heads = config.n_head
|
533 |
+
self.self_attention = BloomAttention(config)
|
534 |
+
self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
535 |
+
|
536 |
+
self.mlp = BloomMLP(config)
|
537 |
+
|
538 |
+
self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
|
539 |
+
self.hidden_dropout = config.hidden_dropout
|
540 |
+
|
541 |
+
def forward(
|
542 |
+
self,
|
543 |
+
hidden_states: torch.Tensor,
|
544 |
+
alibi: torch.Tensor,
|
545 |
+
attention_mask: torch.Tensor,
|
546 |
+
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
547 |
+
head_mask: Optional[torch.Tensor] = None,
|
548 |
+
use_cache: bool = False,
|
549 |
+
output_attentions: bool = False,
|
550 |
+
):
|
551 |
+
# hidden_states: [batch_size, seq_length, hidden_size]
|
552 |
+
|
553 |
+
# Layer norm at the beginning of the transformer layer.
|
554 |
+
layernorm_output = self.input_layernorm(hidden_states)
|
555 |
+
|
556 |
+
# Layer norm post the self attention.
|
557 |
+
if self.apply_residual_connection_post_layernorm:
|
558 |
+
residual = layernorm_output
|
559 |
+
else:
|
560 |
+
residual = hidden_states
|
561 |
+
|
562 |
+
# Self attention.
|
563 |
+
attn_outputs = self.self_attention(
|
564 |
+
layernorm_output,
|
565 |
+
residual,
|
566 |
+
layer_past=layer_past,
|
567 |
+
attention_mask=attention_mask,
|
568 |
+
alibi=alibi,
|
569 |
+
head_mask=head_mask,
|
570 |
+
use_cache=use_cache,
|
571 |
+
output_attentions=output_attentions,
|
572 |
+
)
|
573 |
+
|
574 |
+
attention_output = attn_outputs[0]
|
575 |
+
|
576 |
+
outputs = attn_outputs[1:]
|
577 |
+
|
578 |
+
layernorm_output = self.post_attention_layernorm(attention_output)
|
579 |
+
|
580 |
+
# Get residual
|
581 |
+
if self.apply_residual_connection_post_layernorm:
|
582 |
+
residual = layernorm_output
|
583 |
+
else:
|
584 |
+
residual = attention_output
|
585 |
+
|
586 |
+
# MLP.
|
587 |
+
output = self.mlp(layernorm_output, residual)
|
588 |
+
|
589 |
+
if use_cache:
|
590 |
+
outputs = (output,) + outputs
|
591 |
+
else:
|
592 |
+
outputs = (output,) + outputs[1:]
|
593 |
+
|
594 |
+
return outputs # hidden_states, present, attentions
|
595 |
+
|
596 |
+
|
597 |
+
class VBloomGatedCrossAttentionBlock(nn.Module):
|
598 |
+
def __init__(self, config):
|
599 |
+
super().__init__()
|
600 |
+
hidden_size = config.hidden_size
|
601 |
+
|
602 |
+
self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
603 |
+
self.num_heads = config.n_head
|
604 |
+
self.cross_attention = BloomAttention(config, is_cross_attention=True)
|
605 |
+
self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
606 |
+
|
607 |
+
self.gated_mlp = BloomMLP(config, is_gated=True)
|
608 |
+
|
609 |
+
self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
|
610 |
+
self.hidden_dropout = config.hidden_dropout
|
611 |
+
|
612 |
+
def forward(
|
613 |
+
self,
|
614 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
615 |
+
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
616 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
617 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
618 |
+
image_hidden_states: Optional[torch.Tensor] = None,
|
619 |
+
image_attention_mask: Optional[torch.FloatTensor] = None,
|
620 |
+
use_cache: Optional[bool] = False,
|
621 |
+
output_attentions: Optional[bool] = False,
|
622 |
+
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
|
623 |
+
# hidden_states: [batch_size, seq_length, hidden_size]
|
624 |
+
|
625 |
+
# Layer norm at the beginning of the transformer layer.
|
626 |
+
layernorm_output = self.input_layernorm(hidden_states)
|
627 |
+
|
628 |
+
# Layer norm post the self attention.
|
629 |
+
if self.apply_residual_connection_post_layernorm:
|
630 |
+
residual = layernorm_output
|
631 |
+
else:
|
632 |
+
residual = hidden_states
|
633 |
+
|
634 |
+
# Self attention.
|
635 |
+
attn_outputs = self.cross_attention(
|
636 |
+
layernorm_output,
|
637 |
+
residual,
|
638 |
+
alibi=None,
|
639 |
+
layer_past=layer_past,
|
640 |
+
attention_mask=attention_mask,
|
641 |
+
head_mask=head_mask,
|
642 |
+
encoder_hidden_states=image_hidden_states,
|
643 |
+
encoder_attention_mask=image_attention_mask,
|
644 |
+
use_cache=use_cache,
|
645 |
+
output_attentions=output_attentions,
|
646 |
+
)
|
647 |
+
|
648 |
+
attention_output = attn_outputs[0]
|
649 |
+
|
650 |
+
outputs = attn_outputs[1:]
|
651 |
+
|
652 |
+
layernorm_output = self.post_attention_layernorm(attention_output)
|
653 |
+
|
654 |
+
# Get residual
|
655 |
+
if self.apply_residual_connection_post_layernorm:
|
656 |
+
residual = layernorm_output
|
657 |
+
else:
|
658 |
+
residual = attention_output
|
659 |
+
|
660 |
+
# MLP.
|
661 |
+
output = self.gated_mlp(layernorm_output, residual)
|
662 |
+
|
663 |
+
if use_cache:
|
664 |
+
outputs = (output,) + outputs
|
665 |
+
else:
|
666 |
+
outputs = (output,) + outputs[1:]
|
667 |
+
|
668 |
+
return outputs # hidden_states, present, attentions
|
669 |
+
|
670 |
+
|
671 |
+
class VBloomPreTrainedModel(VLOOMPreTrainedModelBase):
|
672 |
+
_keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
|
673 |
+
"""
|
674 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
675 |
+
models.
|
676 |
+
"""
|
677 |
+
|
678 |
+
config_class = VBloomConfig
|
679 |
+
base_model_prefix = "transformer"
|
680 |
+
supports_gradient_checkpointing = True
|
681 |
+
_no_split_modules = ["BloomBlock"]
|
682 |
+
|
683 |
+
def __init__(self, *inputs, **kwargs):
|
684 |
+
super().__init__(*inputs, **kwargs)
|
685 |
+
|
686 |
+
def _init_weights(self, module: nn.Module):
|
687 |
+
"""Initialize the weights."""
|
688 |
+
if isinstance(module, nn.Linear):
|
689 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
690 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
691 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
692 |
+
if module.bias is not None:
|
693 |
+
module.bias.data.zero_()
|
694 |
+
elif isinstance(module, nn.Embedding):
|
695 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
696 |
+
if module.padding_idx is not None:
|
697 |
+
module.weight.data[module.padding_idx].zero_()
|
698 |
+
elif isinstance(module, LayerNorm):
|
699 |
+
module.bias.data.zero_()
|
700 |
+
module.weight.data.fill_(1.0)
|
701 |
+
|
702 |
+
def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
|
703 |
+
if isinstance(module, VBloomModel):
|
704 |
+
module.gradient_checkpointing = value
|
705 |
+
|
706 |
+
@classmethod
|
707 |
+
def override_vision_model_wrapper(cls, model, config, vision_model_name, vision_model_params, torch_dtype):
|
708 |
+
# this can be called via from_pretrained from a class w/ head or w/o head so we extract the beheaded model version
|
709 |
+
beheaded_model = model.transformer if hasattr(model, "transformer") else model
|
710 |
+
cls.override_vision_model(beheaded_model, vision_model_name, vision_model_params, torch_dtype)
|
711 |
+
beheaded_model.freeze_relevant_params(config)
|
712 |
+
|
713 |
+
|
714 |
+
BLOOM_START_DOCSTRING = r"""
|
715 |
+
|
716 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
717 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings etc.)
|
718 |
+
|
719 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
720 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
721 |
+
and behavior.
|
722 |
+
|
723 |
+
Parameters:
|
724 |
+
config ([`BloomConfig`]): Model configuration class with all the parameters of the model.
|
725 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
726 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
727 |
+
"""
|
728 |
+
|
729 |
+
BLOOM_INPUTS_DOCSTRING = r"""
|
730 |
+
Args:
|
731 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
|
732 |
+
`input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
|
733 |
+
(`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
|
734 |
+
|
735 |
+
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
|
736 |
+
`input_ids`.
|
737 |
+
|
738 |
+
Indices can be obtained using [`BloomTokenizerFast`]. See [`PreTrainedTokenizer.encode`] and
|
739 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
740 |
+
|
741 |
+
[What are input IDs?](../glossary#input-ids)
|
742 |
+
past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
|
743 |
+
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
|
744 |
+
`past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
|
745 |
+
their past given to this model should not be passed as `input_ids` as they have already been computed.
|
746 |
+
|
747 |
+
Each element of `past_key_values` is a tuple (past_key, past_value):
|
748 |
+
- past_key: [batch_size * num_heads, head_dim, kv_length]
|
749 |
+
- past_value: [batch_size * num_heads, kv_length, head_dim]
|
750 |
+
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
751 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
752 |
+
|
753 |
+
- 1 for tokens that are **not masked**,
|
754 |
+
- 0 for tokens that are **masked**.
|
755 |
+
|
756 |
+
[What are attention masks?](../glossary#attention-mask)
|
757 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
758 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
759 |
+
|
760 |
+
- 1 indicates the head is **not masked**,
|
761 |
+
- 0 indicates the head is **masked**.
|
762 |
+
|
763 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
764 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
765 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
766 |
+
model's internal embedding lookup matrix.
|
767 |
+
|
768 |
+
If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
|
769 |
+
`past_key_values`).
|
770 |
+
use_cache (`bool`, *optional*):
|
771 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
772 |
+
`past_key_values`).
|
773 |
+
output_attentions (`bool`, *optional*):
|
774 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
775 |
+
tensors for more detail.
|
776 |
+
output_hidden_states (`bool`, *optional*):
|
777 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
778 |
+
more detail.
|
779 |
+
return_dict (`bool`, *optional*):
|
780 |
+
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
|
781 |
+
"""
|
782 |
+
|
783 |
+
|
784 |
+
@add_start_docstrings(
|
785 |
+
"The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
|
786 |
+
BLOOM_START_DOCSTRING,
|
787 |
+
)
|
788 |
+
class VBloomModel(VBloomPreTrainedModel):
|
789 |
+
def __init__(self, config: VBloomConfig, vision_model=None):
|
790 |
+
super().__init__(config)
|
791 |
+
|
792 |
+
self.embed_dim = config.hidden_size
|
793 |
+
self.num_heads = config.n_head
|
794 |
+
|
795 |
+
# Embedding + LN Embedding
|
796 |
+
self.word_embeddings = DecoupledEmbedding(
|
797 |
+
num_embeddings=config.vocab_size,
|
798 |
+
num_additional_embeddings=config.additional_vocab_size,
|
799 |
+
embedding_dim=self.embed_dim,
|
800 |
+
partially_freeze=config.freeze_text_layers,
|
801 |
+
)
|
802 |
+
self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
803 |
+
|
804 |
+
# Transformer blocks
|
805 |
+
self.h = nn.ModuleList([BloomBlock(config) for _ in range(config.num_hidden_layers)])
|
806 |
+
|
807 |
+
# Final Layer Norm
|
808 |
+
self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
809 |
+
|
810 |
+
self.cross_layer_interval = config.cross_layer_interval
|
811 |
+
num_cross_layers = config.num_hidden_layers // self.cross_layer_interval
|
812 |
+
self.gated_cross_attn_layers = nn.ModuleList(
|
813 |
+
[VBloomGatedCrossAttentionBlock(config) for i in range(num_cross_layers)]
|
814 |
+
)
|
815 |
+
|
816 |
+
# Perceiver Resampler
|
817 |
+
if config.use_resampler:
|
818 |
+
self.perceiver_resampler = PerceiverResampler(
|
819 |
+
self.config,
|
820 |
+
self.config.vision_embed_dim,
|
821 |
+
config.resampler_depth,
|
822 |
+
config.resampler_n_heads,
|
823 |
+
config.resampler_head_dim,
|
824 |
+
config.resampler_n_latents,
|
825 |
+
)
|
826 |
+
self.gradient_checkpointing = False
|
827 |
+
|
828 |
+
# Load an uninitialized model and later in from_pretrained will load the pre-trained model -
|
829 |
+
# this solves the losing of weights in `from_pretrained` on the main model
|
830 |
+
self.vision_model = vision_model
|
831 |
+
|
832 |
+
# Initialize weights and apply final processing
|
833 |
+
self.post_init()
|
834 |
+
|
835 |
+
self.freeze_relevant_params(config)
|
836 |
+
|
837 |
+
def freeze_relevant_params(self, config=None):
|
838 |
+
if config is None:
|
839 |
+
config = self.config
|
840 |
+
|
841 |
+
if config.freeze_text_layers:
|
842 |
+
self.freeze_text_layers()
|
843 |
+
|
844 |
+
if config.freeze_vision_layers:
|
845 |
+
freeze_model(self.vision_model)
|
846 |
+
|
847 |
+
def freeze_text_layers(self):
|
848 |
+
for module in [self.word_embeddings_layernorm, self.h, self.ln_f]:
|
849 |
+
freeze_model(module)
|
850 |
+
|
851 |
+
def get_input_embeddings(self):
|
852 |
+
return self.word_embeddings
|
853 |
+
|
854 |
+
def _prepare_attn_mask(
|
855 |
+
self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
|
856 |
+
) -> torch.BoolTensor:
|
857 |
+
# create causal mask
|
858 |
+
# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
|
859 |
+
combined_attention_mask = None
|
860 |
+
device = attention_mask.device
|
861 |
+
_, src_length = input_shape
|
862 |
+
|
863 |
+
if src_length > 1:
|
864 |
+
combined_attention_mask = _make_causal_mask(
|
865 |
+
input_shape, device=device, past_key_values_length=past_key_values_length
|
866 |
+
)
|
867 |
+
|
868 |
+
# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
|
869 |
+
expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
|
870 |
+
combined_attention_mask = (
|
871 |
+
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
|
872 |
+
)
|
873 |
+
|
874 |
+
return combined_attention_mask
|
875 |
+
|
876 |
+
def set_input_embeddings(self, new_embeddings: torch.Tensor):
|
877 |
+
self.word_embeddings = new_embeddings
|
878 |
+
|
879 |
+
@add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
|
880 |
+
@add_code_sample_docstrings(
|
881 |
+
processor_class=_TOKENIZER_FOR_DOC,
|
882 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
883 |
+
output_type=BaseModelOutputWithPastAndCrossAttentions,
|
884 |
+
config_class=_CONFIG_FOR_DOC,
|
885 |
+
)
|
886 |
+
def forward(
|
887 |
+
self,
|
888 |
+
input_ids: Optional[torch.LongTensor] = None,
|
889 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
890 |
+
attention_mask: Optional[torch.Tensor] = None,
|
891 |
+
head_mask: Optional[torch.LongTensor] = None,
|
892 |
+
inputs_embeds: Optional[torch.LongTensor] = None,
|
893 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
894 |
+
image_embeddings: Optional[torch.FloatTensor] = None,
|
895 |
+
image_attention_mask: Optional[torch.Tensor] = None,
|
896 |
+
crossblock_head_mask: Optional[torch.Tensor] = None,
|
897 |
+
use_cache: Optional[bool] = None,
|
898 |
+
output_attentions: Optional[bool] = None,
|
899 |
+
output_hidden_states: Optional[bool] = None,
|
900 |
+
return_dict: Optional[bool] = None,
|
901 |
+
**deprecated_arguments,
|
902 |
+
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
|
903 |
+
if deprecated_arguments.pop("position_ids", False) is not False:
|
904 |
+
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
|
905 |
+
warnings.warn(
|
906 |
+
(
|
907 |
+
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely"
|
908 |
+
" ignore passing `position_ids`."
|
909 |
+
),
|
910 |
+
FutureWarning,
|
911 |
+
)
|
912 |
+
if len(deprecated_arguments) > 0:
|
913 |
+
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
|
914 |
+
|
915 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
916 |
+
output_hidden_states = (
|
917 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
918 |
+
)
|
919 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
920 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
921 |
+
|
922 |
+
if input_ids is not None and inputs_embeds is not None:
|
923 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
924 |
+
elif input_ids is not None:
|
925 |
+
batch_size, seq_length = input_ids.shape
|
926 |
+
elif inputs_embeds is not None:
|
927 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
928 |
+
else:
|
929 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
930 |
+
|
931 |
+
if past_key_values is None:
|
932 |
+
past_key_values = tuple([None] * len(self.h))
|
933 |
+
|
934 |
+
# Prepare head mask if needed
|
935 |
+
# 1.0 in head_mask indicate we keep the head
|
936 |
+
# attention_probs has shape batch_size x num_heads x N x N
|
937 |
+
# head_mask has shape n_layer x batch x num_heads x N x N
|
938 |
+
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
939 |
+
|
940 |
+
if inputs_embeds is None:
|
941 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
942 |
+
|
943 |
+
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
|
944 |
+
|
945 |
+
presents = () if use_cache else None
|
946 |
+
all_self_attentions = () if output_attentions else None
|
947 |
+
all_hidden_states = () if output_hidden_states else None
|
948 |
+
|
949 |
+
# Compute alibi tensor: check build_alibi_tensor documentation
|
950 |
+
seq_length_with_past = seq_length
|
951 |
+
past_key_values_length = 0
|
952 |
+
if past_key_values[0] is not None:
|
953 |
+
past_key_values_length = past_key_values[0][0].shape[2]
|
954 |
+
seq_length_with_past = seq_length_with_past + past_key_values_length
|
955 |
+
if attention_mask is None:
|
956 |
+
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
|
957 |
+
else:
|
958 |
+
attention_mask = attention_mask.to(hidden_states.device)
|
959 |
+
|
960 |
+
alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
|
961 |
+
|
962 |
+
causal_mask = self._prepare_attn_mask(
|
963 |
+
attention_mask,
|
964 |
+
input_shape=(batch_size, seq_length),
|
965 |
+
past_key_values_length=past_key_values_length,
|
966 |
+
)
|
967 |
+
|
968 |
+
if pixel_values is not None and image_embeddings is not None:
|
969 |
+
raise ValueError("You cannot specify both pixel_values and image_embeddings at the same time")
|
970 |
+
elif pixel_values is not None:
|
971 |
+
pixel_values = pixel_values.to(dtype=self.dtype, device=input_ids.device) # fp16 compatibility
|
972 |
+
batch_size, num_images = pixel_values.size(0), pixel_values.size(1)
|
973 |
+
pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:])
|
974 |
+
# Get sequence from the vision encoder
|
975 |
+
image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
|
976 |
+
elif image_embeddings is not None:
|
977 |
+
batch_size, num_images, image_seq_len, image_hidden_size = image_embeddings.size()
|
978 |
+
image_hidden_states = image_embeddings.to(dtype=self.dtype, device=input_ids.device)
|
979 |
+
image_hidden_states = image_hidden_states.view(batch_size * num_images, image_seq_len, image_hidden_size)
|
980 |
+
|
981 |
+
if self.config.use_resampler:
|
982 |
+
image_hidden_states = self.perceiver_resampler(image_hidden_states)
|
983 |
+
image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
|
984 |
+
image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size)
|
985 |
+
# Make image_attention_mask compatible with hidden states
|
986 |
+
text_seq_len = image_attention_mask.size(1)
|
987 |
+
image_attention_mask = image_attention_mask.unsqueeze(
|
988 |
+
-1
|
989 |
+
) # TODO: something i don't understand here. why are the few last tokens not attending when there is just a single image?
|
990 |
+
image_attention_mask = image_attention_mask.repeat(1, 1, 1, image_seq_len)
|
991 |
+
image_attention_mask = image_attention_mask.view(batch_size, text_seq_len, num_images * image_seq_len)
|
992 |
+
|
993 |
+
if image_hidden_states is not None:
|
994 |
+
image_batch_size, image_sequence_length, _ = image_hidden_states.size()
|
995 |
+
image_hidden_shape = (image_batch_size, image_sequence_length)
|
996 |
+
if image_attention_mask is None:
|
997 |
+
image_attention_mask = torch.ones(image_hidden_shape, device=hidden_states.device)
|
998 |
+
# image_attention_mask = self.invert_attention_mask(image_attention_mask)
|
999 |
+
image_attention_mask = image_attention_mask.to(torch.bool)
|
1000 |
+
image_attention_mask = image_attention_mask[:, None, :, :]
|
1001 |
+
else:
|
1002 |
+
image_attention_mask = None
|
1003 |
+
|
1004 |
+
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
1005 |
+
if output_hidden_states:
|
1006 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
1007 |
+
|
1008 |
+
def vblock(
|
1009 |
+
main_block,
|
1010 |
+
hidden_states,
|
1011 |
+
alibi,
|
1012 |
+
layer_past,
|
1013 |
+
attention_mask,
|
1014 |
+
layer_head_mask,
|
1015 |
+
use_cache,
|
1016 |
+
output_attentions,
|
1017 |
+
image_hidden_states,
|
1018 |
+
image_attention_mask,
|
1019 |
+
layer_idx,
|
1020 |
+
cross_layer_interval,
|
1021 |
+
gated_cross_attn_layers,
|
1022 |
+
):
|
1023 |
+
if layer_idx % cross_layer_interval == 0:
|
1024 |
+
xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval]
|
1025 |
+
outputs = xblock(
|
1026 |
+
hidden_states,
|
1027 |
+
attention_mask=attention_mask,
|
1028 |
+
image_hidden_states=image_hidden_states,
|
1029 |
+
image_attention_mask=image_attention_mask,
|
1030 |
+
use_cache=use_cache,
|
1031 |
+
output_attentions=output_attentions,
|
1032 |
+
)
|
1033 |
+
hidden_states = outputs[0]
|
1034 |
+
|
1035 |
+
outputs = main_block(
|
1036 |
+
hidden_states,
|
1037 |
+
alibi=alibi,
|
1038 |
+
layer_past=layer_past,
|
1039 |
+
attention_mask=attention_mask,
|
1040 |
+
head_mask=layer_head_mask,
|
1041 |
+
use_cache=use_cache,
|
1042 |
+
output_attentions=output_attentions,
|
1043 |
+
)
|
1044 |
+
|
1045 |
+
return outputs
|
1046 |
+
|
1047 |
+
if self.gradient_checkpointing and self.training:
|
1048 |
+
layer_past = None
|
1049 |
+
if use_cache:
|
1050 |
+
logger.warning_once(
|
1051 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
1052 |
+
)
|
1053 |
+
use_cache = False
|
1054 |
+
|
1055 |
+
outputs = torch.utils.checkpoint.checkpoint(
|
1056 |
+
vblock,
|
1057 |
+
block,
|
1058 |
+
hidden_states,
|
1059 |
+
alibi,
|
1060 |
+
layer_past,
|
1061 |
+
causal_mask,
|
1062 |
+
head_mask[i],
|
1063 |
+
use_cache,
|
1064 |
+
output_attentions,
|
1065 |
+
image_hidden_states,
|
1066 |
+
image_attention_mask,
|
1067 |
+
i,
|
1068 |
+
self.cross_layer_interval,
|
1069 |
+
self.gated_cross_attn_layers,
|
1070 |
+
)
|
1071 |
+
else:
|
1072 |
+
outputs = vblock(
|
1073 |
+
block,
|
1074 |
+
hidden_states,
|
1075 |
+
alibi=alibi,
|
1076 |
+
layer_past=layer_past,
|
1077 |
+
attention_mask=causal_mask,
|
1078 |
+
layer_head_mask=head_mask[i],
|
1079 |
+
use_cache=use_cache,
|
1080 |
+
output_attentions=output_attentions,
|
1081 |
+
image_hidden_states=image_hidden_states,
|
1082 |
+
image_attention_mask=image_attention_mask,
|
1083 |
+
layer_idx=i,
|
1084 |
+
cross_layer_interval=self.cross_layer_interval,
|
1085 |
+
gated_cross_attn_layers=self.gated_cross_attn_layers,
|
1086 |
+
)
|
1087 |
+
|
1088 |
+
hidden_states = outputs[0]
|
1089 |
+
if use_cache is True:
|
1090 |
+
presents = presents + (outputs[1],)
|
1091 |
+
|
1092 |
+
if output_attentions:
|
1093 |
+
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
1094 |
+
|
1095 |
+
# Add last hidden state
|
1096 |
+
hidden_states = self.ln_f(hidden_states)
|
1097 |
+
|
1098 |
+
if output_hidden_states:
|
1099 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
1100 |
+
|
1101 |
+
if not return_dict:
|
1102 |
+
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
1103 |
+
|
1104 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
1105 |
+
last_hidden_state=hidden_states,
|
1106 |
+
past_key_values=presents,
|
1107 |
+
hidden_states=all_hidden_states,
|
1108 |
+
attentions=all_self_attentions,
|
1109 |
+
)
|
1110 |
+
|
1111 |
+
|
1112 |
+
@add_start_docstrings(
|
1113 |
+
"""
|
1114 |
+
The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
1115 |
+
embeddings).
|
1116 |
+
""",
|
1117 |
+
BLOOM_START_DOCSTRING,
|
1118 |
+
)
|
1119 |
+
class VBloomForCausalLM(VBloomPreTrainedModel):
|
1120 |
+
_keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
|
1121 |
+
|
1122 |
+
def __init__(self, config: VBloomConfig, vision_model=None):
|
1123 |
+
super().__init__(config)
|
1124 |
+
self.transformer = VBloomModel(config, vision_model=vision_model)
|
1125 |
+
self.lm_head = DecoupledLinear(
|
1126 |
+
in_features=config.hidden_size,
|
1127 |
+
out_features=config.vocab_size,
|
1128 |
+
out_additional_features=config.additional_vocab_size,
|
1129 |
+
bias=False,
|
1130 |
+
partially_freeze=config.freeze_lm_head,
|
1131 |
+
)
|
1132 |
+
# Initialize weights and apply final processing
|
1133 |
+
self.post_init()
|
1134 |
+
|
1135 |
+
def get_output_embeddings(self):
|
1136 |
+
return self.lm_head
|
1137 |
+
|
1138 |
+
def set_output_embeddings(self, new_embeddings: torch.Tensor):
|
1139 |
+
self.lm_head = new_embeddings
|
1140 |
+
|
1141 |
+
def tie_weights(self):
|
1142 |
+
"""
|
1143 |
+
Overwrite `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of DecoupledLinear and DecoupledEmbedding.
|
1144 |
+
"""
|
1145 |
+
output_embeddings = self.get_output_embeddings()
|
1146 |
+
input_embeddings = self.get_input_embeddings()
|
1147 |
+
|
1148 |
+
if getattr(self.config, "tie_word_embeddings", True):
|
1149 |
+
output_embeddings.weight = input_embeddings.weight
|
1150 |
+
if input_embeddings.num_additional_embeddings > 0:
|
1151 |
+
assert output_embeddings.out_additional_features == input_embeddings.num_additional_embeddings
|
1152 |
+
output_embeddings.additional_fc.weight = input_embeddings.additional_embedding.weight
|
1153 |
+
|
1154 |
+
if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
|
1155 |
+
output_embeddings.out_features = input_embeddings.num_embeddings
|
1156 |
+
if hasattr(output_embeddings, "out_additional_features") and hasattr(
|
1157 |
+
input_embeddings, "num_additional_embeddings"
|
1158 |
+
):
|
1159 |
+
output_embeddings.out_additional_features = input_embeddings.num_additional_embeddings
|
1160 |
+
|
1161 |
+
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
1162 |
+
inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs)
|
1163 |
+
unwanted_kwargs = ["position_ids", "token_type_ids"]
|
1164 |
+
for kwarg in unwanted_kwargs:
|
1165 |
+
inputs.pop(kwarg, None)
|
1166 |
+
return inputs
|
1167 |
+
|
1168 |
+
@staticmethod
|
1169 |
+
def _expand_inputs_for_generation(
|
1170 |
+
*args,
|
1171 |
+
**model_kwargs,
|
1172 |
+
):
|
1173 |
+
return expand_inputs_for_generation(*args, **model_kwargs)
|
1174 |
+
|
1175 |
+
@staticmethod
|
1176 |
+
def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
|
1177 |
+
return update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder)
|
1178 |
+
|
1179 |
+
@add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
|
1180 |
+
@add_code_sample_docstrings(
|
1181 |
+
processor_class=_TOKENIZER_FOR_DOC,
|
1182 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
1183 |
+
output_type=CausalLMOutputWithCrossAttentions,
|
1184 |
+
config_class=_CONFIG_FOR_DOC,
|
1185 |
+
)
|
1186 |
+
def forward(
|
1187 |
+
self,
|
1188 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1189 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
1190 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1191 |
+
head_mask: Optional[torch.Tensor] = None,
|
1192 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1193 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
1194 |
+
image_embeddings: Optional[torch.FloatTensor] = None,
|
1195 |
+
image_attention_mask: Optional[torch.Tensor] = None,
|
1196 |
+
crossblock_head_mask: Optional[torch.Tensor] = None,
|
1197 |
+
labels: Optional[torch.Tensor] = None,
|
1198 |
+
use_cache: Optional[bool] = None,
|
1199 |
+
output_attentions: Optional[bool] = None,
|
1200 |
+
output_hidden_states: Optional[bool] = None,
|
1201 |
+
return_dict: Optional[bool] = None,
|
1202 |
+
**deprecated_arguments,
|
1203 |
+
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
|
1204 |
+
r"""
|
1205 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1206 |
+
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
1207 |
+
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
1208 |
+
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
1209 |
+
"""
|
1210 |
+
if deprecated_arguments.pop("position_ids", False) is not False:
|
1211 |
+
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
|
1212 |
+
warnings.warn(
|
1213 |
+
(
|
1214 |
+
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely"
|
1215 |
+
" ignore passing `position_ids`."
|
1216 |
+
),
|
1217 |
+
FutureWarning,
|
1218 |
+
)
|
1219 |
+
if len(deprecated_arguments) > 0:
|
1220 |
+
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
|
1221 |
+
|
1222 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1223 |
+
|
1224 |
+
transformer_outputs = self.transformer(
|
1225 |
+
input_ids,
|
1226 |
+
past_key_values=past_key_values,
|
1227 |
+
attention_mask=attention_mask,
|
1228 |
+
head_mask=head_mask,
|
1229 |
+
inputs_embeds=inputs_embeds,
|
1230 |
+
pixel_values=pixel_values,
|
1231 |
+
image_embeddings=image_embeddings,
|
1232 |
+
image_attention_mask=image_attention_mask,
|
1233 |
+
crossblock_head_mask=crossblock_head_mask,
|
1234 |
+
use_cache=use_cache,
|
1235 |
+
output_attentions=output_attentions,
|
1236 |
+
output_hidden_states=output_hidden_states,
|
1237 |
+
return_dict=return_dict,
|
1238 |
+
)
|
1239 |
+
hidden_states = transformer_outputs[0]
|
1240 |
+
|
1241 |
+
lm_logits = self.lm_head(hidden_states)
|
1242 |
+
|
1243 |
+
loss = None
|
1244 |
+
if labels is not None:
|
1245 |
+
# Shift so that tokens < n predict n
|
1246 |
+
if attention_mask is not None:
|
1247 |
+
shift_attention_mask = attention_mask[..., 1:]
|
1248 |
+
shift_logits = lm_logits[..., :-1, :][shift_attention_mask != 0].contiguous()
|
1249 |
+
shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
|
1250 |
+
else:
|
1251 |
+
shift_logits = lm_logits[..., :-1, :].contiguous()
|
1252 |
+
shift_labels = labels[..., 1:].contiguous()
|
1253 |
+
# Flatten the tokens
|
1254 |
+
loss_fct = CrossEntropyLoss()
|
1255 |
+
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
1256 |
+
|
1257 |
+
if not return_dict:
|
1258 |
+
output = (lm_logits,) + transformer_outputs[1:]
|
1259 |
+
return ((loss,) + output) if loss is not None else output
|
1260 |
+
|
1261 |
+
return CausalLMOutputWithCrossAttentions(
|
1262 |
+
loss=loss,
|
1263 |
+
logits=lm_logits,
|
1264 |
+
past_key_values=transformer_outputs.past_key_values,
|
1265 |
+
hidden_states=transformer_outputs.hidden_states,
|
1266 |
+
attentions=transformer_outputs.attentions,
|
1267 |
+
)
|
1268 |
+
|
1269 |
+
@staticmethod
|
1270 |
+
def _reorder_cache(
|
1271 |
+
past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
|
1272 |
+
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
|
1273 |
+
"""
|
1274 |
+
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
|
1275 |
+
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
|
1276 |
+
beam_idx at every generation step.
|
1277 |
+
|
1278 |
+
Output shares the same memory storage as `past`.
|
1279 |
+
"""
|
1280 |
+
batch_size_times_num_heads, head_dim, seq_length = past[0][0].shape
|
1281 |
+
batch_size = len(beam_idx)
|
1282 |
+
num_heads = batch_size_times_num_heads // batch_size
|
1283 |
+
# Get a copy of `beam_idx` on all the devices where we need those indices.
|
1284 |
+
device_to_beam_idx = {
|
1285 |
+
past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
|
1286 |
+
}
|
1287 |
+
# key: layer_past[0] [batch_size * num_heads, head_dim, seq_length]
|
1288 |
+
# value: layer_past[1] [batch_size * num_heads, seq_length, head_dim]
|
1289 |
+
return tuple(
|
1290 |
+
(
|
1291 |
+
layer_past[0]
|
1292 |
+
.view(batch_size, num_heads, head_dim, seq_length)
|
1293 |
+
.index_select(0, device_to_beam_idx[layer_past[0].device])
|
1294 |
+
.view(batch_size_times_num_heads, head_dim, seq_length),
|
1295 |
+
layer_past[1]
|
1296 |
+
.view(batch_size, num_heads, seq_length, head_dim)
|
1297 |
+
.index_select(0, device_to_beam_idx[layer_past[0].device])
|
1298 |
+
.view(batch_size_times_num_heads, seq_length, head_dim),
|
1299 |
+
)
|
1300 |
+
for layer_past in past
|
1301 |
+
)
|
1302 |
+
|
1303 |
+
def get_model_tflops_per_batch_per_gpu(self, hparams, data_param, tokenizer, max_num_images):
|
1304 |
+
config_vl_model = self.config
|
1305 |
+
|
1306 |
+
language_embed_size = config_vl_model.hidden_size
|
1307 |
+
vision_config = self.transformer.vision_model.config
|
1308 |
+
num_language_layers = config_vl_model.n_layer
|
1309 |
+
ffn_inner_size = 4 * config_vl_model.hidden_size
|
1310 |
+
|
1311 |
+
# Get vision model blocks infos
|
1312 |
+
vision_patch_size = vision_config.patch_size
|
1313 |
+
vision_hidden_size = vision_config.hidden_size
|
1314 |
+
num_vision_layers = vision_config.num_hidden_layers
|
1315 |
+
# The +1 is for the CLS token
|
1316 |
+
single_image_seq_len = (vision_config.image_size // vision_patch_size) ** 2 + 1
|
1317 |
+
vision_exp_factor = vision_config.intermediate_size // vision_hidden_size
|
1318 |
+
|
1319 |
+
# Get language and cross-att blocks infos
|
1320 |
+
num_cross_attn_layers = num_language_layers // config_vl_model.cross_layer_interval
|
1321 |
+
language_seq_len = data_param.max_seq_len
|
1322 |
+
language_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4
|
1323 |
+
cross_att_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4
|
1324 |
+
k_v_cross_attn_seq_len = (
|
1325 |
+
(self.config.resampler_n_latents * max_num_images)
|
1326 |
+
if self.config.use_resampler
|
1327 |
+
else (single_image_seq_len * max_num_images)
|
1328 |
+
)
|
1329 |
+
|
1330 |
+
language_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
|
1331 |
+
num_layers=num_language_layers,
|
1332 |
+
batch_size=hparams.batch_size_per_gpu,
|
1333 |
+
q_seq_len=language_seq_len,
|
1334 |
+
k_seq_len=language_seq_len,
|
1335 |
+
hidden_size=language_embed_size,
|
1336 |
+
kv_in_dim=language_embed_size,
|
1337 |
+
ff_exp_factor=language_exp_factor,
|
1338 |
+
grad_acc_size=hparams.grad_acc_size,
|
1339 |
+
swiglu=False,
|
1340 |
+
vocab_size=tokenizer.vocab_size,
|
1341 |
+
count_backward=True, # Always True regardless of freezing, because gradients are computed for cross-attentions
|
1342 |
+
use_grad_checkpointing=hparams.gradient_checkpointing,
|
1343 |
+
)
|
1344 |
+
cross_attention_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
|
1345 |
+
num_layers=num_cross_attn_layers,
|
1346 |
+
batch_size=hparams.batch_size_per_gpu,
|
1347 |
+
q_seq_len=language_seq_len,
|
1348 |
+
k_seq_len=k_v_cross_attn_seq_len,
|
1349 |
+
hidden_size=language_embed_size,
|
1350 |
+
kv_in_dim=vision_hidden_size,
|
1351 |
+
ff_exp_factor=cross_att_exp_factor,
|
1352 |
+
grad_acc_size=hparams.grad_acc_size,
|
1353 |
+
swiglu=False,
|
1354 |
+
vocab_size=None,
|
1355 |
+
count_backward=True,
|
1356 |
+
use_grad_checkpointing=hparams.gradient_checkpointing,
|
1357 |
+
)
|
1358 |
+
vision_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
|
1359 |
+
num_layers=num_vision_layers,
|
1360 |
+
batch_size=hparams.batch_size_per_gpu * max_num_images,
|
1361 |
+
q_seq_len=single_image_seq_len,
|
1362 |
+
k_seq_len=single_image_seq_len,
|
1363 |
+
hidden_size=vision_hidden_size,
|
1364 |
+
kv_in_dim=vision_hidden_size,
|
1365 |
+
ff_exp_factor=vision_exp_factor,
|
1366 |
+
grad_acc_size=hparams.grad_acc_size,
|
1367 |
+
swiglu=False,
|
1368 |
+
vocab_size=None,
|
1369 |
+
count_backward=not hparams.model_params["freeze_vision_layers"],
|
1370 |
+
use_grad_checkpointing=hparams.gradient_checkpointing,
|
1371 |
+
)
|
1372 |
+
if self.config.use_resampler:
|
1373 |
+
perceiver_tflops_per_batch_per_gpu = compute_perceiver_tflops_per_batch_per_gpu(
|
1374 |
+
num_layers=self.config.resampler_depth,
|
1375 |
+
batch_size=hparams.batch_size_per_gpu * max_num_images,
|
1376 |
+
q_seq_len=self.config.resampler_n_latents,
|
1377 |
+
vision_embed_seq_len=single_image_seq_len,
|
1378 |
+
q_k_v_input_dim=vision_hidden_size,
|
1379 |
+
attention_hidden_size=self.config.resampler_n_heads * self.config.resampler_head_dim,
|
1380 |
+
ff_exp_factor=cross_att_exp_factor,
|
1381 |
+
count_backward=True,
|
1382 |
+
use_grad_checkpointing=hparams.gradient_checkpointing,
|
1383 |
+
)
|
1384 |
+
flop_count = (
|
1385 |
+
language_tflops_per_batch_per_gpu
|
1386 |
+
+ cross_attention_tflops_per_batch_per_gpu
|
1387 |
+
+ vision_tflops_per_batch_per_gpu
|
1388 |
+
+ perceiver_tflops_per_batch_per_gpu
|
1389 |
+
)
|
1390 |
+
else:
|
1391 |
+
flop_count = (
|
1392 |
+
language_tflops_per_batch_per_gpu
|
1393 |
+
+ cross_attention_tflops_per_batch_per_gpu
|
1394 |
+
+ vision_tflops_per_batch_per_gpu
|
1395 |
+
)
|
1396 |
+
return flop_count
|
m4/models/vgpt2/__init__.py
ADDED
File without changes
|
m4/models/vgpt2/configuration_vgpt2.py
ADDED
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
""" OpenAI GPT-2 configuration"""
|
17 |
+
import os
|
18 |
+
from typing import Tuple, Union
|
19 |
+
|
20 |
+
from transformers import AutoConfig
|
21 |
+
from transformers.configuration_utils import PretrainedConfig
|
22 |
+
from transformers.utils import logging
|
23 |
+
|
24 |
+
|
25 |
+
logger = logging.get_logger(__name__)
|
26 |
+
|
27 |
+
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
28 |
+
"gpt2": "https://huggingface.co/gpt2/resolve/main/config.json",
|
29 |
+
"gpt2-medium": "https://huggingface.co/gpt2-medium/resolve/main/config.json",
|
30 |
+
"gpt2-large": "https://huggingface.co/gpt2-large/resolve/main/config.json",
|
31 |
+
"gpt2-xl": "https://huggingface.co/gpt2-xl/resolve/main/config.json",
|
32 |
+
"distilgpt2": "https://huggingface.co/distilgpt2/resolve/main/config.json",
|
33 |
+
}
|
34 |
+
|
35 |
+
|
36 |
+
class VGPT2Config(PretrainedConfig):
|
37 |
+
"""
|
38 |
+
This is the configuration class to store the configuration of a [`GPT2Model`] or a [`TFGPT2Model`]. It is used to
|
39 |
+
instantiate a GPT-2 model according to the specified arguments, defining the model architecture. Instantiating a
|
40 |
+
configuration with the defaults will yield a similar configuration to that of the GPT-2
|
41 |
+
[gpt2](https://huggingface.co/gpt2) architecture.
|
42 |
+
|
43 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
44 |
+
documentation from [`PretrainedConfig`] for more information.
|
45 |
+
|
46 |
+
TODO: this doc is completely out of sync with the actual args
|
47 |
+
|
48 |
+
Args:
|
49 |
+
vocab_size (`int`, *optional*, defaults to 50257):
|
50 |
+
Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the
|
51 |
+
`inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`].
|
52 |
+
additional_vocab_size (`int`, *optional`, defaults to 0):
|
53 |
+
Additional vocabulary size of the model, typically for the special "<img>" token. Additional vocab tokens
|
54 |
+
are always trainable whereas regular vocab tokens can be frozen or not.
|
55 |
+
n_positions (`int`, *optional*, defaults to 1024):
|
56 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
57 |
+
just in case (e.g., 512 or 1024 or 2048).
|
58 |
+
n_embd (`int`, *optional*, defaults to 768):
|
59 |
+
Dimensionality of the embeddings and hidden states.
|
60 |
+
n_layer (`int`, *optional*, defaults to 12):
|
61 |
+
Number of hidden layers in the Transformer encoder.
|
62 |
+
n_head (`int`, *optional*, defaults to 12):
|
63 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
64 |
+
n_inner (`int`, *optional*, defaults to None):
|
65 |
+
Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd
|
66 |
+
activation_function (`str`, *optional*, defaults to `"gelu"`):
|
67 |
+
Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
|
68 |
+
resid_pdrop (`float`, *optional*, defaults to 0.1):
|
69 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
70 |
+
embd_pdrop (`int`, *optional*, defaults to 0.1):
|
71 |
+
The dropout ratio for the embeddings.
|
72 |
+
attn_pdrop (`float`, *optional*, defaults to 0.1):
|
73 |
+
The dropout ratio for the attention.
|
74 |
+
layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
|
75 |
+
The epsilon to use in the layer normalization layers.
|
76 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
77 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
78 |
+
alpha_initializer (`str`, *optional*, defaults to `"ones"`):
|
79 |
+
Initialization type for the alphas.
|
80 |
+
alphas_initializer_range (`float`, *optional*, defaults to 0.0):
|
81 |
+
The standard deviation of the truncated_normal_initializer for initializing the alphas in the Gated Cross Attention.
|
82 |
+
alpha_type (`str`, *optional*, defaults to `"vector"`):
|
83 |
+
Whether the gating alphas should be vectors or single floats.
|
84 |
+
summary_type (`string`, *optional*, defaults to `"cls_index"`):
|
85 |
+
Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
|
86 |
+
[`TFGPT2DoubleHeadsModel`].
|
87 |
+
|
88 |
+
Has to be one of the following options:
|
89 |
+
|
90 |
+
- `"last"`: Take the last token hidden state (like XLNet).
|
91 |
+
- `"first"`: Take the first token hidden state (like BERT).
|
92 |
+
- `"mean"`: Take the mean of all tokens hidden states.
|
93 |
+
- `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2).
|
94 |
+
- `"attn"`: Not implemented now, use multi-head attention.
|
95 |
+
summary_use_proj (`bool`, *optional*, defaults to `True`):
|
96 |
+
Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
|
97 |
+
[`TFGPT2DoubleHeadsModel`].
|
98 |
+
|
99 |
+
Whether or not to add a projection after the vector extraction.
|
100 |
+
summary_activation (`str`, *optional*):
|
101 |
+
Argument used when doing sequence summary. Used in for the multiple choice head in
|
102 |
+
[`GPT2DoubleHeadsModel`].
|
103 |
+
|
104 |
+
Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation.
|
105 |
+
summary_proj_to_labels (`bool`, *optional*, defaults to `True`):
|
106 |
+
Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
|
107 |
+
[`TFGPT2DoubleHeadsModel`].
|
108 |
+
|
109 |
+
Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes.
|
110 |
+
summary_first_dropout (`float`, *optional*, defaults to 0.1):
|
111 |
+
Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
|
112 |
+
[`TFGPT2DoubleHeadsModel`].
|
113 |
+
|
114 |
+
The dropout ratio to be used after the projection and activation.
|
115 |
+
scale_attn_weights (`bool`, *optional*, defaults to `True`):
|
116 |
+
Scale attention weights by dividing by sqrt(hidden_size)..
|
117 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
118 |
+
Whether or not the model should return the last key/values attentions (not used by all models).
|
119 |
+
scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
|
120 |
+
Whether to additionally scale attention weights by `1 / layer_idx + 1`.
|
121 |
+
reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
|
122 |
+
Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention
|
123 |
+
dot-product/softmax to float() when training with mixed precision.
|
124 |
+
cross_layer_interval (`int`, *optional*, default to 1)
|
125 |
+
Interval for cross attention (from text to image) layers.
|
126 |
+
|
127 |
+
Example:
|
128 |
+
|
129 |
+
```python
|
130 |
+
>>> from transformers import GPT2Model, GPT2Config
|
131 |
+
|
132 |
+
>>> # Initializing a GPT2 configuration
|
133 |
+
>>> configuration = GPT2Config()
|
134 |
+
|
135 |
+
>>> # Initializing a model from the configuration
|
136 |
+
>>> model = GPT2Model(configuration)
|
137 |
+
|
138 |
+
>>> # Accessing the model configuration
|
139 |
+
>>> configuration = model.config
|
140 |
+
```"""
|
141 |
+
|
142 |
+
model_type = "vgpt2"
|
143 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
144 |
+
attribute_map = {
|
145 |
+
"hidden_size": "n_embd",
|
146 |
+
"max_position_embeddings": "n_positions",
|
147 |
+
"num_attention_heads": "n_head",
|
148 |
+
"num_hidden_layers": "n_layer",
|
149 |
+
}
|
150 |
+
|
151 |
+
def __init__(
|
152 |
+
self,
|
153 |
+
vocab_size=50257,
|
154 |
+
additional_vocab_size=0,
|
155 |
+
n_positions=1024,
|
156 |
+
n_embd=768,
|
157 |
+
n_layer=12,
|
158 |
+
n_head=12,
|
159 |
+
n_inner=None,
|
160 |
+
activation_function="gelu_new",
|
161 |
+
resid_pdrop=0.1,
|
162 |
+
embd_pdrop=0.1,
|
163 |
+
attn_pdrop=0.1,
|
164 |
+
layer_norm_epsilon=1e-5,
|
165 |
+
initializer_range=0.02,
|
166 |
+
alpha_initializer="ones",
|
167 |
+
alphas_initializer_range=0.0,
|
168 |
+
alpha_type="vector",
|
169 |
+
summary_type="cls_index",
|
170 |
+
summary_use_proj=True,
|
171 |
+
summary_activation=None,
|
172 |
+
summary_proj_to_labels=True,
|
173 |
+
summary_first_dropout=0.1,
|
174 |
+
scale_attn_weights=True,
|
175 |
+
use_cache=True,
|
176 |
+
bos_token_id=50256,
|
177 |
+
eos_token_id=50256,
|
178 |
+
scale_attn_by_inverse_layer_idx=False,
|
179 |
+
reorder_and_upcast_attn=False,
|
180 |
+
cross_layer_interval=1,
|
181 |
+
tie_word_embeddings=False,
|
182 |
+
freeze_text_layers=True,
|
183 |
+
freeze_lm_head=False,
|
184 |
+
freeze_vision_layers=True,
|
185 |
+
vision_model_name="google/vit-base-patch16-224",
|
186 |
+
vision_model_params="{}",
|
187 |
+
vision_embed_dim=768,
|
188 |
+
vision_image_size=224,
|
189 |
+
image_token_index=50257,
|
190 |
+
use_resampler=False,
|
191 |
+
resampler_n_latents=64,
|
192 |
+
resampler_depth=6,
|
193 |
+
resampler_n_heads=16,
|
194 |
+
resampler_head_dim=96,
|
195 |
+
**kwargs,
|
196 |
+
):
|
197 |
+
self.vocab_size = vocab_size
|
198 |
+
self.additional_vocab_size = additional_vocab_size
|
199 |
+
self.n_positions = n_positions
|
200 |
+
self.n_embd = n_embd
|
201 |
+
self.n_layer = n_layer
|
202 |
+
self.n_head = n_head
|
203 |
+
self.n_inner = n_inner
|
204 |
+
self.activation_function = activation_function
|
205 |
+
self.resid_pdrop = resid_pdrop
|
206 |
+
self.embd_pdrop = embd_pdrop
|
207 |
+
self.attn_pdrop = attn_pdrop
|
208 |
+
self.layer_norm_epsilon = layer_norm_epsilon
|
209 |
+
self.initializer_range = initializer_range
|
210 |
+
self.alpha_initializer = alpha_initializer
|
211 |
+
self.alphas_initializer_range = alphas_initializer_range
|
212 |
+
self.alpha_type = alpha_type
|
213 |
+
self.summary_type = summary_type
|
214 |
+
self.summary_use_proj = summary_use_proj
|
215 |
+
self.summary_activation = summary_activation
|
216 |
+
self.summary_first_dropout = summary_first_dropout
|
217 |
+
self.summary_proj_to_labels = summary_proj_to_labels
|
218 |
+
self.scale_attn_weights = scale_attn_weights
|
219 |
+
self.use_cache = use_cache
|
220 |
+
self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
|
221 |
+
self.reorder_and_upcast_attn = reorder_and_upcast_attn
|
222 |
+
|
223 |
+
self.bos_token_id = bos_token_id
|
224 |
+
self.eos_token_id = eos_token_id
|
225 |
+
self.cross_layer_interval = cross_layer_interval
|
226 |
+
self.freeze_vision_layers = freeze_vision_layers
|
227 |
+
self.vision_model_name = vision_model_name
|
228 |
+
self.vision_model_params = vision_model_params
|
229 |
+
|
230 |
+
self.tie_word_embeddings = tie_word_embeddings
|
231 |
+
self.freeze_text_layers = freeze_text_layers
|
232 |
+
self.freeze_lm_head = freeze_lm_head
|
233 |
+
self.image_token_index = image_token_index
|
234 |
+
|
235 |
+
self.vision_embed_dim = vision_embed_dim
|
236 |
+
self.vision_image_size = vision_image_size
|
237 |
+
|
238 |
+
# Resampler params
|
239 |
+
self.use_resampler = use_resampler
|
240 |
+
self.resampler_n_latents = resampler_n_latents
|
241 |
+
self.resampler_depth = resampler_depth
|
242 |
+
self.resampler_n_heads = resampler_n_heads
|
243 |
+
self.resampler_head_dim = resampler_head_dim
|
244 |
+
|
245 |
+
# IMPORTANT: Do not do any __init__ args-based checks in the constructor, since
|
246 |
+
# PretrainedConfig.from_dict first instantiates the class with the config dict and only then
|
247 |
+
# updates the config object with `kwargs` from from_pretrained, so during the instantiation
|
248 |
+
# of this object many attributes have default values and haven't yet been overridden.
|
249 |
+
# Do any required checks inside `from_pretrained` once the superclass' `from_pretrained` was run.
|
250 |
+
|
251 |
+
super().__init__(
|
252 |
+
bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
|
253 |
+
)
|
254 |
+
|
255 |
+
def check_compatibilities(self):
|
256 |
+
if self.tie_word_embeddings and (self.freeze_text_layers != self.freeze_lm_head):
|
257 |
+
raise ValueError(
|
258 |
+
"if `tie_word_embeddings` is True, then `freeze_lm_head` and `freeze_text_layers` must be equal."
|
259 |
+
)
|
260 |
+
|
261 |
+
vision_model_params = eval(self.vision_model_params)
|
262 |
+
config = AutoConfig.from_pretrained(self.vision_model_name, **vision_model_params)
|
263 |
+
if hasattr(config, "vision_config"):
|
264 |
+
vision_config = config.vision_config
|
265 |
+
else:
|
266 |
+
vision_config = config
|
267 |
+
vision_embed_dim = vision_config.hidden_size
|
268 |
+
if self.vision_embed_dim != vision_embed_dim:
|
269 |
+
raise ValueError(
|
270 |
+
f"vision_embed_dim ({self.vision_embed_dim}) must match the hidden size of the vision model"
|
271 |
+
f" ({vision_embed_dim})"
|
272 |
+
)
|
273 |
+
vision_image_size = vision_config.image_size
|
274 |
+
if self.vision_image_size != vision_image_size:
|
275 |
+
raise ValueError(
|
276 |
+
f"vision_image_size ({self.vision_image_size}) must match the hidden size of the vision model"
|
277 |
+
f" ({vision_image_size})"
|
278 |
+
)
|
279 |
+
|
280 |
+
@classmethod
|
281 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
282 |
+
outputs = super(VGPT2Config, cls).from_pretrained(pretrained_model_name_or_path, **kwargs)
|
283 |
+
if isinstance(outputs, Tuple):
|
284 |
+
# When called with return_unused_kwargs=True, the first item will be the config
|
285 |
+
outputs[0].check_compatibilities()
|
286 |
+
else:
|
287 |
+
outputs.check_compatibilities()
|
288 |
+
return outputs
|
m4/models/vgpt2/modeling_vgpt2.py
ADDED
@@ -0,0 +1,1384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""PyTorch OpenAI GPT-2 model."""
|
17 |
+
|
18 |
+
import math
|
19 |
+
import os
|
20 |
+
from typing import Optional, Tuple, Union
|
21 |
+
|
22 |
+
import torch
|
23 |
+
import torch.utils.checkpoint
|
24 |
+
from torch import nn
|
25 |
+
from torch.cuda.amp import autocast
|
26 |
+
from torch.nn import CrossEntropyLoss
|
27 |
+
from transformers.activations import ACT2FN
|
28 |
+
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
|
29 |
+
from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
|
30 |
+
from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
|
31 |
+
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
|
32 |
+
|
33 |
+
from m4.models import DecoupledEmbedding, DecoupledLinear
|
34 |
+
from m4.models.common import (
|
35 |
+
expand_inputs_for_generation,
|
36 |
+
prepare_inputs_for_generation,
|
37 |
+
update_model_kwargs_for_generation,
|
38 |
+
)
|
39 |
+
from m4.models.custom_modules import VLOOMPreTrainedModelBase
|
40 |
+
from m4.models.perceiver.perceiver import PerceiverResampler
|
41 |
+
from m4.models.vgpt2.configuration_vgpt2 import VGPT2Config
|
42 |
+
from m4.training.utils import (
|
43 |
+
compute_perceiver_tflops_per_batch_per_gpu,
|
44 |
+
compute_tflops_per_batch_per_gpu,
|
45 |
+
freeze_model,
|
46 |
+
)
|
47 |
+
from m4.utils import logging
|
48 |
+
|
49 |
+
|
50 |
+
logger = logging.get_logger(__name__)
|
51 |
+
|
52 |
+
_CHECKPOINT_FOR_DOC = "gpt2"
|
53 |
+
_CONFIG_FOR_DOC = "VGPT2Config"
|
54 |
+
_TOKENIZER_FOR_DOC = "GPT2Tokenizer"
|
55 |
+
|
56 |
+
GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
57 |
+
"gpt2",
|
58 |
+
"gpt2-medium",
|
59 |
+
"gpt2-large",
|
60 |
+
"gpt2-xl",
|
61 |
+
"distilgpt2",
|
62 |
+
# See all GPT-2 models at https://huggingface.co/models?filter=gpt2
|
63 |
+
]
|
64 |
+
|
65 |
+
|
66 |
+
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
|
67 |
+
"""Load tf checkpoints in a pytorch model"""
|
68 |
+
try:
|
69 |
+
import re
|
70 |
+
|
71 |
+
import tensorflow as tf
|
72 |
+
except ImportError:
|
73 |
+
logger.error(
|
74 |
+
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
75 |
+
"https://www.tensorflow.org/install/ for installation instructions."
|
76 |
+
)
|
77 |
+
raise
|
78 |
+
tf_path = os.path.abspath(gpt2_checkpoint_path)
|
79 |
+
logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
|
80 |
+
# Load weights from TF model
|
81 |
+
init_vars = tf.train.list_variables(tf_path)
|
82 |
+
names = []
|
83 |
+
arrays = []
|
84 |
+
for name, shape in init_vars:
|
85 |
+
logger.info(f"Loading TF weight {name} with shape {shape}")
|
86 |
+
array = tf.train.load_variable(tf_path, name)
|
87 |
+
names.append(name)
|
88 |
+
arrays.append(array.squeeze())
|
89 |
+
|
90 |
+
for name, array in zip(names, arrays):
|
91 |
+
name = name[6:] # skip "model/"
|
92 |
+
name = name.split("/")
|
93 |
+
pointer = model
|
94 |
+
for m_name in name:
|
95 |
+
if re.fullmatch(r"[A-Za-z]+\d+", m_name):
|
96 |
+
scope_names = re.split(r"(\d+)", m_name)
|
97 |
+
else:
|
98 |
+
scope_names = [m_name]
|
99 |
+
if scope_names[0] == "w" or scope_names[0] == "g":
|
100 |
+
pointer = getattr(pointer, "weight")
|
101 |
+
elif scope_names[0] == "b":
|
102 |
+
pointer = getattr(pointer, "bias")
|
103 |
+
elif scope_names[0] == "wpe" or scope_names[0] == "wte":
|
104 |
+
pointer = getattr(pointer, scope_names[0])
|
105 |
+
pointer = getattr(pointer, "weight")
|
106 |
+
else:
|
107 |
+
pointer = getattr(pointer, scope_names[0])
|
108 |
+
if len(scope_names) >= 2:
|
109 |
+
num = int(scope_names[1])
|
110 |
+
pointer = pointer[num]
|
111 |
+
try:
|
112 |
+
assert (
|
113 |
+
pointer.shape == array.shape
|
114 |
+
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
|
115 |
+
except AssertionError as e:
|
116 |
+
e.args += (pointer.shape, array.shape)
|
117 |
+
raise
|
118 |
+
logger.info(f"Initialize PyTorch weight {name}")
|
119 |
+
pointer.data = torch.from_numpy(array)
|
120 |
+
return model
|
121 |
+
|
122 |
+
|
123 |
+
class GPT2Attention(nn.Module):
|
124 |
+
def __init__(self, config, is_cross_attention=False, layer_idx=None):
|
125 |
+
super().__init__()
|
126 |
+
|
127 |
+
max_positions = config.max_position_embeddings
|
128 |
+
self.register_buffer(
|
129 |
+
"bias",
|
130 |
+
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
|
131 |
+
1, 1, max_positions, max_positions
|
132 |
+
),
|
133 |
+
)
|
134 |
+
self.register_buffer("masked_bias", torch.tensor(-1e4))
|
135 |
+
|
136 |
+
self.embed_dim = config.hidden_size
|
137 |
+
self.num_heads = config.num_attention_heads
|
138 |
+
self.head_dim = self.embed_dim // self.num_heads
|
139 |
+
self.split_size = self.embed_dim
|
140 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
141 |
+
raise ValueError(
|
142 |
+
f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
143 |
+
f" {self.num_heads})."
|
144 |
+
)
|
145 |
+
|
146 |
+
self.scale_attn_weights = config.scale_attn_weights
|
147 |
+
self.is_cross_attention = is_cross_attention
|
148 |
+
|
149 |
+
# Layer-wise attention scaling, reordering, and upcasting
|
150 |
+
self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
|
151 |
+
self.layer_idx = layer_idx
|
152 |
+
self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
|
153 |
+
|
154 |
+
if self.is_cross_attention:
|
155 |
+
in_dim = self.embed_dim if not hasattr(config, "vision_embed_dim") else config.vision_embed_dim
|
156 |
+
self.c_attn = Conv1D(2 * self.embed_dim, in_dim)
|
157 |
+
self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
|
158 |
+
else:
|
159 |
+
self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
|
160 |
+
self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
|
161 |
+
|
162 |
+
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
163 |
+
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
164 |
+
|
165 |
+
self.pruned_heads = set()
|
166 |
+
|
167 |
+
def prune_heads(self, heads):
|
168 |
+
if len(heads) == 0:
|
169 |
+
return
|
170 |
+
heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
|
171 |
+
index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
|
172 |
+
|
173 |
+
# Prune conv1d layers
|
174 |
+
self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
|
175 |
+
self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
|
176 |
+
|
177 |
+
# Update hyper params
|
178 |
+
self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
|
179 |
+
self.num_heads = self.num_heads - len(heads)
|
180 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
181 |
+
|
182 |
+
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
|
183 |
+
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
184 |
+
|
185 |
+
if self.scale_attn_weights:
|
186 |
+
attn_weights = attn_weights / torch.tensor(
|
187 |
+
value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
|
188 |
+
)
|
189 |
+
|
190 |
+
# Layer-wise attention scaling
|
191 |
+
if self.scale_attn_by_inverse_layer_idx:
|
192 |
+
attn_weights = attn_weights / float(self.layer_idx + 1)
|
193 |
+
|
194 |
+
if not self.is_cross_attention:
|
195 |
+
# if only "normal" attention layer implements causal mask
|
196 |
+
query_length, key_length = query.size(-2), key.size(-2)
|
197 |
+
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
|
198 |
+
mask_value = torch.finfo(attn_weights.dtype).min
|
199 |
+
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
200 |
+
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
201 |
+
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
|
202 |
+
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
|
203 |
+
|
204 |
+
if attention_mask is not None:
|
205 |
+
# Apply the attention mask
|
206 |
+
attn_weights = attn_weights + attention_mask
|
207 |
+
|
208 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
209 |
+
|
210 |
+
# Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
|
211 |
+
attn_weights = attn_weights.type(value.dtype)
|
212 |
+
attn_weights = self.attn_dropout(attn_weights)
|
213 |
+
|
214 |
+
# Mask heads if we want to
|
215 |
+
if head_mask is not None:
|
216 |
+
attn_weights = attn_weights * head_mask
|
217 |
+
|
218 |
+
attn_output = torch.matmul(attn_weights, value)
|
219 |
+
|
220 |
+
return attn_output, attn_weights
|
221 |
+
|
222 |
+
def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
|
223 |
+
# Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
|
224 |
+
bsz, num_heads, q_seq_len, dk = query.size()
|
225 |
+
_, _, k_seq_len, _ = key.size()
|
226 |
+
|
227 |
+
# Preallocate attn_weights for `baddbmm`
|
228 |
+
attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
|
229 |
+
|
230 |
+
# Compute Scale Factor
|
231 |
+
scale_factor = 1.0
|
232 |
+
if self.scale_attn_weights:
|
233 |
+
scale_factor /= float(value.size(-1)) ** 0.5
|
234 |
+
|
235 |
+
if self.scale_attn_by_inverse_layer_idx:
|
236 |
+
scale_factor /= float(self.layer_idx + 1)
|
237 |
+
|
238 |
+
# Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
|
239 |
+
with autocast(enabled=False):
|
240 |
+
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
|
241 |
+
attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
|
242 |
+
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
|
243 |
+
|
244 |
+
if not self.is_cross_attention:
|
245 |
+
# if only "normal" attention layer implements causal mask
|
246 |
+
query_length, key_length = query.size(-2), key.size(-2)
|
247 |
+
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
|
248 |
+
mask_value = torch.finfo(attn_weights.dtype).min
|
249 |
+
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
250 |
+
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
251 |
+
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
|
252 |
+
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
|
253 |
+
|
254 |
+
if attention_mask is not None:
|
255 |
+
# Apply the attention mask
|
256 |
+
attn_weights = attn_weights + attention_mask
|
257 |
+
|
258 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
259 |
+
|
260 |
+
# Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
|
261 |
+
if attn_weights.dtype != torch.float32:
|
262 |
+
raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
|
263 |
+
attn_weights = attn_weights.type(value.dtype)
|
264 |
+
attn_weights = self.attn_dropout(attn_weights)
|
265 |
+
|
266 |
+
# Mask heads if we want to
|
267 |
+
if head_mask is not None:
|
268 |
+
attn_weights = attn_weights * head_mask
|
269 |
+
|
270 |
+
attn_output = torch.matmul(attn_weights, value)
|
271 |
+
|
272 |
+
return attn_output, attn_weights
|
273 |
+
|
274 |
+
def _split_heads(self, tensor, num_heads, attn_head_size):
|
275 |
+
"""
|
276 |
+
Splits hidden_size dim into attn_head_size and num_heads
|
277 |
+
"""
|
278 |
+
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
|
279 |
+
tensor = tensor.view(new_shape)
|
280 |
+
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
|
281 |
+
|
282 |
+
def _merge_heads(self, tensor, num_heads, attn_head_size):
|
283 |
+
"""
|
284 |
+
Merges attn_head_size dim and num_attn_heads dim into hidden_size
|
285 |
+
"""
|
286 |
+
tensor = tensor.permute(0, 2, 1, 3).contiguous()
|
287 |
+
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
|
288 |
+
return tensor.view(new_shape)
|
289 |
+
|
290 |
+
def forward(
|
291 |
+
self,
|
292 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
293 |
+
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
294 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
295 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
296 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
297 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
298 |
+
use_cache: Optional[bool] = False,
|
299 |
+
output_attentions: Optional[bool] = False,
|
300 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
|
301 |
+
if encoder_hidden_states is not None:
|
302 |
+
if not hasattr(self, "q_attn"):
|
303 |
+
raise ValueError(
|
304 |
+
"If class is used as cross attention, the weights `q_attn` have to be defined. "
|
305 |
+
"Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
|
306 |
+
)
|
307 |
+
|
308 |
+
query = self.q_attn(hidden_states)
|
309 |
+
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
|
310 |
+
attention_mask = encoder_attention_mask
|
311 |
+
else:
|
312 |
+
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
|
313 |
+
|
314 |
+
query = self._split_heads(query, self.num_heads, self.head_dim)
|
315 |
+
key = self._split_heads(key, self.num_heads, self.head_dim)
|
316 |
+
value = self._split_heads(value, self.num_heads, self.head_dim)
|
317 |
+
|
318 |
+
if layer_past is not None:
|
319 |
+
past_key, past_value = layer_past
|
320 |
+
key = torch.cat((past_key, key), dim=-2)
|
321 |
+
value = torch.cat((past_value, value), dim=-2)
|
322 |
+
|
323 |
+
if use_cache is True:
|
324 |
+
present = (key, value)
|
325 |
+
else:
|
326 |
+
present = None
|
327 |
+
|
328 |
+
if self.reorder_and_upcast_attn:
|
329 |
+
attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
|
330 |
+
else:
|
331 |
+
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
332 |
+
|
333 |
+
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
|
334 |
+
attn_output = self.c_proj(attn_output)
|
335 |
+
attn_output = self.resid_dropout(attn_output)
|
336 |
+
|
337 |
+
outputs = (attn_output, present)
|
338 |
+
if output_attentions:
|
339 |
+
outputs += (attn_weights,)
|
340 |
+
|
341 |
+
return outputs # a, present, (attentions)
|
342 |
+
|
343 |
+
|
344 |
+
class GPT2MLP(nn.Module):
|
345 |
+
def __init__(self, intermediate_size, config):
|
346 |
+
super().__init__()
|
347 |
+
embed_dim = config.hidden_size
|
348 |
+
self.c_fc = Conv1D(intermediate_size, embed_dim)
|
349 |
+
self.c_proj = Conv1D(embed_dim, intermediate_size)
|
350 |
+
self.act = ACT2FN[config.activation_function]
|
351 |
+
self.dropout = nn.Dropout(config.resid_pdrop)
|
352 |
+
|
353 |
+
def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
|
354 |
+
hidden_states = self.c_fc(hidden_states)
|
355 |
+
hidden_states = self.act(hidden_states)
|
356 |
+
hidden_states = self.c_proj(hidden_states)
|
357 |
+
hidden_states = self.dropout(hidden_states)
|
358 |
+
return hidden_states
|
359 |
+
|
360 |
+
|
361 |
+
class GPT2Block(nn.Module):
|
362 |
+
def __init__(self, config, layer_idx=None):
|
363 |
+
super().__init__()
|
364 |
+
hidden_size = config.hidden_size
|
365 |
+
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
|
366 |
+
|
367 |
+
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
368 |
+
self.attn = GPT2Attention(config, layer_idx=layer_idx)
|
369 |
+
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
370 |
+
|
371 |
+
if config.add_cross_attention:
|
372 |
+
self.crossattention = GPT2Attention(config, is_cross_attention=True, layer_idx=layer_idx)
|
373 |
+
self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
374 |
+
|
375 |
+
self.mlp = GPT2MLP(inner_dim, config)
|
376 |
+
|
377 |
+
def forward(
|
378 |
+
self,
|
379 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
380 |
+
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
381 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
382 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
383 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
384 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
385 |
+
use_cache: Optional[bool] = False,
|
386 |
+
output_attentions: Optional[bool] = False,
|
387 |
+
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
|
388 |
+
residual = hidden_states
|
389 |
+
hidden_states = self.ln_1(hidden_states)
|
390 |
+
attn_outputs = self.attn(
|
391 |
+
hidden_states,
|
392 |
+
layer_past=layer_past,
|
393 |
+
attention_mask=attention_mask,
|
394 |
+
head_mask=head_mask,
|
395 |
+
use_cache=use_cache,
|
396 |
+
output_attentions=output_attentions,
|
397 |
+
)
|
398 |
+
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
399 |
+
outputs = attn_outputs[1:]
|
400 |
+
# residual connection
|
401 |
+
hidden_states = attn_output + residual
|
402 |
+
|
403 |
+
if encoder_hidden_states is not None:
|
404 |
+
# add one self-attention block for cross-attention
|
405 |
+
if not hasattr(self, "crossattention"):
|
406 |
+
raise ValueError(
|
407 |
+
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
|
408 |
+
"cross-attention layers by setting `config.add_cross_attention=True`"
|
409 |
+
)
|
410 |
+
residual = hidden_states
|
411 |
+
hidden_states = self.ln_cross_attn(hidden_states)
|
412 |
+
cross_attn_outputs = self.crossattention(
|
413 |
+
hidden_states,
|
414 |
+
attention_mask=attention_mask,
|
415 |
+
head_mask=head_mask,
|
416 |
+
encoder_hidden_states=encoder_hidden_states,
|
417 |
+
encoder_attention_mask=encoder_attention_mask,
|
418 |
+
output_attentions=output_attentions,
|
419 |
+
)
|
420 |
+
attn_output = cross_attn_outputs[0]
|
421 |
+
# residual connection
|
422 |
+
hidden_states = residual + attn_output
|
423 |
+
outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
|
424 |
+
|
425 |
+
residual = hidden_states
|
426 |
+
hidden_states = self.ln_2(hidden_states)
|
427 |
+
feed_forward_hidden_states = self.mlp(hidden_states)
|
428 |
+
# residual connection
|
429 |
+
hidden_states = residual + feed_forward_hidden_states
|
430 |
+
|
431 |
+
if use_cache:
|
432 |
+
outputs = (hidden_states,) + outputs
|
433 |
+
else:
|
434 |
+
outputs = (hidden_states,) + outputs[1:]
|
435 |
+
|
436 |
+
return outputs # hidden_states, present, (attentions, cross_attentions)
|
437 |
+
|
438 |
+
|
439 |
+
class VGPT2GatedCrossAttentionBlock(nn.Module):
|
440 |
+
def __init__(self, config, layer_idx=None):
|
441 |
+
super().__init__()
|
442 |
+
hidden_size = config.hidden_size
|
443 |
+
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
|
444 |
+
|
445 |
+
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
446 |
+
self.cross_attn = GPT2Attention(config, is_cross_attention=True, layer_idx=layer_idx)
|
447 |
+
self.mlp = GPT2MLP(inner_dim, config)
|
448 |
+
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
449 |
+
self.act = nn.Tanh()
|
450 |
+
|
451 |
+
if config.alpha_initializer == "zeros":
|
452 |
+
if config.alpha_type == "vector":
|
453 |
+
self.alpha_cross_attn = nn.Parameter(torch.zeros(1, 1, hidden_size))
|
454 |
+
self.alpha_dense = nn.Parameter(torch.zeros(1, 1, hidden_size))
|
455 |
+
elif config.alpha_type == "float":
|
456 |
+
self.alpha_cross_attn = nn.Parameter(torch.zeros(1))
|
457 |
+
self.alpha_dense = nn.Parameter(torch.zeros(1))
|
458 |
+
else:
|
459 |
+
raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
|
460 |
+
|
461 |
+
elif config.alpha_initializer == "ones":
|
462 |
+
if config.alpha_type == "vector":
|
463 |
+
self.alpha_cross_attn = nn.Parameter(torch.ones(1, 1, hidden_size))
|
464 |
+
self.alpha_dense = nn.Parameter(torch.ones(1, 1, hidden_size))
|
465 |
+
elif config.alpha_type == "float":
|
466 |
+
self.alpha_cross_attn = nn.Parameter(torch.ones(1))
|
467 |
+
self.alpha_dense = nn.Parameter(torch.ones(1))
|
468 |
+
else:
|
469 |
+
raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
|
470 |
+
|
471 |
+
elif config.alpha_initializer in {"normal", "gaussian", "random"}:
|
472 |
+
if config.alpha_type == "vector":
|
473 |
+
self.alpha_cross_attn = nn.Parameter(
|
474 |
+
torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1, 1, hidden_size))
|
475 |
+
)
|
476 |
+
self.alpha_dense = nn.Parameter(
|
477 |
+
torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1, 1, hidden_size))
|
478 |
+
)
|
479 |
+
elif config.alpha_type == "float":
|
480 |
+
self.alpha_cross_attn = nn.Parameter(
|
481 |
+
torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1))
|
482 |
+
)
|
483 |
+
self.alpha_dense = nn.Parameter(torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1)))
|
484 |
+
else:
|
485 |
+
raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
|
486 |
+
|
487 |
+
else:
|
488 |
+
raise NotImplementedError(f"Alpha initialization scheme {config.alpha_initializer} not yet implemented!")
|
489 |
+
|
490 |
+
def forward(
|
491 |
+
self,
|
492 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
493 |
+
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
494 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
495 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
496 |
+
image_hidden_states: Optional[torch.Tensor] = None,
|
497 |
+
image_attention_mask: Optional[torch.FloatTensor] = None,
|
498 |
+
use_cache: Optional[bool] = False,
|
499 |
+
output_attentions: Optional[bool] = False,
|
500 |
+
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
|
501 |
+
if image_hidden_states is None:
|
502 |
+
raise ValueError(
|
503 |
+
"`image_hidden_states` is required for VGPT2 cross attention module which are visual features to be"
|
504 |
+
" conditioned on."
|
505 |
+
)
|
506 |
+
# add one self-attention block for cross-attention
|
507 |
+
|
508 |
+
# TODO(aps): Handle cross attention in the outputs
|
509 |
+
# if not hasattr(self, "crossattention"):
|
510 |
+
# raise ValueError(
|
511 |
+
# f"If `image_hidden_states` are passed, {self} has to be instantiated with "
|
512 |
+
# "cross-attention layers by setting `config.add_cross_attention=True`"
|
513 |
+
# )
|
514 |
+
residual = hidden_states
|
515 |
+
|
516 |
+
hidden_states = self.ln_1(hidden_states)
|
517 |
+
cross_attn_outputs = self.cross_attn(
|
518 |
+
hidden_states,
|
519 |
+
attention_mask=attention_mask,
|
520 |
+
head_mask=head_mask,
|
521 |
+
encoder_hidden_states=image_hidden_states,
|
522 |
+
encoder_attention_mask=image_attention_mask,
|
523 |
+
output_attentions=output_attentions,
|
524 |
+
)
|
525 |
+
attn_output = cross_attn_outputs[0]
|
526 |
+
outputs = cross_attn_outputs[1:]
|
527 |
+
# residual connection
|
528 |
+
hidden_states = residual + self.act(self.alpha_cross_attn) * attn_output
|
529 |
+
outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
|
530 |
+
|
531 |
+
residual = hidden_states
|
532 |
+
hidden_states = self.ln_2(hidden_states)
|
533 |
+
feed_forward_hidden_states = self.mlp(hidden_states)
|
534 |
+
# residual connection
|
535 |
+
hidden_states = residual + self.act(self.alpha_dense) * feed_forward_hidden_states
|
536 |
+
|
537 |
+
if use_cache:
|
538 |
+
outputs = (hidden_states,) + outputs
|
539 |
+
else:
|
540 |
+
outputs = (hidden_states,) + outputs[1:]
|
541 |
+
|
542 |
+
return outputs # hidden_states, present, (attentions, cross_attentions)
|
543 |
+
|
544 |
+
|
545 |
+
class VGPT2PreTrainedModel(VLOOMPreTrainedModelBase):
|
546 |
+
"""
|
547 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
548 |
+
models.
|
549 |
+
"""
|
550 |
+
|
551 |
+
config_class = VGPT2Config
|
552 |
+
load_tf_weights = load_tf_weights_in_gpt2
|
553 |
+
base_model_prefix = "transformer"
|
554 |
+
is_parallelizable = True
|
555 |
+
supports_gradient_checkpointing = True
|
556 |
+
_no_split_modules = ["GPT2Block"]
|
557 |
+
|
558 |
+
def __init__(self, *inputs, **kwargs):
|
559 |
+
super().__init__(*inputs, **kwargs)
|
560 |
+
|
561 |
+
def _init_weights(self, module):
|
562 |
+
"""Initialize the weights."""
|
563 |
+
if isinstance(module, (nn.Linear, Conv1D)):
|
564 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
565 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
566 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
567 |
+
if module.bias is not None:
|
568 |
+
module.bias.data.zero_()
|
569 |
+
elif isinstance(module, nn.Embedding):
|
570 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
571 |
+
if module.padding_idx is not None:
|
572 |
+
module.weight.data[module.padding_idx].zero_()
|
573 |
+
elif isinstance(module, nn.LayerNorm):
|
574 |
+
module.bias.data.zero_()
|
575 |
+
module.weight.data.fill_(1.0)
|
576 |
+
|
577 |
+
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
|
578 |
+
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
|
579 |
+
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
|
580 |
+
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
|
581 |
+
#
|
582 |
+
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
|
583 |
+
for name, p in module.named_parameters():
|
584 |
+
if name == "c_proj.weight":
|
585 |
+
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
|
586 |
+
p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
|
587 |
+
|
588 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
589 |
+
if isinstance(module, VGPT2Model):
|
590 |
+
module.gradient_checkpointing = value
|
591 |
+
|
592 |
+
@classmethod
|
593 |
+
def override_vision_model_wrapper(cls, model, config, vision_model_name, vision_model_params, torch_dtype):
|
594 |
+
# this can be called via from_pretrained from a class w/ head or w/o head so we extract the beheaded model version
|
595 |
+
beheaded_model = model.transformer if hasattr(model, "transformer") else model
|
596 |
+
cls.override_vision_model(beheaded_model, vision_model_name, vision_model_params, torch_dtype)
|
597 |
+
beheaded_model.freeze_relevant_params(config)
|
598 |
+
|
599 |
+
|
600 |
+
GPT2_START_DOCSTRING = r"""
|
601 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
602 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
603 |
+
etc.)
|
604 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
605 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
606 |
+
and behavior.
|
607 |
+
Parameters:
|
608 |
+
config ([`VGPT2Config`]): Model configuration class with all the parameters of the model.
|
609 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
610 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
611 |
+
"""
|
612 |
+
|
613 |
+
GPT2_INPUTS_DOCSTRING = r"""
|
614 |
+
Args:
|
615 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
|
616 |
+
`input_ids_length` = `sequence_length` if `past_key_values` is `None` else
|
617 |
+
`past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
|
618 |
+
sequence tokens in the vocabulary.
|
619 |
+
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
|
620 |
+
`input_ids`.
|
621 |
+
Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
622 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
623 |
+
[What are input IDs?](../glossary#input-ids)
|
624 |
+
past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
|
625 |
+
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
|
626 |
+
`past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
|
627 |
+
their past given to this model should not be passed as `input_ids` as they have already been computed.
|
628 |
+
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
629 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
630 |
+
- 1 for tokens that are **not masked**,
|
631 |
+
- 0 for tokens that are **masked**.
|
632 |
+
If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
|
633 |
+
`past_key_values`. In other words, the `attention_mask` always has to have the length:
|
634 |
+
`len(past_key_values) + len(input_ids)`
|
635 |
+
[What are attention masks?](../glossary#attention-mask)
|
636 |
+
token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
|
637 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
638 |
+
1]`:
|
639 |
+
- 0 corresponds to a *sentence A* token,
|
640 |
+
- 1 corresponds to a *sentence B* token.
|
641 |
+
[What are token type IDs?](../glossary#token-type-ids)
|
642 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
643 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
644 |
+
config.max_position_embeddings - 1]`.
|
645 |
+
[What are position IDs?](../glossary#position-ids)
|
646 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
647 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
648 |
+
- 1 indicates the head is **not masked**,
|
649 |
+
- 0 indicates the head is **masked**.
|
650 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
651 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
652 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
653 |
+
model's internal embedding lookup matrix.
|
654 |
+
If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
|
655 |
+
`past_key_values`).
|
656 |
+
use_cache (`bool`, *optional*):
|
657 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
658 |
+
`past_key_values`).
|
659 |
+
output_attentions (`bool`, *optional*):
|
660 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
661 |
+
tensors for more detail.
|
662 |
+
output_hidden_states (`bool`, *optional*):
|
663 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
664 |
+
more detail.
|
665 |
+
return_dict (`bool`, *optional*):
|
666 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
667 |
+
"""
|
668 |
+
PARALLELIZE_DOCSTRING = r"""
|
669 |
+
This is an experimental feature and is a subject to change at a moment's notice.
|
670 |
+
Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
|
671 |
+
it will evenly distribute blocks across all devices.
|
672 |
+
Args:
|
673 |
+
device_map (`Dict[int, list]`, optional, defaults to None):
|
674 |
+
A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
|
675 |
+
automatically mapped to the first device (for esoteric reasons). That means that the first device should
|
676 |
+
have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
|
677 |
+
following number of attention modules:
|
678 |
+
- gpt2: 12
|
679 |
+
- gpt2-medium: 24
|
680 |
+
- gpt2-large: 36
|
681 |
+
- gpt2-xl: 48
|
682 |
+
Example:
|
683 |
+
```python
|
684 |
+
# Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
|
685 |
+
model = GPT2LMHeadModel.from_pretrained("gpt2-xl")
|
686 |
+
device_map = {
|
687 |
+
0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
|
688 |
+
1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
|
689 |
+
2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
|
690 |
+
3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
|
691 |
+
}
|
692 |
+
model.parallelize(device_map)
|
693 |
+
```
|
694 |
+
"""
|
695 |
+
DEPARALLELIZE_DOCSTRING = r"""
|
696 |
+
Moves the model to cpu from a model parallel state.
|
697 |
+
Example:
|
698 |
+
```python
|
699 |
+
# On a 4 GPU machine with gpt2-large:
|
700 |
+
model = GPT2LMHeadModel.from_pretrained("gpt2-large")
|
701 |
+
device_map = {
|
702 |
+
0: [0, 1, 2, 3, 4, 5, 6, 7],
|
703 |
+
1: [8, 9, 10, 11, 12, 13, 14, 15],
|
704 |
+
2: [16, 17, 18, 19, 20, 21, 22, 23],
|
705 |
+
3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
|
706 |
+
}
|
707 |
+
model.parallelize(device_map) # Splits the model across several devices
|
708 |
+
model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
|
709 |
+
```
|
710 |
+
"""
|
711 |
+
|
712 |
+
|
713 |
+
@add_start_docstrings(
|
714 |
+
"The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
|
715 |
+
GPT2_START_DOCSTRING,
|
716 |
+
)
|
717 |
+
class VGPT2Model(VGPT2PreTrainedModel):
|
718 |
+
_keys_to_ignore_on_load_missing = ["attn.masked_bias"]
|
719 |
+
|
720 |
+
def __init__(self, config, vision_model=None):
|
721 |
+
super().__init__(config)
|
722 |
+
|
723 |
+
self.embed_dim = config.hidden_size
|
724 |
+
self.config = config
|
725 |
+
|
726 |
+
self.wte = DecoupledEmbedding(
|
727 |
+
num_embeddings=config.vocab_size,
|
728 |
+
num_additional_embeddings=config.additional_vocab_size,
|
729 |
+
embedding_dim=self.embed_dim,
|
730 |
+
partially_freeze=config.freeze_text_layers,
|
731 |
+
)
|
732 |
+
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
733 |
+
|
734 |
+
self.drop = nn.Dropout(config.embd_pdrop)
|
735 |
+
self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
|
736 |
+
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
737 |
+
|
738 |
+
self.cross_layer_interval = config.cross_layer_interval
|
739 |
+
num_cross_layers = config.num_hidden_layers // self.cross_layer_interval
|
740 |
+
self.gated_cross_attn_layers = nn.ModuleList(
|
741 |
+
[VGPT2GatedCrossAttentionBlock(config, layer_idx=i) for i in range(num_cross_layers)]
|
742 |
+
)
|
743 |
+
|
744 |
+
# Perceiver Resampler
|
745 |
+
if config.use_resampler:
|
746 |
+
self.perceiver_resampler = PerceiverResampler(
|
747 |
+
self.config,
|
748 |
+
self.config.vision_embed_dim,
|
749 |
+
config.resampler_depth,
|
750 |
+
config.resampler_n_heads,
|
751 |
+
config.resampler_head_dim,
|
752 |
+
config.resampler_n_latents,
|
753 |
+
)
|
754 |
+
# Model parallel
|
755 |
+
self.model_parallel = False
|
756 |
+
self.device_map = None
|
757 |
+
self.gradient_checkpointing = False
|
758 |
+
# will be vocab_size because of indices starting from 0
|
759 |
+
self.image_token_idx = config.image_token_index
|
760 |
+
|
761 |
+
# Load an uninitialized model and later in from_pretrained will load the pre-trained model -
|
762 |
+
# this solves the losing of weights in `from_pretrained` on the main model
|
763 |
+
self.vision_model = vision_model
|
764 |
+
|
765 |
+
# Initialize weights and apply final processing
|
766 |
+
self.post_init()
|
767 |
+
|
768 |
+
self.freeze_relevant_params(config)
|
769 |
+
|
770 |
+
def freeze_relevant_params(self, config=None):
|
771 |
+
if config is None:
|
772 |
+
config = self.config
|
773 |
+
|
774 |
+
if config.freeze_text_layers:
|
775 |
+
self.freeze_text_layers()
|
776 |
+
|
777 |
+
if config.freeze_vision_layers:
|
778 |
+
freeze_model(self.vision_model)
|
779 |
+
|
780 |
+
def freeze_text_layers(self):
|
781 |
+
for module in [self.wpe, self.h, self.ln_f]:
|
782 |
+
freeze_model(module)
|
783 |
+
|
784 |
+
@add_start_docstrings(PARALLELIZE_DOCSTRING)
|
785 |
+
# TODO(aps): Implement later for VGPT2
|
786 |
+
def parallelize(self, device_map=None):
|
787 |
+
# Check validity of device_map
|
788 |
+
self.device_map = (
|
789 |
+
get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
|
790 |
+
)
|
791 |
+
assert_device_map(self.device_map, len(self.h))
|
792 |
+
self.model_parallel = True
|
793 |
+
self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
|
794 |
+
self.last_device = "cuda:" + str(max(self.device_map.keys()))
|
795 |
+
self.wte = self.wte.to(self.first_device)
|
796 |
+
self.wpe = self.wpe.to(self.first_device)
|
797 |
+
# Load onto devices
|
798 |
+
for k, v in self.device_map.items():
|
799 |
+
for block in v:
|
800 |
+
cuda_device = "cuda:" + str(k)
|
801 |
+
self.h[block] = self.h[block].to(cuda_device)
|
802 |
+
# ln_f to last
|
803 |
+
self.ln_f = self.ln_f.to(self.last_device)
|
804 |
+
|
805 |
+
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
|
806 |
+
# TODO(aps): Implement later for VGPT2
|
807 |
+
def deparallelize(self):
|
808 |
+
self.model_parallel = False
|
809 |
+
self.device_map = None
|
810 |
+
self.first_device = "cpu"
|
811 |
+
self.last_device = "cpu"
|
812 |
+
self.wte = self.wte.to("cpu")
|
813 |
+
self.wpe = self.wpe.to("cpu")
|
814 |
+
for index in range(len(self.h)):
|
815 |
+
self.h[index] = self.h[index].to("cpu")
|
816 |
+
self.ln_f = self.ln_f.to("cpu")
|
817 |
+
torch.cuda.empty_cache()
|
818 |
+
|
819 |
+
def get_input_embeddings(self):
|
820 |
+
return self.wte
|
821 |
+
|
822 |
+
def set_input_embeddings(self, new_embeddings):
|
823 |
+
self.wte = new_embeddings
|
824 |
+
|
825 |
+
def _prune_heads(self, heads_to_prune):
|
826 |
+
"""
|
827 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
828 |
+
"""
|
829 |
+
for layer, heads in heads_to_prune.items():
|
830 |
+
self.h[layer].attn.prune_heads(heads)
|
831 |
+
|
832 |
+
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
|
833 |
+
@add_code_sample_docstrings(
|
834 |
+
processor_class=_TOKENIZER_FOR_DOC,
|
835 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
836 |
+
output_type=BaseModelOutputWithPastAndCrossAttentions,
|
837 |
+
config_class=_CONFIG_FOR_DOC,
|
838 |
+
)
|
839 |
+
def forward(
|
840 |
+
self,
|
841 |
+
input_ids: Optional[torch.LongTensor] = None,
|
842 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
843 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
844 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
845 |
+
position_ids: Optional[torch.LongTensor] = None,
|
846 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
847 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
848 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
849 |
+
image_embeddings: Optional[torch.FloatTensor] = None,
|
850 |
+
image_attention_mask: Optional[torch.Tensor] = None,
|
851 |
+
crossblock_head_mask: Optional[torch.Tensor] = None,
|
852 |
+
use_cache: Optional[bool] = None,
|
853 |
+
output_attentions: Optional[bool] = None,
|
854 |
+
output_hidden_states: Optional[bool] = None,
|
855 |
+
return_dict: Optional[bool] = None,
|
856 |
+
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
857 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
858 |
+
output_hidden_states = (
|
859 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
860 |
+
)
|
861 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
862 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
863 |
+
|
864 |
+
if input_ids is not None and inputs_embeds is not None:
|
865 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
866 |
+
elif input_ids is not None:
|
867 |
+
input_shape = input_ids.size()
|
868 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
869 |
+
batch_size = input_ids.shape[0]
|
870 |
+
elif inputs_embeds is not None:
|
871 |
+
input_shape = inputs_embeds.size()[:-1]
|
872 |
+
batch_size = inputs_embeds.shape[0]
|
873 |
+
else:
|
874 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
875 |
+
|
876 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
877 |
+
|
878 |
+
if token_type_ids is not None:
|
879 |
+
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
880 |
+
if position_ids is not None:
|
881 |
+
position_ids = position_ids.view(-1, input_shape[-1])
|
882 |
+
|
883 |
+
if past_key_values is None:
|
884 |
+
past_length = 0
|
885 |
+
past_key_values = tuple([None] * len(self.h))
|
886 |
+
else:
|
887 |
+
past_length = past_key_values[0][0].size(-2)
|
888 |
+
if position_ids is None:
|
889 |
+
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
890 |
+
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
891 |
+
|
892 |
+
# GPT2Attention mask.
|
893 |
+
if attention_mask is not None:
|
894 |
+
if batch_size <= 0:
|
895 |
+
raise ValueError("batch_size has to be defined and > 0")
|
896 |
+
attention_mask = attention_mask.view(batch_size, -1)
|
897 |
+
# We create a 3D attention mask from a 2D tensor mask.
|
898 |
+
# Sizes are [batch_size, 1, 1, to_seq_length]
|
899 |
+
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
900 |
+
# this attention mask is more simple than the triangular masking of causal attention
|
901 |
+
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
902 |
+
attention_mask = attention_mask[:, None, None, :]
|
903 |
+
|
904 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
905 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
906 |
+
# positions we want to attend and the dtype's smallest value for masked positions.
|
907 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
908 |
+
# effectively the same as removing these entirely.
|
909 |
+
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
910 |
+
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
911 |
+
|
912 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
913 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
914 |
+
if pixel_values is not None and image_embeddings is not None:
|
915 |
+
raise ValueError("You cannot specify both pixel_values and image_embeddings at the same time")
|
916 |
+
elif pixel_values is not None:
|
917 |
+
pixel_values = pixel_values.to(dtype=self.dtype, device=input_ids.device) # fp16 compatibility
|
918 |
+
batch_size, num_images = pixel_values.size(0), pixel_values.size(1)
|
919 |
+
pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:])
|
920 |
+
# Get sequence from the vision encoder
|
921 |
+
image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
|
922 |
+
elif image_embeddings is not None:
|
923 |
+
batch_size, num_images, image_seq_len, image_hidden_size = image_embeddings.size()
|
924 |
+
image_hidden_states = image_embeddings.to(dtype=self.dtype, device=input_ids.device)
|
925 |
+
image_hidden_states = image_hidden_states.view(batch_size * num_images, image_seq_len, image_hidden_size)
|
926 |
+
|
927 |
+
if self.config.use_resampler:
|
928 |
+
image_hidden_states = self.perceiver_resampler(image_hidden_states)
|
929 |
+
image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
|
930 |
+
image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size)
|
931 |
+
|
932 |
+
# Make image_attention_mask compatible with hidden states
|
933 |
+
text_seq_len = image_attention_mask.size(1)
|
934 |
+
image_attention_mask = image_attention_mask.unsqueeze(-1)
|
935 |
+
image_attention_mask = image_attention_mask.repeat(1, 1, 1, image_seq_len)
|
936 |
+
image_attention_mask = image_attention_mask.view(batch_size, text_seq_len, num_images * image_seq_len)
|
937 |
+
if image_hidden_states is not None:
|
938 |
+
image_batch_size, image_sequence_length, _ = image_hidden_states.size()
|
939 |
+
image_hidden_shape = (image_batch_size, image_sequence_length)
|
940 |
+
if image_attention_mask is None:
|
941 |
+
image_attention_mask = torch.ones(image_hidden_shape, device=device)
|
942 |
+
image_attention_mask = self.invert_attention_mask(image_attention_mask)
|
943 |
+
else:
|
944 |
+
image_attention_mask = None
|
945 |
+
|
946 |
+
# Prepare head mask if needed
|
947 |
+
# 1.0 in head_mask indicate we keep the head
|
948 |
+
# attention_probs has shape bsz x n_heads x N x N
|
949 |
+
# head_mask has shape n_layer x batch x n_heads x N x N
|
950 |
+
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
951 |
+
|
952 |
+
if inputs_embeds is None:
|
953 |
+
inputs_embeds = self.wte(input_ids)
|
954 |
+
|
955 |
+
position_embeds = self.wpe(position_ids)
|
956 |
+
hidden_states = inputs_embeds + position_embeds
|
957 |
+
|
958 |
+
if token_type_ids is not None:
|
959 |
+
token_type_embeds = self.wte(token_type_ids)
|
960 |
+
hidden_states = hidden_states + token_type_embeds
|
961 |
+
|
962 |
+
hidden_states = self.drop(hidden_states)
|
963 |
+
|
964 |
+
output_shape = input_shape + (hidden_states.size(-1),)
|
965 |
+
|
966 |
+
presents = () if use_cache else None
|
967 |
+
all_self_attentions = () if output_attentions else None
|
968 |
+
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
969 |
+
all_hidden_states = () if output_hidden_states else None
|
970 |
+
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
971 |
+
# Model parallel
|
972 |
+
if self.model_parallel:
|
973 |
+
torch.cuda.set_device(hidden_states.device)
|
974 |
+
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
975 |
+
if layer_past is not None:
|
976 |
+
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
977 |
+
# Ensure that attention_mask is always on the same device as hidden_states
|
978 |
+
if attention_mask is not None:
|
979 |
+
attention_mask = attention_mask.to(hidden_states.device)
|
980 |
+
if isinstance(head_mask, torch.Tensor):
|
981 |
+
head_mask = head_mask.to(hidden_states.device)
|
982 |
+
if output_hidden_states:
|
983 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
984 |
+
|
985 |
+
def vblock(
|
986 |
+
main_block,
|
987 |
+
hidden_states,
|
988 |
+
layer_past,
|
989 |
+
attention_mask,
|
990 |
+
layer_head_mask,
|
991 |
+
use_cache,
|
992 |
+
output_attentions,
|
993 |
+
image_hidden_states,
|
994 |
+
image_attention_mask,
|
995 |
+
layer_idx,
|
996 |
+
cross_layer_interval,
|
997 |
+
gated_cross_attn_layers,
|
998 |
+
):
|
999 |
+
# TODO(aps): Add cross attention values to respective lists
|
1000 |
+
# TODO(aps): Add xblock head mask support
|
1001 |
+
if layer_idx % cross_layer_interval == 0:
|
1002 |
+
xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval]
|
1003 |
+
outputs = xblock(
|
1004 |
+
hidden_states,
|
1005 |
+
attention_mask=attention_mask,
|
1006 |
+
image_hidden_states=image_hidden_states,
|
1007 |
+
image_attention_mask=image_attention_mask,
|
1008 |
+
use_cache=use_cache,
|
1009 |
+
output_attentions=output_attentions,
|
1010 |
+
)
|
1011 |
+
hidden_states = outputs[0]
|
1012 |
+
|
1013 |
+
outputs = main_block(
|
1014 |
+
hidden_states,
|
1015 |
+
layer_past=layer_past,
|
1016 |
+
attention_mask=attention_mask,
|
1017 |
+
head_mask=layer_head_mask,
|
1018 |
+
use_cache=use_cache,
|
1019 |
+
output_attentions=output_attentions,
|
1020 |
+
)
|
1021 |
+
|
1022 |
+
return outputs
|
1023 |
+
|
1024 |
+
if self.gradient_checkpointing and self.training:
|
1025 |
+
layer_past = None
|
1026 |
+
if use_cache:
|
1027 |
+
logger.warning_once(
|
1028 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
1029 |
+
)
|
1030 |
+
use_cache = False
|
1031 |
+
|
1032 |
+
outputs = torch.utils.checkpoint.checkpoint(
|
1033 |
+
vblock,
|
1034 |
+
block,
|
1035 |
+
hidden_states,
|
1036 |
+
layer_past,
|
1037 |
+
attention_mask,
|
1038 |
+
head_mask[i],
|
1039 |
+
use_cache,
|
1040 |
+
output_attentions,
|
1041 |
+
image_hidden_states,
|
1042 |
+
image_attention_mask,
|
1043 |
+
i,
|
1044 |
+
self.cross_layer_interval,
|
1045 |
+
self.gated_cross_attn_layers,
|
1046 |
+
)
|
1047 |
+
else:
|
1048 |
+
outputs = vblock(
|
1049 |
+
block,
|
1050 |
+
hidden_states,
|
1051 |
+
layer_past=layer_past,
|
1052 |
+
attention_mask=attention_mask,
|
1053 |
+
layer_head_mask=head_mask[i],
|
1054 |
+
use_cache=use_cache,
|
1055 |
+
output_attentions=output_attentions,
|
1056 |
+
image_hidden_states=image_hidden_states,
|
1057 |
+
image_attention_mask=image_attention_mask,
|
1058 |
+
layer_idx=i,
|
1059 |
+
cross_layer_interval=self.cross_layer_interval,
|
1060 |
+
gated_cross_attn_layers=self.gated_cross_attn_layers,
|
1061 |
+
)
|
1062 |
+
|
1063 |
+
hidden_states = outputs[0]
|
1064 |
+
if use_cache is True:
|
1065 |
+
presents = presents + (outputs[1],)
|
1066 |
+
|
1067 |
+
if output_attentions:
|
1068 |
+
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
1069 |
+
if self.config.add_cross_attention:
|
1070 |
+
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
1071 |
+
|
1072 |
+
# Model Parallel: If it's the last layer for that device, put things on the next device
|
1073 |
+
if self.model_parallel:
|
1074 |
+
for k, v in self.device_map.items():
|
1075 |
+
if i == v[-1] and "cuda:" + str(k) != self.last_device:
|
1076 |
+
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
1077 |
+
|
1078 |
+
hidden_states = self.ln_f(hidden_states)
|
1079 |
+
|
1080 |
+
hidden_states = hidden_states.view(output_shape)
|
1081 |
+
# Add last hidden state
|
1082 |
+
if output_hidden_states:
|
1083 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
1084 |
+
|
1085 |
+
if not return_dict:
|
1086 |
+
return tuple(
|
1087 |
+
v
|
1088 |
+
for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
|
1089 |
+
if v is not None
|
1090 |
+
)
|
1091 |
+
|
1092 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
1093 |
+
last_hidden_state=hidden_states,
|
1094 |
+
past_key_values=presents,
|
1095 |
+
hidden_states=all_hidden_states,
|
1096 |
+
attentions=all_self_attentions,
|
1097 |
+
cross_attentions=all_cross_attentions,
|
1098 |
+
)
|
1099 |
+
|
1100 |
+
|
1101 |
+
@add_start_docstrings(
|
1102 |
+
"""
|
1103 |
+
The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
1104 |
+
embeddings).
|
1105 |
+
""",
|
1106 |
+
GPT2_START_DOCSTRING,
|
1107 |
+
)
|
1108 |
+
class VGPT2LMHeadModel(VGPT2PreTrainedModel):
|
1109 |
+
_keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
|
1110 |
+
|
1111 |
+
def __init__(self, config, vision_model=None):
|
1112 |
+
super().__init__(config)
|
1113 |
+
self.transformer = VGPT2Model(config, vision_model=vision_model)
|
1114 |
+
self.lm_head = DecoupledLinear(
|
1115 |
+
in_features=config.n_embd,
|
1116 |
+
out_features=config.vocab_size,
|
1117 |
+
out_additional_features=config.additional_vocab_size,
|
1118 |
+
bias=False,
|
1119 |
+
partially_freeze=config.freeze_lm_head,
|
1120 |
+
)
|
1121 |
+
|
1122 |
+
# Model parallel
|
1123 |
+
self.model_parallel = False
|
1124 |
+
self.device_map = None
|
1125 |
+
|
1126 |
+
# Initialize weights and apply final processing
|
1127 |
+
self.post_init()
|
1128 |
+
|
1129 |
+
@add_start_docstrings(PARALLELIZE_DOCSTRING)
|
1130 |
+
def parallelize(self, device_map=None):
|
1131 |
+
self.device_map = (
|
1132 |
+
get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
|
1133 |
+
if device_map is None
|
1134 |
+
else device_map
|
1135 |
+
)
|
1136 |
+
assert_device_map(self.device_map, len(self.transformer.h))
|
1137 |
+
self.transformer.parallelize(self.device_map)
|
1138 |
+
self.lm_head = self.lm_head.to(self.transformer.first_device)
|
1139 |
+
self.model_parallel = True
|
1140 |
+
|
1141 |
+
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
|
1142 |
+
def deparallelize(self):
|
1143 |
+
self.transformer.deparallelize()
|
1144 |
+
self.transformer = self.transformer.to("cpu")
|
1145 |
+
self.lm_head = self.lm_head.to("cpu")
|
1146 |
+
self.model_parallel = False
|
1147 |
+
torch.cuda.empty_cache()
|
1148 |
+
|
1149 |
+
def get_output_embeddings(self):
|
1150 |
+
return self.lm_head
|
1151 |
+
|
1152 |
+
def set_output_embeddings(self, new_embeddings):
|
1153 |
+
self.lm_head = new_embeddings
|
1154 |
+
|
1155 |
+
def tie_weights(self):
|
1156 |
+
"""
|
1157 |
+
Overwrite `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of DecoupledLinear and DecoupledEmbedding.
|
1158 |
+
"""
|
1159 |
+
output_embeddings = self.get_output_embeddings()
|
1160 |
+
input_embeddings = self.get_input_embeddings()
|
1161 |
+
|
1162 |
+
if getattr(self.config, "tie_word_embeddings", True):
|
1163 |
+
output_embeddings.weight = input_embeddings.weight
|
1164 |
+
if input_embeddings.num_additional_embeddings > 0:
|
1165 |
+
assert output_embeddings.out_additional_features == input_embeddings.num_additional_embeddings
|
1166 |
+
output_embeddings.additional_fc.weight = input_embeddings.additional_embedding.weight
|
1167 |
+
|
1168 |
+
if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
|
1169 |
+
output_embeddings.out_features = input_embeddings.num_embeddings
|
1170 |
+
if hasattr(output_embeddings, "out_additional_features") and hasattr(
|
1171 |
+
input_embeddings, "num_additional_embeddings"
|
1172 |
+
):
|
1173 |
+
output_embeddings.out_additional_features = input_embeddings.num_additional_embeddings
|
1174 |
+
|
1175 |
+
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
1176 |
+
return prepare_inputs_for_generation(input_ids, past=past, **kwargs)
|
1177 |
+
|
1178 |
+
@staticmethod
|
1179 |
+
def _expand_inputs_for_generation(
|
1180 |
+
*args,
|
1181 |
+
**model_kwargs,
|
1182 |
+
):
|
1183 |
+
return expand_inputs_for_generation(*args, **model_kwargs)
|
1184 |
+
|
1185 |
+
@staticmethod
|
1186 |
+
def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
|
1187 |
+
return update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder)
|
1188 |
+
|
1189 |
+
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
|
1190 |
+
@add_code_sample_docstrings(
|
1191 |
+
processor_class=_TOKENIZER_FOR_DOC,
|
1192 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
1193 |
+
output_type=CausalLMOutputWithCrossAttentions,
|
1194 |
+
config_class=_CONFIG_FOR_DOC,
|
1195 |
+
)
|
1196 |
+
def forward(
|
1197 |
+
self,
|
1198 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1199 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
1200 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1201 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
1202 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1203 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
1204 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1205 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
1206 |
+
image_embeddings: Optional[torch.FloatTensor] = None,
|
1207 |
+
image_attention_mask: Optional[torch.Tensor] = None,
|
1208 |
+
crossblock_head_mask: Optional[torch.Tensor] = None,
|
1209 |
+
labels: Optional[torch.LongTensor] = None,
|
1210 |
+
use_cache: Optional[bool] = None,
|
1211 |
+
output_attentions: Optional[bool] = None,
|
1212 |
+
output_hidden_states: Optional[bool] = None,
|
1213 |
+
return_dict: Optional[bool] = None,
|
1214 |
+
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
1215 |
+
r"""
|
1216 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1217 |
+
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
1218 |
+
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
1219 |
+
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
1220 |
+
"""
|
1221 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1222 |
+
|
1223 |
+
transformer_outputs = self.transformer(
|
1224 |
+
input_ids,
|
1225 |
+
past_key_values=past_key_values,
|
1226 |
+
attention_mask=attention_mask,
|
1227 |
+
token_type_ids=token_type_ids,
|
1228 |
+
position_ids=position_ids,
|
1229 |
+
head_mask=head_mask,
|
1230 |
+
inputs_embeds=inputs_embeds,
|
1231 |
+
pixel_values=pixel_values,
|
1232 |
+
image_embeddings=image_embeddings,
|
1233 |
+
image_attention_mask=image_attention_mask,
|
1234 |
+
crossblock_head_mask=crossblock_head_mask,
|
1235 |
+
use_cache=use_cache,
|
1236 |
+
output_attentions=output_attentions,
|
1237 |
+
output_hidden_states=output_hidden_states,
|
1238 |
+
return_dict=return_dict,
|
1239 |
+
)
|
1240 |
+
hidden_states = transformer_outputs[0]
|
1241 |
+
|
1242 |
+
# Set device for model parallelism
|
1243 |
+
if self.model_parallel:
|
1244 |
+
torch.cuda.set_device(self.transformer.first_device)
|
1245 |
+
hidden_states = hidden_states.to(self.lm_head.weight.device)
|
1246 |
+
|
1247 |
+
lm_logits = self.lm_head(hidden_states)
|
1248 |
+
|
1249 |
+
loss = None
|
1250 |
+
if labels is not None:
|
1251 |
+
# Shift so that tokens < n predict n
|
1252 |
+
if attention_mask is not None:
|
1253 |
+
shift_attention_mask = attention_mask[..., 1:]
|
1254 |
+
shift_logits = lm_logits[..., :-1, :][shift_attention_mask != 0].contiguous()
|
1255 |
+
shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
|
1256 |
+
else:
|
1257 |
+
shift_logits = lm_logits[..., :-1, :].contiguous()
|
1258 |
+
shift_labels = labels[..., 1:].contiguous()
|
1259 |
+
# Flatten the tokens
|
1260 |
+
loss_fct = CrossEntropyLoss()
|
1261 |
+
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
1262 |
+
|
1263 |
+
if not return_dict:
|
1264 |
+
output = (lm_logits,) + transformer_outputs[1:]
|
1265 |
+
return ((loss,) + output) if loss is not None else output
|
1266 |
+
|
1267 |
+
return CausalLMOutputWithCrossAttentions(
|
1268 |
+
loss=loss,
|
1269 |
+
logits=lm_logits,
|
1270 |
+
past_key_values=transformer_outputs.past_key_values,
|
1271 |
+
hidden_states=transformer_outputs.hidden_states,
|
1272 |
+
attentions=transformer_outputs.attentions,
|
1273 |
+
cross_attentions=transformer_outputs.cross_attentions,
|
1274 |
+
)
|
1275 |
+
|
1276 |
+
@staticmethod
|
1277 |
+
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
|
1278 |
+
"""
|
1279 |
+
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
|
1280 |
+
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
|
1281 |
+
beam_idx at every generation step.
|
1282 |
+
"""
|
1283 |
+
return tuple(
|
1284 |
+
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
1285 |
+
for layer_past in past
|
1286 |
+
)
|
1287 |
+
|
1288 |
+
def get_model_tflops_per_batch_per_gpu(self, hparams, data_param, tokenizer, max_num_images):
|
1289 |
+
config_vl_model = self.config
|
1290 |
+
|
1291 |
+
language_embed_size = config_vl_model.n_embd
|
1292 |
+
num_language_layers = config_vl_model.n_layer
|
1293 |
+
ffn_inner_size = config_vl_model.n_inner
|
1294 |
+
|
1295 |
+
vision_config = self.transformer.vision_model.config
|
1296 |
+
if hasattr(vision_config, "vision_config"):
|
1297 |
+
vision_config = vision_config.vision_config
|
1298 |
+
|
1299 |
+
# Get vision model blocks infos
|
1300 |
+
vision_patch_size = vision_config.patch_size
|
1301 |
+
vision_hidden_size = vision_config.hidden_size
|
1302 |
+
num_vision_layers = vision_config.num_hidden_layers
|
1303 |
+
# The +1 is for the CLS token
|
1304 |
+
single_image_seq_len = (vision_config.image_size // vision_patch_size) ** 2 + 1
|
1305 |
+
vision_exp_factor = vision_config.intermediate_size // vision_hidden_size
|
1306 |
+
|
1307 |
+
# Get language and cross-att blocks infos
|
1308 |
+
num_cross_attn_layers = num_language_layers // config_vl_model.cross_layer_interval
|
1309 |
+
language_seq_len = data_param.max_seq_len
|
1310 |
+
language_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4
|
1311 |
+
cross_att_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4
|
1312 |
+
k_v_cross_attn_seq_len = (
|
1313 |
+
(self.config.resampler_n_latents * max_num_images)
|
1314 |
+
if self.config.use_resampler
|
1315 |
+
else (single_image_seq_len * max_num_images)
|
1316 |
+
)
|
1317 |
+
|
1318 |
+
language_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
|
1319 |
+
num_layers=num_language_layers,
|
1320 |
+
batch_size=hparams.batch_size_per_gpu,
|
1321 |
+
q_seq_len=language_seq_len,
|
1322 |
+
k_seq_len=language_seq_len,
|
1323 |
+
hidden_size=language_embed_size,
|
1324 |
+
kv_in_dim=language_embed_size,
|
1325 |
+
ff_exp_factor=language_exp_factor,
|
1326 |
+
grad_acc_size=hparams.grad_acc_size,
|
1327 |
+
swiglu=False,
|
1328 |
+
vocab_size=tokenizer.vocab_size,
|
1329 |
+
count_backward=True, # Always True regardless of freezing, because gradients are computed for cross-attentions
|
1330 |
+
use_grad_checkpointing=hparams.gradient_checkpointing,
|
1331 |
+
)
|
1332 |
+
cross_attention_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
|
1333 |
+
num_layers=num_cross_attn_layers,
|
1334 |
+
batch_size=hparams.batch_size_per_gpu,
|
1335 |
+
q_seq_len=language_seq_len,
|
1336 |
+
k_seq_len=k_v_cross_attn_seq_len,
|
1337 |
+
hidden_size=language_embed_size,
|
1338 |
+
kv_in_dim=vision_hidden_size,
|
1339 |
+
ff_exp_factor=cross_att_exp_factor,
|
1340 |
+
grad_acc_size=hparams.grad_acc_size,
|
1341 |
+
swiglu=False,
|
1342 |
+
vocab_size=None,
|
1343 |
+
count_backward=True,
|
1344 |
+
use_grad_checkpointing=hparams.gradient_checkpointing,
|
1345 |
+
)
|
1346 |
+
vision_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
|
1347 |
+
num_layers=num_vision_layers,
|
1348 |
+
batch_size=hparams.batch_size_per_gpu * max_num_images,
|
1349 |
+
q_seq_len=single_image_seq_len,
|
1350 |
+
k_seq_len=single_image_seq_len,
|
1351 |
+
hidden_size=vision_hidden_size,
|
1352 |
+
kv_in_dim=vision_hidden_size,
|
1353 |
+
ff_exp_factor=vision_exp_factor,
|
1354 |
+
grad_acc_size=hparams.grad_acc_size,
|
1355 |
+
swiglu=False,
|
1356 |
+
vocab_size=None,
|
1357 |
+
count_backward=not hparams.model_params["freeze_vision_layers"],
|
1358 |
+
use_grad_checkpointing=hparams.gradient_checkpointing,
|
1359 |
+
)
|
1360 |
+
if self.config.use_resampler:
|
1361 |
+
perceiver_tflops_per_batch_per_gpu = compute_perceiver_tflops_per_batch_per_gpu(
|
1362 |
+
num_layers=self.config.resampler_depth,
|
1363 |
+
batch_size=hparams.batch_size_per_gpu * max_num_images,
|
1364 |
+
q_seq_len=self.config.resampler_n_latents,
|
1365 |
+
vision_embed_seq_len=single_image_seq_len,
|
1366 |
+
q_k_v_input_dim=vision_hidden_size,
|
1367 |
+
attention_hidden_size=self.config.resampler_n_heads * self.config.resampler_head_dim,
|
1368 |
+
ff_exp_factor=cross_att_exp_factor,
|
1369 |
+
count_backward=True,
|
1370 |
+
use_grad_checkpointing=hparams.gradient_checkpointing,
|
1371 |
+
)
|
1372 |
+
flop_count = (
|
1373 |
+
language_tflops_per_batch_per_gpu
|
1374 |
+
+ cross_attention_tflops_per_batch_per_gpu
|
1375 |
+
+ vision_tflops_per_batch_per_gpu
|
1376 |
+
+ perceiver_tflops_per_batch_per_gpu
|
1377 |
+
)
|
1378 |
+
else:
|
1379 |
+
flop_count = (
|
1380 |
+
language_tflops_per_batch_per_gpu
|
1381 |
+
+ cross_attention_tflops_per_batch_per_gpu
|
1382 |
+
+ vision_tflops_per_batch_per_gpu
|
1383 |
+
)
|
1384 |
+
return flop_count
|
m4/models/vgpt_neo/__init__.py
ADDED
File without changes
|
m4/models/vgpt_neo/configuration_vgpt_neo.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" GPT Neo model configuration"""
|
16 |
+
import os
|
17 |
+
from typing import Tuple, Union
|
18 |
+
|
19 |
+
from transformers import AutoConfig
|
20 |
+
from transformers.configuration_utils import PretrainedConfig
|
21 |
+
from transformers.utils import logging
|
22 |
+
|
23 |
+
|
24 |
+
logger = logging.get_logger(__name__)
|
25 |
+
|
26 |
+
GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
27 |
+
"EleutherAI/gpt-neo-125M": "https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/config.json",
|
28 |
+
"EleutherAI/gpt-neo-1.3B": "https://huggingface.co/EleutherAI/gpt-neo-1.3B/resolve/main/config.json",
|
29 |
+
# See all GPTNeo models at https://huggingface.co/models?filter=gpt_neo
|
30 |
+
}
|
31 |
+
|
32 |
+
|
33 |
+
class VGPTNeoConfig(PretrainedConfig):
|
34 |
+
r"""
|
35 |
+
This is the configuration class to store the configuration of a [`GPTNeoModel`]. It is used to instantiate a GPT
|
36 |
+
Neo model according to the specified arguments, defining the model architecture. Instantiating a configuration with
|
37 |
+
the defaults will yield a similar configuration to that of the GPTNeo
|
38 |
+
[EleutherAI/gpt-neo-1.3B](https://huggingface.co/EleutherAI/gpt-neo-1.3B) architecture.
|
39 |
+
|
40 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
41 |
+
documentation from [`PretrainedConfig`] for more information.
|
42 |
+
|
43 |
+
TODO: this doc is completely out of sync with the actual args
|
44 |
+
|
45 |
+
Args:
|
46 |
+
vocab_size (`int`, *optional*, defaults to 50257):
|
47 |
+
Vocabulary size of the GPT Neo model. Defines the number of different tokens that can be represented by the
|
48 |
+
`inputs_ids` passed when calling [`GPTNeoModel`]. Vocabulary size of the model. Defines the different
|
49 |
+
tokens that can be represented by the *inputs_ids* passed to the forward method of [`GPTNeoModel`].
|
50 |
+
additional_vocab_size (`int`, *optional`, defaults to 0):
|
51 |
+
Additional vocabulary size of the model, typically for the special "<img>" token. Additional vocab tokens
|
52 |
+
are always trainable whereas regular vocab tokens can be frozen or not.
|
53 |
+
attention_types (`List`, *optional*, defaults to `[[["global", "local"], 12]]`):
|
54 |
+
The type of attention for each layer in a `List` of the following format `[[["attention_type"],
|
55 |
+
num_layerss]]` e.g. for a 24 layer model `[[["global"], 24]]` or `[[["global", "local"], 12]]` Choose the
|
56 |
+
value of `attention_type` from `["global", "local"]`
|
57 |
+
hidden_size (`int`, *optional*, defaults to 2048):
|
58 |
+
Dimensionality of the encoder layers and the pooler layer.
|
59 |
+
num_layers (`int`, *optional*, defaults to 24):
|
60 |
+
Number of hidden layers in the Transformer encoder.
|
61 |
+
num_heads (`int`, *optional*, defaults to 16):
|
62 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
63 |
+
intermediate_size (`int`, *optional*, defaults to 8192):
|
64 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
65 |
+
activation_function (`str` or `function`, *optional*, defaults to `"gelu_new"`):
|
66 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
67 |
+
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
68 |
+
embed_dropout (`float`, *optional*, defaults to 0.0):
|
69 |
+
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
|
70 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
71 |
+
The dropout ratio for the attention probabilities.
|
72 |
+
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
73 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
74 |
+
just in case (e.g., 512 or 1024 or 2048).
|
75 |
+
type_vocab_size (`int`, *optional*, defaults to 2):
|
76 |
+
The vocabulary size of the `token_type_ids` passed when calling [`GPTNeoModel`].
|
77 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
78 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
79 |
+
alpha_initializer (`str`, *optional*, defaults to `"ones"`):
|
80 |
+
Initialization type for the alphas.
|
81 |
+
alphas_initializer_range (`float`, *optional*, defaults to 0.0):
|
82 |
+
The standard deviation of the truncated_normal_initializer for initializing the alphas in the Gated Cross Attention.
|
83 |
+
alpha_type (`str`, *optional*, defaults to `"vector"`):
|
84 |
+
Whether the gating alphas should be vectors or single floats.
|
85 |
+
layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
|
86 |
+
The epsilon used by the layer normalization layers.
|
87 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
88 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
89 |
+
relevant if `config.is_decoder=True`.
|
90 |
+
cross_layer_interval (`int`, *optional*, default to 1)
|
91 |
+
Interval for cross attention (from text to image) layers.
|
92 |
+
Example:
|
93 |
+
```python
|
94 |
+
>>> from transformers import GPTNeoConfig, GPTNeoModel
|
95 |
+
>>> # Initializing a GPTNeo EleutherAI/gpt-neo-1.3B style configuration
|
96 |
+
>>> configuration = GPTNeoConfig()
|
97 |
+
>>> # Initializing a model (with random weights) from the EleutherAI/gpt-neo-1.3B style configuration
|
98 |
+
>>> model = GPTNeoModel(configuration)
|
99 |
+
>>> # Accessing the model configuration
|
100 |
+
>>> configuration = model.config
|
101 |
+
```"""
|
102 |
+
model_type = "vgpt_neo"
|
103 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
104 |
+
attribute_map = {"num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
|
105 |
+
|
106 |
+
def __init__(
|
107 |
+
self,
|
108 |
+
vocab_size=50257,
|
109 |
+
additional_vocab_size=0,
|
110 |
+
max_position_embeddings=2048,
|
111 |
+
hidden_size=2048,
|
112 |
+
num_layers=24,
|
113 |
+
attention_types=[[["global", "local"], 12]],
|
114 |
+
num_heads=16,
|
115 |
+
intermediate_size=None,
|
116 |
+
window_size=256,
|
117 |
+
activation_function="gelu_new",
|
118 |
+
resid_dropout=0.0,
|
119 |
+
embed_dropout=0.0,
|
120 |
+
attention_dropout=0.0,
|
121 |
+
layer_norm_epsilon=1e-5,
|
122 |
+
initializer_range=0.02,
|
123 |
+
alpha_initializer="ones",
|
124 |
+
alphas_initializer_range=0.0,
|
125 |
+
alpha_type="vector",
|
126 |
+
summary_type="cls_index",
|
127 |
+
summary_use_proj=True,
|
128 |
+
summary_activation=None,
|
129 |
+
summary_proj_to_labels=True,
|
130 |
+
summary_first_dropout=0.1,
|
131 |
+
use_cache=True,
|
132 |
+
bos_token_id=50256,
|
133 |
+
eos_token_id=50256,
|
134 |
+
cross_layer_interval=1,
|
135 |
+
tie_word_embeddings=False,
|
136 |
+
freeze_text_layers=True,
|
137 |
+
freeze_lm_head=False,
|
138 |
+
freeze_vision_layers=True,
|
139 |
+
vision_model_name="google/vit-base-patch16-224",
|
140 |
+
vision_model_params="{}",
|
141 |
+
vision_embed_dim=768,
|
142 |
+
vision_image_size=224,
|
143 |
+
image_token_index=50257,
|
144 |
+
use_resampler=False,
|
145 |
+
resampler_n_latents=64,
|
146 |
+
resampler_depth=6,
|
147 |
+
resampler_n_heads=16,
|
148 |
+
resampler_head_dim=96,
|
149 |
+
**kwargs,
|
150 |
+
):
|
151 |
+
self.vocab_size = vocab_size
|
152 |
+
self.additional_vocab_size = additional_vocab_size
|
153 |
+
self.max_position_embeddings = max_position_embeddings
|
154 |
+
self.hidden_size = hidden_size
|
155 |
+
self.num_layers = num_layers
|
156 |
+
self.num_heads = num_heads
|
157 |
+
self.intermediate_size = intermediate_size
|
158 |
+
self.window_size = window_size
|
159 |
+
self.activation_function = activation_function
|
160 |
+
self.resid_dropout = resid_dropout
|
161 |
+
self.embed_dropout = embed_dropout
|
162 |
+
self.attention_dropout = attention_dropout
|
163 |
+
self.layer_norm_epsilon = layer_norm_epsilon
|
164 |
+
self.initializer_range = initializer_range
|
165 |
+
self.alpha_initializer = alpha_initializer
|
166 |
+
self.alphas_initializer_range = alphas_initializer_range
|
167 |
+
self.alpha_type = alpha_type
|
168 |
+
self.summary_type = summary_type
|
169 |
+
self.summary_use_proj = summary_use_proj
|
170 |
+
self.summary_activation = summary_activation
|
171 |
+
self.summary_first_dropout = summary_first_dropout
|
172 |
+
self.summary_proj_to_labels = summary_proj_to_labels
|
173 |
+
self.use_cache = use_cache
|
174 |
+
|
175 |
+
self.bos_token_id = bos_token_id
|
176 |
+
self.eos_token_id = eos_token_id
|
177 |
+
self.cross_layer_interval = cross_layer_interval
|
178 |
+
self.freeze_vision_layers = freeze_vision_layers
|
179 |
+
self.vision_model_name = vision_model_name
|
180 |
+
self.vision_model_params = vision_model_params
|
181 |
+
|
182 |
+
self.tie_word_embeddings = tie_word_embeddings
|
183 |
+
self.freeze_text_layers = freeze_text_layers
|
184 |
+
self.freeze_lm_head = freeze_lm_head
|
185 |
+
self.image_token_index = image_token_index
|
186 |
+
self.attention_types = attention_types
|
187 |
+
self.attention_layers = self.expand_attention_types_params(attention_types)
|
188 |
+
|
189 |
+
self.vision_embed_dim = vision_embed_dim
|
190 |
+
self.vision_image_size = vision_image_size
|
191 |
+
|
192 |
+
# Resampler params
|
193 |
+
self.use_resampler = use_resampler
|
194 |
+
self.resampler_n_latents = resampler_n_latents
|
195 |
+
self.resampler_depth = resampler_depth
|
196 |
+
self.resampler_n_heads = resampler_n_heads
|
197 |
+
self.resampler_head_dim = resampler_head_dim
|
198 |
+
|
199 |
+
# IMPORTANT: Do not do any __init__ args-based checks in the constructor, since
|
200 |
+
# PretrainedConfig.from_dict first instantiates the class with the config dict and only then
|
201 |
+
# updates the config object with `kwargs` from from_pretrained, so during the instantiation
|
202 |
+
# of this object many attributes have default values and haven't yet been overridden.
|
203 |
+
# Do any required checks inside `from_pretrained` once the superclass' `from_pretrained` was run.
|
204 |
+
|
205 |
+
super().__init__(
|
206 |
+
bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
|
207 |
+
)
|
208 |
+
|
209 |
+
def check_compatibilities(self):
|
210 |
+
if self.tie_word_embeddings and (self.freeze_text_layers != self.freeze_lm_head):
|
211 |
+
raise ValueError(
|
212 |
+
"if `tie_word_embeddings` is True, then `freeze_lm_head` and `freeze_text_layers` must be equal."
|
213 |
+
)
|
214 |
+
|
215 |
+
vision_model_params = eval(self.vision_model_params)
|
216 |
+
config = AutoConfig.from_pretrained(self.vision_model_name, **vision_model_params)
|
217 |
+
if hasattr(config, "vision_config"):
|
218 |
+
vision_config = config.vision_config
|
219 |
+
else:
|
220 |
+
vision_config = config
|
221 |
+
vision_embed_dim = vision_config.hidden_size
|
222 |
+
if self.vision_embed_dim != vision_embed_dim:
|
223 |
+
raise ValueError(
|
224 |
+
f"vision_embed_dim ({self.vision_embed_dim}) must match the hidden size of the vision model"
|
225 |
+
f" ({vision_embed_dim})"
|
226 |
+
)
|
227 |
+
vision_image_size = vision_config.image_size
|
228 |
+
if self.vision_image_size != vision_image_size:
|
229 |
+
raise ValueError(
|
230 |
+
f"vision_image_size ({self.vision_image_size}) must match the hidden size of the vision model"
|
231 |
+
f" ({vision_image_size})"
|
232 |
+
)
|
233 |
+
|
234 |
+
@classmethod
|
235 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
236 |
+
outputs = super(VGPTNeoConfig, cls).from_pretrained(pretrained_model_name_or_path, **kwargs)
|
237 |
+
if isinstance(outputs, Tuple):
|
238 |
+
# When called with return_unused_kwargs=True, the first item will be the config
|
239 |
+
outputs[0].check_compatibilities()
|
240 |
+
else:
|
241 |
+
outputs.check_compatibilities()
|
242 |
+
return outputs
|
243 |
+
|
244 |
+
@staticmethod
|
245 |
+
def expand_attention_types_params(attention_types):
|
246 |
+
attentions = []
|
247 |
+
for item in attention_types:
|
248 |
+
for _ in range(item[1]):
|
249 |
+
attentions.extend(item[0])
|
250 |
+
return attentions
|
m4/models/vgpt_neo/modeling_vgpt_neo.py
ADDED
@@ -0,0 +1,1182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2021 The Eleuther AI and HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch GPT Neo model."""
|
16 |
+
|
17 |
+
|
18 |
+
import os
|
19 |
+
from typing import Optional, Tuple, Union
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.utils.checkpoint
|
23 |
+
from torch import nn
|
24 |
+
from torch.nn import CrossEntropyLoss
|
25 |
+
from transformers.activations import ACT2FN
|
26 |
+
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
|
27 |
+
from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
|
28 |
+
|
29 |
+
from m4.models import DecoupledEmbedding, DecoupledLinear
|
30 |
+
from m4.models.common import (
|
31 |
+
expand_inputs_for_generation,
|
32 |
+
prepare_inputs_for_generation,
|
33 |
+
update_model_kwargs_for_generation,
|
34 |
+
)
|
35 |
+
from m4.models.custom_modules import VLOOMPreTrainedModelBase
|
36 |
+
from m4.models.perceiver.perceiver import PerceiverResampler
|
37 |
+
from m4.models.vgpt_neo.configuration_vgpt_neo import VGPTNeoConfig
|
38 |
+
from m4.training.utils import (
|
39 |
+
compute_perceiver_tflops_per_batch_per_gpu,
|
40 |
+
compute_tflops_per_batch_per_gpu,
|
41 |
+
freeze_model,
|
42 |
+
)
|
43 |
+
from m4.utils import logging
|
44 |
+
|
45 |
+
|
46 |
+
logger = logging.get_logger(__name__)
|
47 |
+
|
48 |
+
_CHECKPOINT_FOR_DOC = "EleutherAI/gpt-neo-1.3B"
|
49 |
+
_CONFIG_FOR_DOC = "VGPTNeoConfig"
|
50 |
+
_TOKENIZER_FOR_DOC = "GPT2Tokenizer"
|
51 |
+
|
52 |
+
GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
53 |
+
"EleutherAI/gpt-neo-125M",
|
54 |
+
"EleutherAI/gpt-neo-1.3B",
|
55 |
+
# See all GPTNeo models at https://huggingface.co/models?filter=gpt_neo
|
56 |
+
]
|
57 |
+
|
58 |
+
|
59 |
+
def load_tf_weights_in_gpt_neo(model, config, gpt_neo_checkpoint_path):
|
60 |
+
"""Load tf checkpoints in a pytorch model"""
|
61 |
+
try:
|
62 |
+
import re
|
63 |
+
|
64 |
+
import tensorflow as tf
|
65 |
+
except ImportError:
|
66 |
+
logger.error(
|
67 |
+
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
68 |
+
"https://www.tensorflow.org/install/ for installation instructions."
|
69 |
+
)
|
70 |
+
raise
|
71 |
+
tf_path = os.path.abspath(gpt_neo_checkpoint_path)
|
72 |
+
logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
|
73 |
+
# Load weights from TF model
|
74 |
+
init_vars = tf.train.list_variables(tf_path)
|
75 |
+
names = []
|
76 |
+
arrays = []
|
77 |
+
for name, shape in init_vars:
|
78 |
+
if "global_step" not in name and "adam" not in name:
|
79 |
+
array = tf.train.load_variable(tf_path, name)
|
80 |
+
array = tf.dtypes.cast(array.squeeze(), tf.float32).numpy()
|
81 |
+
name = name.replace("attn/q", "attn/attention/q_proj/w")
|
82 |
+
name = name.replace("attn/k", "attn/attention/k_proj/w")
|
83 |
+
name = name.replace("attn/v", "attn/attention/v_proj/w")
|
84 |
+
name = name.replace("attn/o", "attn/attention/out_proj/w")
|
85 |
+
name = name.replace("norm_1", "ln_1")
|
86 |
+
name = name.replace("norm_2", "ln_2")
|
87 |
+
name = name.replace("attn/compute_output_bias/o_b", "attn/attention/out_proj/b")
|
88 |
+
name = name.replace("conv1d_main/c_fc/kernel", "c_fc/w")
|
89 |
+
name = name.replace("conv1d_main/c_fc/bias", "c_fc/b")
|
90 |
+
name = name.replace("conv1d_main/c_proj/kernel", "c_proj/w")
|
91 |
+
name = name.replace("conv1d_main/c_proj/bias", "c_proj/b")
|
92 |
+
|
93 |
+
names.append(name)
|
94 |
+
arrays.append(array)
|
95 |
+
|
96 |
+
for name, array in zip(names, arrays):
|
97 |
+
name = name[5:] # skip "gpt2/"
|
98 |
+
name = name.split("/")
|
99 |
+
pointer = model.transformer
|
100 |
+
for m_name in name:
|
101 |
+
if re.fullmatch(r"[A-Za-z]+\d+", m_name):
|
102 |
+
scope_names = re.split(r"(\d+)", m_name)
|
103 |
+
else:
|
104 |
+
scope_names = [m_name]
|
105 |
+
if scope_names[0] == "w" or scope_names[0] == "g":
|
106 |
+
pointer = getattr(pointer, "weight")
|
107 |
+
elif scope_names[0] == "b":
|
108 |
+
pointer = getattr(pointer, "bias")
|
109 |
+
elif scope_names[0] == "wpe" or scope_names[0] == "wte":
|
110 |
+
pointer = getattr(pointer, scope_names[0])
|
111 |
+
pointer = getattr(pointer, "weight")
|
112 |
+
else:
|
113 |
+
pointer = getattr(pointer, scope_names[0])
|
114 |
+
if len(scope_names) >= 2:
|
115 |
+
num = int(scope_names[1])
|
116 |
+
pointer = pointer[num]
|
117 |
+
|
118 |
+
if name[-1] == "w" and name[-2] in ["out_proj", "k_proj", "q_proj", "v_proj", "c_proj", "c_fc"]:
|
119 |
+
array = array.transpose()
|
120 |
+
|
121 |
+
if name == ["wte"]:
|
122 |
+
# if vocab is padded, then trim off the padding embeddings
|
123 |
+
array = array[: config.vocab_size]
|
124 |
+
|
125 |
+
if pointer.shape != array.shape:
|
126 |
+
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched {name}")
|
127 |
+
|
128 |
+
print(f"Initialize PyTorch weight {name}")
|
129 |
+
pointer.data = torch.from_numpy(array)
|
130 |
+
|
131 |
+
# init the final linear layer using word embeddings
|
132 |
+
embs = model.transformer.wte.weight
|
133 |
+
lin = nn.Linear(embs.size()[1], embs.size()[0], bias=False)
|
134 |
+
lin.weight = embs
|
135 |
+
model.set_output_embeddings(lin)
|
136 |
+
return model
|
137 |
+
|
138 |
+
|
139 |
+
class GPTNeoSelfAttention(nn.Module):
|
140 |
+
def __init__(self, config, attention_type, is_cross_attention=False):
|
141 |
+
super().__init__()
|
142 |
+
|
143 |
+
max_positions = config.max_position_embeddings
|
144 |
+
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
|
145 |
+
1, 1, max_positions, max_positions
|
146 |
+
)
|
147 |
+
|
148 |
+
# local causal self attention is a sliding window where each token can only attend to the previous
|
149 |
+
# window_size tokens. This is implemented by updating the causal mask such that for each token
|
150 |
+
# all other tokens are masked except the previous window_size tokens.
|
151 |
+
if attention_type == "local":
|
152 |
+
bias = torch.bitwise_xor(bias, torch.tril(bias, -config.window_size))
|
153 |
+
self.is_cross_attention = is_cross_attention
|
154 |
+
self.register_buffer("bias", bias)
|
155 |
+
self.register_buffer("masked_bias", torch.tensor(-1e9))
|
156 |
+
|
157 |
+
self.attn_dropout = nn.Dropout(float(config.attention_dropout))
|
158 |
+
self.resid_dropout = nn.Dropout(float(config.resid_dropout))
|
159 |
+
|
160 |
+
self.embed_dim = config.hidden_size
|
161 |
+
self.num_heads = config.num_heads
|
162 |
+
self.head_dim = self.embed_dim // self.num_heads
|
163 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
164 |
+
raise ValueError(
|
165 |
+
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
166 |
+
f" {self.num_heads})."
|
167 |
+
)
|
168 |
+
if self.is_cross_attention:
|
169 |
+
in_dim = self.embed_dim if not hasattr(config, "vision_embed_dim") else config.vision_embed_dim
|
170 |
+
self.k_proj = nn.Linear(in_dim, self.embed_dim, bias=False)
|
171 |
+
self.v_proj = nn.Linear(in_dim, self.embed_dim, bias=False)
|
172 |
+
else:
|
173 |
+
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
|
174 |
+
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
|
175 |
+
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
|
176 |
+
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
|
177 |
+
|
178 |
+
def _split_heads(self, tensor, num_heads, attn_head_size):
|
179 |
+
"""
|
180 |
+
Splits hidden_size dim into attn_head_size and num_heads
|
181 |
+
"""
|
182 |
+
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
|
183 |
+
tensor = tensor.view(new_shape)
|
184 |
+
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
|
185 |
+
|
186 |
+
def _merge_heads(self, tensor, num_heads, attn_head_size):
|
187 |
+
"""
|
188 |
+
Merges attn_head_size dim and num_attn_heads dim into hidden_size
|
189 |
+
"""
|
190 |
+
tensor = tensor.permute(0, 2, 1, 3).contiguous()
|
191 |
+
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
|
192 |
+
return tensor.view(new_shape)
|
193 |
+
|
194 |
+
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
|
195 |
+
# Keep the attention weights computation in fp32 to avoid overflow issues
|
196 |
+
query = query.to(torch.float32)
|
197 |
+
key = key.to(torch.float32)
|
198 |
+
|
199 |
+
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
200 |
+
|
201 |
+
if not self.is_cross_attention:
|
202 |
+
query_length, key_length = query.size(-2), key.size(-2)
|
203 |
+
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
|
204 |
+
mask_value = torch.finfo(attn_weights.dtype).min
|
205 |
+
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
206 |
+
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
207 |
+
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
|
208 |
+
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
|
209 |
+
|
210 |
+
if attention_mask is not None:
|
211 |
+
# Apply the attention mask
|
212 |
+
attn_weights = attn_weights + attention_mask
|
213 |
+
|
214 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
215 |
+
attn_weights = attn_weights.to(value.dtype)
|
216 |
+
attn_weights = self.attn_dropout(attn_weights)
|
217 |
+
|
218 |
+
# Mask heads if we want to
|
219 |
+
if head_mask is not None:
|
220 |
+
attn_weights = attn_weights * head_mask
|
221 |
+
|
222 |
+
attn_output = torch.matmul(attn_weights, value)
|
223 |
+
|
224 |
+
return attn_output, attn_weights
|
225 |
+
|
226 |
+
def forward(
|
227 |
+
self,
|
228 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
229 |
+
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
230 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
231 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
232 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
233 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
234 |
+
use_cache: Optional[bool] = False,
|
235 |
+
output_attentions: Optional[bool] = False,
|
236 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
|
237 |
+
if encoder_hidden_states is not None:
|
238 |
+
key = self.k_proj(encoder_hidden_states)
|
239 |
+
value = self.v_proj(encoder_hidden_states)
|
240 |
+
attention_mask = encoder_attention_mask
|
241 |
+
else:
|
242 |
+
key = self.k_proj(hidden_states)
|
243 |
+
value = self.v_proj(hidden_states)
|
244 |
+
query = self.q_proj(hidden_states)
|
245 |
+
|
246 |
+
query = self._split_heads(query, self.num_heads, self.head_dim)
|
247 |
+
key = self._split_heads(key, self.num_heads, self.head_dim)
|
248 |
+
value = self._split_heads(value, self.num_heads, self.head_dim)
|
249 |
+
|
250 |
+
if layer_past is not None:
|
251 |
+
past_key = layer_past[0]
|
252 |
+
past_value = layer_past[1]
|
253 |
+
key = torch.cat((past_key, key), dim=-2)
|
254 |
+
value = torch.cat((past_value, value), dim=-2)
|
255 |
+
|
256 |
+
if use_cache is True:
|
257 |
+
present = (key, value)
|
258 |
+
else:
|
259 |
+
present = None
|
260 |
+
|
261 |
+
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
262 |
+
|
263 |
+
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
|
264 |
+
attn_output = self.out_proj(attn_output)
|
265 |
+
attn_output = self.resid_dropout(attn_output)
|
266 |
+
|
267 |
+
outputs = (attn_output, present)
|
268 |
+
if output_attentions:
|
269 |
+
outputs += (attn_weights,)
|
270 |
+
|
271 |
+
return outputs # a, present, (attentions)
|
272 |
+
|
273 |
+
|
274 |
+
class GPTNeoAttention(nn.Module):
|
275 |
+
def __init__(self, config, layer_id=0, is_cross_attention=False):
|
276 |
+
super().__init__()
|
277 |
+
self.layer_id = layer_id
|
278 |
+
self.attention_layers = config.attention_layers
|
279 |
+
self.attention_type = self.attention_layers[layer_id]
|
280 |
+
if self.attention_type in ["global", "local"]:
|
281 |
+
self.attention = GPTNeoSelfAttention(config, self.attention_type, is_cross_attention=is_cross_attention)
|
282 |
+
else:
|
283 |
+
raise NotImplementedError(
|
284 |
+
"Only attn layer types 'global' and 'local' exist, but got `config.attention_layers`: "
|
285 |
+
f"{config.attention_layers}. Select attn layer types from ['global', 'local'] only."
|
286 |
+
)
|
287 |
+
|
288 |
+
def forward(
|
289 |
+
self,
|
290 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
291 |
+
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
292 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
293 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
294 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
295 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
296 |
+
use_cache: Optional[bool] = False,
|
297 |
+
output_attentions: Optional[bool] = False,
|
298 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
|
299 |
+
return self.attention(
|
300 |
+
hidden_states,
|
301 |
+
attention_mask=attention_mask,
|
302 |
+
layer_past=layer_past,
|
303 |
+
head_mask=head_mask,
|
304 |
+
encoder_hidden_states=encoder_hidden_states,
|
305 |
+
encoder_attention_mask=encoder_attention_mask,
|
306 |
+
use_cache=use_cache,
|
307 |
+
output_attentions=output_attentions,
|
308 |
+
)
|
309 |
+
|
310 |
+
|
311 |
+
class GPTNeoMLP(nn.Module):
|
312 |
+
def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * hidden_size
|
313 |
+
super().__init__()
|
314 |
+
embed_dim = config.hidden_size
|
315 |
+
self.c_fc = nn.Linear(embed_dim, intermediate_size)
|
316 |
+
self.c_proj = nn.Linear(intermediate_size, embed_dim)
|
317 |
+
self.act = ACT2FN[config.activation_function]
|
318 |
+
self.dropout = nn.Dropout(float(config.resid_dropout))
|
319 |
+
|
320 |
+
def forward(self, hidden_states):
|
321 |
+
hidden_states = self.c_fc(hidden_states)
|
322 |
+
hidden_states = self.act(hidden_states)
|
323 |
+
hidden_states = self.c_proj(hidden_states)
|
324 |
+
hidden_states = self.dropout(hidden_states)
|
325 |
+
return hidden_states
|
326 |
+
|
327 |
+
|
328 |
+
class GPTNeoBlock(nn.Module):
|
329 |
+
def __init__(self, config, layer_id):
|
330 |
+
super().__init__()
|
331 |
+
hidden_size = config.hidden_size
|
332 |
+
inner_dim = config.intermediate_size if config.intermediate_size is not None else 4 * hidden_size
|
333 |
+
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
334 |
+
self.attn = GPTNeoAttention(config, layer_id, is_cross_attention=False)
|
335 |
+
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
336 |
+
|
337 |
+
self.mlp = GPTNeoMLP(inner_dim, config)
|
338 |
+
|
339 |
+
def forward(
|
340 |
+
self,
|
341 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
342 |
+
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
343 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
344 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
345 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
346 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
347 |
+
use_cache: Optional[bool] = False,
|
348 |
+
output_attentions: Optional[bool] = False,
|
349 |
+
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
|
350 |
+
residual = hidden_states
|
351 |
+
hidden_states = self.ln_1(hidden_states)
|
352 |
+
attn_outputs = self.attn(
|
353 |
+
hidden_states,
|
354 |
+
layer_past=layer_past,
|
355 |
+
attention_mask=attention_mask,
|
356 |
+
head_mask=head_mask,
|
357 |
+
use_cache=use_cache,
|
358 |
+
output_attentions=output_attentions,
|
359 |
+
)
|
360 |
+
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
361 |
+
outputs = attn_outputs[1:]
|
362 |
+
# residual connection
|
363 |
+
hidden_states = attn_output + residual
|
364 |
+
residual = hidden_states
|
365 |
+
hidden_states = self.ln_2(hidden_states)
|
366 |
+
feed_forward_hidden_states = self.mlp(hidden_states)
|
367 |
+
# residual connection
|
368 |
+
hidden_states = residual + feed_forward_hidden_states
|
369 |
+
|
370 |
+
if use_cache:
|
371 |
+
outputs = (hidden_states,) + outputs
|
372 |
+
else:
|
373 |
+
outputs = (hidden_states,) + outputs[1:]
|
374 |
+
|
375 |
+
return outputs # hidden_states, present, (attentions, cross_attentions)
|
376 |
+
|
377 |
+
|
378 |
+
class VGPTNeoGatedCrossAttentionBlock(nn.Module):
|
379 |
+
def __init__(self, config, layer_id):
|
380 |
+
super().__init__()
|
381 |
+
hidden_size = config.hidden_size
|
382 |
+
inner_dim = config.intermediate_size if config.intermediate_size is not None else 4 * hidden_size
|
383 |
+
|
384 |
+
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
385 |
+
self.cross_attn = GPTNeoAttention(config, layer_id, is_cross_attention=True)
|
386 |
+
self.mlp = GPTNeoMLP(inner_dim, config)
|
387 |
+
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
388 |
+
self.act = nn.Tanh()
|
389 |
+
|
390 |
+
if config.alpha_initializer == "zeros":
|
391 |
+
if config.alpha_type == "vector":
|
392 |
+
self.alpha_cross_attn = nn.Parameter(torch.zeros(1, 1, hidden_size))
|
393 |
+
self.alpha_dense = nn.Parameter(torch.zeros(1, 1, hidden_size))
|
394 |
+
elif config.alpha_type == "float":
|
395 |
+
self.alpha_cross_attn = nn.Parameter(torch.zeros(1))
|
396 |
+
self.alpha_dense = nn.Parameter(torch.zeros(1))
|
397 |
+
else:
|
398 |
+
raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
|
399 |
+
|
400 |
+
elif config.alpha_initializer == "ones":
|
401 |
+
if config.alpha_type == "vector":
|
402 |
+
self.alpha_cross_attn = nn.Parameter(torch.ones(1, 1, hidden_size))
|
403 |
+
self.alpha_dense = nn.Parameter(torch.ones(1, 1, hidden_size))
|
404 |
+
elif config.alpha_type == "float":
|
405 |
+
self.alpha_cross_attn = nn.Parameter(torch.ones(1))
|
406 |
+
self.alpha_dense = nn.Parameter(torch.ones(1))
|
407 |
+
else:
|
408 |
+
raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
|
409 |
+
|
410 |
+
elif config.alpha_initializer in {"normal", "gaussian", "random"}:
|
411 |
+
if config.alpha_type == "vector":
|
412 |
+
self.alpha_cross_attn = nn.Parameter(
|
413 |
+
torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1, 1, hidden_size))
|
414 |
+
)
|
415 |
+
self.alpha_dense = nn.Parameter(
|
416 |
+
torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1, 1, hidden_size))
|
417 |
+
)
|
418 |
+
elif config.alpha_type == "float":
|
419 |
+
self.alpha_cross_attn = nn.Parameter(
|
420 |
+
torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1))
|
421 |
+
)
|
422 |
+
self.alpha_dense = nn.Parameter(torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1)))
|
423 |
+
else:
|
424 |
+
raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
|
425 |
+
|
426 |
+
else:
|
427 |
+
raise NotImplementedError(f"Alpha initialization scheme {config.alpha_initializer} not yet implemented!")
|
428 |
+
|
429 |
+
def forward(
|
430 |
+
self,
|
431 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
432 |
+
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
433 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
434 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
435 |
+
image_hidden_states: Optional[torch.Tensor] = None,
|
436 |
+
image_attention_mask: Optional[torch.FloatTensor] = None,
|
437 |
+
use_cache: Optional[bool] = False,
|
438 |
+
output_attentions: Optional[bool] = False,
|
439 |
+
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
|
440 |
+
if image_hidden_states is None:
|
441 |
+
raise ValueError(
|
442 |
+
"`image_hidden_states` is required for VGPT2 cross attention module which are visual features to be"
|
443 |
+
" conditioned on."
|
444 |
+
)
|
445 |
+
# add one self-attention block for cross-attention
|
446 |
+
|
447 |
+
# TODO(aps): Handle cross attention in the outputs
|
448 |
+
# if not hasattr(self, "crossattention"):
|
449 |
+
# raise ValueError(
|
450 |
+
# f"If `image_hidden_states` are passed, {self} has to be instantiated with "
|
451 |
+
# "cross-attention layers by setting `config.add_cross_attention=True`"
|
452 |
+
# )
|
453 |
+
residual = hidden_states
|
454 |
+
|
455 |
+
hidden_states = self.ln_1(hidden_states)
|
456 |
+
cross_attn_outputs = self.cross_attn(
|
457 |
+
hidden_states,
|
458 |
+
attention_mask=attention_mask,
|
459 |
+
head_mask=head_mask,
|
460 |
+
encoder_hidden_states=image_hidden_states,
|
461 |
+
encoder_attention_mask=image_attention_mask,
|
462 |
+
output_attentions=output_attentions,
|
463 |
+
)
|
464 |
+
attn_output = cross_attn_outputs[0]
|
465 |
+
outputs = cross_attn_outputs[1:]
|
466 |
+
# residual connection
|
467 |
+
hidden_states = residual + self.act(self.alpha_cross_attn) * attn_output
|
468 |
+
outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
|
469 |
+
|
470 |
+
residual = hidden_states
|
471 |
+
hidden_states = self.ln_2(hidden_states)
|
472 |
+
feed_forward_hidden_states = self.mlp(hidden_states)
|
473 |
+
# residual connection
|
474 |
+
hidden_states = residual + self.act(self.alpha_dense) * feed_forward_hidden_states
|
475 |
+
|
476 |
+
if use_cache:
|
477 |
+
outputs = (hidden_states,) + outputs
|
478 |
+
else:
|
479 |
+
outputs = (hidden_states,) + outputs[1:]
|
480 |
+
|
481 |
+
return outputs
|
482 |
+
|
483 |
+
|
484 |
+
class VGPTNeoPreTrainedModel(VLOOMPreTrainedModelBase):
|
485 |
+
"""
|
486 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
487 |
+
models.
|
488 |
+
"""
|
489 |
+
|
490 |
+
config_class = VGPTNeoConfig
|
491 |
+
load_tf_weights = load_tf_weights_in_gpt_neo
|
492 |
+
base_model_prefix = "transformer"
|
493 |
+
supports_gradient_checkpointing = True
|
494 |
+
_no_split_modules = ["GPTNeoBlock"]
|
495 |
+
|
496 |
+
def __init__(self, *inputs, **kwargs):
|
497 |
+
super().__init__(*inputs, **kwargs)
|
498 |
+
|
499 |
+
def _init_weights(self, module):
|
500 |
+
"""Initialize the weights."""
|
501 |
+
if isinstance(module, (nn.Linear,)):
|
502 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
503 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
504 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
505 |
+
if module.bias is not None:
|
506 |
+
module.bias.data.zero_()
|
507 |
+
elif isinstance(module, nn.Embedding):
|
508 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
509 |
+
if module.padding_idx is not None:
|
510 |
+
module.weight.data[module.padding_idx].zero_()
|
511 |
+
elif isinstance(module, nn.LayerNorm):
|
512 |
+
module.bias.data.zero_()
|
513 |
+
module.weight.data.fill_(1.0)
|
514 |
+
|
515 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
516 |
+
if isinstance(module, VGPTNeoModel):
|
517 |
+
module.gradient_checkpointing = value
|
518 |
+
|
519 |
+
@classmethod
|
520 |
+
def override_vision_model_wrapper(cls, model, config, vision_model_name, vision_model_params, torch_dtype):
|
521 |
+
# this can be called via from_pretrained from a class w/ head or w/o head so we extract the beheaded model version
|
522 |
+
beheaded_model = model.transformer if hasattr(model, "transformer") else model
|
523 |
+
cls.override_vision_model(beheaded_model, vision_model_name, vision_model_params, torch_dtype)
|
524 |
+
beheaded_model.freeze_relevant_params(config)
|
525 |
+
|
526 |
+
|
527 |
+
GPT_NEO_START_DOCSTRING = r"""
|
528 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
529 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
530 |
+
etc.)
|
531 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
532 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
533 |
+
and behavior.
|
534 |
+
Parameters:
|
535 |
+
config ([`GPTNeoConfig`]): Model configuration class with all the parameters of the model.
|
536 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
537 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
538 |
+
"""
|
539 |
+
|
540 |
+
GPT_NEO_INPUTS_DOCSTRING = r"""
|
541 |
+
Args:
|
542 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
|
543 |
+
`input_ids_length` = `sequence_length` if `past_key_values` is `None` else
|
544 |
+
`past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
|
545 |
+
sequence tokens in the vocabulary.
|
546 |
+
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
|
547 |
+
`input_ids`.
|
548 |
+
Indices can be obtained using [`GPTNeoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
549 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
550 |
+
[What are input IDs?](../glossary#input-ids)
|
551 |
+
past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.num_layers`):
|
552 |
+
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
|
553 |
+
`past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
|
554 |
+
their past given to this model should not be passed as `input_ids` as they have already been computed.
|
555 |
+
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
556 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
557 |
+
- 1 for tokens that are **not masked**,
|
558 |
+
- 0 for tokens that are **masked**.
|
559 |
+
[What are attention masks?](../glossary#attention-mask)
|
560 |
+
token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
|
561 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
562 |
+
1]`:
|
563 |
+
- 0 corresponds to a *sentence A* token,
|
564 |
+
- 1 corresponds to a *sentence B* token.
|
565 |
+
[What are token type IDs?](../glossary#token-type-ids)
|
566 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
567 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
568 |
+
config.max_position_embeddings - 1]`.
|
569 |
+
[What are position IDs?](../glossary#position-ids)
|
570 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
571 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
572 |
+
- 1 indicates the head is **not masked**,
|
573 |
+
- 0 indicates the head is **masked**.
|
574 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
575 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
576 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
577 |
+
model's internal embedding lookup matrix.
|
578 |
+
If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
|
579 |
+
`past_key_values`).
|
580 |
+
use_cache (`bool`, *optional*):
|
581 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
582 |
+
`past_key_values`).
|
583 |
+
output_attentions (`bool`, *optional*):
|
584 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
585 |
+
tensors for more detail.
|
586 |
+
output_hidden_states (`bool`, *optional*):
|
587 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
588 |
+
more detail.
|
589 |
+
return_dict (`bool`, *optional*):
|
590 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
591 |
+
"""
|
592 |
+
|
593 |
+
|
594 |
+
@add_start_docstrings(
|
595 |
+
"The bare GPT Neo Model transformer outputting raw hidden-states without any specific head on top.",
|
596 |
+
GPT_NEO_START_DOCSTRING,
|
597 |
+
)
|
598 |
+
class VGPTNeoModel(VGPTNeoPreTrainedModel):
|
599 |
+
def __init__(self, config, vision_model=None):
|
600 |
+
super().__init__(config)
|
601 |
+
|
602 |
+
self.embed_dim = config.hidden_size
|
603 |
+
self.wte = DecoupledEmbedding(
|
604 |
+
num_embeddings=config.vocab_size,
|
605 |
+
num_additional_embeddings=config.additional_vocab_size,
|
606 |
+
embedding_dim=self.embed_dim,
|
607 |
+
partially_freeze=config.freeze_text_layers,
|
608 |
+
)
|
609 |
+
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
610 |
+
self.drop = nn.Dropout(float(config.embed_dropout))
|
611 |
+
self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)])
|
612 |
+
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
613 |
+
|
614 |
+
self.cross_layer_interval = config.cross_layer_interval
|
615 |
+
num_cross_layers = config.num_layers // self.cross_layer_interval
|
616 |
+
self.gated_cross_attn_layers = nn.ModuleList(
|
617 |
+
[VGPTNeoGatedCrossAttentionBlock(config, layer_id=i) for i in range(num_cross_layers)]
|
618 |
+
)
|
619 |
+
|
620 |
+
# Perceiver Resampler
|
621 |
+
if config.use_resampler:
|
622 |
+
self.perceiver_resampler = PerceiverResampler(
|
623 |
+
self.config,
|
624 |
+
self.config.vision_embed_dim,
|
625 |
+
config.resampler_depth,
|
626 |
+
config.resampler_n_heads,
|
627 |
+
config.resampler_head_dim,
|
628 |
+
config.resampler_n_latents,
|
629 |
+
)
|
630 |
+
self.gradient_checkpointing = False
|
631 |
+
self.image_token_idx = config.image_token_index
|
632 |
+
|
633 |
+
# Load an uninitialized model and later in from_pretrained will load the pre-trained model -
|
634 |
+
# this solves the losing of weights in `from_pretrained` on the main model
|
635 |
+
self.vision_model = vision_model
|
636 |
+
|
637 |
+
# Initialize weights and apply final processing
|
638 |
+
self.post_init()
|
639 |
+
|
640 |
+
self.freeze_relevant_params(config)
|
641 |
+
|
642 |
+
def freeze_relevant_params(self, config=None):
|
643 |
+
if config is None:
|
644 |
+
config = self.config
|
645 |
+
|
646 |
+
if config.freeze_text_layers:
|
647 |
+
self.freeze_text_layers()
|
648 |
+
|
649 |
+
if config.freeze_vision_layers:
|
650 |
+
freeze_model(self.vision_model)
|
651 |
+
|
652 |
+
def freeze_text_layers(self):
|
653 |
+
for module in [self.wpe, self.h, self.ln_f]:
|
654 |
+
freeze_model(module)
|
655 |
+
|
656 |
+
def get_input_embeddings(self):
|
657 |
+
return self.wte
|
658 |
+
|
659 |
+
def set_input_embeddings(self, new_embeddings):
|
660 |
+
self.wte = new_embeddings
|
661 |
+
|
662 |
+
@add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
|
663 |
+
@add_code_sample_docstrings(
|
664 |
+
processor_class=_TOKENIZER_FOR_DOC,
|
665 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
666 |
+
output_type=BaseModelOutputWithPastAndCrossAttentions,
|
667 |
+
config_class=_CONFIG_FOR_DOC,
|
668 |
+
)
|
669 |
+
def forward(
|
670 |
+
self,
|
671 |
+
input_ids: Optional[torch.LongTensor] = None,
|
672 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
673 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
674 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
675 |
+
position_ids: Optional[torch.LongTensor] = None,
|
676 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
677 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
678 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
679 |
+
image_embeddings: Optional[torch.FloatTensor] = None,
|
680 |
+
image_attention_mask: Optional[torch.Tensor] = None,
|
681 |
+
crossblock_head_mask: Optional[torch.Tensor] = None,
|
682 |
+
use_cache: Optional[bool] = None,
|
683 |
+
output_attentions: Optional[bool] = None,
|
684 |
+
output_hidden_states: Optional[bool] = None,
|
685 |
+
return_dict: Optional[bool] = None,
|
686 |
+
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
687 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
688 |
+
output_hidden_states = (
|
689 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
690 |
+
)
|
691 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
692 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
693 |
+
|
694 |
+
if input_ids is not None and inputs_embeds is not None:
|
695 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
696 |
+
elif input_ids is not None:
|
697 |
+
input_shape = input_ids.size()
|
698 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
699 |
+
batch_size = input_ids.shape[0]
|
700 |
+
elif inputs_embeds is not None:
|
701 |
+
input_shape = inputs_embeds.size()[:-1]
|
702 |
+
batch_size = inputs_embeds.shape[0]
|
703 |
+
else:
|
704 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
705 |
+
|
706 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
707 |
+
|
708 |
+
if token_type_ids is not None:
|
709 |
+
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
710 |
+
if position_ids is not None:
|
711 |
+
position_ids = position_ids.view(-1, input_shape[-1])
|
712 |
+
|
713 |
+
if past_key_values is None:
|
714 |
+
past_length = 0
|
715 |
+
past_key_values = tuple([None] * len(self.h))
|
716 |
+
else:
|
717 |
+
past_length = past_key_values[0][0].size(-2)
|
718 |
+
if position_ids is None:
|
719 |
+
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
720 |
+
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
721 |
+
|
722 |
+
# GPT2Attention mask.
|
723 |
+
if attention_mask is not None:
|
724 |
+
if batch_size <= 0:
|
725 |
+
raise ValueError("batch_size has to be defined and > 0")
|
726 |
+
attention_mask = attention_mask.view(batch_size, -1)
|
727 |
+
# We create a 3D attention mask from a 2D tensor mask.
|
728 |
+
# Sizes are [batch_size, 1, 1, to_seq_length]
|
729 |
+
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
730 |
+
# this attention mask is more simple than the triangular masking of causal attention
|
731 |
+
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
732 |
+
attention_mask = attention_mask[:, None, None, :]
|
733 |
+
|
734 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
735 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
736 |
+
# positions we want to attend and the dtype's smallest value for masked positions.
|
737 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
738 |
+
# effectively the same as removing these entirely.
|
739 |
+
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
740 |
+
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
741 |
+
|
742 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
743 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
744 |
+
if pixel_values is not None and image_embeddings is not None:
|
745 |
+
raise ValueError("You cannot specify both pixel_values and image_embeddings at the same time")
|
746 |
+
elif pixel_values is not None:
|
747 |
+
pixel_values = pixel_values.to(dtype=self.dtype, device=input_ids.device) # fp16 compatibility
|
748 |
+
batch_size, num_images = pixel_values.size(0), pixel_values.size(1)
|
749 |
+
pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:])
|
750 |
+
# Get sequence from the vision encoder
|
751 |
+
image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
|
752 |
+
elif image_embeddings is not None:
|
753 |
+
batch_size, num_images, image_seq_len, image_hidden_size = image_embeddings.size()
|
754 |
+
image_hidden_states = image_embeddings.to(dtype=self.dtype, device=input_ids.device)
|
755 |
+
image_hidden_states = image_hidden_states.view(batch_size * num_images, image_seq_len, image_hidden_size)
|
756 |
+
|
757 |
+
if self.config.use_resampler:
|
758 |
+
image_hidden_states = self.perceiver_resampler(image_hidden_states)
|
759 |
+
image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
|
760 |
+
image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size)
|
761 |
+
# Make image_attention_mask compatible with hidden states
|
762 |
+
text_seq_len = image_attention_mask.size(1)
|
763 |
+
image_attention_mask = image_attention_mask.unsqueeze(-1)
|
764 |
+
image_attention_mask = image_attention_mask.repeat(1, 1, 1, image_seq_len)
|
765 |
+
image_attention_mask = image_attention_mask.view(batch_size, text_seq_len, num_images * image_seq_len)
|
766 |
+
|
767 |
+
if image_hidden_states is not None:
|
768 |
+
image_batch_size, image_sequence_length, _ = image_hidden_states.size()
|
769 |
+
image_hidden_shape = (image_batch_size, image_sequence_length)
|
770 |
+
if image_attention_mask is None:
|
771 |
+
image_attention_mask = torch.ones(image_hidden_shape, device=device)
|
772 |
+
# image_attention_mask = self.invert_attention_mask(image_attention_mask)
|
773 |
+
image_attention_mask = image_attention_mask.to(torch.bool)
|
774 |
+
image_attention_mask = image_attention_mask[:, None, :, :]
|
775 |
+
else:
|
776 |
+
image_attention_mask = None
|
777 |
+
|
778 |
+
# Prepare head mask if needed
|
779 |
+
# 1.0 in head_mask indicate we keep the head
|
780 |
+
# attention_probs has shape bsz x n_heads x N x N
|
781 |
+
# head_mask has shape n_layer x batch x n_heads x N x N
|
782 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_layers)
|
783 |
+
|
784 |
+
if inputs_embeds is None:
|
785 |
+
inputs_embeds = self.wte(input_ids)
|
786 |
+
|
787 |
+
position_embeds = self.wpe(position_ids)
|
788 |
+
hidden_states = inputs_embeds + position_embeds
|
789 |
+
|
790 |
+
if token_type_ids is not None:
|
791 |
+
token_type_embeds = self.wte(token_type_ids)
|
792 |
+
hidden_states = hidden_states + token_type_embeds
|
793 |
+
|
794 |
+
hidden_states = self.drop(hidden_states)
|
795 |
+
|
796 |
+
output_shape = input_shape + (hidden_states.size(-1),)
|
797 |
+
|
798 |
+
presents = () if use_cache else None
|
799 |
+
all_self_attentions = () if output_attentions else None
|
800 |
+
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
801 |
+
all_hidden_states = () if output_hidden_states else None
|
802 |
+
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
803 |
+
if output_hidden_states:
|
804 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
805 |
+
|
806 |
+
def vblock(
|
807 |
+
main_block,
|
808 |
+
hidden_states,
|
809 |
+
layer_past,
|
810 |
+
attention_mask,
|
811 |
+
layer_head_mask,
|
812 |
+
use_cache,
|
813 |
+
output_attentions,
|
814 |
+
image_hidden_states,
|
815 |
+
image_attention_mask,
|
816 |
+
layer_idx,
|
817 |
+
cross_layer_interval,
|
818 |
+
gated_cross_attn_layers,
|
819 |
+
):
|
820 |
+
# TODO(aps): Add cross attention values to respective lists
|
821 |
+
# TODO(aps): Add xblock head mask support
|
822 |
+
if layer_idx % cross_layer_interval == 0:
|
823 |
+
xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval]
|
824 |
+
outputs = xblock(
|
825 |
+
hidden_states,
|
826 |
+
attention_mask=attention_mask,
|
827 |
+
image_hidden_states=image_hidden_states,
|
828 |
+
image_attention_mask=image_attention_mask,
|
829 |
+
use_cache=use_cache,
|
830 |
+
output_attentions=output_attentions,
|
831 |
+
)
|
832 |
+
hidden_states = outputs[0]
|
833 |
+
|
834 |
+
outputs = main_block(
|
835 |
+
hidden_states,
|
836 |
+
layer_past=layer_past,
|
837 |
+
attention_mask=attention_mask,
|
838 |
+
head_mask=layer_head_mask,
|
839 |
+
use_cache=use_cache,
|
840 |
+
output_attentions=output_attentions,
|
841 |
+
)
|
842 |
+
|
843 |
+
return outputs
|
844 |
+
|
845 |
+
if self.gradient_checkpointing and self.training:
|
846 |
+
if use_cache:
|
847 |
+
logger.warning_once(
|
848 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
849 |
+
)
|
850 |
+
use_cache = False
|
851 |
+
|
852 |
+
outputs = torch.utils.checkpoint.checkpoint(
|
853 |
+
vblock,
|
854 |
+
block,
|
855 |
+
hidden_states,
|
856 |
+
layer_past,
|
857 |
+
attention_mask,
|
858 |
+
head_mask[i],
|
859 |
+
use_cache,
|
860 |
+
output_attentions,
|
861 |
+
image_hidden_states,
|
862 |
+
image_attention_mask,
|
863 |
+
i,
|
864 |
+
self.cross_layer_interval,
|
865 |
+
self.gated_cross_attn_layers,
|
866 |
+
)
|
867 |
+
else:
|
868 |
+
outputs = vblock(
|
869 |
+
block,
|
870 |
+
hidden_states,
|
871 |
+
layer_past=layer_past,
|
872 |
+
attention_mask=attention_mask,
|
873 |
+
layer_head_mask=head_mask[i],
|
874 |
+
use_cache=use_cache,
|
875 |
+
output_attentions=output_attentions,
|
876 |
+
layer_idx=i,
|
877 |
+
image_hidden_states=image_hidden_states,
|
878 |
+
image_attention_mask=image_attention_mask,
|
879 |
+
cross_layer_interval=self.cross_layer_interval,
|
880 |
+
gated_cross_attn_layers=self.gated_cross_attn_layers,
|
881 |
+
)
|
882 |
+
|
883 |
+
hidden_states = outputs[0]
|
884 |
+
if use_cache is True:
|
885 |
+
presents = presents + (outputs[1],)
|
886 |
+
|
887 |
+
if output_attentions:
|
888 |
+
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
889 |
+
if self.config.add_cross_attention:
|
890 |
+
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
891 |
+
|
892 |
+
hidden_states = self.ln_f(hidden_states)
|
893 |
+
|
894 |
+
hidden_states = hidden_states.view(output_shape)
|
895 |
+
# Add last hidden state
|
896 |
+
if output_hidden_states:
|
897 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
898 |
+
|
899 |
+
if not return_dict:
|
900 |
+
return tuple(
|
901 |
+
v
|
902 |
+
for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
|
903 |
+
if v is not None
|
904 |
+
)
|
905 |
+
|
906 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
907 |
+
last_hidden_state=hidden_states,
|
908 |
+
past_key_values=presents,
|
909 |
+
hidden_states=all_hidden_states,
|
910 |
+
attentions=all_self_attentions,
|
911 |
+
cross_attentions=all_cross_attentions,
|
912 |
+
)
|
913 |
+
|
914 |
+
|
915 |
+
@add_start_docstrings(
|
916 |
+
"""
|
917 |
+
The GPT Neo Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
918 |
+
embeddings).
|
919 |
+
""",
|
920 |
+
GPT_NEO_START_DOCSTRING,
|
921 |
+
)
|
922 |
+
class VGPTNeoForCausalLM(VGPTNeoPreTrainedModel):
|
923 |
+
_keys_to_ignore_on_load_missing = [
|
924 |
+
r"h\.\d+\.attn\.masked_bias",
|
925 |
+
r"lm_head.weight",
|
926 |
+
r"h\.\d+\.attn\.attention\.bias",
|
927 |
+
]
|
928 |
+
_keys_to_ignore_on_save = [r"lm_head.weight"]
|
929 |
+
|
930 |
+
def __init__(self, config, vision_model=None):
|
931 |
+
super().__init__(config)
|
932 |
+
self.transformer = VGPTNeoModel(config, vision_model=vision_model)
|
933 |
+
self.lm_head = DecoupledLinear(
|
934 |
+
in_features=config.hidden_size,
|
935 |
+
out_features=config.vocab_size,
|
936 |
+
out_additional_features=config.additional_vocab_size,
|
937 |
+
bias=False,
|
938 |
+
partially_freeze=config.freeze_lm_head,
|
939 |
+
)
|
940 |
+
|
941 |
+
# Initialize weights and apply final processing
|
942 |
+
self.post_init()
|
943 |
+
|
944 |
+
def get_output_embeddings(self):
|
945 |
+
return self.lm_head
|
946 |
+
|
947 |
+
def set_output_embeddings(self, new_embeddings):
|
948 |
+
self.lm_head = new_embeddings
|
949 |
+
|
950 |
+
def tie_weights(self):
|
951 |
+
"""
|
952 |
+
Overwrite `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of DecoupledLinear and DecoupledEmbedding.
|
953 |
+
"""
|
954 |
+
output_embeddings = self.get_output_embeddings()
|
955 |
+
input_embeddings = self.get_input_embeddings()
|
956 |
+
|
957 |
+
if getattr(self.config, "tie_word_embeddings", True):
|
958 |
+
output_embeddings.weight = input_embeddings.weight
|
959 |
+
if input_embeddings.num_additional_embeddings > 0:
|
960 |
+
assert output_embeddings.out_additional_features == input_embeddings.num_additional_embeddings
|
961 |
+
output_embeddings.additional_fc.weight = input_embeddings.additional_embedding.weight
|
962 |
+
|
963 |
+
if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
|
964 |
+
output_embeddings.out_features = input_embeddings.num_embeddings
|
965 |
+
if hasattr(output_embeddings, "out_additional_features") and hasattr(
|
966 |
+
input_embeddings, "num_additional_embeddings"
|
967 |
+
):
|
968 |
+
output_embeddings.out_additional_features = input_embeddings.num_additional_embeddings
|
969 |
+
|
970 |
+
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
971 |
+
return prepare_inputs_for_generation(input_ids, past=past, **kwargs)
|
972 |
+
|
973 |
+
@staticmethod
|
974 |
+
def _expand_inputs_for_generation(
|
975 |
+
*args,
|
976 |
+
**model_kwargs,
|
977 |
+
):
|
978 |
+
return expand_inputs_for_generation(*args, **model_kwargs)
|
979 |
+
|
980 |
+
@staticmethod
|
981 |
+
def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
|
982 |
+
return update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder)
|
983 |
+
|
984 |
+
@add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
|
985 |
+
@add_code_sample_docstrings(
|
986 |
+
processor_class=_TOKENIZER_FOR_DOC,
|
987 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
988 |
+
output_type=CausalLMOutputWithCrossAttentions,
|
989 |
+
config_class=_CONFIG_FOR_DOC,
|
990 |
+
)
|
991 |
+
def forward(
|
992 |
+
self,
|
993 |
+
input_ids: Optional[torch.Tensor] = None,
|
994 |
+
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
|
995 |
+
attention_mask: Optional[torch.Tensor] = None,
|
996 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
997 |
+
position_ids: Optional[torch.Tensor] = None,
|
998 |
+
head_mask: Optional[torch.Tensor] = None,
|
999 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1000 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
1001 |
+
image_embeddings: Optional[torch.FloatTensor] = None,
|
1002 |
+
image_attention_mask: Optional[torch.Tensor] = None,
|
1003 |
+
crossblock_head_mask: Optional[torch.Tensor] = None,
|
1004 |
+
labels: Optional[torch.Tensor] = None,
|
1005 |
+
use_cache: Optional[bool] = None,
|
1006 |
+
output_attentions: Optional[bool] = None,
|
1007 |
+
output_hidden_states: Optional[bool] = None,
|
1008 |
+
return_dict: Optional[bool] = None,
|
1009 |
+
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
|
1010 |
+
r"""
|
1011 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1012 |
+
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
1013 |
+
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
1014 |
+
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
1015 |
+
"""
|
1016 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1017 |
+
|
1018 |
+
transformer_outputs = self.transformer(
|
1019 |
+
input_ids,
|
1020 |
+
past_key_values=past_key_values,
|
1021 |
+
attention_mask=attention_mask,
|
1022 |
+
token_type_ids=token_type_ids,
|
1023 |
+
position_ids=position_ids,
|
1024 |
+
head_mask=head_mask,
|
1025 |
+
inputs_embeds=inputs_embeds,
|
1026 |
+
pixel_values=pixel_values,
|
1027 |
+
image_embeddings=image_embeddings,
|
1028 |
+
image_attention_mask=image_attention_mask,
|
1029 |
+
crossblock_head_mask=crossblock_head_mask,
|
1030 |
+
use_cache=use_cache,
|
1031 |
+
output_attentions=output_attentions,
|
1032 |
+
output_hidden_states=output_hidden_states,
|
1033 |
+
return_dict=return_dict,
|
1034 |
+
)
|
1035 |
+
hidden_states = transformer_outputs[0]
|
1036 |
+
|
1037 |
+
lm_logits = self.lm_head(hidden_states)
|
1038 |
+
|
1039 |
+
loss = None
|
1040 |
+
if labels is not None:
|
1041 |
+
# Compute loss in fp32 to match with mesh-tf version
|
1042 |
+
# https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
|
1043 |
+
lm_logits = lm_logits.to(torch.float32)
|
1044 |
+
|
1045 |
+
# Shift so that tokens < n predict n
|
1046 |
+
if attention_mask is not None:
|
1047 |
+
shift_attention_mask = attention_mask[..., 1:]
|
1048 |
+
shift_logits = lm_logits[..., :-1, :][shift_attention_mask != 0].contiguous()
|
1049 |
+
shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
|
1050 |
+
else:
|
1051 |
+
shift_logits = lm_logits[..., :-1, :].contiguous()
|
1052 |
+
shift_labels = labels[..., 1:].contiguous()
|
1053 |
+
# Flatten the tokens
|
1054 |
+
loss_fct = CrossEntropyLoss()
|
1055 |
+
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
1056 |
+
|
1057 |
+
lm_logits = lm_logits.to(hidden_states.dtype)
|
1058 |
+
loss = loss.to(hidden_states.dtype)
|
1059 |
+
|
1060 |
+
if not return_dict:
|
1061 |
+
output = (lm_logits,) + transformer_outputs[1:]
|
1062 |
+
return ((loss,) + output) if loss is not None else output
|
1063 |
+
|
1064 |
+
return CausalLMOutputWithCrossAttentions(
|
1065 |
+
loss=loss,
|
1066 |
+
logits=lm_logits,
|
1067 |
+
past_key_values=transformer_outputs.past_key_values,
|
1068 |
+
hidden_states=transformer_outputs.hidden_states,
|
1069 |
+
attentions=transformer_outputs.attentions,
|
1070 |
+
cross_attentions=transformer_outputs.cross_attentions,
|
1071 |
+
)
|
1072 |
+
|
1073 |
+
@staticmethod
|
1074 |
+
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
|
1075 |
+
"""
|
1076 |
+
This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or
|
1077 |
+
[`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
|
1078 |
+
beam_idx at every generation step.
|
1079 |
+
"""
|
1080 |
+
return tuple(
|
1081 |
+
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
1082 |
+
for layer_past in past
|
1083 |
+
)
|
1084 |
+
|
1085 |
+
def get_model_tflops_per_batch_per_gpu(self, hparams, data_param, tokenizer, max_num_images):
|
1086 |
+
config_vl_model = self.config
|
1087 |
+
|
1088 |
+
language_embed_size = config_vl_model.hidden_size
|
1089 |
+
vision_config = self.transformer.vision_model.config
|
1090 |
+
num_language_layers = config_vl_model.num_layers
|
1091 |
+
ffn_inner_size = (
|
1092 |
+
config_vl_model.intermediate_size
|
1093 |
+
if config_vl_model.intermediate_size is not None
|
1094 |
+
else 4 * config_vl_model.hidden_size
|
1095 |
+
)
|
1096 |
+
|
1097 |
+
# Get vision model blocks infos
|
1098 |
+
vision_patch_size = vision_config.patch_size
|
1099 |
+
vision_hidden_size = vision_config.hidden_size
|
1100 |
+
num_vision_layers = vision_config.num_hidden_layers
|
1101 |
+
# The +1 is for the CLS token
|
1102 |
+
single_image_seq_len = (vision_config.image_size // vision_patch_size) ** 2 + 1
|
1103 |
+
vision_exp_factor = vision_config.intermediate_size // vision_hidden_size
|
1104 |
+
|
1105 |
+
# Get language and cross-att blocks infos
|
1106 |
+
num_cross_attn_layers = num_language_layers // config_vl_model.cross_layer_interval
|
1107 |
+
language_seq_len = data_param.max_seq_len
|
1108 |
+
language_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4
|
1109 |
+
cross_att_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4
|
1110 |
+
k_v_cross_attn_seq_len = (
|
1111 |
+
(self.config.resampler_n_latents * max_num_images)
|
1112 |
+
if self.config.use_resampler
|
1113 |
+
else (single_image_seq_len * max_num_images)
|
1114 |
+
)
|
1115 |
+
|
1116 |
+
language_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
|
1117 |
+
num_layers=num_language_layers,
|
1118 |
+
batch_size=hparams.batch_size_per_gpu,
|
1119 |
+
q_seq_len=language_seq_len,
|
1120 |
+
k_seq_len=language_seq_len,
|
1121 |
+
hidden_size=language_embed_size,
|
1122 |
+
kv_in_dim=language_embed_size,
|
1123 |
+
ff_exp_factor=language_exp_factor,
|
1124 |
+
grad_acc_size=hparams.grad_acc_size,
|
1125 |
+
swiglu=False,
|
1126 |
+
vocab_size=tokenizer.vocab_size,
|
1127 |
+
count_backward=True, # Always True regardless of freezing, because gradients are computed for cross-attentions
|
1128 |
+
use_grad_checkpointing=hparams.gradient_checkpointing,
|
1129 |
+
)
|
1130 |
+
cross_attention_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
|
1131 |
+
num_layers=num_cross_attn_layers,
|
1132 |
+
batch_size=hparams.batch_size_per_gpu,
|
1133 |
+
q_seq_len=language_seq_len,
|
1134 |
+
k_seq_len=k_v_cross_attn_seq_len,
|
1135 |
+
hidden_size=language_embed_size,
|
1136 |
+
kv_in_dim=vision_hidden_size,
|
1137 |
+
ff_exp_factor=cross_att_exp_factor,
|
1138 |
+
grad_acc_size=hparams.grad_acc_size,
|
1139 |
+
swiglu=False,
|
1140 |
+
vocab_size=None,
|
1141 |
+
count_backward=True,
|
1142 |
+
use_grad_checkpointing=hparams.gradient_checkpointing,
|
1143 |
+
)
|
1144 |
+
vision_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
|
1145 |
+
num_layers=num_vision_layers,
|
1146 |
+
batch_size=hparams.batch_size_per_gpu * max_num_images,
|
1147 |
+
q_seq_len=single_image_seq_len,
|
1148 |
+
k_seq_len=single_image_seq_len,
|
1149 |
+
hidden_size=vision_hidden_size,
|
1150 |
+
kv_in_dim=vision_hidden_size,
|
1151 |
+
ff_exp_factor=vision_exp_factor,
|
1152 |
+
grad_acc_size=hparams.grad_acc_size,
|
1153 |
+
swiglu=False,
|
1154 |
+
vocab_size=None,
|
1155 |
+
count_backward=not hparams.model_params["freeze_vision_layers"],
|
1156 |
+
use_grad_checkpointing=hparams.gradient_checkpointing,
|
1157 |
+
)
|
1158 |
+
if self.config.use_resampler:
|
1159 |
+
perceiver_tflops_per_batch_per_gpu = compute_perceiver_tflops_per_batch_per_gpu(
|
1160 |
+
num_layers=self.config.resampler_depth,
|
1161 |
+
batch_size=hparams.batch_size_per_gpu * max_num_images,
|
1162 |
+
q_seq_len=self.config.resampler_n_latents,
|
1163 |
+
vision_embed_seq_len=single_image_seq_len,
|
1164 |
+
q_k_v_input_dim=vision_hidden_size,
|
1165 |
+
attention_hidden_size=self.config.resampler_n_heads * self.config.resampler_head_dim,
|
1166 |
+
ff_exp_factor=cross_att_exp_factor,
|
1167 |
+
count_backward=True,
|
1168 |
+
use_grad_checkpointing=hparams.gradient_checkpointing,
|
1169 |
+
)
|
1170 |
+
flop_count = (
|
1171 |
+
language_tflops_per_batch_per_gpu
|
1172 |
+
+ cross_attention_tflops_per_batch_per_gpu
|
1173 |
+
+ vision_tflops_per_batch_per_gpu
|
1174 |
+
+ perceiver_tflops_per_batch_per_gpu
|
1175 |
+
)
|
1176 |
+
else:
|
1177 |
+
flop_count = (
|
1178 |
+
language_tflops_per_batch_per_gpu
|
1179 |
+
+ cross_attention_tflops_per_batch_per_gpu
|
1180 |
+
+ vision_tflops_per_batch_per_gpu
|
1181 |
+
)
|
1182 |
+
return flop_count
|
m4/models/vllama/configuration_vllama.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
5 |
+
# and OPT implementations in this library. It has been modified from its
|
6 |
+
# original forms to accommodate minor architectural differences compared
|
7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
8 |
+
#
|
9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
# you may not use this file except in compliance with the License.
|
11 |
+
# You may obtain a copy of the License at
|
12 |
+
#
|
13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
14 |
+
#
|
15 |
+
# Unless required by applicable law or agreed to in writing, software
|
16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
18 |
+
# See the License for the specific language governing permissions and
|
19 |
+
# limitations under the License.
|
20 |
+
""" LLaMA model configuration"""
|
21 |
+
import os
|
22 |
+
from typing import Tuple, Union
|
23 |
+
|
24 |
+
from transformers import AutoConfig
|
25 |
+
from transformers.configuration_utils import PretrainedConfig
|
26 |
+
from transformers.utils import logging
|
27 |
+
|
28 |
+
|
29 |
+
logger = logging.get_logger(__name__)
|
30 |
+
|
31 |
+
LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
|
32 |
+
|
33 |
+
|
34 |
+
class VLlamaConfig(PretrainedConfig):
|
35 |
+
r"""
|
36 |
+
TODO: update docstring with respect to new arguments
|
37 |
+
|
38 |
+
This is the configuration class to store the configuration of a [`~LlamaModel`]. It is used to instantiate an LLaMA
|
39 |
+
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
40 |
+
defaults will yield a similar configuration to that of the LLaMA-7B.
|
41 |
+
|
42 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
43 |
+
documentation from [`PretrainedConfig`] for more information.
|
44 |
+
|
45 |
+
|
46 |
+
Args:
|
47 |
+
vocab_size (`int`, *optional*, defaults to 32000):
|
48 |
+
Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
|
49 |
+
`inputs_ids` passed when calling [`~LlamaModel`]
|
50 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
51 |
+
Dimension of the hidden representations.
|
52 |
+
intermediate_size (`int`, *optional*, defaults to 11008):
|
53 |
+
Dimension of the MLP representations.
|
54 |
+
num_hidden_layers (`int`, *optional*, defaults to 32):
|
55 |
+
Number of hidden layers in the Transformer encoder.
|
56 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
57 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
58 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
59 |
+
The non-linear activation function (function or string) in the decoder.
|
60 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
61 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
62 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-12):
|
63 |
+
The epsilon used by the rms normalization layers.
|
64 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
65 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
66 |
+
relevant if `config.is_decoder=True`.
|
67 |
+
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
|
68 |
+
Whether to tie weight embeddings
|
69 |
+
Example:
|
70 |
+
|
71 |
+
```python
|
72 |
+
>>> from transformers import LlamaModel, LlamaConfig
|
73 |
+
|
74 |
+
>>> # Initializing a LLaMA llama-7b style configuration
|
75 |
+
>>> configuration = LlamaConfig()
|
76 |
+
|
77 |
+
>>> # Initializing a model from the llama-7b style configuration
|
78 |
+
>>> model = LlamaModel(configuration)
|
79 |
+
|
80 |
+
>>> # Accessing the model configuration
|
81 |
+
>>> configuration = model.config
|
82 |
+
```"""
|
83 |
+
model_type = "vllama"
|
84 |
+
|
85 |
+
def __init__(
|
86 |
+
self,
|
87 |
+
vocab_size=32000,
|
88 |
+
additional_vocab_size=0,
|
89 |
+
hidden_size=4096,
|
90 |
+
intermediate_size=11008,
|
91 |
+
num_hidden_layers=32,
|
92 |
+
num_attention_heads=32,
|
93 |
+
dropout=0.0,
|
94 |
+
hidden_act="silu",
|
95 |
+
initializer_range=0.02,
|
96 |
+
alpha_initializer="ones",
|
97 |
+
alphas_initializer_range=0.0,
|
98 |
+
alpha_type="vector",
|
99 |
+
rms_norm_eps=1e-6,
|
100 |
+
use_cache=True,
|
101 |
+
pad_token_id=0,
|
102 |
+
bos_token_id=1,
|
103 |
+
eos_token_id=2,
|
104 |
+
tie_word_embeddings=False,
|
105 |
+
cross_layer_interval=1,
|
106 |
+
cross_layer_activation_function="swiglu",
|
107 |
+
qk_layer_norms=False,
|
108 |
+
qk_layer_norms_perceiver=False,
|
109 |
+
freeze_text_layers=True,
|
110 |
+
freeze_text_module_exceptions=[],
|
111 |
+
freeze_lm_head=False,
|
112 |
+
freeze_vision_layers=True,
|
113 |
+
freeze_vision_module_exceptions=[],
|
114 |
+
vision_model_name="google/vit-base-patch16-224",
|
115 |
+
vision_model_params="{}",
|
116 |
+
vision_embed_dim=768,
|
117 |
+
vision_image_size=224,
|
118 |
+
use_resampler=False,
|
119 |
+
resampler_n_latents=64,
|
120 |
+
resampler_depth=6,
|
121 |
+
resampler_n_heads=16,
|
122 |
+
resampler_head_dim=96,
|
123 |
+
**kwargs,
|
124 |
+
):
|
125 |
+
self.vocab_size = vocab_size
|
126 |
+
self.additional_vocab_size = additional_vocab_size
|
127 |
+
self.hidden_size = hidden_size
|
128 |
+
self.intermediate_size = intermediate_size
|
129 |
+
self.num_hidden_layers = num_hidden_layers
|
130 |
+
self.num_attention_heads = num_attention_heads
|
131 |
+
self.dropout = dropout
|
132 |
+
self.hidden_act = hidden_act
|
133 |
+
self.initializer_range = initializer_range
|
134 |
+
self.alpha_initializer = alpha_initializer
|
135 |
+
self.alphas_initializer_range = alphas_initializer_range
|
136 |
+
self.alpha_type = alpha_type
|
137 |
+
self.rms_norm_eps = rms_norm_eps
|
138 |
+
self.use_cache = use_cache
|
139 |
+
super().__init__(
|
140 |
+
pad_token_id=pad_token_id,
|
141 |
+
bos_token_id=bos_token_id,
|
142 |
+
eos_token_id=eos_token_id,
|
143 |
+
tie_word_embeddings=tie_word_embeddings,
|
144 |
+
**kwargs,
|
145 |
+
)
|
146 |
+
|
147 |
+
self.cross_layer_interval = cross_layer_interval
|
148 |
+
self.cross_layer_activation_function = cross_layer_activation_function
|
149 |
+
self.qk_layer_norms = qk_layer_norms
|
150 |
+
self.qk_layer_norms_perceiver = qk_layer_norms_perceiver
|
151 |
+
self.freeze_vision_layers = freeze_vision_layers
|
152 |
+
self.vision_model_name = vision_model_name
|
153 |
+
self.vision_model_params = vision_model_params
|
154 |
+
|
155 |
+
self.freeze_text_layers = freeze_text_layers
|
156 |
+
self.freeze_text_module_exceptions = freeze_text_module_exceptions
|
157 |
+
self.freeze_vision_module_exceptions = freeze_vision_module_exceptions
|
158 |
+
self.freeze_lm_head = freeze_lm_head
|
159 |
+
|
160 |
+
self.vision_embed_dim = vision_embed_dim
|
161 |
+
self.vision_image_size = vision_image_size
|
162 |
+
|
163 |
+
# Resampler params
|
164 |
+
self.use_resampler = use_resampler
|
165 |
+
self.resampler_n_latents = resampler_n_latents
|
166 |
+
self.resampler_depth = resampler_depth
|
167 |
+
self.resampler_n_heads = resampler_n_heads
|
168 |
+
self.resampler_head_dim = resampler_head_dim
|
169 |
+
|
170 |
+
# IMPORTANT: Do not do any __init__ args-based checks in the constructor, since
|
171 |
+
# PretrainedConfig.from_dict first instantiates the class with the config dict and only then
|
172 |
+
# updates the config object with `kwargs` from from_pretrained, so during the instantiation
|
173 |
+
# of this object many attributes have default values and haven't yet been overridden.
|
174 |
+
# Do any required checks inside `from_pretrained` once the superclass' `from_pretrained` was run.
|
175 |
+
|
176 |
+
def check_compatibilities(self):
|
177 |
+
vision_model_params = eval(self.vision_model_params)
|
178 |
+
config = AutoConfig.from_pretrained(self.vision_model_name, **vision_model_params)
|
179 |
+
if hasattr(config, "vision_config"):
|
180 |
+
vision_config = config.vision_config
|
181 |
+
else:
|
182 |
+
vision_config = config
|
183 |
+
vision_embed_dim = vision_config.hidden_size
|
184 |
+
if self.vision_embed_dim != vision_embed_dim:
|
185 |
+
raise ValueError(
|
186 |
+
f"vision_embed_dim ({self.vision_embed_dim}) must match the hidden size of the vision model"
|
187 |
+
f" ({vision_embed_dim})"
|
188 |
+
)
|
189 |
+
vision_image_size = vision_config.image_size
|
190 |
+
if self.vision_image_size != vision_image_size:
|
191 |
+
raise ValueError(
|
192 |
+
f"vision_image_size ({self.vision_image_size}) must match the hidden size of the vision model"
|
193 |
+
f" ({vision_image_size})"
|
194 |
+
)
|
195 |
+
|
196 |
+
@classmethod
|
197 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
198 |
+
outputs = super(VLlamaConfig, cls).from_pretrained(pretrained_model_name_or_path, **kwargs)
|
199 |
+
if isinstance(outputs, Tuple):
|
200 |
+
# When called with return_unused_kwargs=True, the first item will be the config
|
201 |
+
outputs[0].check_compatibilities()
|
202 |
+
else:
|
203 |
+
outputs.check_compatibilities()
|
204 |
+
return outputs
|
m4/models/vllama/make_tiny_llama.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
# This script creates a super tiny model that is useful inside tests, when we just want to test that
|
4 |
+
# the machinery works, without needing to check the quality of the outcomes.
|
5 |
+
#
|
6 |
+
# usage: adjust the configs if wanted, but otherwise just run the script
|
7 |
+
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
|
11 |
+
|
12 |
+
|
13 |
+
mname_tiny = "tiny-random-LlamaForCausalLM"
|
14 |
+
|
15 |
+
path = Path(mname_tiny)
|
16 |
+
path.mkdir(parents=True, exist_ok=True)
|
17 |
+
|
18 |
+
config = LlamaConfig()
|
19 |
+
config.update(
|
20 |
+
dict(
|
21 |
+
vocab_size=32000,
|
22 |
+
hidden_size=16,
|
23 |
+
intermediate_size=16 * 4,
|
24 |
+
num_hidden_layers=2,
|
25 |
+
num_attention_heads=4,
|
26 |
+
)
|
27 |
+
)
|
28 |
+
model = LlamaForCausalLM(config)
|
29 |
+
tokenizer = LlamaTokenizer.from_pretrained("path_to_llama_7b")
|
30 |
+
|
31 |
+
# Test w/ one text
|
32 |
+
query = "This is a test"
|
33 |
+
query_tokens = tokenizer(query, return_tensors="pt")
|
34 |
+
|
35 |
+
input = {
|
36 |
+
"input_ids": query_tokens["input_ids"],
|
37 |
+
"attention_mask": query_tokens["attention_mask"],
|
38 |
+
}
|
39 |
+
|
40 |
+
out_gen = model.generate(**input)
|
41 |
+
text = tokenizer.batch_decode(out_gen)
|
42 |
+
|
43 |
+
# Save model + config + tokenizer
|
44 |
+
model.half() # makes it smaller
|
45 |
+
model.save_pretrained(path)
|
46 |
+
tokenizer.save_pretrained(path)
|
47 |
+
|
48 |
+
# test we can load it back
|
49 |
+
model = LlamaForCausalLM.from_pretrained(path)
|
50 |
+
|
51 |
+
print(f"Generated {mname_tiny} - Upload the generated folder to the hub")
|
m4/models/vllama/make_tiny_model.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
# This script creates a super tiny model that is useful inside tests, when we just want to test that
|
4 |
+
# the machinery works, without needing to check the quality of the outcomes.
|
5 |
+
#
|
6 |
+
# usage: adjust the configs if wanted, but otherwise just run the script
|
7 |
+
|
8 |
+
from pathlib import Path
|
9 |
+
from types import SimpleNamespace
|
10 |
+
|
11 |
+
import torchvision.transforms as transforms
|
12 |
+
from PIL import Image
|
13 |
+
|
14 |
+
from m4.models.vllama.modeling_vllama import VLlamaConfig, VLlamaForCausalLM
|
15 |
+
from m4.training.packing import image_attention_mask_for_packed_input_ids, incremental_to_binary_attention_mask
|
16 |
+
from m4.training.utils import get_tokenizer
|
17 |
+
|
18 |
+
|
19 |
+
mname_tiny = "tiny-random-vllama-clip"
|
20 |
+
|
21 |
+
path = Path(mname_tiny)
|
22 |
+
path.mkdir(parents=True, exist_ok=True)
|
23 |
+
|
24 |
+
# from the hardcoded https://github.com/huggingface/m4/blob/adf102f0000cb2632cd8a3ebb87398c65e448a97/m4/training/main.py#L80
|
25 |
+
additional_vocab_size = 2
|
26 |
+
|
27 |
+
config = VLlamaConfig()
|
28 |
+
config.update(
|
29 |
+
dict(
|
30 |
+
ffn_dim=64,
|
31 |
+
hidden_size=16,
|
32 |
+
max_position_embeddings=128,
|
33 |
+
num_attention_heads=4,
|
34 |
+
num_hidden_layers=2,
|
35 |
+
word_embed_proj_dim=16,
|
36 |
+
max_new_tokens=100,
|
37 |
+
use_resampler=True,
|
38 |
+
resampler_depth=2,
|
39 |
+
resampler_head_dim=8,
|
40 |
+
resampler_n_heads=2,
|
41 |
+
resampler_n_latents=16,
|
42 |
+
vision_embed_dim=32,
|
43 |
+
vision_image_size=30,
|
44 |
+
vision_model_name="hf-internal-testing/tiny-random-clip",
|
45 |
+
vision_model_params="{}",
|
46 |
+
vocab_size=32000,
|
47 |
+
additional_vocab_size=additional_vocab_size,
|
48 |
+
)
|
49 |
+
)
|
50 |
+
|
51 |
+
# print(config)
|
52 |
+
# can now modify config to say tiny values
|
53 |
+
|
54 |
+
model = VLlamaForCausalLM.from_config(config)
|
55 |
+
# print(model.config)
|
56 |
+
# print(model)
|
57 |
+
|
58 |
+
tokenizer_config = dict(
|
59 |
+
tokenizer_add_special_tokens="{}",
|
60 |
+
tokenizer_add_tokens=(
|
61 |
+
'[AddedToken("<fake_token_around_image>", rstrip=False, lstrip=False), AddedToken("<image>", rstrip=False,'
|
62 |
+
" lstrip=False)]"
|
63 |
+
),
|
64 |
+
tokenizer_name="HuggingFaceM4/huggy-llama-tokenizer-7b",
|
65 |
+
tokenizer_params='{"use_fast": True}',
|
66 |
+
)
|
67 |
+
tokenizer_config = SimpleNamespace(**tokenizer_config)
|
68 |
+
# print(tokenizer_config)
|
69 |
+
|
70 |
+
tokenizer = get_tokenizer(
|
71 |
+
tokenizer_name=tokenizer_config.tokenizer_name,
|
72 |
+
tokenizer_add_tokens=tokenizer_config.tokenizer_add_tokens,
|
73 |
+
tokenizer_add_special_tokens=tokenizer_config.tokenizer_add_special_tokens,
|
74 |
+
tokenizer_params=tokenizer_config.tokenizer_params,
|
75 |
+
additional_vocab_size=model.config.additional_vocab_size,
|
76 |
+
model_vocab_size=model.config.vocab_size,
|
77 |
+
)
|
78 |
+
assert "<image>" in tokenizer.get_vocab()
|
79 |
+
|
80 |
+
# Test w/ one image and one text
|
81 |
+
query = "<fake_token_around_image><image><fake_token_around_image>This is a picture of a cat."
|
82 |
+
query_tokens = tokenizer(query, return_tensors="pt")
|
83 |
+
|
84 |
+
num_images_per_ex = 1
|
85 |
+
pixel_values = transforms.ToTensor()(Image.new("RGB", (30, 30))).repeat(1, 1, 1, 1).unsqueeze(0)
|
86 |
+
image_attention_mask, _ = image_attention_mask_for_packed_input_ids(query_tokens["input_ids"], tokenizer)
|
87 |
+
image_attention_mask = incremental_to_binary_attention_mask(image_attention_mask, num_classes=num_images_per_ex)
|
88 |
+
|
89 |
+
input = {
|
90 |
+
"input_ids": query_tokens["input_ids"],
|
91 |
+
"attention_mask": query_tokens["attention_mask"],
|
92 |
+
"pixel_values": pixel_values,
|
93 |
+
"pixel_values": pixel_values,
|
94 |
+
"image_attention_mask": image_attention_mask,
|
95 |
+
}
|
96 |
+
# debug shapes
|
97 |
+
# print(query_tokens["input_ids"].shape)
|
98 |
+
# print(query_tokens["attention_mask"].shape)
|
99 |
+
# print(pixel_values.shape)
|
100 |
+
# print(image_attention_mask.shape)
|
101 |
+
|
102 |
+
out_gen = model.generate(**input)
|
103 |
+
text = tokenizer.batch_decode(out_gen)
|
104 |
+
# print(text)
|
105 |
+
|
106 |
+
# Save model + config + tokenizer
|
107 |
+
model.half() # makes it smaller
|
108 |
+
model.save_pretrained(path)
|
109 |
+
tokenizer.save_pretrained(path)
|
110 |
+
|
111 |
+
# test we can load it back
|
112 |
+
model = VLlamaForCausalLM.from_pretrained(path)
|
113 |
+
|
114 |
+
print(f"Generated {mname_tiny} - Upload the generated folder to the hub")
|
m4/models/vllama/modeling_vllama.py
ADDED
@@ -0,0 +1,1260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
5 |
+
# and OPT implementations in this library. It has been modified from its
|
6 |
+
# original forms to accommodate minor architectural differences compared
|
7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
8 |
+
#
|
9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
# you may not use this file except in compliance with the License.
|
11 |
+
# You may obtain a copy of the License at
|
12 |
+
#
|
13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
14 |
+
#
|
15 |
+
# Unless required by applicable law or agreed to in writing, software
|
16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
18 |
+
# See the License for the specific language governing permissions and
|
19 |
+
# limitations under the License.
|
20 |
+
""" PyTorch LLaMA model."""
|
21 |
+
from typing import List, Optional, Tuple, Union
|
22 |
+
|
23 |
+
import torch
|
24 |
+
import torch.utils.checkpoint
|
25 |
+
from torch import nn
|
26 |
+
from torch.nn import CrossEntropyLoss
|
27 |
+
from transformers.activations import ACT2FN
|
28 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
29 |
+
from transformers.modeling_utils import PretrainedConfig
|
30 |
+
from transformers.utils import (
|
31 |
+
ContextManagers,
|
32 |
+
add_start_docstrings,
|
33 |
+
add_start_docstrings_to_model_forward,
|
34 |
+
replace_return_docstrings,
|
35 |
+
)
|
36 |
+
|
37 |
+
from m4.models import DecoupledEmbedding, DecoupledLinear
|
38 |
+
from m4.models.common import (
|
39 |
+
expand_inputs_for_generation,
|
40 |
+
prepare_inputs_for_generation,
|
41 |
+
update_model_kwargs_for_generation,
|
42 |
+
)
|
43 |
+
from m4.models.custom_modules import VLOOMPreTrainedModelBase
|
44 |
+
from m4.models.perceiver.perceiver import PerceiverResampler
|
45 |
+
from m4.models.vllama.configuration_vllama import VLlamaConfig
|
46 |
+
from m4.training.utils import (
|
47 |
+
compute_perceiver_tflops_per_batch_per_gpu,
|
48 |
+
compute_tflops_per_batch_per_gpu,
|
49 |
+
deepspeed_gathered_parameters_context_manager,
|
50 |
+
freeze_model,
|
51 |
+
)
|
52 |
+
from m4.utils import logging
|
53 |
+
|
54 |
+
|
55 |
+
logger = logging.get_logger(__name__)
|
56 |
+
|
57 |
+
_CONFIG_FOR_DOC = "VLlamaConfig"
|
58 |
+
|
59 |
+
|
60 |
+
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
|
61 |
+
"""
|
62 |
+
Make causal mask used for bi-directional self-attention.
|
63 |
+
"""
|
64 |
+
bsz, tgt_len = input_ids_shape
|
65 |
+
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
|
66 |
+
mask_cond = torch.arange(mask.size(-1))
|
67 |
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
68 |
+
mask = mask.to(dtype)
|
69 |
+
|
70 |
+
if past_key_values_length > 0:
|
71 |
+
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
|
72 |
+
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
73 |
+
|
74 |
+
|
75 |
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
76 |
+
"""
|
77 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
78 |
+
"""
|
79 |
+
bsz, src_len = mask.size()
|
80 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
81 |
+
|
82 |
+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
83 |
+
|
84 |
+
inverted_mask = 1.0 - expanded_mask
|
85 |
+
|
86 |
+
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
87 |
+
|
88 |
+
|
89 |
+
class LlamaRMSNorm(nn.Module):
|
90 |
+
def __init__(self, hidden_size, eps=1e-6):
|
91 |
+
"""
|
92 |
+
LlamaRMSNorm is equivalent to T5LayerNorm
|
93 |
+
"""
|
94 |
+
super().__init__()
|
95 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
96 |
+
self.variance_epsilon = eps
|
97 |
+
|
98 |
+
def forward(self, hidden_states):
|
99 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
100 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
101 |
+
|
102 |
+
# convert into half-precision if necessary
|
103 |
+
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
104 |
+
hidden_states = hidden_states.to(self.weight.dtype)
|
105 |
+
|
106 |
+
return self.weight * hidden_states
|
107 |
+
|
108 |
+
|
109 |
+
class LlamaRotaryEmbedding(torch.nn.Module):
|
110 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
111 |
+
super().__init__()
|
112 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
|
113 |
+
self.register_buffer("inv_freq", inv_freq)
|
114 |
+
|
115 |
+
# Build here to make `torch.jit.trace` work.
|
116 |
+
self.max_seq_len_cached = max_position_embeddings
|
117 |
+
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
118 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
119 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
120 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
121 |
+
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
|
122 |
+
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
|
123 |
+
|
124 |
+
def forward(self, x, seq_len=None):
|
125 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
126 |
+
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
|
127 |
+
if seq_len > self.max_seq_len_cached:
|
128 |
+
self.max_seq_len_cached = seq_len
|
129 |
+
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
|
130 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
131 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
132 |
+
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
133 |
+
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
|
134 |
+
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
|
135 |
+
return (
|
136 |
+
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
137 |
+
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
138 |
+
)
|
139 |
+
|
140 |
+
|
141 |
+
def rotate_half(x):
|
142 |
+
"""Rotates half the hidden dims of the input."""
|
143 |
+
x1 = x[..., : x.shape[-1] // 2]
|
144 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
145 |
+
return torch.cat((-x2, x1), dim=-1)
|
146 |
+
|
147 |
+
|
148 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
149 |
+
gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
|
150 |
+
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
|
151 |
+
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
|
152 |
+
sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
|
153 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
154 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
155 |
+
return q_embed, k_embed
|
156 |
+
|
157 |
+
|
158 |
+
class LlamaMLP(nn.Module):
|
159 |
+
def __init__(
|
160 |
+
self,
|
161 |
+
hidden_size: int,
|
162 |
+
intermediate_size: int,
|
163 |
+
hidden_act: str,
|
164 |
+
):
|
165 |
+
super().__init__()
|
166 |
+
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
167 |
+
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
|
168 |
+
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
169 |
+
self.act_fn = ACT2FN[hidden_act]
|
170 |
+
|
171 |
+
def forward(self, x):
|
172 |
+
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
173 |
+
|
174 |
+
|
175 |
+
class LlamaAttention(nn.Module):
|
176 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
177 |
+
|
178 |
+
def __init__(
|
179 |
+
self,
|
180 |
+
hidden_size: int,
|
181 |
+
num_heads: int,
|
182 |
+
dropout: float = 0.0,
|
183 |
+
is_cross_attention: bool = False,
|
184 |
+
config: PretrainedConfig = None,
|
185 |
+
qk_layer_norms: bool = False,
|
186 |
+
):
|
187 |
+
super().__init__()
|
188 |
+
self.hidden_size = hidden_size
|
189 |
+
self.num_heads = num_heads
|
190 |
+
self.head_dim = hidden_size // num_heads
|
191 |
+
self.dropout = dropout
|
192 |
+
|
193 |
+
if (self.head_dim * num_heads) != self.hidden_size:
|
194 |
+
raise ValueError(
|
195 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
196 |
+
f" and `num_heads`: {num_heads})."
|
197 |
+
)
|
198 |
+
|
199 |
+
self.is_cross_attention = is_cross_attention
|
200 |
+
|
201 |
+
if self.is_cross_attention:
|
202 |
+
kv_input_dim = self.hidden_size if not hasattr(config, "vision_embed_dim") else config.vision_embed_dim
|
203 |
+
self.q_proj = nn.Linear(
|
204 |
+
self.hidden_size,
|
205 |
+
num_heads * self.head_dim,
|
206 |
+
bias=False,
|
207 |
+
)
|
208 |
+
self.k_proj = nn.Linear(kv_input_dim, num_heads * self.head_dim, bias=False)
|
209 |
+
self.v_proj = nn.Linear(
|
210 |
+
kv_input_dim,
|
211 |
+
num_heads * self.head_dim,
|
212 |
+
bias=False,
|
213 |
+
)
|
214 |
+
else:
|
215 |
+
self.q_proj = nn.Linear(
|
216 |
+
self.hidden_size,
|
217 |
+
num_heads * self.head_dim,
|
218 |
+
bias=False,
|
219 |
+
)
|
220 |
+
self.k_proj = nn.Linear(
|
221 |
+
self.hidden_size,
|
222 |
+
num_heads * self.head_dim,
|
223 |
+
bias=False,
|
224 |
+
)
|
225 |
+
self.v_proj = nn.Linear(
|
226 |
+
self.hidden_size,
|
227 |
+
num_heads * self.head_dim,
|
228 |
+
bias=False,
|
229 |
+
)
|
230 |
+
self.o_proj = nn.Linear(
|
231 |
+
num_heads * self.head_dim,
|
232 |
+
hidden_size,
|
233 |
+
bias=False,
|
234 |
+
)
|
235 |
+
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim)
|
236 |
+
|
237 |
+
self.qk_layer_norms = qk_layer_norms
|
238 |
+
if self.qk_layer_norms:
|
239 |
+
self.q_layer_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
240 |
+
self.k_layer_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
241 |
+
|
242 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
243 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
244 |
+
|
245 |
+
def forward(
|
246 |
+
self,
|
247 |
+
hidden_states: torch.Tensor,
|
248 |
+
key_value_states: Optional[torch.Tensor] = None,
|
249 |
+
attention_mask: Optional[torch.Tensor] = None,
|
250 |
+
position_ids: Optional[torch.LongTensor] = None,
|
251 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
252 |
+
output_attentions: bool = False,
|
253 |
+
use_cache: bool = False,
|
254 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
255 |
+
# if key_value_states are provided this layer is used as a cross-attention layer
|
256 |
+
is_cross_attention = self.is_cross_attention or key_value_states is not None
|
257 |
+
|
258 |
+
bsz, q_len, _ = hidden_states.size()
|
259 |
+
|
260 |
+
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
261 |
+
if not is_cross_attention:
|
262 |
+
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
263 |
+
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
264 |
+
else:
|
265 |
+
_, kv_len, _ = key_value_states.size() # Note that, in this case, `kv_len` == `kv_seq_len`
|
266 |
+
key_states = self.k_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
|
267 |
+
value_states = (
|
268 |
+
self.v_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
|
269 |
+
)
|
270 |
+
|
271 |
+
kv_seq_len = key_states.shape[-2]
|
272 |
+
if past_key_value is not None:
|
273 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
274 |
+
if not is_cross_attention:
|
275 |
+
cos, sin = self.rotary_emb(value_states, seq_len=max(kv_seq_len, q_len))
|
276 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
277 |
+
# [bsz, nh, t, hd]
|
278 |
+
|
279 |
+
if past_key_value is not None:
|
280 |
+
# reuse k, v, self_attention
|
281 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
282 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
283 |
+
|
284 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
285 |
+
|
286 |
+
if self.qk_layer_norms:
|
287 |
+
query_states = self.q_layer_norm(query_states)
|
288 |
+
key_states = self.k_layer_norm(key_states)
|
289 |
+
|
290 |
+
if attention_mask is not None:
|
291 |
+
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
292 |
+
raise ValueError(
|
293 |
+
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
294 |
+
)
|
295 |
+
|
296 |
+
attn_output = nn.functional.scaled_dot_product_attention(
|
297 |
+
query_states,
|
298 |
+
key_states,
|
299 |
+
value_states,
|
300 |
+
attn_mask=attention_mask,
|
301 |
+
dropout_p=self.dropout,
|
302 |
+
)
|
303 |
+
|
304 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
305 |
+
raise ValueError(
|
306 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
307 |
+
f" {attn_output.size()}"
|
308 |
+
)
|
309 |
+
|
310 |
+
attn_output = attn_output.transpose(1, 2)
|
311 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
312 |
+
|
313 |
+
attn_output = self.o_proj(attn_output)
|
314 |
+
|
315 |
+
attn_weights = None
|
316 |
+
logger.warning_once(
|
317 |
+
"attn_weights are not extracted in scaled_dot_product_attention. The model returns None instead"
|
318 |
+
)
|
319 |
+
|
320 |
+
return attn_output, attn_weights, past_key_value
|
321 |
+
|
322 |
+
|
323 |
+
class LlamaDecoderLayer(nn.Module):
|
324 |
+
def __init__(self, config: VLlamaConfig):
|
325 |
+
super().__init__()
|
326 |
+
self.hidden_size = config.hidden_size
|
327 |
+
self.self_attn = LlamaAttention(
|
328 |
+
hidden_size=self.hidden_size,
|
329 |
+
num_heads=config.num_attention_heads,
|
330 |
+
dropout=config.dropout,
|
331 |
+
config=config,
|
332 |
+
)
|
333 |
+
self.mlp = LlamaMLP(
|
334 |
+
hidden_size=self.hidden_size,
|
335 |
+
intermediate_size=config.intermediate_size,
|
336 |
+
hidden_act=config.hidden_act,
|
337 |
+
)
|
338 |
+
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
339 |
+
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
340 |
+
self.dropout = config.dropout
|
341 |
+
|
342 |
+
def forward(
|
343 |
+
self,
|
344 |
+
hidden_states: torch.Tensor,
|
345 |
+
attention_mask: Optional[torch.Tensor] = None,
|
346 |
+
position_ids: Optional[torch.LongTensor] = None,
|
347 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
348 |
+
output_attentions: Optional[bool] = False,
|
349 |
+
use_cache: Optional[bool] = False,
|
350 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
351 |
+
"""
|
352 |
+
Args:
|
353 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
354 |
+
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
355 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
356 |
+
output_attentions (`bool`, *optional*):
|
357 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
358 |
+
returned tensors for more detail.
|
359 |
+
use_cache (`bool`, *optional*):
|
360 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
361 |
+
(see `past_key_values`).
|
362 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
363 |
+
"""
|
364 |
+
|
365 |
+
residual = hidden_states
|
366 |
+
|
367 |
+
hidden_states = self.input_layernorm(hidden_states)
|
368 |
+
|
369 |
+
# Self Attention
|
370 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
371 |
+
hidden_states=hidden_states,
|
372 |
+
attention_mask=attention_mask,
|
373 |
+
position_ids=position_ids,
|
374 |
+
past_key_value=past_key_value,
|
375 |
+
output_attentions=output_attentions,
|
376 |
+
use_cache=use_cache,
|
377 |
+
)
|
378 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
379 |
+
hidden_states = residual + hidden_states
|
380 |
+
|
381 |
+
# Fully Connected
|
382 |
+
residual = hidden_states
|
383 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
384 |
+
hidden_states = self.mlp(hidden_states)
|
385 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
386 |
+
hidden_states = residual + hidden_states
|
387 |
+
|
388 |
+
outputs = (hidden_states,)
|
389 |
+
|
390 |
+
if output_attentions:
|
391 |
+
outputs += (self_attn_weights,)
|
392 |
+
|
393 |
+
if use_cache:
|
394 |
+
outputs += (present_key_value,)
|
395 |
+
|
396 |
+
return outputs
|
397 |
+
|
398 |
+
|
399 |
+
class VLlamaGatedCrossAttentionLayer(nn.Module):
|
400 |
+
def __init__(self, config: VLlamaConfig):
|
401 |
+
super().__init__()
|
402 |
+
self.hidden_size = config.hidden_size
|
403 |
+
self.cross_attn = LlamaAttention(
|
404 |
+
hidden_size=self.hidden_size,
|
405 |
+
num_heads=config.num_attention_heads,
|
406 |
+
is_cross_attention=True,
|
407 |
+
dropout=config.dropout,
|
408 |
+
config=config,
|
409 |
+
qk_layer_norms=config.qk_layer_norms,
|
410 |
+
)
|
411 |
+
self.mlp = LlamaMLP(
|
412 |
+
hidden_size=self.hidden_size,
|
413 |
+
intermediate_size=config.intermediate_size,
|
414 |
+
hidden_act=config.hidden_act,
|
415 |
+
)
|
416 |
+
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
417 |
+
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
418 |
+
self.config = config.dropout
|
419 |
+
|
420 |
+
self.act_cross_attn = nn.Tanh()
|
421 |
+
self.act_dense = nn.Tanh()
|
422 |
+
|
423 |
+
if config.alpha_initializer == "zeros":
|
424 |
+
if config.alpha_type == "vector":
|
425 |
+
self.alpha_cross_attn = nn.Parameter(torch.zeros(1, 1, self.hidden_size))
|
426 |
+
self.alpha_dense = nn.Parameter(torch.zeros(1, 1, self.hidden_size))
|
427 |
+
elif config.alpha_type == "float":
|
428 |
+
self.alpha_cross_attn = nn.Parameter(torch.zeros(1))
|
429 |
+
self.alpha_dense = nn.Parameter(torch.zeros(1))
|
430 |
+
else:
|
431 |
+
raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
|
432 |
+
|
433 |
+
elif config.alpha_initializer == "ones":
|
434 |
+
if config.alpha_type == "vector":
|
435 |
+
self.alpha_cross_attn = nn.Parameter(torch.ones(1, 1, self.hidden_size))
|
436 |
+
self.alpha_dense = nn.Parameter(torch.ones(1, 1, self.hidden_size))
|
437 |
+
elif config.alpha_type == "float":
|
438 |
+
self.alpha_cross_attn = nn.Parameter(torch.ones(1))
|
439 |
+
self.alpha_dense = nn.Parameter(torch.ones(1))
|
440 |
+
else:
|
441 |
+
raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
|
442 |
+
|
443 |
+
elif config.alpha_initializer in {"normal", "gaussian", "random"}:
|
444 |
+
if config.alpha_type == "vector":
|
445 |
+
self.alpha_cross_attn = nn.Parameter(
|
446 |
+
torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1, 1, self.hidden_size))
|
447 |
+
)
|
448 |
+
self.alpha_dense = nn.Parameter(
|
449 |
+
torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1, 1, self.hidden_size))
|
450 |
+
)
|
451 |
+
elif config.alpha_type == "float":
|
452 |
+
self.alpha_cross_attn = nn.Parameter(
|
453 |
+
torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1))
|
454 |
+
)
|
455 |
+
self.alpha_dense = nn.Parameter(torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1)))
|
456 |
+
else:
|
457 |
+
raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
|
458 |
+
|
459 |
+
else:
|
460 |
+
raise NotImplementedError(f"Alpha initialization scheme {config.alpha_initializer} not yet implemented!")
|
461 |
+
|
462 |
+
if not (hasattr(self, "alpha_cross_attn") and hasattr(self, "alpha_dense")):
|
463 |
+
raise ValueError("Alpha parameters not initialized correctly!")
|
464 |
+
|
465 |
+
def forward(
|
466 |
+
self,
|
467 |
+
hidden_states: torch.Tensor,
|
468 |
+
attention_mask: Optional[torch.Tensor] = None,
|
469 |
+
image_hidden_states: Optional[torch.Tensor] = None,
|
470 |
+
image_attention_mask: Optional[torch.Tensor] = None,
|
471 |
+
output_attentions: Optional[bool] = False,
|
472 |
+
use_cache: Optional[bool] = False,
|
473 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
474 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
475 |
+
"""
|
476 |
+
Args:
|
477 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
478 |
+
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
479 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
480 |
+
output_attentions (`bool`, *optional*):
|
481 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
482 |
+
returned tensors for more detail.
|
483 |
+
use_cache (`bool`, *optional*):
|
484 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
485 |
+
(see `past_key_values`).
|
486 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
487 |
+
"""
|
488 |
+
if image_hidden_states is None:
|
489 |
+
raise ValueError(
|
490 |
+
"`image_hidden_states` is required for VLlama cross attention module which are visual features to be"
|
491 |
+
" conditioned on."
|
492 |
+
)
|
493 |
+
|
494 |
+
if past_key_value is not None:
|
495 |
+
raise NotImplementedError("Past key value states are not implemented for VLlama cross attention module.")
|
496 |
+
|
497 |
+
residual = hidden_states
|
498 |
+
|
499 |
+
hidden_states = self.input_layernorm(hidden_states)
|
500 |
+
|
501 |
+
# Self Attention
|
502 |
+
hidden_states, self_attn_weights, present_key_value = self.cross_attn(
|
503 |
+
hidden_states=hidden_states,
|
504 |
+
key_value_states=image_hidden_states,
|
505 |
+
attention_mask=image_attention_mask,
|
506 |
+
output_attentions=output_attentions,
|
507 |
+
)
|
508 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
|
509 |
+
hidden_states = residual + self.act_cross_attn(self.alpha_cross_attn) * hidden_states
|
510 |
+
|
511 |
+
# Fully Connected
|
512 |
+
residual = hidden_states
|
513 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
514 |
+
hidden_states = self.mlp(hidden_states)
|
515 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
|
516 |
+
hidden_states = residual + self.act_dense(self.alpha_dense) * hidden_states
|
517 |
+
|
518 |
+
outputs = (hidden_states,)
|
519 |
+
|
520 |
+
if output_attentions:
|
521 |
+
outputs += (self_attn_weights,)
|
522 |
+
|
523 |
+
if use_cache:
|
524 |
+
outputs += (present_key_value,)
|
525 |
+
|
526 |
+
return outputs
|
527 |
+
|
528 |
+
|
529 |
+
LLAMA_START_DOCSTRING = r"""
|
530 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
531 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
532 |
+
etc.)
|
533 |
+
|
534 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
535 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
536 |
+
and behavior.
|
537 |
+
|
538 |
+
Parameters:
|
539 |
+
config ([`VLlamaConfig`]):
|
540 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
541 |
+
load the weights associated with the model, only the configuration. Check out the
|
542 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
543 |
+
"""
|
544 |
+
|
545 |
+
|
546 |
+
@add_start_docstrings(
|
547 |
+
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
548 |
+
LLAMA_START_DOCSTRING,
|
549 |
+
)
|
550 |
+
class VLlamaPreTrainedModel(VLOOMPreTrainedModelBase):
|
551 |
+
config_class = VLlamaConfig
|
552 |
+
base_model_prefix = "model"
|
553 |
+
supports_gradient_checkpointing = True
|
554 |
+
_no_split_modules = ["LlamaDecoderLayer", "VLlamaGatedCrossAttentionLayer"]
|
555 |
+
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
|
556 |
+
|
557 |
+
def _init_weights(self, module):
|
558 |
+
def init_a_linear(module, mean=0.0, std=self.config.initializer_range):
|
559 |
+
with ContextManagers(deepspeed_gathered_parameters_context_manager(module.weight, modify=True)):
|
560 |
+
module.weight.data.normal_(mean=mean, std=std)
|
561 |
+
if module.bias is not None:
|
562 |
+
with ContextManagers(deepspeed_gathered_parameters_context_manager(module.bias, modify=True)):
|
563 |
+
module.bias.data.zero_()
|
564 |
+
|
565 |
+
if isinstance(module, VLlamaGatedCrossAttentionLayer):
|
566 |
+
for sub_module_name, sub_module in module.named_modules():
|
567 |
+
if isinstance(sub_module, nn.Linear):
|
568 |
+
if "down_proj" in sub_module_name:
|
569 |
+
factor = 2 * self.config.num_hidden_layers
|
570 |
+
else:
|
571 |
+
factor = 1.0
|
572 |
+
init_a_linear(sub_module, std=(0.4 / (sub_module.in_features * factor)) ** 0.5)
|
573 |
+
elif isinstance(module, PerceiverResampler):
|
574 |
+
with ContextManagers(deepspeed_gathered_parameters_context_manager(module.latents, modify=True)):
|
575 |
+
module.latents.data.normal_(mean=0.0, std=(1.0 / self.config.vision_embed_dim) ** 0.5)
|
576 |
+
for sub_module_name, sub_module in module.named_modules():
|
577 |
+
if isinstance(sub_module, nn.Linear):
|
578 |
+
if "c_proj" in sub_module_name:
|
579 |
+
factor = 2 * self.config.num_hidden_layers
|
580 |
+
else:
|
581 |
+
factor = 1.0
|
582 |
+
init_a_linear(sub_module, std=(0.4 / (self.config.vision_embed_dim * factor)) ** 0.5)
|
583 |
+
elif isinstance(module, nn.Embedding):
|
584 |
+
with ContextManagers(deepspeed_gathered_parameters_context_manager(module.weight, modify=True)):
|
585 |
+
module.weight.data.normal_(mean=0.0, std=(1.0 / self.config.hidden_size) ** 0.5)
|
586 |
+
if module.padding_idx is not None:
|
587 |
+
module.weight.data[module.padding_idx].zero_()
|
588 |
+
elif isinstance(module, DecoupledLinear):
|
589 |
+
if hasattr(module, "additional_fc"):
|
590 |
+
init_a_linear(module.additional_fc, std=(1.0 / (module.additional_fc.in_features)) ** 0.5)
|
591 |
+
|
592 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
593 |
+
if isinstance(module, VLlamaModel):
|
594 |
+
module.gradient_checkpointing = value
|
595 |
+
|
596 |
+
@classmethod
|
597 |
+
def override_vision_model_wrapper(cls, model, config, vision_model_name, vision_model_params, torch_dtype):
|
598 |
+
# this can be called via from_pretrained from a class w/ head or w/o head so we extract the beheaded model version
|
599 |
+
beheaded_model = model.model if hasattr(model, "model") else model
|
600 |
+
cls.override_vision_model(beheaded_model, vision_model_name, vision_model_params, torch_dtype)
|
601 |
+
beheaded_model.freeze_relevant_params(config)
|
602 |
+
|
603 |
+
|
604 |
+
LLAMA_INPUTS_DOCSTRING = r"""
|
605 |
+
Args:
|
606 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
607 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
608 |
+
it.
|
609 |
+
|
610 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
611 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
612 |
+
|
613 |
+
[What are input IDs?](../glossary#input-ids)
|
614 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
615 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
616 |
+
|
617 |
+
- 1 for tokens that are **not masked**,
|
618 |
+
- 0 for tokens that are **masked**.
|
619 |
+
|
620 |
+
[What are attention masks?](../glossary#attention-mask)
|
621 |
+
|
622 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
623 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
624 |
+
|
625 |
+
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
626 |
+
`past_key_values`).
|
627 |
+
|
628 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
629 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
630 |
+
information on the default strategy.
|
631 |
+
|
632 |
+
- 1 indicates the head is **not masked**,
|
633 |
+
- 0 indicates the head is **masked**.
|
634 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
635 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
636 |
+
config.n_positions - 1]`.
|
637 |
+
[What are position IDs?](../glossary#position-ids)
|
638 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
639 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
640 |
+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
641 |
+
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
642 |
+
|
643 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
644 |
+
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
645 |
+
|
646 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
647 |
+
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
648 |
+
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
649 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
650 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
651 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
652 |
+
model's internal embedding lookup matrix.
|
653 |
+
use_cache (`bool`, *optional*):
|
654 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
655 |
+
`past_key_values`).
|
656 |
+
output_attentions (`bool`, *optional*):
|
657 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
658 |
+
tensors for more detail.
|
659 |
+
output_hidden_states (`bool`, *optional*):
|
660 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
661 |
+
more detail.
|
662 |
+
return_dict (`bool`, *optional*):
|
663 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
664 |
+
"""
|
665 |
+
|
666 |
+
|
667 |
+
@add_start_docstrings(
|
668 |
+
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
669 |
+
LLAMA_START_DOCSTRING,
|
670 |
+
)
|
671 |
+
class VLlamaModel(VLlamaPreTrainedModel):
|
672 |
+
"""
|
673 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
|
674 |
+
|
675 |
+
Args:
|
676 |
+
config: VLlamaConfig
|
677 |
+
"""
|
678 |
+
|
679 |
+
def __init__(self, config: VLlamaConfig, vision_model=None):
|
680 |
+
super().__init__(config)
|
681 |
+
self.config = config
|
682 |
+
self.padding_idx = config.pad_token_id
|
683 |
+
self.vocab_size = config.vocab_size
|
684 |
+
|
685 |
+
self.embed_tokens = DecoupledEmbedding(
|
686 |
+
num_embeddings=config.vocab_size,
|
687 |
+
num_additional_embeddings=config.additional_vocab_size,
|
688 |
+
embedding_dim=config.hidden_size,
|
689 |
+
partially_freeze=config.freeze_text_layers,
|
690 |
+
padding_idx=self.padding_idx,
|
691 |
+
)
|
692 |
+
|
693 |
+
# Load an uninitialized model and later in from_pretrained will load the pre-trained model -
|
694 |
+
# this solves the losing of weights in `from_pretrained` on the main model
|
695 |
+
self.vision_model = vision_model
|
696 |
+
|
697 |
+
# Perceiver Resampler
|
698 |
+
if config.use_resampler:
|
699 |
+
self.perceiver_resampler = PerceiverResampler(
|
700 |
+
self.config,
|
701 |
+
self.config.vision_embed_dim,
|
702 |
+
config.resampler_depth,
|
703 |
+
config.resampler_n_heads,
|
704 |
+
config.resampler_head_dim,
|
705 |
+
config.resampler_n_latents,
|
706 |
+
)
|
707 |
+
|
708 |
+
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
709 |
+
|
710 |
+
self.cross_layer_interval = config.cross_layer_interval
|
711 |
+
num_cross_layers = config.num_hidden_layers // self.cross_layer_interval
|
712 |
+
self.gated_cross_attn_layers = nn.ModuleList(
|
713 |
+
[VLlamaGatedCrossAttentionLayer(config) for _ in range(num_cross_layers)]
|
714 |
+
)
|
715 |
+
self.gradient_checkpointing = False
|
716 |
+
|
717 |
+
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
718 |
+
|
719 |
+
self.gradient_checkpointing = False
|
720 |
+
# Initialize weights and apply final processing
|
721 |
+
self.post_init()
|
722 |
+
|
723 |
+
self.freeze_relevant_params(config)
|
724 |
+
|
725 |
+
def freeze_relevant_params(self, config=None):
|
726 |
+
if config is None:
|
727 |
+
config = self.config
|
728 |
+
|
729 |
+
if config.freeze_text_layers:
|
730 |
+
self.freeze_text_layers(config.freeze_text_module_exceptions)
|
731 |
+
|
732 |
+
if config.freeze_vision_layers:
|
733 |
+
freeze_model(self.vision_model, module_exceptions=config.freeze_vision_module_exceptions)
|
734 |
+
|
735 |
+
def freeze_text_layers(self, module_exceptions):
|
736 |
+
for module in [self.layers, self.norm]:
|
737 |
+
freeze_model(module, module_exceptions=module_exceptions)
|
738 |
+
|
739 |
+
def get_input_embeddings(self):
|
740 |
+
return self.embed_tokens
|
741 |
+
|
742 |
+
def set_input_embeddings(self, value):
|
743 |
+
self.embed_tokens = value
|
744 |
+
|
745 |
+
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
746 |
+
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
|
747 |
+
# create causal mask
|
748 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
749 |
+
combined_attention_mask = None
|
750 |
+
if input_shape[-1] > 1:
|
751 |
+
combined_attention_mask = _make_causal_mask(
|
752 |
+
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
753 |
+
).to(inputs_embeds.device)
|
754 |
+
|
755 |
+
if attention_mask is not None:
|
756 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
757 |
+
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
|
758 |
+
inputs_embeds.device
|
759 |
+
)
|
760 |
+
combined_attention_mask = (
|
761 |
+
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
762 |
+
)
|
763 |
+
|
764 |
+
return combined_attention_mask
|
765 |
+
|
766 |
+
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
767 |
+
def forward(
|
768 |
+
self,
|
769 |
+
input_ids: torch.LongTensor = None,
|
770 |
+
attention_mask: Optional[torch.Tensor] = None,
|
771 |
+
position_ids: Optional[torch.LongTensor] = None,
|
772 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
773 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
774 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
775 |
+
image_embeddings: Optional[torch.FloatTensor] = None,
|
776 |
+
image_attention_mask: Optional[torch.Tensor] = None,
|
777 |
+
use_cache: Optional[bool] = None,
|
778 |
+
output_attentions: Optional[bool] = None,
|
779 |
+
output_hidden_states: Optional[bool] = None,
|
780 |
+
return_dict: Optional[bool] = None,
|
781 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
782 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
783 |
+
|
784 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
785 |
+
output_hidden_states = (
|
786 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
787 |
+
)
|
788 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
789 |
+
|
790 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
791 |
+
|
792 |
+
# retrieve input_ids and inputs_embeds
|
793 |
+
if input_ids is not None and inputs_embeds is not None:
|
794 |
+
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
795 |
+
elif input_ids is not None:
|
796 |
+
batch_size, seq_length = input_ids.shape
|
797 |
+
elif inputs_embeds is not None:
|
798 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
799 |
+
else:
|
800 |
+
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
801 |
+
|
802 |
+
seq_length_with_past = seq_length
|
803 |
+
past_key_values_length = 0
|
804 |
+
|
805 |
+
if past_key_values is not None:
|
806 |
+
past_key_values_length = past_key_values[0][0].shape[2]
|
807 |
+
seq_length_with_past = seq_length_with_past + past_key_values_length
|
808 |
+
|
809 |
+
if attention_mask is not None and position_ids is None:
|
810 |
+
# create position_ids on the fly for batch generation
|
811 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
812 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
813 |
+
elif position_ids is None:
|
814 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
815 |
+
position_ids = torch.arange(
|
816 |
+
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
817 |
+
)
|
818 |
+
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
819 |
+
else:
|
820 |
+
position_ids = position_ids.view(-1, seq_length).long()
|
821 |
+
|
822 |
+
if pixel_values is not None and image_embeddings is not None:
|
823 |
+
raise ValueError("You cannot specify both pixel_values and image_embeddings at the same time")
|
824 |
+
elif pixel_values is not None:
|
825 |
+
pixel_values = pixel_values.to(dtype=self.dtype, device=input_ids.device) # fp16 compatibility
|
826 |
+
batch_size, num_images = pixel_values.size(0), pixel_values.size(1)
|
827 |
+
pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:])
|
828 |
+
# Get sequence from the vision encoder
|
829 |
+
image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
|
830 |
+
elif image_embeddings is not None:
|
831 |
+
batch_size, num_images, image_seq_len, image_hidden_size = image_embeddings.size()
|
832 |
+
image_hidden_states = image_embeddings.to(dtype=self.dtype, device=input_ids.device)
|
833 |
+
image_hidden_states = image_hidden_states.view(batch_size * num_images, image_seq_len, image_hidden_size)
|
834 |
+
|
835 |
+
if self.config.use_resampler:
|
836 |
+
image_hidden_states = self.perceiver_resampler(image_hidden_states)
|
837 |
+
image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
|
838 |
+
image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size)
|
839 |
+
# Make image_attention_mask compatible with hidden states
|
840 |
+
text_seq_len = image_attention_mask.size(1)
|
841 |
+
image_attention_mask = image_attention_mask.unsqueeze(-1)
|
842 |
+
image_attention_mask = image_attention_mask.repeat(1, 1, 1, image_seq_len)
|
843 |
+
image_attention_mask = image_attention_mask.view(batch_size, text_seq_len, num_images * image_seq_len)
|
844 |
+
|
845 |
+
if image_hidden_states is not None:
|
846 |
+
image_batch_size, image_sequence_length, _ = image_hidden_states.size()
|
847 |
+
image_hidden_shape = (image_batch_size, image_sequence_length)
|
848 |
+
if image_attention_mask is None:
|
849 |
+
image_attention_mask = torch.ones(image_hidden_shape, device=device)
|
850 |
+
image_attention_mask = self.invert_attention_mask(image_attention_mask)
|
851 |
+
else:
|
852 |
+
image_attention_mask = None
|
853 |
+
|
854 |
+
if inputs_embeds is None:
|
855 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
856 |
+
# embed positions
|
857 |
+
if attention_mask is None:
|
858 |
+
attention_mask = torch.ones(
|
859 |
+
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
860 |
+
)
|
861 |
+
attention_mask = self._prepare_decoder_attention_mask(
|
862 |
+
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
863 |
+
)
|
864 |
+
|
865 |
+
hidden_states = inputs_embeds
|
866 |
+
|
867 |
+
if self.gradient_checkpointing and self.training:
|
868 |
+
if use_cache:
|
869 |
+
logger.warning_once(
|
870 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
871 |
+
)
|
872 |
+
use_cache = False
|
873 |
+
|
874 |
+
# decoder layers
|
875 |
+
all_hidden_states = () if output_hidden_states else None
|
876 |
+
all_self_attns = () if output_attentions else None
|
877 |
+
next_decoder_cache = () if use_cache else None
|
878 |
+
|
879 |
+
for idx, decoder_layer in enumerate(self.layers):
|
880 |
+
if output_hidden_states:
|
881 |
+
all_hidden_states += (hidden_states,)
|
882 |
+
|
883 |
+
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
884 |
+
|
885 |
+
def vblock(
|
886 |
+
main_block,
|
887 |
+
hidden_states,
|
888 |
+
attention_mask,
|
889 |
+
position_ids,
|
890 |
+
past_key_value,
|
891 |
+
image_hidden_states,
|
892 |
+
image_attention_mask,
|
893 |
+
output_attentions,
|
894 |
+
use_cache,
|
895 |
+
layer_idx,
|
896 |
+
cross_layer_interval,
|
897 |
+
gated_cross_attn_layers,
|
898 |
+
):
|
899 |
+
# TODO(ls): Add cross attention values to respective lists
|
900 |
+
if layer_idx % cross_layer_interval == 0:
|
901 |
+
xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval]
|
902 |
+
outputs = xblock(
|
903 |
+
hidden_states,
|
904 |
+
attention_mask=attention_mask,
|
905 |
+
image_hidden_states=image_hidden_states,
|
906 |
+
image_attention_mask=image_attention_mask,
|
907 |
+
output_attentions=output_attentions,
|
908 |
+
use_cache=use_cache,
|
909 |
+
past_key_value=None, # not implemented
|
910 |
+
)
|
911 |
+
hidden_states = outputs[0]
|
912 |
+
|
913 |
+
layer_outputs = main_block(
|
914 |
+
hidden_states,
|
915 |
+
attention_mask=attention_mask,
|
916 |
+
position_ids=position_ids,
|
917 |
+
past_key_value=past_key_value,
|
918 |
+
output_attentions=output_attentions,
|
919 |
+
use_cache=use_cache,
|
920 |
+
)
|
921 |
+
|
922 |
+
return layer_outputs
|
923 |
+
|
924 |
+
if self.gradient_checkpointing and self.training:
|
925 |
+
past_key_value = None
|
926 |
+
if use_cache:
|
927 |
+
logger.warning_once(
|
928 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
929 |
+
)
|
930 |
+
use_cache = False
|
931 |
+
|
932 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
933 |
+
vblock,
|
934 |
+
decoder_layer,
|
935 |
+
hidden_states,
|
936 |
+
attention_mask,
|
937 |
+
position_ids,
|
938 |
+
past_key_value,
|
939 |
+
image_hidden_states,
|
940 |
+
image_attention_mask,
|
941 |
+
output_attentions,
|
942 |
+
use_cache,
|
943 |
+
idx,
|
944 |
+
self.cross_layer_interval,
|
945 |
+
self.gated_cross_attn_layers,
|
946 |
+
)
|
947 |
+
else:
|
948 |
+
layer_outputs = vblock(
|
949 |
+
decoder_layer,
|
950 |
+
hidden_states,
|
951 |
+
attention_mask=attention_mask,
|
952 |
+
position_ids=position_ids,
|
953 |
+
past_key_value=past_key_value,
|
954 |
+
image_hidden_states=image_hidden_states,
|
955 |
+
image_attention_mask=image_attention_mask,
|
956 |
+
output_attentions=output_attentions,
|
957 |
+
use_cache=use_cache,
|
958 |
+
layer_idx=idx,
|
959 |
+
cross_layer_interval=self.cross_layer_interval,
|
960 |
+
gated_cross_attn_layers=self.gated_cross_attn_layers,
|
961 |
+
)
|
962 |
+
|
963 |
+
hidden_states = layer_outputs[0]
|
964 |
+
|
965 |
+
if use_cache:
|
966 |
+
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
967 |
+
|
968 |
+
if output_attentions:
|
969 |
+
all_self_attns += (layer_outputs[1],)
|
970 |
+
|
971 |
+
hidden_states = self.norm(hidden_states)
|
972 |
+
|
973 |
+
# add hidden states from the last decoder layer
|
974 |
+
if output_hidden_states:
|
975 |
+
all_hidden_states += (hidden_states,)
|
976 |
+
|
977 |
+
next_cache = next_decoder_cache if use_cache else None
|
978 |
+
if not return_dict:
|
979 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
980 |
+
return BaseModelOutputWithPast(
|
981 |
+
last_hidden_state=hidden_states,
|
982 |
+
past_key_values=next_cache,
|
983 |
+
hidden_states=all_hidden_states,
|
984 |
+
attentions=all_self_attns,
|
985 |
+
)
|
986 |
+
|
987 |
+
|
988 |
+
class VLlamaForCausalLM(VLlamaPreTrainedModel):
|
989 |
+
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
|
990 |
+
|
991 |
+
def __init__(self, config, vision_model=None):
|
992 |
+
super().__init__(config)
|
993 |
+
self.model = VLlamaModel(config, vision_model=vision_model)
|
994 |
+
|
995 |
+
self.lm_head = DecoupledLinear(
|
996 |
+
in_features=config.hidden_size,
|
997 |
+
out_features=config.vocab_size,
|
998 |
+
out_additional_features=config.additional_vocab_size,
|
999 |
+
bias=False,
|
1000 |
+
partially_freeze=config.freeze_lm_head,
|
1001 |
+
)
|
1002 |
+
|
1003 |
+
# Initialize weights and apply final processing
|
1004 |
+
self.post_init()
|
1005 |
+
|
1006 |
+
def get_input_embeddings(self):
|
1007 |
+
return self.model.embed_tokens
|
1008 |
+
|
1009 |
+
def set_input_embeddings(self, value):
|
1010 |
+
self.model.embed_tokens = value
|
1011 |
+
|
1012 |
+
def get_output_embeddings(self):
|
1013 |
+
return self.lm_head
|
1014 |
+
|
1015 |
+
def set_output_embeddings(self, new_embeddings):
|
1016 |
+
self.lm_head = new_embeddings
|
1017 |
+
|
1018 |
+
def set_decoder(self, decoder):
|
1019 |
+
self.model = decoder
|
1020 |
+
|
1021 |
+
def get_decoder(self):
|
1022 |
+
return self.model
|
1023 |
+
|
1024 |
+
def tie_weights(self):
|
1025 |
+
"""
|
1026 |
+
Overwrite `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of DecoupledLinear and DecoupledEmbedding.
|
1027 |
+
"""
|
1028 |
+
output_embeddings = self.get_output_embeddings()
|
1029 |
+
input_embeddings = self.get_input_embeddings()
|
1030 |
+
|
1031 |
+
if getattr(self.config, "tie_word_embeddings", True):
|
1032 |
+
output_embeddings.weight = input_embeddings.weight
|
1033 |
+
if input_embeddings.num_additional_embeddings > 0:
|
1034 |
+
assert output_embeddings.out_additional_features == input_embeddings.num_additional_embeddings
|
1035 |
+
output_embeddings.additional_fc.weight = input_embeddings.additional_embedding.weight
|
1036 |
+
|
1037 |
+
if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
|
1038 |
+
output_embeddings.out_features = input_embeddings.num_embeddings
|
1039 |
+
if hasattr(output_embeddings, "out_additional_features") and hasattr(
|
1040 |
+
input_embeddings, "num_additional_embeddings"
|
1041 |
+
):
|
1042 |
+
output_embeddings.out_additional_features = input_embeddings.num_additional_embeddings
|
1043 |
+
|
1044 |
+
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
1045 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
1046 |
+
def forward(
|
1047 |
+
self,
|
1048 |
+
input_ids: torch.LongTensor = None,
|
1049 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1050 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1051 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1052 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1053 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
1054 |
+
image_embeddings: Optional[torch.FloatTensor] = None,
|
1055 |
+
image_attention_mask: Optional[torch.Tensor] = None,
|
1056 |
+
labels: Optional[torch.LongTensor] = None,
|
1057 |
+
use_cache: Optional[bool] = None,
|
1058 |
+
output_attentions: Optional[bool] = None,
|
1059 |
+
output_hidden_states: Optional[bool] = None,
|
1060 |
+
return_dict: Optional[bool] = None,
|
1061 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
1062 |
+
r"""
|
1063 |
+
Args:
|
1064 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1065 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
1066 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
1067 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
1068 |
+
|
1069 |
+
Returns:
|
1070 |
+
|
1071 |
+
Example:
|
1072 |
+
|
1073 |
+
```python
|
1074 |
+
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
1075 |
+
|
1076 |
+
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
1077 |
+
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
1078 |
+
|
1079 |
+
>>> prompt = "Hey, are you consciours? Can you talk to me?"
|
1080 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
1081 |
+
|
1082 |
+
>>> # Generate
|
1083 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
1084 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
1085 |
+
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
|
1086 |
+
```"""
|
1087 |
+
|
1088 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1089 |
+
output_hidden_states = (
|
1090 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1091 |
+
)
|
1092 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1093 |
+
|
1094 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
1095 |
+
outputs = self.model(
|
1096 |
+
input_ids=input_ids,
|
1097 |
+
attention_mask=attention_mask,
|
1098 |
+
position_ids=position_ids,
|
1099 |
+
past_key_values=past_key_values,
|
1100 |
+
inputs_embeds=inputs_embeds,
|
1101 |
+
pixel_values=pixel_values,
|
1102 |
+
image_embeddings=image_embeddings,
|
1103 |
+
image_attention_mask=image_attention_mask,
|
1104 |
+
use_cache=use_cache,
|
1105 |
+
output_attentions=output_attentions,
|
1106 |
+
output_hidden_states=output_hidden_states,
|
1107 |
+
return_dict=return_dict,
|
1108 |
+
)
|
1109 |
+
|
1110 |
+
hidden_states = outputs[0]
|
1111 |
+
logits = self.lm_head(hidden_states)
|
1112 |
+
|
1113 |
+
loss = None
|
1114 |
+
if labels is not None:
|
1115 |
+
# Shift so that tokens < n predict n
|
1116 |
+
if attention_mask is not None:
|
1117 |
+
shift_attention_mask = attention_mask[..., 1:]
|
1118 |
+
shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
|
1119 |
+
shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
|
1120 |
+
else:
|
1121 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
1122 |
+
shift_labels = labels[..., 1:].contiguous()
|
1123 |
+
# Flatten the tokens
|
1124 |
+
loss_fct = CrossEntropyLoss()
|
1125 |
+
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
1126 |
+
|
1127 |
+
if not return_dict:
|
1128 |
+
output = (logits,) + outputs[1:]
|
1129 |
+
return (loss,) + output if loss is not None else output
|
1130 |
+
|
1131 |
+
return CausalLMOutputWithPast(
|
1132 |
+
loss=loss,
|
1133 |
+
logits=logits,
|
1134 |
+
past_key_values=outputs.past_key_values,
|
1135 |
+
hidden_states=outputs.hidden_states,
|
1136 |
+
attentions=outputs.attentions,
|
1137 |
+
)
|
1138 |
+
|
1139 |
+
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
1140 |
+
inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs)
|
1141 |
+
unwanted_kwargs = ["token_type_ids"]
|
1142 |
+
for kwarg in unwanted_kwargs:
|
1143 |
+
inputs.pop(kwarg, None)
|
1144 |
+
return inputs
|
1145 |
+
|
1146 |
+
@staticmethod
|
1147 |
+
def _expand_inputs_for_generation(
|
1148 |
+
*args,
|
1149 |
+
**model_kwargs,
|
1150 |
+
):
|
1151 |
+
return expand_inputs_for_generation(*args, **model_kwargs)
|
1152 |
+
|
1153 |
+
@staticmethod
|
1154 |
+
def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
|
1155 |
+
return update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder)
|
1156 |
+
|
1157 |
+
@staticmethod
|
1158 |
+
def _reorder_cache(past, beam_idx):
|
1159 |
+
reordered_past = ()
|
1160 |
+
for layer_past in past:
|
1161 |
+
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
1162 |
+
return reordered_past
|
1163 |
+
|
1164 |
+
def get_model_tflops_per_batch_per_gpu(self, hparams, data_param, tokenizer, max_num_images):
|
1165 |
+
config_vl_model = self.config
|
1166 |
+
|
1167 |
+
language_embed_size = config_vl_model.hidden_size
|
1168 |
+
num_language_layers = config_vl_model.num_hidden_layers
|
1169 |
+
ffn_inner_size = config_vl_model.intermediate_size
|
1170 |
+
|
1171 |
+
vision_config = self.model.vision_model.config
|
1172 |
+
if hasattr(vision_config, "vision_config"):
|
1173 |
+
vision_config = vision_config.vision_config
|
1174 |
+
|
1175 |
+
# Get vision model blocks infos
|
1176 |
+
vision_patch_size = vision_config.patch_size
|
1177 |
+
vision_hidden_size = vision_config.hidden_size
|
1178 |
+
num_vision_layers = vision_config.num_hidden_layers
|
1179 |
+
# The +1 is for the CLS token
|
1180 |
+
single_image_seq_len = (vision_config.image_size // vision_patch_size) ** 2 + 1
|
1181 |
+
vision_exp_factor = vision_config.intermediate_size // vision_hidden_size
|
1182 |
+
|
1183 |
+
# Get language and cross-att blocks infos
|
1184 |
+
num_cross_attn_layers = num_language_layers // config_vl_model.cross_layer_interval
|
1185 |
+
language_seq_len = data_param.max_seq_len
|
1186 |
+
language_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4
|
1187 |
+
cross_att_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4
|
1188 |
+
k_v_cross_attn_seq_len = (
|
1189 |
+
(self.config.resampler_n_latents * max_num_images)
|
1190 |
+
if self.config.use_resampler
|
1191 |
+
else (single_image_seq_len * max_num_images)
|
1192 |
+
)
|
1193 |
+
|
1194 |
+
language_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
|
1195 |
+
num_layers=num_language_layers,
|
1196 |
+
batch_size=hparams.batch_size_per_gpu,
|
1197 |
+
q_seq_len=language_seq_len,
|
1198 |
+
k_seq_len=language_seq_len,
|
1199 |
+
hidden_size=language_embed_size,
|
1200 |
+
kv_in_dim=language_embed_size,
|
1201 |
+
ff_exp_factor=language_exp_factor,
|
1202 |
+
grad_acc_size=hparams.grad_acc_size,
|
1203 |
+
swiglu=True,
|
1204 |
+
vocab_size=tokenizer.vocab_size,
|
1205 |
+
count_backward=True, # Always True regardless of freezing, because gradients are computed for cross-attentions
|
1206 |
+
use_grad_checkpointing=hparams.gradient_checkpointing,
|
1207 |
+
)
|
1208 |
+
cross_attention_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
|
1209 |
+
num_layers=num_cross_attn_layers,
|
1210 |
+
batch_size=hparams.batch_size_per_gpu,
|
1211 |
+
q_seq_len=language_seq_len,
|
1212 |
+
k_seq_len=k_v_cross_attn_seq_len,
|
1213 |
+
hidden_size=language_embed_size,
|
1214 |
+
kv_in_dim=vision_hidden_size,
|
1215 |
+
ff_exp_factor=cross_att_exp_factor,
|
1216 |
+
grad_acc_size=hparams.grad_acc_size,
|
1217 |
+
swiglu=True,
|
1218 |
+
vocab_size=None,
|
1219 |
+
count_backward=True,
|
1220 |
+
use_grad_checkpointing=hparams.gradient_checkpointing,
|
1221 |
+
)
|
1222 |
+
vision_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
|
1223 |
+
num_layers=num_vision_layers,
|
1224 |
+
batch_size=hparams.batch_size_per_gpu * max_num_images,
|
1225 |
+
q_seq_len=single_image_seq_len,
|
1226 |
+
k_seq_len=single_image_seq_len,
|
1227 |
+
hidden_size=vision_hidden_size,
|
1228 |
+
kv_in_dim=vision_hidden_size,
|
1229 |
+
ff_exp_factor=vision_exp_factor,
|
1230 |
+
grad_acc_size=hparams.grad_acc_size,
|
1231 |
+
swiglu=False,
|
1232 |
+
vocab_size=None,
|
1233 |
+
count_backward=not hparams.model_params["freeze_vision_layers"],
|
1234 |
+
use_grad_checkpointing=hparams.gradient_checkpointing,
|
1235 |
+
)
|
1236 |
+
if self.config.use_resampler:
|
1237 |
+
perceiver_tflops_per_batch_per_gpu = compute_perceiver_tflops_per_batch_per_gpu(
|
1238 |
+
num_layers=self.config.resampler_depth,
|
1239 |
+
batch_size=hparams.batch_size_per_gpu * max_num_images,
|
1240 |
+
q_seq_len=self.config.resampler_n_latents,
|
1241 |
+
vision_embed_seq_len=single_image_seq_len,
|
1242 |
+
q_k_v_input_dim=vision_hidden_size,
|
1243 |
+
attention_hidden_size=self.config.resampler_n_heads * self.config.resampler_head_dim,
|
1244 |
+
ff_exp_factor=cross_att_exp_factor,
|
1245 |
+
count_backward=True,
|
1246 |
+
use_grad_checkpointing=hparams.gradient_checkpointing,
|
1247 |
+
)
|
1248 |
+
flop_count = (
|
1249 |
+
language_tflops_per_batch_per_gpu
|
1250 |
+
+ cross_attention_tflops_per_batch_per_gpu
|
1251 |
+
+ vision_tflops_per_batch_per_gpu
|
1252 |
+
+ perceiver_tflops_per_batch_per_gpu
|
1253 |
+
)
|
1254 |
+
else:
|
1255 |
+
flop_count = (
|
1256 |
+
language_tflops_per_batch_per_gpu
|
1257 |
+
+ cross_attention_tflops_per_batch_per_gpu
|
1258 |
+
+ vision_tflops_per_batch_per_gpu
|
1259 |
+
)
|
1260 |
+
return flop_count
|
m4/models/vopt/__init__.py
ADDED
File without changes
|
m4/models/vopt/configuration_vopt.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The Metaseq Authors and The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" OPT model configuration"""
|
16 |
+
import os
|
17 |
+
from typing import Tuple, Union
|
18 |
+
|
19 |
+
from transformers import AutoConfig
|
20 |
+
from transformers.configuration_utils import PretrainedConfig
|
21 |
+
from transformers.utils import logging
|
22 |
+
|
23 |
+
|
24 |
+
logger = logging.get_logger(__name__)
|
25 |
+
|
26 |
+
OPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
27 |
+
"facebook/opt-125m": "https://huggingface.co/facebook/opt-125m/blob/main/config.json",
|
28 |
+
"facebook/opt-350m": "https://huggingface.co/facebook/opt-350m/blob/main/config.json",
|
29 |
+
"facebook/opt-1.3b": "https://huggingface.co/facebook/opt-1.3b/blob/main/config.json",
|
30 |
+
"facebook/opt-2.7b": "https://huggingface.co/facebook/opt-2.7b/blob/main/config.json",
|
31 |
+
"facebook/opt-6.7b": "https://huggingface.co/facebook/opt-6.7b/blob/main/config.json",
|
32 |
+
"facebook/opt-13b": "https://huggingface.co/facebook/opt-13b/blob/main/config.json",
|
33 |
+
}
|
34 |
+
|
35 |
+
|
36 |
+
class VOPTConfig(PretrainedConfig):
|
37 |
+
r"""
|
38 |
+
This is the configuration class to store the configuration of a [`OPTModel`]. It is used to instantiate a OPT model
|
39 |
+
according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
40 |
+
defaults will yield a similar configuration to that of the OPT
|
41 |
+
[facebook/opt-350m](https://huggingface.co/facebook/opt-350m) architecture.
|
42 |
+
|
43 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
44 |
+
documentation from [`PretrainedConfig`] for more information.
|
45 |
+
|
46 |
+
TODO: this doc is completely out of sync with the actual args
|
47 |
+
|
48 |
+
Args:
|
49 |
+
vocab_size (`int`, *optional*, defaults to 50272):
|
50 |
+
Vocabulary size of the OPT model. Defines the number of different tokens that can be represented by the
|
51 |
+
`inputs_ids` passed when calling [`OPTModel`]
|
52 |
+
additional_vocab_size (`int`, *optional`, defaults to 0):
|
53 |
+
Additional vocabulary size of the model, typically for the special "<img>" token. Additional vocab tokens
|
54 |
+
are always trainable whereas regular vocab tokens can be frozen or not.
|
55 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
56 |
+
Dimensionality of the layers and the pooler layer.
|
57 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
58 |
+
Number of decoder layers.
|
59 |
+
ffn_dim (`int`, *optional*, defaults to 3072):
|
60 |
+
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
61 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
62 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
63 |
+
activation_function (`str` or `function`, *optional*, defaults to `"relu"`):
|
64 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
65 |
+
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
66 |
+
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
67 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
68 |
+
just in case (e.g., 512 or 1024 or 2048).
|
69 |
+
do_layer_norm_before (`bool`, *optional*, defaults to `True`):
|
70 |
+
Whether to perform layer normalization before the attention block.
|
71 |
+
word_embed_proj_dim (`int`, *optional*):
|
72 |
+
`word_embed_proj_dim` can be set to down-project word embeddings, *e.g.* `opt-350m`. Defaults to
|
73 |
+
`hidden_size`.
|
74 |
+
dropout (`float`, *optional*, defaults to 0.1):
|
75 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
76 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
77 |
+
The dropout ratio for the attention probabilities.
|
78 |
+
layerdrop: (`float`, *optional*, defaults to 0.0):
|
79 |
+
The LayerDrop probability. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) for more
|
80 |
+
details.
|
81 |
+
init_std (`float`, *optional*, defaults to 0.02):
|
82 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
83 |
+
alpha_initializer (`str`, *optional*, defaults to `"ones"`):
|
84 |
+
Initialization type for the alphas.
|
85 |
+
alphas_initializer_range (`float`, *optional*, defaults to 0.0):
|
86 |
+
The standard deviation of the truncated_normal_initializer for initializing the alphas in the Gated Cross Attention.
|
87 |
+
alpha_type (`str`, *optional*, defaults to `"vector"`):
|
88 |
+
Whether the gating alphas should be vectors or single floats.
|
89 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
90 |
+
Whether or not the model should return the last key/values attentions (not used by all models).
|
91 |
+
cross_layer_interval (`int`, *optional*, default to 1)
|
92 |
+
Interval for cross attention (from text to image) layers.
|
93 |
+
Example:
|
94 |
+
|
95 |
+
```python
|
96 |
+
>>> from transformers import OPTModel, OPTConfig
|
97 |
+
|
98 |
+
>>> # Initializing a OPT facebook/opt-large style configuration
|
99 |
+
>>> configuration = OPTConfig()
|
100 |
+
|
101 |
+
>>> # Initializing a model from the facebook/opt-large style configuration
|
102 |
+
>>> model = OPTModel(configuration)
|
103 |
+
|
104 |
+
>>> # Accessing the model configuration
|
105 |
+
>>> configuration = model.config
|
106 |
+
```"""
|
107 |
+
model_type = "vopt"
|
108 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
109 |
+
|
110 |
+
def __init__(
|
111 |
+
self,
|
112 |
+
vocab_size=50272,
|
113 |
+
additional_vocab_size=0,
|
114 |
+
hidden_size=768,
|
115 |
+
num_hidden_layers=12,
|
116 |
+
ffn_dim=3072,
|
117 |
+
max_position_embeddings=2048,
|
118 |
+
do_layer_norm_before=True,
|
119 |
+
_remove_final_layer_norm=False,
|
120 |
+
word_embed_proj_dim=None,
|
121 |
+
dropout=0.1,
|
122 |
+
attention_dropout=0.0,
|
123 |
+
num_attention_heads=12,
|
124 |
+
activation_function="relu",
|
125 |
+
layerdrop=0.0,
|
126 |
+
init_std=0.02,
|
127 |
+
alpha_initializer="ones",
|
128 |
+
alphas_initializer_range=0.0,
|
129 |
+
alpha_type="vector",
|
130 |
+
use_cache=True,
|
131 |
+
pad_token_id=1,
|
132 |
+
bos_token_id=2,
|
133 |
+
eos_token_id=2,
|
134 |
+
cross_layer_interval=1,
|
135 |
+
cross_layer_activation_function="swiglu",
|
136 |
+
normformer_layer_norms=False,
|
137 |
+
qk_layer_norms=False,
|
138 |
+
rms_norm=False,
|
139 |
+
qk_layer_norms_perceiver=False,
|
140 |
+
tie_word_embeddings=False,
|
141 |
+
freeze_text_layers=True,
|
142 |
+
freeze_text_module_exceptions=[],
|
143 |
+
freeze_lm_head=False,
|
144 |
+
freeze_vision_layers=True,
|
145 |
+
freeze_vision_module_exceptions=[],
|
146 |
+
vision_model_name="google/vit-base-patch16-224",
|
147 |
+
vision_model_params="{}",
|
148 |
+
vision_embed_dim=768,
|
149 |
+
vision_image_size=224,
|
150 |
+
image_token_index=50257, # TODO: change this to right value
|
151 |
+
use_resampler=False,
|
152 |
+
resampler_n_latents=64,
|
153 |
+
resampler_depth=6,
|
154 |
+
resampler_n_heads=16,
|
155 |
+
resampler_head_dim=96,
|
156 |
+
**kwargs,
|
157 |
+
):
|
158 |
+
super().__init__(
|
159 |
+
pad_token_id=pad_token_id,
|
160 |
+
bos_token_id=bos_token_id,
|
161 |
+
eos_token_id=eos_token_id,
|
162 |
+
tie_word_embeddings=tie_word_embeddings,
|
163 |
+
**kwargs,
|
164 |
+
)
|
165 |
+
self.vocab_size = vocab_size
|
166 |
+
self.additional_vocab_size = additional_vocab_size
|
167 |
+
self.max_position_embeddings = max_position_embeddings
|
168 |
+
self.num_attention_heads = num_attention_heads
|
169 |
+
self.word_embed_proj_dim = word_embed_proj_dim if word_embed_proj_dim is not None else hidden_size
|
170 |
+
self.ffn_dim = ffn_dim
|
171 |
+
self.hidden_size = hidden_size
|
172 |
+
self.num_hidden_layers = num_hidden_layers
|
173 |
+
self.dropout = dropout
|
174 |
+
self.attention_dropout = attention_dropout
|
175 |
+
self.activation_function = activation_function
|
176 |
+
self.init_std = init_std
|
177 |
+
self.alpha_initializer = alpha_initializer
|
178 |
+
self.alphas_initializer_range = alphas_initializer_range
|
179 |
+
self.alpha_type = alpha_type
|
180 |
+
self.layerdrop = layerdrop
|
181 |
+
self.use_cache = use_cache
|
182 |
+
self.do_layer_norm_before = do_layer_norm_before
|
183 |
+
|
184 |
+
# Note that the only purpose of `_remove_final_layer_norm` is to keep backward compatibility
|
185 |
+
# with checkpoints that have been fine-tuned before transformers v4.20.1
|
186 |
+
# see https://github.com/facebookresearch/metaseq/pull/164
|
187 |
+
self._remove_final_layer_norm = _remove_final_layer_norm
|
188 |
+
|
189 |
+
self.cross_layer_interval = cross_layer_interval
|
190 |
+
self.cross_layer_activation_function = cross_layer_activation_function
|
191 |
+
self.normformer_layer_norms = normformer_layer_norms
|
192 |
+
self.qk_layer_norms = qk_layer_norms
|
193 |
+
self.rms_norm = rms_norm
|
194 |
+
self.qk_layer_norms_perceiver = qk_layer_norms_perceiver
|
195 |
+
self.freeze_vision_layers = freeze_vision_layers
|
196 |
+
self.vision_model_name = vision_model_name
|
197 |
+
self.vision_model_params = vision_model_params
|
198 |
+
|
199 |
+
self.tie_word_embeddings = tie_word_embeddings
|
200 |
+
self.freeze_text_layers = freeze_text_layers
|
201 |
+
self.freeze_text_module_exceptions = freeze_text_module_exceptions
|
202 |
+
self.freeze_vision_module_exceptions = freeze_vision_module_exceptions
|
203 |
+
self.freeze_lm_head = freeze_lm_head
|
204 |
+
self.image_token_index = image_token_index
|
205 |
+
|
206 |
+
self.vision_embed_dim = vision_embed_dim
|
207 |
+
self.vision_image_size = vision_image_size
|
208 |
+
|
209 |
+
# Resampler params
|
210 |
+
self.use_resampler = use_resampler
|
211 |
+
self.resampler_n_latents = resampler_n_latents
|
212 |
+
self.resampler_depth = resampler_depth
|
213 |
+
self.resampler_n_heads = resampler_n_heads
|
214 |
+
self.resampler_head_dim = resampler_head_dim
|
215 |
+
|
216 |
+
# IMPORTANT: Do not do any __init__ args-based checks in the constructor, since
|
217 |
+
# PretrainedConfig.from_dict first instantiates the class with the config dict and only then
|
218 |
+
# updates the config object with `kwargs` from from_pretrained, so during the instantiation
|
219 |
+
# of this object many attributes have default values and haven't yet been overridden.
|
220 |
+
# Do any required checks inside `from_pretrained` once the superclass' `from_pretrained` was run.
|
221 |
+
|
222 |
+
def check_compatibilities(self):
|
223 |
+
vision_model_params = eval(self.vision_model_params)
|
224 |
+
config = AutoConfig.from_pretrained(self.vision_model_name, **vision_model_params)
|
225 |
+
if hasattr(config, "vision_config"):
|
226 |
+
vision_config = config.vision_config
|
227 |
+
else:
|
228 |
+
vision_config = config
|
229 |
+
vision_embed_dim = vision_config.hidden_size
|
230 |
+
if self.vision_embed_dim != vision_embed_dim:
|
231 |
+
raise ValueError(
|
232 |
+
f"vision_embed_dim ({self.vision_embed_dim}) must match the hidden size of the vision model"
|
233 |
+
f" ({vision_embed_dim})"
|
234 |
+
)
|
235 |
+
vision_image_size = vision_config.image_size
|
236 |
+
if self.vision_image_size != vision_image_size:
|
237 |
+
raise ValueError(
|
238 |
+
f"vision_image_size ({self.vision_image_size}) must match the hidden size of the vision model"
|
239 |
+
f" ({vision_image_size})"
|
240 |
+
)
|
241 |
+
|
242 |
+
@classmethod
|
243 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
244 |
+
outputs = super(VOPTConfig, cls).from_pretrained(pretrained_model_name_or_path, **kwargs)
|
245 |
+
if isinstance(outputs, Tuple):
|
246 |
+
# When called with return_unused_kwargs=True, the first item will be the config
|
247 |
+
outputs[0].check_compatibilities()
|
248 |
+
else:
|
249 |
+
outputs.check_compatibilities()
|
250 |
+
return outputs
|
m4/models/vopt/make_tiny_model.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
# This script creates a super tiny model that is useful inside tests, when we just want to test that
|
4 |
+
# the machinery works, without needing to check the quality of the outcomes.
|
5 |
+
#
|
6 |
+
# usage: adjust the configs if wanted, but otherwise just run the script
|
7 |
+
|
8 |
+
from pathlib import Path
|
9 |
+
from types import SimpleNamespace
|
10 |
+
|
11 |
+
import torchvision.transforms as transforms
|
12 |
+
from PIL import Image
|
13 |
+
|
14 |
+
from m4.models.vopt.modeling_vopt import VOPTConfig, VOPTForCausalLM
|
15 |
+
from m4.training.packing import image_attention_mask_for_packed_input_ids, incremental_to_binary_attention_mask
|
16 |
+
from m4.training.utils import get_tokenizer
|
17 |
+
|
18 |
+
|
19 |
+
mname_tiny = "tiny-random-vopt-clip"
|
20 |
+
|
21 |
+
path = Path(mname_tiny)
|
22 |
+
path.mkdir(parents=True, exist_ok=True)
|
23 |
+
|
24 |
+
# from the hardcoded https://github.com/huggingface/m4/blob/adf102f0000cb2632cd8a3ebb87398c65e448a97/m4/training/main.py#L80
|
25 |
+
additional_vocab_size = 2
|
26 |
+
|
27 |
+
config = VOPTConfig()
|
28 |
+
config.update(
|
29 |
+
dict(
|
30 |
+
ffn_dim=64,
|
31 |
+
hidden_size=16,
|
32 |
+
max_position_embeddings=128,
|
33 |
+
num_attention_heads=4,
|
34 |
+
num_hidden_layers=2,
|
35 |
+
word_embed_proj_dim=16,
|
36 |
+
max_new_tokens=100,
|
37 |
+
use_resampler=True,
|
38 |
+
resampler_depth=2,
|
39 |
+
resampler_head_dim=8,
|
40 |
+
resampler_n_heads=2,
|
41 |
+
resampler_n_latents=16,
|
42 |
+
vision_embed_dim=32,
|
43 |
+
vision_image_size=30,
|
44 |
+
vision_model_name="hf-internal-testing/tiny-random-clip",
|
45 |
+
vision_model_params="{}",
|
46 |
+
vocab_size=50265,
|
47 |
+
additional_vocab_size=additional_vocab_size,
|
48 |
+
)
|
49 |
+
)
|
50 |
+
|
51 |
+
# print(config)
|
52 |
+
# can now modify config to say tiny values
|
53 |
+
|
54 |
+
model = VOPTForCausalLM.from_config(config)
|
55 |
+
# print(model.config)
|
56 |
+
# print(model)
|
57 |
+
|
58 |
+
tokenizer_config = dict(
|
59 |
+
tokenizer_add_special_tokens="{}",
|
60 |
+
tokenizer_add_tokens=(
|
61 |
+
'[AddedToken("<fake_token_around_image>", rstrip=False, lstrip=False), AddedToken("<image>", rstrip=False,'
|
62 |
+
" lstrip=False)]"
|
63 |
+
),
|
64 |
+
tokenizer_name="facebook/opt-13b",
|
65 |
+
tokenizer_params='{"use_fast":True}',
|
66 |
+
)
|
67 |
+
tokenizer_config = SimpleNamespace(**tokenizer_config)
|
68 |
+
# print(tokenizer_config)
|
69 |
+
|
70 |
+
tokenizer = get_tokenizer(
|
71 |
+
tokenizer_name=tokenizer_config.tokenizer_name,
|
72 |
+
tokenizer_add_tokens=tokenizer_config.tokenizer_add_tokens,
|
73 |
+
tokenizer_add_special_tokens=tokenizer_config.tokenizer_add_special_tokens,
|
74 |
+
tokenizer_params=tokenizer_config.tokenizer_params,
|
75 |
+
additional_vocab_size=model.config.additional_vocab_size,
|
76 |
+
model_vocab_size=model.config.vocab_size,
|
77 |
+
)
|
78 |
+
assert "<image>" in tokenizer.get_vocab()
|
79 |
+
|
80 |
+
# Test w/ one image and one text
|
81 |
+
query = "<fake_token_around_image><image><fake_token_around_image>This is a picture of a cat."
|
82 |
+
query_tokens = tokenizer(query, return_tensors="pt")
|
83 |
+
|
84 |
+
num_images_per_ex = 1
|
85 |
+
pixel_values = transforms.ToTensor()(Image.new("RGB", (30, 30))).repeat(1, 1, 1, 1).unsqueeze(0)
|
86 |
+
image_attention_mask, _ = image_attention_mask_for_packed_input_ids(query_tokens["input_ids"], tokenizer)
|
87 |
+
image_attention_mask = incremental_to_binary_attention_mask(image_attention_mask, num_classes=num_images_per_ex)
|
88 |
+
|
89 |
+
input = {
|
90 |
+
"input_ids": query_tokens["input_ids"],
|
91 |
+
"attention_mask": query_tokens["attention_mask"],
|
92 |
+
"pixel_values": pixel_values,
|
93 |
+
"pixel_values": pixel_values,
|
94 |
+
"image_attention_mask": image_attention_mask,
|
95 |
+
}
|
96 |
+
# debug shapes
|
97 |
+
# print(query_tokens["input_ids"].shape)
|
98 |
+
# print(query_tokens["attention_mask"].shape)
|
99 |
+
# print(pixel_values.shape)
|
100 |
+
# print(image_attention_mask.shape)
|
101 |
+
|
102 |
+
out_gen = model.generate(**input)
|
103 |
+
text = tokenizer.batch_decode(out_gen)
|
104 |
+
# print(text)
|
105 |
+
|
106 |
+
# Save model + config + tokenizer
|
107 |
+
model.half() # makes it smaller
|
108 |
+
model.save_pretrained(path)
|
109 |
+
tokenizer.save_pretrained(path)
|
110 |
+
|
111 |
+
# test we can load it back
|
112 |
+
model = VOPTForCausalLM.from_pretrained(path)
|
113 |
+
|
114 |
+
print(f"Generated {mname_tiny} - Upload the generated folder to the hub")
|
m4/models/vopt/modeling_vopt.py
ADDED
@@ -0,0 +1,1513 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch OPT model."""
|
16 |
+
import random
|
17 |
+
from typing import List, Optional, Tuple, Union
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.utils.checkpoint
|
21 |
+
from torch import nn
|
22 |
+
from torch.nn import CrossEntropyLoss
|
23 |
+
from transformers.activations import ACT2FN
|
24 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
25 |
+
from transformers.utils import (
|
26 |
+
ContextManagers,
|
27 |
+
add_code_sample_docstrings,
|
28 |
+
add_start_docstrings,
|
29 |
+
add_start_docstrings_to_model_forward,
|
30 |
+
replace_return_docstrings,
|
31 |
+
)
|
32 |
+
|
33 |
+
from m4.models import DecoupledEmbedding, DecoupledLinear
|
34 |
+
from m4.models.common import (
|
35 |
+
expand_inputs_for_generation,
|
36 |
+
prepare_inputs_for_generation,
|
37 |
+
update_model_kwargs_for_generation,
|
38 |
+
)
|
39 |
+
from m4.models.custom_modules import VLOOMPreTrainedModelBase
|
40 |
+
from m4.models.perceiver.perceiver import PerceiverResampler
|
41 |
+
from m4.models.vopt.configuration_vopt import VOPTConfig
|
42 |
+
from m4.training.utils import (
|
43 |
+
compute_perceiver_tflops_per_batch_per_gpu,
|
44 |
+
compute_tflops_per_batch_per_gpu,
|
45 |
+
deepspeed_gathered_parameters_context_manager,
|
46 |
+
freeze_model,
|
47 |
+
)
|
48 |
+
from m4.utils import logging
|
49 |
+
|
50 |
+
|
51 |
+
logger = logging.get_logger(__name__)
|
52 |
+
|
53 |
+
_CHECKPOINT_FOR_DOC = "facebook/opt-350m"
|
54 |
+
_CONFIG_FOR_DOC = "VOPTConfig"
|
55 |
+
_TOKENIZER_FOR_DOC = "GPT2Tokenizer"
|
56 |
+
|
57 |
+
# Base model docstring
|
58 |
+
_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]
|
59 |
+
|
60 |
+
# SequenceClassification docstring
|
61 |
+
_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ArthurZ/opt-350m-dummy-sc"
|
62 |
+
_SEQ_CLASS_EXPECTED_LOSS = 1.71
|
63 |
+
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'"
|
64 |
+
|
65 |
+
|
66 |
+
OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
67 |
+
"facebook/opt-125m",
|
68 |
+
"facebook/opt-350m",
|
69 |
+
"facebook/opt-1.3b",
|
70 |
+
"facebook/opt-2.7b",
|
71 |
+
"facebook/opt-6.7b",
|
72 |
+
"facebook/opt-13b",
|
73 |
+
"facebook/opt-30b",
|
74 |
+
# See all OPT models at https://huggingface.co/models?filter=opt
|
75 |
+
]
|
76 |
+
|
77 |
+
|
78 |
+
class SwiGLUActivation(nn.Module):
|
79 |
+
def __init__(self, in_features: int, out_features: int):
|
80 |
+
super().__init__()
|
81 |
+
self.gate = nn.Linear(in_features, out_features, bias=False)
|
82 |
+
|
83 |
+
def forward(self, hidden_states_to_gate, hidden_states):
|
84 |
+
gate = self.gate(hidden_states)
|
85 |
+
return nn.functional.silu(gate) * hidden_states_to_gate
|
86 |
+
|
87 |
+
|
88 |
+
# Taken from LLaMA codebase
|
89 |
+
class RMSNorm(torch.nn.Module):
|
90 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
91 |
+
super().__init__()
|
92 |
+
self.eps = eps
|
93 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
94 |
+
|
95 |
+
def _norm(self, x):
|
96 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
97 |
+
|
98 |
+
def forward(self, x):
|
99 |
+
output = self._norm(x.float()).type_as(x)
|
100 |
+
return output * self.weight
|
101 |
+
|
102 |
+
|
103 |
+
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
|
104 |
+
"""
|
105 |
+
Make causal mask used for bi-directional self-attention.
|
106 |
+
"""
|
107 |
+
bsz, tgt_len = input_ids_shape
|
108 |
+
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
|
109 |
+
mask_cond = torch.arange(mask.size(-1))
|
110 |
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
111 |
+
mask = mask.to(dtype)
|
112 |
+
|
113 |
+
if past_key_values_length > 0:
|
114 |
+
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
|
115 |
+
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
116 |
+
|
117 |
+
|
118 |
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
119 |
+
"""
|
120 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
121 |
+
"""
|
122 |
+
bsz, src_len = mask.size()
|
123 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
124 |
+
|
125 |
+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
126 |
+
|
127 |
+
inverted_mask = 1.0 - expanded_mask
|
128 |
+
|
129 |
+
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
130 |
+
|
131 |
+
|
132 |
+
class OPTLearnedPositionalEmbedding(nn.Embedding):
|
133 |
+
"""
|
134 |
+
This module learns positional embeddings up to a fixed maximum size.
|
135 |
+
"""
|
136 |
+
|
137 |
+
def __init__(self, num_embeddings: int, embedding_dim: int):
|
138 |
+
# OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
|
139 |
+
# and adjust num_embeddings appropriately. Other models don't have this hack
|
140 |
+
self.offset = 2
|
141 |
+
super().__init__(num_embeddings + self.offset, embedding_dim)
|
142 |
+
|
143 |
+
def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
|
144 |
+
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
|
145 |
+
attention_mask = attention_mask.long()
|
146 |
+
|
147 |
+
# create positions depending on attention_mask
|
148 |
+
positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
|
149 |
+
|
150 |
+
# cut positions if `past_key_values_length` is > 0
|
151 |
+
positions = positions[:, past_key_values_length:]
|
152 |
+
|
153 |
+
return super().forward(positions + self.offset)
|
154 |
+
|
155 |
+
|
156 |
+
class OPTAttention(nn.Module):
|
157 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
158 |
+
|
159 |
+
def __init__(
|
160 |
+
self,
|
161 |
+
embed_dim: int,
|
162 |
+
num_heads: int,
|
163 |
+
dropout: float = 0.0,
|
164 |
+
is_decoder: bool = False,
|
165 |
+
bias: bool = True,
|
166 |
+
is_cross_attention=False,
|
167 |
+
config=None,
|
168 |
+
qk_layer_norms=False,
|
169 |
+
):
|
170 |
+
super().__init__()
|
171 |
+
self.embed_dim = embed_dim
|
172 |
+
self.num_heads = num_heads
|
173 |
+
self.dropout = dropout
|
174 |
+
self.head_dim = embed_dim // num_heads
|
175 |
+
|
176 |
+
if (self.head_dim * num_heads) != self.embed_dim:
|
177 |
+
raise ValueError(
|
178 |
+
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
|
179 |
+
f" and `num_heads`: {num_heads})."
|
180 |
+
)
|
181 |
+
self.scaling = self.head_dim**-0.5
|
182 |
+
self.is_decoder = is_decoder
|
183 |
+
|
184 |
+
self.is_cross_attention = is_cross_attention
|
185 |
+
|
186 |
+
if self.is_cross_attention:
|
187 |
+
kv_input_dim = self.hidden_size if not hasattr(config, "vision_embed_dim") else config.vision_embed_dim
|
188 |
+
self.k_proj = nn.Linear(kv_input_dim, embed_dim, bias=bias)
|
189 |
+
self.v_proj = nn.Linear(kv_input_dim, embed_dim, bias=bias)
|
190 |
+
|
191 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
192 |
+
else:
|
193 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
194 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
195 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
196 |
+
|
197 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
198 |
+
|
199 |
+
self.qk_layer_norms = qk_layer_norms
|
200 |
+
if self.qk_layer_norms and config.rms_norm:
|
201 |
+
self.q_layer_norm = RMSNorm(self.head_dim, eps=1e-6)
|
202 |
+
self.k_layer_norm = RMSNorm(self.head_dim, eps=1e-6)
|
203 |
+
elif self.qk_layer_norms:
|
204 |
+
self.q_layer_norm = nn.LayerNorm(self.head_dim)
|
205 |
+
self.k_layer_norm = nn.LayerNorm(self.head_dim)
|
206 |
+
|
207 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
208 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
209 |
+
|
210 |
+
def forward(
|
211 |
+
self,
|
212 |
+
hidden_states: torch.Tensor,
|
213 |
+
key_value_states: Optional[torch.Tensor] = None,
|
214 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
215 |
+
attention_mask: Optional[torch.Tensor] = None,
|
216 |
+
layer_head_mask: Optional[torch.Tensor] = None,
|
217 |
+
output_attentions: bool = False,
|
218 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
219 |
+
"""Input shape: Batch x Time x Channel"""
|
220 |
+
|
221 |
+
# if key_value_states are provided this layer is used as a cross-attention layer
|
222 |
+
# for the decoder
|
223 |
+
is_cross_attention = self.is_cross_attention or key_value_states is not None
|
224 |
+
|
225 |
+
bsz, tgt_len, _ = hidden_states.size()
|
226 |
+
|
227 |
+
# get query proj
|
228 |
+
query_states = self._shape(self.q_proj(hidden_states), -1, bsz)
|
229 |
+
# get key, value proj
|
230 |
+
if is_cross_attention and past_key_value is not None:
|
231 |
+
# reuse k,v, cross_attentions
|
232 |
+
key_states = past_key_value[0]
|
233 |
+
value_states = past_key_value[1]
|
234 |
+
elif is_cross_attention:
|
235 |
+
# cross_attentions
|
236 |
+
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
237 |
+
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
238 |
+
elif past_key_value is not None:
|
239 |
+
# reuse k, v, self_attention
|
240 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
241 |
+
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
242 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
243 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
244 |
+
else:
|
245 |
+
# self_attention
|
246 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
247 |
+
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
248 |
+
|
249 |
+
if self.is_decoder:
|
250 |
+
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
251 |
+
# Further calls to cross_attention layer can then reuse all cross-attention
|
252 |
+
# key/value_states (first "if" case)
|
253 |
+
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
254 |
+
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
255 |
+
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
256 |
+
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
257 |
+
past_key_value = (key_states, value_states)
|
258 |
+
|
259 |
+
if self.qk_layer_norms:
|
260 |
+
query_states = self.q_layer_norm(query_states)
|
261 |
+
key_states = self.k_layer_norm(key_states)
|
262 |
+
|
263 |
+
src_len = key_states.size(2)
|
264 |
+
|
265 |
+
if attention_mask is not None:
|
266 |
+
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
267 |
+
raise ValueError(
|
268 |
+
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
269 |
+
)
|
270 |
+
if layer_head_mask is not None:
|
271 |
+
if layer_head_mask.size() != (self.num_heads,):
|
272 |
+
raise ValueError(
|
273 |
+
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
|
274 |
+
f" {layer_head_mask.size()}"
|
275 |
+
)
|
276 |
+
attention_mask = attention_mask.expand(-1, self.num_heads, -1, -1)
|
277 |
+
attention_mask = attention_mask + layer_head_mask.view(1, -1, 1, 1)
|
278 |
+
|
279 |
+
attn_output = nn.functional.scaled_dot_product_attention(
|
280 |
+
query_states,
|
281 |
+
key_states,
|
282 |
+
value_states,
|
283 |
+
attn_mask=attention_mask,
|
284 |
+
dropout_p=self.dropout,
|
285 |
+
)
|
286 |
+
|
287 |
+
attn_weights_reshaped = None
|
288 |
+
logger.warning_once(
|
289 |
+
"attn_weights are not extracted in scaled_dot_product_attention. The model returns None instead"
|
290 |
+
)
|
291 |
+
attn_output = attn_output.transpose(1, 2)
|
292 |
+
|
293 |
+
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
294 |
+
# partitioned aross GPUs when using tensor-parallelism.
|
295 |
+
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
296 |
+
|
297 |
+
attn_output = self.out_proj(attn_output)
|
298 |
+
|
299 |
+
return attn_output, attn_weights_reshaped, past_key_value
|
300 |
+
|
301 |
+
|
302 |
+
class OPTDecoderLayer(nn.Module):
|
303 |
+
def __init__(self, config: VOPTConfig):
|
304 |
+
super().__init__()
|
305 |
+
self.embed_dim = config.hidden_size
|
306 |
+
self.self_attn = OPTAttention(
|
307 |
+
embed_dim=self.embed_dim,
|
308 |
+
num_heads=config.num_attention_heads,
|
309 |
+
dropout=config.attention_dropout,
|
310 |
+
is_decoder=True,
|
311 |
+
config=config,
|
312 |
+
)
|
313 |
+
self.do_layer_norm_before = config.do_layer_norm_before
|
314 |
+
self.dropout = config.dropout
|
315 |
+
self.activation_fn = ACT2FN[config.activation_function]
|
316 |
+
|
317 |
+
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
318 |
+
self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim)
|
319 |
+
self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim)
|
320 |
+
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
321 |
+
|
322 |
+
def forward(
|
323 |
+
self,
|
324 |
+
hidden_states: torch.Tensor,
|
325 |
+
attention_mask: Optional[torch.Tensor] = None,
|
326 |
+
layer_head_mask: Optional[torch.Tensor] = None,
|
327 |
+
output_attentions: Optional[bool] = False,
|
328 |
+
use_cache: Optional[bool] = False,
|
329 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
330 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
331 |
+
"""
|
332 |
+
Args:
|
333 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
334 |
+
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
335 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
336 |
+
layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
|
337 |
+
`(encoder_attention_heads,)`.
|
338 |
+
output_attentions (`bool`, *optional*):
|
339 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
340 |
+
returned tensors for more detail.
|
341 |
+
use_cache (`bool`, *optional*):
|
342 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
343 |
+
(see `past_key_values`).
|
344 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
345 |
+
"""
|
346 |
+
|
347 |
+
residual = hidden_states
|
348 |
+
|
349 |
+
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
|
350 |
+
if self.do_layer_norm_before:
|
351 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
352 |
+
|
353 |
+
# Self Attention
|
354 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
355 |
+
hidden_states=hidden_states,
|
356 |
+
past_key_value=past_key_value,
|
357 |
+
attention_mask=attention_mask,
|
358 |
+
layer_head_mask=layer_head_mask,
|
359 |
+
output_attentions=output_attentions,
|
360 |
+
)
|
361 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
362 |
+
hidden_states = residual + hidden_states
|
363 |
+
|
364 |
+
# 350m applies layer norm AFTER attention
|
365 |
+
if not self.do_layer_norm_before:
|
366 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
367 |
+
|
368 |
+
# Fully Connected
|
369 |
+
hidden_states_shape = hidden_states.shape
|
370 |
+
hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
|
371 |
+
residual = hidden_states
|
372 |
+
|
373 |
+
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
|
374 |
+
if self.do_layer_norm_before:
|
375 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
376 |
+
|
377 |
+
hidden_states = self.fc1(hidden_states)
|
378 |
+
hidden_states = self.activation_fn(hidden_states)
|
379 |
+
|
380 |
+
hidden_states = self.fc2(hidden_states)
|
381 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
382 |
+
|
383 |
+
hidden_states = (residual + hidden_states).view(hidden_states_shape)
|
384 |
+
|
385 |
+
# 350m applies layer norm AFTER attention
|
386 |
+
if not self.do_layer_norm_before:
|
387 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
388 |
+
|
389 |
+
outputs = (hidden_states,)
|
390 |
+
|
391 |
+
if output_attentions:
|
392 |
+
outputs += (self_attn_weights,)
|
393 |
+
|
394 |
+
if use_cache:
|
395 |
+
outputs += (present_key_value,)
|
396 |
+
|
397 |
+
return outputs
|
398 |
+
|
399 |
+
|
400 |
+
class VOPTGatedAttentionLayer(nn.Module):
|
401 |
+
def __init__(self, config: VOPTConfig):
|
402 |
+
"""
|
403 |
+
Note: Based on `tr_101_cm401xPMD09_nobias`, setting the biases to False in all of the nn.Linear for the gated cross attention.
|
404 |
+
Provide a small stability gain at opt-13b scale.
|
405 |
+
"""
|
406 |
+
super().__init__()
|
407 |
+
self.embed_dim = config.hidden_size
|
408 |
+
self.cross_attn = OPTAttention(
|
409 |
+
embed_dim=self.embed_dim,
|
410 |
+
num_heads=config.num_attention_heads,
|
411 |
+
dropout=config.attention_dropout,
|
412 |
+
is_decoder=True,
|
413 |
+
config=config,
|
414 |
+
is_cross_attention=True,
|
415 |
+
bias=False,
|
416 |
+
qk_layer_norms=config.qk_layer_norms,
|
417 |
+
)
|
418 |
+
self.do_layer_norm_before = config.do_layer_norm_before
|
419 |
+
self.normformer_layer_norms = config.normformer_layer_norms
|
420 |
+
self.dropout = config.dropout
|
421 |
+
if config.cross_layer_activation_function == "swiglu":
|
422 |
+
# We cannot put `SwiGLUActivation` in `ACT2FN` because it takes two arguments (`in_features` and
|
423 |
+
# `out_features`) that we don't know until entering this module.
|
424 |
+
self.activation_fn = SwiGLUActivation(self.embed_dim, config.ffn_dim)
|
425 |
+
else:
|
426 |
+
self.activation_fn = ACT2FN[config.cross_layer_activation_function]
|
427 |
+
|
428 |
+
if config.rms_norm:
|
429 |
+
self.self_attn_layer_norm = RMSNorm(self.embed_dim, eps=1e-6)
|
430 |
+
else:
|
431 |
+
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
432 |
+
|
433 |
+
if self.normformer_layer_norms:
|
434 |
+
self.self_attn_post_layer_norm = nn.LayerNorm(self.embed_dim)
|
435 |
+
self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=False)
|
436 |
+
self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=False)
|
437 |
+
|
438 |
+
if config.rms_norm:
|
439 |
+
self.final_layer_norm = RMSNorm(self.embed_dim, eps=1e-6)
|
440 |
+
else:
|
441 |
+
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
442 |
+
|
443 |
+
if self.normformer_layer_norms:
|
444 |
+
self.mlp_post_layer_norm = nn.LayerNorm(config.ffn_dim)
|
445 |
+
|
446 |
+
self.act_cross_attn = nn.Tanh()
|
447 |
+
self.act_dense = nn.Tanh()
|
448 |
+
|
449 |
+
if config.alpha_initializer == "zeros":
|
450 |
+
if config.alpha_type == "vector":
|
451 |
+
self.alpha_cross_attn = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
|
452 |
+
self.alpha_dense = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
|
453 |
+
elif config.alpha_type == "float":
|
454 |
+
self.alpha_cross_attn = nn.Parameter(torch.zeros(1))
|
455 |
+
self.alpha_dense = nn.Parameter(torch.zeros(1))
|
456 |
+
else:
|
457 |
+
raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
|
458 |
+
|
459 |
+
elif config.alpha_initializer == "ones":
|
460 |
+
if config.alpha_type == "vector":
|
461 |
+
self.alpha_cross_attn = nn.Parameter(torch.ones(1, 1, self.embed_dim))
|
462 |
+
self.alpha_dense = nn.Parameter(torch.ones(1, 1, self.embed_dim))
|
463 |
+
elif config.alpha_type == "float":
|
464 |
+
self.alpha_cross_attn = nn.Parameter(torch.ones(1))
|
465 |
+
self.alpha_dense = nn.Parameter(torch.ones(1))
|
466 |
+
else:
|
467 |
+
raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
|
468 |
+
|
469 |
+
elif config.alpha_initializer in {"normal", "gaussian", "random"}:
|
470 |
+
if config.alpha_type == "vector":
|
471 |
+
self.alpha_cross_attn = nn.Parameter(
|
472 |
+
torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1, 1, self.embed_dim))
|
473 |
+
)
|
474 |
+
self.alpha_dense = nn.Parameter(
|
475 |
+
torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1, 1, self.embed_dim))
|
476 |
+
)
|
477 |
+
elif config.alpha_type == "float":
|
478 |
+
self.alpha_cross_attn = nn.Parameter(
|
479 |
+
torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1))
|
480 |
+
)
|
481 |
+
self.alpha_dense = nn.Parameter(torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1)))
|
482 |
+
else:
|
483 |
+
raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
|
484 |
+
|
485 |
+
else:
|
486 |
+
raise NotImplementedError(f"Alpha initialization scheme {config.alpha_initializer} not yet implemented!")
|
487 |
+
|
488 |
+
assert hasattr(self, "alpha_cross_attn") and hasattr(self, "alpha_dense")
|
489 |
+
|
490 |
+
def forward(
|
491 |
+
self,
|
492 |
+
hidden_states: torch.Tensor,
|
493 |
+
attention_mask: Optional[torch.Tensor] = None,
|
494 |
+
layer_head_mask: Optional[torch.Tensor] = None,
|
495 |
+
image_hidden_states: Optional[torch.Tensor] = None,
|
496 |
+
image_attention_mask: Optional[torch.FloatTensor] = None,
|
497 |
+
output_attentions: Optional[bool] = False,
|
498 |
+
use_cache: Optional[bool] = False,
|
499 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
500 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
501 |
+
"""
|
502 |
+
Args:
|
503 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
504 |
+
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
505 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
506 |
+
layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
|
507 |
+
`(encoder_attention_heads,)`.
|
508 |
+
output_attentions (`bool`, *optional*):
|
509 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
510 |
+
returned tensors for more detail.
|
511 |
+
use_cache (`bool`, *optional*):
|
512 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
513 |
+
(see `past_key_values`).
|
514 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
515 |
+
"""
|
516 |
+
if image_hidden_states is None:
|
517 |
+
raise ValueError(
|
518 |
+
"`image_hidden_states` is required for VOPT cross attention module which are visual features to be"
|
519 |
+
" conditioned on."
|
520 |
+
)
|
521 |
+
|
522 |
+
if past_key_value is not None:
|
523 |
+
raise NotImplementedError("Past key value states are not implemented for VOPT cross attention module.")
|
524 |
+
|
525 |
+
residual = hidden_states
|
526 |
+
|
527 |
+
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
|
528 |
+
if self.do_layer_norm_before:
|
529 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
530 |
+
|
531 |
+
# Self Attention
|
532 |
+
hidden_states, self_attn_weights, present_key_value = self.cross_attn(
|
533 |
+
hidden_states=hidden_states,
|
534 |
+
key_value_states=image_hidden_states,
|
535 |
+
attention_mask=image_attention_mask,
|
536 |
+
layer_head_mask=layer_head_mask,
|
537 |
+
output_attentions=output_attentions,
|
538 |
+
)
|
539 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
540 |
+
if self.normformer_layer_norms:
|
541 |
+
hidden_states = self.self_attn_post_layer_norm(hidden_states)
|
542 |
+
hidden_states = residual + self.act_cross_attn(self.alpha_cross_attn) * hidden_states
|
543 |
+
|
544 |
+
# 350m applies layer norm AFTER attention
|
545 |
+
if not self.do_layer_norm_before:
|
546 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
547 |
+
|
548 |
+
# Fully Connected
|
549 |
+
hidden_states_shape = hidden_states.shape
|
550 |
+
hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
|
551 |
+
residual = hidden_states
|
552 |
+
|
553 |
+
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
|
554 |
+
if self.do_layer_norm_before:
|
555 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
556 |
+
|
557 |
+
hidden_states_to_gate = self.fc1(hidden_states)
|
558 |
+
if isinstance(self.activation_fn, SwiGLUActivation):
|
559 |
+
hidden_states = self.activation_fn(hidden_states_to_gate, hidden_states)
|
560 |
+
else:
|
561 |
+
hidden_states = self.activation_fn(hidden_states_to_gate)
|
562 |
+
|
563 |
+
if self.normformer_layer_norms:
|
564 |
+
hidden_states = self.mlp_post_layer_norm(hidden_states)
|
565 |
+
hidden_states = self.fc2(hidden_states)
|
566 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
567 |
+
|
568 |
+
hidden_states = (residual + self.act_dense(self.alpha_dense) * hidden_states).view(hidden_states_shape)
|
569 |
+
|
570 |
+
# 350m applies layer norm AFTER attention
|
571 |
+
if not self.do_layer_norm_before:
|
572 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
573 |
+
|
574 |
+
outputs = (hidden_states,)
|
575 |
+
|
576 |
+
if output_attentions:
|
577 |
+
outputs += (self_attn_weights,)
|
578 |
+
|
579 |
+
if use_cache:
|
580 |
+
outputs += (present_key_value,)
|
581 |
+
|
582 |
+
return outputs
|
583 |
+
|
584 |
+
|
585 |
+
OPT_START_DOCSTRING = r"""
|
586 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
587 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
588 |
+
etc.)
|
589 |
+
|
590 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
591 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
592 |
+
and behavior.
|
593 |
+
|
594 |
+
Parameters:
|
595 |
+
config ([`VOPTConfig`]):
|
596 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
597 |
+
load the weights associated with the model, only the configuration. Check out the
|
598 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
599 |
+
"""
|
600 |
+
|
601 |
+
|
602 |
+
@add_start_docstrings(
|
603 |
+
"The bare OPT Model outputting raw hidden-states without any specific head on top.",
|
604 |
+
OPT_START_DOCSTRING,
|
605 |
+
)
|
606 |
+
class VOPTPreTrainedModel(VLOOMPreTrainedModelBase):
|
607 |
+
config_class = VOPTConfig
|
608 |
+
base_model_prefix = "model"
|
609 |
+
supports_gradient_checkpointing = True
|
610 |
+
_no_split_modules = ["OPTDecoderLayer", "VOPTGatedAttentionLayer", "CLIPEncoderLayer"]
|
611 |
+
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
|
612 |
+
|
613 |
+
def _init_weights(self, module):
|
614 |
+
def init_a_linear(module, mean=0.0, std=self.config.init_std):
|
615 |
+
with ContextManagers(deepspeed_gathered_parameters_context_manager(module.weight, modify=True)):
|
616 |
+
module.weight.data.normal_(mean=mean, std=std)
|
617 |
+
if module.bias is not None:
|
618 |
+
with ContextManagers(deepspeed_gathered_parameters_context_manager(module.bias, modify=True)):
|
619 |
+
module.bias.data.zero_()
|
620 |
+
|
621 |
+
if isinstance(module, VOPTGatedAttentionLayer):
|
622 |
+
for sub_module_name, sub_module in module.named_modules():
|
623 |
+
if isinstance(sub_module, nn.Linear):
|
624 |
+
if "fc2" in sub_module_name:
|
625 |
+
factor = 2 * self.config.num_hidden_layers
|
626 |
+
else:
|
627 |
+
factor = 1.0
|
628 |
+
init_a_linear(sub_module, std=(0.4 / (sub_module.in_features * factor)) ** 0.5)
|
629 |
+
elif isinstance(module, PerceiverResampler):
|
630 |
+
with ContextManagers(deepspeed_gathered_parameters_context_manager(module.latents, modify=True)):
|
631 |
+
module.latents.data.normal_(mean=0.0, std=(1.0 / self.config.vision_embed_dim) ** 0.5)
|
632 |
+
for sub_module_name, sub_module in module.named_modules():
|
633 |
+
if isinstance(sub_module, nn.Linear):
|
634 |
+
if "c_proj" in sub_module_name:
|
635 |
+
factor = 2 * self.config.num_hidden_layers
|
636 |
+
else:
|
637 |
+
factor = 1.0
|
638 |
+
init_a_linear(sub_module, std=(0.4 / (self.config.vision_embed_dim * factor)) ** 0.5)
|
639 |
+
elif isinstance(module, nn.Embedding):
|
640 |
+
with ContextManagers(deepspeed_gathered_parameters_context_manager(module.weight, modify=True)):
|
641 |
+
module.weight.data.normal_(mean=0.0, std=(1.0 / self.config.hidden_size) ** 0.5)
|
642 |
+
if module.padding_idx is not None:
|
643 |
+
module.weight.data[module.padding_idx].zero_()
|
644 |
+
elif isinstance(module, DecoupledLinear):
|
645 |
+
if hasattr(module, "additional_fc"):
|
646 |
+
init_a_linear(module.additional_fc, std=(1.0 / (module.additional_fc.in_features)) ** 0.5)
|
647 |
+
|
648 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
649 |
+
if isinstance(module, (VOPTDecoder)):
|
650 |
+
module.gradient_checkpointing = value
|
651 |
+
|
652 |
+
@classmethod
|
653 |
+
def override_vision_model_wrapper(cls, model, config, vision_model_name, vision_model_params, torch_dtype):
|
654 |
+
# this can be called via from_pretrained from a class w/ head or w/o head so we extract the beheaded model version
|
655 |
+
beheaded_model = model.model if hasattr(model, "model") else model
|
656 |
+
cls.override_vision_model(beheaded_model.decoder, vision_model_name, vision_model_params, torch_dtype)
|
657 |
+
beheaded_model.freeze_relevant_params(config)
|
658 |
+
|
659 |
+
|
660 |
+
OPT_INPUTS_DOCSTRING = r"""
|
661 |
+
Args:
|
662 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
663 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
664 |
+
it.
|
665 |
+
|
666 |
+
Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
667 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
668 |
+
|
669 |
+
[What are input IDs?](../glossary#input-ids)
|
670 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
671 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
672 |
+
|
673 |
+
- 1 for tokens that are **not masked**,
|
674 |
+
- 0 for tokens that are **masked**.
|
675 |
+
|
676 |
+
[What are attention masks?](../glossary#attention-mask)
|
677 |
+
|
678 |
+
Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
679 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
680 |
+
|
681 |
+
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
682 |
+
`past_key_values`).
|
683 |
+
|
684 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
685 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
686 |
+
information on the default strategy.
|
687 |
+
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
|
688 |
+
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
|
689 |
+
|
690 |
+
- 1 indicates the head is **not masked**,
|
691 |
+
- 0 indicates the head is **masked**.
|
692 |
+
|
693 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
694 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
695 |
+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
696 |
+
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
697 |
+
|
698 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
699 |
+
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
700 |
+
|
701 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
702 |
+
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
703 |
+
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
704 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
705 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
706 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
707 |
+
model's internal embedding lookup matrix.
|
708 |
+
use_cache (`bool`, *optional*):
|
709 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
710 |
+
`past_key_values`).
|
711 |
+
output_attentions (`bool`, *optional*):
|
712 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
713 |
+
tensors for more detail.
|
714 |
+
output_hidden_states (`bool`, *optional*):
|
715 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
716 |
+
more detail.
|
717 |
+
return_dict (`bool`, *optional*):
|
718 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
719 |
+
"""
|
720 |
+
|
721 |
+
|
722 |
+
class VOPTDecoder(VOPTPreTrainedModel):
|
723 |
+
"""
|
724 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`]
|
725 |
+
|
726 |
+
Args:
|
727 |
+
config: VOPTConfig
|
728 |
+
"""
|
729 |
+
|
730 |
+
def __init__(self, config: VOPTConfig, vision_model=None):
|
731 |
+
super().__init__(config)
|
732 |
+
self.config = config
|
733 |
+
self.dropout = config.dropout
|
734 |
+
self.layerdrop = config.layerdrop
|
735 |
+
self.padding_idx = config.pad_token_id
|
736 |
+
self.max_target_positions = config.max_position_embeddings
|
737 |
+
self.vocab_size = config.vocab_size
|
738 |
+
|
739 |
+
self.embed_tokens = DecoupledEmbedding(
|
740 |
+
num_embeddings=config.vocab_size,
|
741 |
+
num_additional_embeddings=config.additional_vocab_size,
|
742 |
+
embedding_dim=config.word_embed_proj_dim,
|
743 |
+
partially_freeze=config.freeze_text_layers,
|
744 |
+
padding_idx=self.padding_idx,
|
745 |
+
)
|
746 |
+
self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)
|
747 |
+
|
748 |
+
# Load an uninitialized model and later in from_pretrained will load the pre-trained model -
|
749 |
+
# this solves the losing of weights in `from_pretrained` on the main model
|
750 |
+
self.vision_model = vision_model
|
751 |
+
|
752 |
+
# Perceiver Resampler
|
753 |
+
if config.use_resampler:
|
754 |
+
self.perceiver_resampler = PerceiverResampler(
|
755 |
+
self.config,
|
756 |
+
self.config.vision_embed_dim,
|
757 |
+
config.resampler_depth,
|
758 |
+
config.resampler_n_heads,
|
759 |
+
config.resampler_head_dim,
|
760 |
+
config.resampler_n_latents,
|
761 |
+
)
|
762 |
+
|
763 |
+
if config.word_embed_proj_dim != config.hidden_size:
|
764 |
+
self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False)
|
765 |
+
else:
|
766 |
+
self.project_in = None
|
767 |
+
|
768 |
+
self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
769 |
+
|
770 |
+
self.cross_layer_interval = config.cross_layer_interval
|
771 |
+
num_cross_layers = config.num_hidden_layers // self.cross_layer_interval
|
772 |
+
self.gated_cross_attn_layers = nn.ModuleList(
|
773 |
+
[VOPTGatedAttentionLayer(config) for i in range(num_cross_layers)]
|
774 |
+
)
|
775 |
+
self.gradient_checkpointing = False
|
776 |
+
|
777 |
+
# Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
|
778 |
+
# with checkpoints that have been fine-tuned before transformers v4.20.1
|
779 |
+
# see https://github.com/facebookresearch/metaseq/pull/164
|
780 |
+
if config.do_layer_norm_before and not config._remove_final_layer_norm:
|
781 |
+
self.final_layer_norm = nn.LayerNorm(config.hidden_size)
|
782 |
+
else:
|
783 |
+
self.final_layer_norm = None
|
784 |
+
|
785 |
+
if config.word_embed_proj_dim != config.hidden_size:
|
786 |
+
self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)
|
787 |
+
else:
|
788 |
+
self.project_out = None
|
789 |
+
|
790 |
+
# Initialize weights and apply final processing
|
791 |
+
self.post_init()
|
792 |
+
|
793 |
+
def get_input_embeddings(self):
|
794 |
+
return self.embed_tokens
|
795 |
+
|
796 |
+
def set_input_embeddings(self, value):
|
797 |
+
self.embed_tokens = value
|
798 |
+
|
799 |
+
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
800 |
+
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
|
801 |
+
# create causal mask
|
802 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
803 |
+
combined_attention_mask = None
|
804 |
+
if input_shape[-1] > 1:
|
805 |
+
combined_attention_mask = _make_causal_mask(
|
806 |
+
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
807 |
+
).to(inputs_embeds.device)
|
808 |
+
|
809 |
+
if attention_mask is not None:
|
810 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
811 |
+
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
|
812 |
+
inputs_embeds.device
|
813 |
+
)
|
814 |
+
combined_attention_mask = (
|
815 |
+
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
816 |
+
)
|
817 |
+
|
818 |
+
return combined_attention_mask
|
819 |
+
|
820 |
+
def forward(
|
821 |
+
self,
|
822 |
+
input_ids: torch.LongTensor = None,
|
823 |
+
attention_mask: Optional[torch.Tensor] = None,
|
824 |
+
head_mask: Optional[torch.Tensor] = None,
|
825 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
826 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
827 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
828 |
+
image_embeddings: Optional[torch.FloatTensor] = None,
|
829 |
+
image_attention_mask: Optional[torch.Tensor] = None,
|
830 |
+
crossblock_head_mask: Optional[torch.Tensor] = None, # TOFO (ls): check if this is needed
|
831 |
+
use_cache: Optional[bool] = None,
|
832 |
+
output_attentions: Optional[bool] = None,
|
833 |
+
output_hidden_states: Optional[bool] = None,
|
834 |
+
return_dict: Optional[bool] = None,
|
835 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
836 |
+
r"""
|
837 |
+
Args:
|
838 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
839 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
840 |
+
provide it.
|
841 |
+
|
842 |
+
Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
843 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
844 |
+
|
845 |
+
[What are input IDs?](../glossary#input-ids)
|
846 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
847 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
848 |
+
|
849 |
+
- 1 for tokens that are **not masked**,
|
850 |
+
- 0 for tokens that are **masked**.
|
851 |
+
|
852 |
+
[What are attention masks?](../glossary#attention-mask)
|
853 |
+
head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
|
854 |
+
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
|
855 |
+
|
856 |
+
- 1 indicates the head is **not masked**,
|
857 |
+
- 0 indicates the head is **masked**.
|
858 |
+
|
859 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
860 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
861 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
|
862 |
+
|
863 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
|
864 |
+
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
865 |
+
|
866 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
|
867 |
+
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
|
868 |
+
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
869 |
+
|
870 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
871 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
872 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
873 |
+
than the model's internal embedding lookup matrix.
|
874 |
+
output_attentions (`bool`, *optional*):
|
875 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
876 |
+
returned tensors for more detail.
|
877 |
+
output_hidden_states (`bool`, *optional*):
|
878 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
879 |
+
for more detail.
|
880 |
+
return_dict (`bool`, *optional*):
|
881 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
882 |
+
"""
|
883 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
884 |
+
|
885 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
886 |
+
output_hidden_states = (
|
887 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
888 |
+
)
|
889 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
890 |
+
|
891 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
892 |
+
|
893 |
+
# retrieve input_ids and inputs_embeds
|
894 |
+
if input_ids is not None and inputs_embeds is not None:
|
895 |
+
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
896 |
+
elif input_ids is not None:
|
897 |
+
input_shape = input_ids.size()
|
898 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
899 |
+
elif inputs_embeds is not None:
|
900 |
+
input_shape = inputs_embeds.size()[:-1]
|
901 |
+
else:
|
902 |
+
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
903 |
+
|
904 |
+
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
905 |
+
|
906 |
+
if pixel_values is not None and image_embeddings is not None:
|
907 |
+
raise ValueError("You cannot specify both pixel_values and image_embeddings at the same time")
|
908 |
+
elif pixel_values is not None:
|
909 |
+
pixel_values = pixel_values.to(dtype=self.dtype, device=input_ids.device) # fp16 compatibility
|
910 |
+
batch_size, num_images = pixel_values.size(0), pixel_values.size(1)
|
911 |
+
pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:])
|
912 |
+
# Get sequence from the vision encoder
|
913 |
+
image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
|
914 |
+
elif image_embeddings is not None:
|
915 |
+
batch_size, num_images, image_seq_len, image_hidden_size = image_embeddings.size()
|
916 |
+
image_hidden_states = image_embeddings.to(dtype=self.dtype, device=input_ids.device)
|
917 |
+
image_hidden_states = image_hidden_states.view(batch_size * num_images, image_seq_len, image_hidden_size)
|
918 |
+
|
919 |
+
if self.config.use_resampler:
|
920 |
+
image_hidden_states = self.perceiver_resampler(image_hidden_states)
|
921 |
+
image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
|
922 |
+
image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size)
|
923 |
+
# Make image_attention_mask compatible with hidden states
|
924 |
+
text_seq_len = image_attention_mask.size(1)
|
925 |
+
image_attention_mask = image_attention_mask.unsqueeze(-1)
|
926 |
+
image_attention_mask = image_attention_mask.repeat(1, 1, 1, image_seq_len)
|
927 |
+
image_attention_mask = image_attention_mask.view(batch_size, text_seq_len, num_images * image_seq_len)
|
928 |
+
|
929 |
+
if image_hidden_states is not None:
|
930 |
+
image_batch_size, image_sequence_length, _ = image_hidden_states.size()
|
931 |
+
image_hidden_shape = (image_batch_size, image_sequence_length)
|
932 |
+
if image_attention_mask is None:
|
933 |
+
image_attention_mask = torch.ones(image_hidden_shape, device=device)
|
934 |
+
image_attention_mask = self.invert_attention_mask(image_attention_mask)
|
935 |
+
else:
|
936 |
+
image_attention_mask = None
|
937 |
+
|
938 |
+
if inputs_embeds is None:
|
939 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
940 |
+
|
941 |
+
# embed positions
|
942 |
+
if attention_mask is None:
|
943 |
+
attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
|
944 |
+
pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
|
945 |
+
|
946 |
+
attention_mask = self._prepare_decoder_attention_mask(
|
947 |
+
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
948 |
+
)
|
949 |
+
|
950 |
+
if self.project_in is not None:
|
951 |
+
inputs_embeds = self.project_in(inputs_embeds)
|
952 |
+
|
953 |
+
hidden_states = inputs_embeds + pos_embeds
|
954 |
+
|
955 |
+
# decoder layers
|
956 |
+
all_hidden_states = () if output_hidden_states else None
|
957 |
+
all_self_attns = () if output_attentions else None
|
958 |
+
next_decoder_cache = () if use_cache else None
|
959 |
+
|
960 |
+
# check if head_mask has a correct number of layers specified if desired
|
961 |
+
for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
|
962 |
+
if attn_mask is not None:
|
963 |
+
if attn_mask.size()[0] != (len(self.layers)):
|
964 |
+
raise ValueError(
|
965 |
+
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
|
966 |
+
f" {head_mask.size()[0]}."
|
967 |
+
)
|
968 |
+
|
969 |
+
for idx, decoder_layer in enumerate(self.layers):
|
970 |
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
971 |
+
if output_hidden_states:
|
972 |
+
all_hidden_states += (hidden_states,)
|
973 |
+
|
974 |
+
dropout_probability = random.uniform(0, 1)
|
975 |
+
if self.training and (dropout_probability < self.layerdrop):
|
976 |
+
continue
|
977 |
+
|
978 |
+
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
979 |
+
layer_head_mask = head_mask[idx] if head_mask is not None else None
|
980 |
+
|
981 |
+
def vblock(
|
982 |
+
main_block,
|
983 |
+
hidden_states,
|
984 |
+
attention_mask,
|
985 |
+
layer_head_mask,
|
986 |
+
past_key_value,
|
987 |
+
image_hidden_states,
|
988 |
+
image_attention_mask,
|
989 |
+
output_attentions,
|
990 |
+
use_cache,
|
991 |
+
layer_idx,
|
992 |
+
cross_layer_interval,
|
993 |
+
gated_cross_attn_layers,
|
994 |
+
):
|
995 |
+
# TODO(ls): Add cross attention values to respective lists
|
996 |
+
if layer_idx % cross_layer_interval == 0:
|
997 |
+
xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval]
|
998 |
+
outputs = xblock(
|
999 |
+
hidden_states,
|
1000 |
+
attention_mask=attention_mask,
|
1001 |
+
layer_head_mask=layer_head_mask,
|
1002 |
+
image_hidden_states=image_hidden_states,
|
1003 |
+
image_attention_mask=image_attention_mask,
|
1004 |
+
output_attentions=output_attentions,
|
1005 |
+
use_cache=use_cache,
|
1006 |
+
past_key_value=None, # not implemented
|
1007 |
+
)
|
1008 |
+
hidden_states = outputs[0]
|
1009 |
+
|
1010 |
+
layer_outputs = main_block(
|
1011 |
+
hidden_states,
|
1012 |
+
attention_mask=attention_mask,
|
1013 |
+
layer_head_mask=layer_head_mask,
|
1014 |
+
past_key_value=past_key_value,
|
1015 |
+
output_attentions=output_attentions,
|
1016 |
+
use_cache=use_cache,
|
1017 |
+
)
|
1018 |
+
|
1019 |
+
return layer_outputs
|
1020 |
+
|
1021 |
+
if self.gradient_checkpointing and self.training:
|
1022 |
+
past_key_value = None
|
1023 |
+
if use_cache:
|
1024 |
+
logger.warning_once(
|
1025 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
1026 |
+
)
|
1027 |
+
use_cache = False
|
1028 |
+
|
1029 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
1030 |
+
vblock,
|
1031 |
+
decoder_layer,
|
1032 |
+
hidden_states,
|
1033 |
+
attention_mask,
|
1034 |
+
layer_head_mask,
|
1035 |
+
past_key_value,
|
1036 |
+
image_hidden_states,
|
1037 |
+
image_attention_mask,
|
1038 |
+
output_attentions,
|
1039 |
+
use_cache,
|
1040 |
+
idx,
|
1041 |
+
self.cross_layer_interval,
|
1042 |
+
self.gated_cross_attn_layers,
|
1043 |
+
)
|
1044 |
+
else:
|
1045 |
+
layer_outputs = vblock(
|
1046 |
+
decoder_layer,
|
1047 |
+
hidden_states,
|
1048 |
+
attention_mask=attention_mask,
|
1049 |
+
layer_head_mask=layer_head_mask,
|
1050 |
+
past_key_value=past_key_value,
|
1051 |
+
image_hidden_states=image_hidden_states,
|
1052 |
+
image_attention_mask=image_attention_mask,
|
1053 |
+
output_attentions=output_attentions,
|
1054 |
+
use_cache=use_cache,
|
1055 |
+
layer_idx=idx,
|
1056 |
+
cross_layer_interval=self.cross_layer_interval,
|
1057 |
+
gated_cross_attn_layers=self.gated_cross_attn_layers,
|
1058 |
+
)
|
1059 |
+
|
1060 |
+
hidden_states = layer_outputs[0]
|
1061 |
+
|
1062 |
+
if use_cache:
|
1063 |
+
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
1064 |
+
|
1065 |
+
if output_attentions:
|
1066 |
+
all_self_attns += (layer_outputs[1],)
|
1067 |
+
|
1068 |
+
if self.final_layer_norm is not None:
|
1069 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
1070 |
+
|
1071 |
+
if self.project_out is not None:
|
1072 |
+
hidden_states = self.project_out(hidden_states)
|
1073 |
+
|
1074 |
+
# add hidden states from the last decoder layer
|
1075 |
+
if output_hidden_states:
|
1076 |
+
all_hidden_states += (hidden_states,)
|
1077 |
+
|
1078 |
+
next_cache = next_decoder_cache if use_cache else None
|
1079 |
+
if not return_dict:
|
1080 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
1081 |
+
return BaseModelOutputWithPast(
|
1082 |
+
last_hidden_state=hidden_states,
|
1083 |
+
past_key_values=next_cache,
|
1084 |
+
hidden_states=all_hidden_states,
|
1085 |
+
attentions=all_self_attns,
|
1086 |
+
)
|
1087 |
+
|
1088 |
+
|
1089 |
+
@add_start_docstrings(
|
1090 |
+
"The bare OPT Model outputting raw hidden-states without any specific head on top.",
|
1091 |
+
OPT_START_DOCSTRING,
|
1092 |
+
)
|
1093 |
+
class VOPTModel(VOPTPreTrainedModel):
|
1094 |
+
def __init__(self, config: VOPTConfig, vision_model=None):
|
1095 |
+
super().__init__(config)
|
1096 |
+
self.decoder = VOPTDecoder(config, vision_model=vision_model)
|
1097 |
+
|
1098 |
+
# Initialize weights and apply final processing
|
1099 |
+
self.post_init()
|
1100 |
+
|
1101 |
+
self.freeze_relevant_params(config)
|
1102 |
+
|
1103 |
+
def freeze_relevant_params(self, config=None):
|
1104 |
+
if config is None:
|
1105 |
+
config = self.config
|
1106 |
+
|
1107 |
+
if config.freeze_text_layers:
|
1108 |
+
self.freeze_text_layers(config.freeze_text_module_exceptions)
|
1109 |
+
|
1110 |
+
if config.freeze_vision_layers:
|
1111 |
+
freeze_model(self.decoder.vision_model, module_exceptions=config.freeze_vision_module_exceptions)
|
1112 |
+
|
1113 |
+
def freeze_text_layers(self, module_exceptions):
|
1114 |
+
for module in [self.decoder.embed_positions, self.decoder.layers]:
|
1115 |
+
freeze_model(module, module_exceptions=module_exceptions)
|
1116 |
+
|
1117 |
+
if self.decoder.project_out is not None:
|
1118 |
+
freeze_model(self.decoder.project_out, module_exceptions=module_exceptions)
|
1119 |
+
|
1120 |
+
if self.decoder.final_layer_norm is not None:
|
1121 |
+
freeze_model(self.decoder.final_layer_norm, module_exceptions=module_exceptions)
|
1122 |
+
|
1123 |
+
def get_input_embeddings(self):
|
1124 |
+
return self.decoder.embed_tokens
|
1125 |
+
|
1126 |
+
def set_input_embeddings(self, value):
|
1127 |
+
self.decoder.embed_tokens = value
|
1128 |
+
|
1129 |
+
def get_decoder(self):
|
1130 |
+
return self.decoder
|
1131 |
+
|
1132 |
+
@add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
|
1133 |
+
@add_code_sample_docstrings(
|
1134 |
+
processor_class=_TOKENIZER_FOR_DOC,
|
1135 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
1136 |
+
output_type=BaseModelOutputWithPast,
|
1137 |
+
config_class=_CONFIG_FOR_DOC,
|
1138 |
+
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
1139 |
+
)
|
1140 |
+
def forward(
|
1141 |
+
self,
|
1142 |
+
input_ids: torch.LongTensor = None,
|
1143 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1144 |
+
head_mask: Optional[torch.Tensor] = None,
|
1145 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1146 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1147 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
1148 |
+
image_embeddings: Optional[torch.FloatTensor] = None,
|
1149 |
+
image_attention_mask: Optional[torch.Tensor] = None,
|
1150 |
+
crossblock_head_mask: Optional[torch.Tensor] = None, # TOFO (ls): check if this is needed
|
1151 |
+
use_cache: Optional[bool] = None,
|
1152 |
+
output_attentions: Optional[bool] = None,
|
1153 |
+
output_hidden_states: Optional[bool] = None,
|
1154 |
+
return_dict: Optional[bool] = None,
|
1155 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
1156 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1157 |
+
output_hidden_states = (
|
1158 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1159 |
+
)
|
1160 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
1161 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1162 |
+
|
1163 |
+
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
|
1164 |
+
decoder_outputs = self.decoder(
|
1165 |
+
input_ids=input_ids,
|
1166 |
+
attention_mask=attention_mask,
|
1167 |
+
head_mask=head_mask,
|
1168 |
+
past_key_values=past_key_values,
|
1169 |
+
inputs_embeds=inputs_embeds,
|
1170 |
+
pixel_values=pixel_values,
|
1171 |
+
image_embeddings=image_embeddings,
|
1172 |
+
image_attention_mask=image_attention_mask,
|
1173 |
+
crossblock_head_mask=crossblock_head_mask,
|
1174 |
+
use_cache=use_cache,
|
1175 |
+
output_attentions=output_attentions,
|
1176 |
+
output_hidden_states=output_hidden_states,
|
1177 |
+
return_dict=return_dict,
|
1178 |
+
)
|
1179 |
+
|
1180 |
+
if not return_dict:
|
1181 |
+
return decoder_outputs
|
1182 |
+
|
1183 |
+
return BaseModelOutputWithPast(
|
1184 |
+
last_hidden_state=decoder_outputs.last_hidden_state,
|
1185 |
+
past_key_values=decoder_outputs.past_key_values,
|
1186 |
+
hidden_states=decoder_outputs.hidden_states,
|
1187 |
+
attentions=decoder_outputs.attentions,
|
1188 |
+
)
|
1189 |
+
|
1190 |
+
|
1191 |
+
class VOPTForCausalLM(VOPTPreTrainedModel):
|
1192 |
+
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
|
1193 |
+
|
1194 |
+
def __init__(self, config, vision_model=None):
|
1195 |
+
super().__init__(config)
|
1196 |
+
|
1197 |
+
# Initialize LM head first so that it is not directly offloaded to the CPU/disk
|
1198 |
+
# the lm_head weight is automatically tied to the embed tokens weight
|
1199 |
+
self.lm_head = DecoupledLinear(
|
1200 |
+
in_features=config.word_embed_proj_dim,
|
1201 |
+
out_features=config.vocab_size,
|
1202 |
+
out_additional_features=config.additional_vocab_size,
|
1203 |
+
bias=False,
|
1204 |
+
partially_freeze=config.freeze_lm_head,
|
1205 |
+
)
|
1206 |
+
|
1207 |
+
self.model = VOPTModel(config, vision_model=vision_model)
|
1208 |
+
# Initialize weights and apply final processing
|
1209 |
+
self.post_init()
|
1210 |
+
|
1211 |
+
def tie_weights(self):
|
1212 |
+
"""
|
1213 |
+
Overwrite `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of DecoupledLinear and DecoupledEmbedding.
|
1214 |
+
"""
|
1215 |
+
output_embeddings = self.get_output_embeddings()
|
1216 |
+
input_embeddings = self.get_input_embeddings()
|
1217 |
+
|
1218 |
+
if getattr(self.config, "tie_word_embeddings", True):
|
1219 |
+
output_embeddings.weight = input_embeddings.weight
|
1220 |
+
if input_embeddings.num_additional_embeddings > 0:
|
1221 |
+
assert output_embeddings.out_additional_features == input_embeddings.num_additional_embeddings
|
1222 |
+
output_embeddings.additional_fc.weight = input_embeddings.additional_embedding.weight
|
1223 |
+
|
1224 |
+
if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
|
1225 |
+
output_embeddings.out_features = input_embeddings.num_embeddings
|
1226 |
+
if hasattr(output_embeddings, "out_additional_features") and hasattr(
|
1227 |
+
input_embeddings, "num_additional_embeddings"
|
1228 |
+
):
|
1229 |
+
output_embeddings.out_additional_features = input_embeddings.num_additional_embeddings
|
1230 |
+
|
1231 |
+
def get_input_embeddings(self):
|
1232 |
+
return self.model.decoder.embed_tokens
|
1233 |
+
|
1234 |
+
def set_input_embeddings(self, value):
|
1235 |
+
self.model.decoder.embed_tokens = value
|
1236 |
+
|
1237 |
+
def get_output_embeddings(self):
|
1238 |
+
return self.lm_head
|
1239 |
+
|
1240 |
+
def set_output_embeddings(self, new_embeddings):
|
1241 |
+
self.lm_head = new_embeddings
|
1242 |
+
|
1243 |
+
def set_decoder(self, decoder):
|
1244 |
+
self.model.decoder = decoder
|
1245 |
+
|
1246 |
+
def get_decoder(self):
|
1247 |
+
return self.model.decoder
|
1248 |
+
|
1249 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
1250 |
+
def forward(
|
1251 |
+
self,
|
1252 |
+
input_ids: torch.LongTensor = None,
|
1253 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1254 |
+
head_mask: Optional[torch.Tensor] = None,
|
1255 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1256 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1257 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
1258 |
+
image_embeddings: Optional[torch.FloatTensor] = None,
|
1259 |
+
image_attention_mask: Optional[torch.Tensor] = None,
|
1260 |
+
crossblock_head_mask: Optional[torch.Tensor] = None,
|
1261 |
+
labels: Optional[torch.LongTensor] = None,
|
1262 |
+
use_cache: Optional[bool] = None,
|
1263 |
+
output_attentions: Optional[bool] = None,
|
1264 |
+
output_hidden_states: Optional[bool] = None,
|
1265 |
+
return_dict: Optional[bool] = None,
|
1266 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
1267 |
+
r"""
|
1268 |
+
Args:
|
1269 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
1270 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
1271 |
+
provide it.
|
1272 |
+
|
1273 |
+
Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
1274 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
1275 |
+
|
1276 |
+
[What are input IDs?](../glossary#input-ids)
|
1277 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1278 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
1279 |
+
|
1280 |
+
- 1 for tokens that are **not masked**,
|
1281 |
+
- 0 for tokens that are **masked**.
|
1282 |
+
|
1283 |
+
[What are attention masks?](../glossary#attention-mask)
|
1284 |
+
head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
|
1285 |
+
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
|
1286 |
+
|
1287 |
+
- 1 indicates the head is **not masked**,
|
1288 |
+
- 0 indicates the head is **masked**.
|
1289 |
+
|
1290 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
1291 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
1292 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
|
1293 |
+
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
|
1294 |
+
tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
|
1295 |
+
|
1296 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
|
1297 |
+
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
1298 |
+
|
1299 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
|
1300 |
+
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
|
1301 |
+
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
1302 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
1303 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
1304 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
1305 |
+
than the model's internal embedding lookup matrix.
|
1306 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1307 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
1308 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
1309 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
1310 |
+
use_cache (`bool`, *optional*):
|
1311 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
1312 |
+
(see `past_key_values`).
|
1313 |
+
output_attentions (`bool`, *optional*):
|
1314 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
1315 |
+
returned tensors for more detail.
|
1316 |
+
output_hidden_states (`bool`, *optional*):
|
1317 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
1318 |
+
for more detail.
|
1319 |
+
return_dict (`bool`, *optional*):
|
1320 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
1321 |
+
|
1322 |
+
Returns:
|
1323 |
+
|
1324 |
+
Example:
|
1325 |
+
|
1326 |
+
```python
|
1327 |
+
>>> from transformers import GPT2Tokenizer, OPTForCausalLM
|
1328 |
+
|
1329 |
+
>>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
|
1330 |
+
>>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
|
1331 |
+
|
1332 |
+
>>> prompt = "Hey, are you consciours? Can you talk to me?"
|
1333 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
1334 |
+
|
1335 |
+
>>> # Generate
|
1336 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
1337 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
1338 |
+
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
|
1339 |
+
```"""
|
1340 |
+
|
1341 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1342 |
+
output_hidden_states = (
|
1343 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1344 |
+
)
|
1345 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1346 |
+
|
1347 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
1348 |
+
outputs = self.model.decoder(
|
1349 |
+
input_ids=input_ids,
|
1350 |
+
attention_mask=attention_mask,
|
1351 |
+
head_mask=head_mask,
|
1352 |
+
past_key_values=past_key_values,
|
1353 |
+
inputs_embeds=inputs_embeds,
|
1354 |
+
pixel_values=pixel_values,
|
1355 |
+
image_embeddings=image_embeddings,
|
1356 |
+
image_attention_mask=image_attention_mask,
|
1357 |
+
crossblock_head_mask=crossblock_head_mask,
|
1358 |
+
use_cache=use_cache,
|
1359 |
+
output_attentions=output_attentions,
|
1360 |
+
output_hidden_states=output_hidden_states,
|
1361 |
+
return_dict=return_dict,
|
1362 |
+
)
|
1363 |
+
|
1364 |
+
logits = self.lm_head(outputs[0]).contiguous()
|
1365 |
+
|
1366 |
+
loss = None
|
1367 |
+
if labels is not None:
|
1368 |
+
# Shift so that tokens < n predict n
|
1369 |
+
if attention_mask is not None:
|
1370 |
+
shift_attention_mask = attention_mask[..., 1:]
|
1371 |
+
shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
|
1372 |
+
shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
|
1373 |
+
else:
|
1374 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
1375 |
+
shift_labels = labels[..., 1:].contiguous()
|
1376 |
+
# Flatten the tokens
|
1377 |
+
loss_fct = CrossEntropyLoss()
|
1378 |
+
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
1379 |
+
|
1380 |
+
if not return_dict:
|
1381 |
+
output = (logits,) + outputs[1:]
|
1382 |
+
return (loss,) + output if loss is not None else output
|
1383 |
+
|
1384 |
+
return CausalLMOutputWithPast(
|
1385 |
+
loss=loss,
|
1386 |
+
logits=logits,
|
1387 |
+
past_key_values=outputs.past_key_values,
|
1388 |
+
hidden_states=outputs.hidden_states,
|
1389 |
+
attentions=outputs.attentions,
|
1390 |
+
)
|
1391 |
+
|
1392 |
+
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
1393 |
+
inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs)
|
1394 |
+
unwanted_kwargs = ["position_ids", "token_type_ids"]
|
1395 |
+
for kwarg in unwanted_kwargs:
|
1396 |
+
inputs.pop(kwarg, None)
|
1397 |
+
return inputs
|
1398 |
+
|
1399 |
+
@staticmethod
|
1400 |
+
def _expand_inputs_for_generation(
|
1401 |
+
*args,
|
1402 |
+
**model_kwargs,
|
1403 |
+
):
|
1404 |
+
return expand_inputs_for_generation(*args, **model_kwargs)
|
1405 |
+
|
1406 |
+
@staticmethod
|
1407 |
+
def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
|
1408 |
+
return update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder)
|
1409 |
+
|
1410 |
+
@staticmethod
|
1411 |
+
def _reorder_cache(past, beam_idx):
|
1412 |
+
reordered_past = ()
|
1413 |
+
for layer_past in past:
|
1414 |
+
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
1415 |
+
return reordered_past
|
1416 |
+
|
1417 |
+
def get_model_tflops_per_batch_per_gpu(self, hparams, data_param, tokenizer, max_num_images):
|
1418 |
+
config_vl_model = self.config
|
1419 |
+
|
1420 |
+
language_embed_size = config_vl_model.hidden_size
|
1421 |
+
num_language_layers = config_vl_model.num_hidden_layers
|
1422 |
+
ffn_inner_size = config_vl_model.ffn_dim
|
1423 |
+
|
1424 |
+
vision_config = self.model.decoder.vision_model.config
|
1425 |
+
if hasattr(vision_config, "vision_config"):
|
1426 |
+
vision_config = vision_config.vision_config
|
1427 |
+
|
1428 |
+
# Get vision model blocks infos
|
1429 |
+
vision_patch_size = vision_config.patch_size
|
1430 |
+
vision_hidden_size = vision_config.hidden_size
|
1431 |
+
num_vision_layers = vision_config.num_hidden_layers
|
1432 |
+
# The +1 is for the CLS token
|
1433 |
+
single_image_seq_len = (vision_config.image_size // vision_patch_size) ** 2 + 1
|
1434 |
+
vision_exp_factor = vision_config.intermediate_size // vision_hidden_size
|
1435 |
+
|
1436 |
+
# Get language and cross-att blocks infos
|
1437 |
+
num_cross_attn_layers = num_language_layers // config_vl_model.cross_layer_interval
|
1438 |
+
language_seq_len = data_param.max_seq_len
|
1439 |
+
language_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4
|
1440 |
+
cross_att_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4
|
1441 |
+
k_v_cross_attn_seq_len = (
|
1442 |
+
(self.config.resampler_n_latents * max_num_images)
|
1443 |
+
if self.config.use_resampler
|
1444 |
+
else (single_image_seq_len * max_num_images)
|
1445 |
+
)
|
1446 |
+
|
1447 |
+
language_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
|
1448 |
+
num_layers=num_language_layers,
|
1449 |
+
batch_size=hparams.batch_size_per_gpu,
|
1450 |
+
q_seq_len=language_seq_len,
|
1451 |
+
k_seq_len=language_seq_len,
|
1452 |
+
hidden_size=language_embed_size,
|
1453 |
+
kv_in_dim=language_embed_size,
|
1454 |
+
ff_exp_factor=language_exp_factor,
|
1455 |
+
grad_acc_size=hparams.grad_acc_size,
|
1456 |
+
swiglu=False,
|
1457 |
+
vocab_size=tokenizer.vocab_size,
|
1458 |
+
count_backward=True, # Always True regardless of freezing, because gradients are computed for cross-attentions
|
1459 |
+
use_grad_checkpointing=hparams.gradient_checkpointing,
|
1460 |
+
)
|
1461 |
+
cross_attention_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
|
1462 |
+
num_layers=num_cross_attn_layers,
|
1463 |
+
batch_size=hparams.batch_size_per_gpu,
|
1464 |
+
q_seq_len=language_seq_len,
|
1465 |
+
k_seq_len=k_v_cross_attn_seq_len,
|
1466 |
+
hidden_size=language_embed_size,
|
1467 |
+
kv_in_dim=vision_hidden_size,
|
1468 |
+
ff_exp_factor=cross_att_exp_factor,
|
1469 |
+
grad_acc_size=hparams.grad_acc_size,
|
1470 |
+
swiglu=self.config.cross_layer_activation_function == "swiglu",
|
1471 |
+
vocab_size=None,
|
1472 |
+
count_backward=True,
|
1473 |
+
use_grad_checkpointing=hparams.gradient_checkpointing,
|
1474 |
+
)
|
1475 |
+
vision_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
|
1476 |
+
num_layers=num_vision_layers,
|
1477 |
+
batch_size=hparams.batch_size_per_gpu * max_num_images,
|
1478 |
+
q_seq_len=single_image_seq_len,
|
1479 |
+
k_seq_len=single_image_seq_len,
|
1480 |
+
hidden_size=vision_hidden_size,
|
1481 |
+
kv_in_dim=vision_hidden_size,
|
1482 |
+
ff_exp_factor=vision_exp_factor,
|
1483 |
+
grad_acc_size=hparams.grad_acc_size,
|
1484 |
+
swiglu=False,
|
1485 |
+
vocab_size=None,
|
1486 |
+
count_backward=not hparams.model_params["freeze_vision_layers"],
|
1487 |
+
use_grad_checkpointing=hparams.gradient_checkpointing,
|
1488 |
+
)
|
1489 |
+
if self.config.use_resampler:
|
1490 |
+
perceiver_tflops_per_batch_per_gpu = compute_perceiver_tflops_per_batch_per_gpu(
|
1491 |
+
num_layers=self.config.resampler_depth,
|
1492 |
+
batch_size=hparams.batch_size_per_gpu * max_num_images,
|
1493 |
+
q_seq_len=self.config.resampler_n_latents,
|
1494 |
+
vision_embed_seq_len=single_image_seq_len,
|
1495 |
+
q_k_v_input_dim=vision_hidden_size,
|
1496 |
+
attention_hidden_size=self.config.resampler_n_heads * self.config.resampler_head_dim,
|
1497 |
+
ff_exp_factor=cross_att_exp_factor,
|
1498 |
+
count_backward=True,
|
1499 |
+
use_grad_checkpointing=hparams.gradient_checkpointing,
|
1500 |
+
)
|
1501 |
+
flop_count = (
|
1502 |
+
language_tflops_per_batch_per_gpu
|
1503 |
+
+ cross_attention_tflops_per_batch_per_gpu
|
1504 |
+
+ vision_tflops_per_batch_per_gpu
|
1505 |
+
+ perceiver_tflops_per_batch_per_gpu
|
1506 |
+
)
|
1507 |
+
else:
|
1508 |
+
flop_count = (
|
1509 |
+
language_tflops_per_batch_per_gpu
|
1510 |
+
+ cross_attention_tflops_per_batch_per_gpu
|
1511 |
+
+ vision_tflops_per_batch_per_gpu
|
1512 |
+
)
|
1513 |
+
return flop_count
|
m4/models/vt5/__init__.py
ADDED
File without changes
|
m4/models/vt5/configuration_vt5.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2020, The T5 Authors and HuggingFace Inc.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" VT5 model configuration"""
|
16 |
+
import os
|
17 |
+
from typing import Tuple, Union
|
18 |
+
|
19 |
+
from transformers import AutoConfig
|
20 |
+
from transformers.configuration_utils import PretrainedConfig
|
21 |
+
from transformers.utils import logging
|
22 |
+
|
23 |
+
|
24 |
+
logger = logging.get_logger(__name__)
|
25 |
+
|
26 |
+
T5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
27 |
+
"t5-small": "https://huggingface.co/t5-small/resolve/main/config.json",
|
28 |
+
"t5-base": "https://huggingface.co/t5-base/resolve/main/config.json",
|
29 |
+
"t5-large": "https://huggingface.co/t5-large/resolve/main/config.json",
|
30 |
+
"t5-3b": "https://huggingface.co/t5-3b/resolve/main/config.json",
|
31 |
+
"t5-11b": "https://huggingface.co/t5-11b/resolve/main/config.json",
|
32 |
+
}
|
33 |
+
|
34 |
+
|
35 |
+
class VT5Config(PretrainedConfig):
|
36 |
+
r"""
|
37 |
+
This is the configuration class to store the configuration of a [`T5Model`] or a [`TFT5Model`]. It is used to
|
38 |
+
instantiate a T5 model according to the specified arguments, defining the model architecture. Instantiating a
|
39 |
+
configuration with the defaults will yield a similar configuration to that of the T5
|
40 |
+
[t5-small](https://huggingface.co/t5-small) architecture.
|
41 |
+
|
42 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
43 |
+
documentation from [`PretrainedConfig`] for more information.
|
44 |
+
|
45 |
+
TODO: this doc is completely out of sync with the actual args
|
46 |
+
|
47 |
+
Arguments:
|
48 |
+
vocab_size (`int`, *optional*, defaults to 32128):
|
49 |
+
Vocabulary size of the T5 model. Defines the number of different tokens that can be represented by the
|
50 |
+
`inputs_ids` passed when calling [`T5Model`] or [`TFT5Model`].
|
51 |
+
d_model (`int`, *optional*, defaults to 512):
|
52 |
+
Size of the encoder layers and the pooler layer.
|
53 |
+
d_kv (`int`, *optional*, defaults to 64):
|
54 |
+
Size of the key, query, value projections per attention head. `d_kv` has to be equal to `d_model //
|
55 |
+
num_heads`.
|
56 |
+
d_ff (`int`, *optional*, defaults to 2048):
|
57 |
+
Size of the intermediate feed forward layer in each `T5Block`.
|
58 |
+
num_layers (`int`, *optional*, defaults to 6):
|
59 |
+
Number of hidden layers in the Transformer encoder.
|
60 |
+
num_decoder_layers (`int`, *optional*):
|
61 |
+
Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
|
62 |
+
num_heads (`int`, *optional*, defaults to 8):
|
63 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
64 |
+
relative_attention_num_buckets (`int`, *optional*, defaults to 32):
|
65 |
+
The number of buckets to use for each attention layer.
|
66 |
+
relative_attention_max_distance (`int`, *optional*, defaults to 128):
|
67 |
+
The maximum distance of the longer sequences for the bucket separation.
|
68 |
+
dropout_rate (`float`, *optional*, defaults to 0.1):
|
69 |
+
The ratio for all dropout layers.
|
70 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-6):
|
71 |
+
The epsilon used by the layer normalization layers.
|
72 |
+
initializer_factor (`float`, *optional*, defaults to 1):
|
73 |
+
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
|
74 |
+
testing).
|
75 |
+
feed_forward_proj (`string`, *optional*, defaults to `"relu"`):
|
76 |
+
Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. T5v1.1 uses the
|
77 |
+
`"gated-gelu"` feed forward projection. Original T5 uses `"relu"`.
|
78 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
79 |
+
Whether or not the model should return the last key/values attentions (not used by all models).
|
80 |
+
additional_vocab_size (`int`, *optional`, defaults to 0):
|
81 |
+
Additional vocabulary size of the model, typically for the special "<img>" token. Additional vocab tokens
|
82 |
+
are always trainable whereas regular vocab tokens can be frozen or not.
|
83 |
+
alpha_initializer (`str`, *optional*, defaults to `"ones"`):
|
84 |
+
Initialization type for the alphas.
|
85 |
+
alphas_initializer_range (`float`, *optional*, defaults to 0.0):
|
86 |
+
The standard deviation of the truncated_normal_initializer for initializing the alphas in the Gated Cross Attention.
|
87 |
+
alpha_type (`str`, *optional*, defaults to `"vector"`):
|
88 |
+
Whether the gating alphas should be vectors or single floats.
|
89 |
+
"""
|
90 |
+
model_type = "vt5"
|
91 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
92 |
+
attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
|
93 |
+
|
94 |
+
def __init__(
|
95 |
+
self,
|
96 |
+
vocab_size=32128,
|
97 |
+
d_model=512,
|
98 |
+
d_kv=64,
|
99 |
+
d_ff=2048,
|
100 |
+
num_layers=6,
|
101 |
+
num_decoder_layers=None,
|
102 |
+
num_heads=8,
|
103 |
+
relative_attention_num_buckets=32,
|
104 |
+
relative_attention_max_distance=128,
|
105 |
+
dropout_rate=0.1,
|
106 |
+
layer_norm_epsilon=1e-6,
|
107 |
+
initializer_factor=1.0,
|
108 |
+
feed_forward_proj="relu",
|
109 |
+
is_encoder_decoder=True,
|
110 |
+
use_cache=True,
|
111 |
+
pad_token_id=0,
|
112 |
+
eos_token_id=1,
|
113 |
+
additional_vocab_size=0,
|
114 |
+
alpha_initializer="ones",
|
115 |
+
alphas_initializer_range=0.0,
|
116 |
+
alpha_type="vector",
|
117 |
+
cross_layer_interval=1,
|
118 |
+
tie_word_embeddings=False,
|
119 |
+
freeze_text_layers=True,
|
120 |
+
freeze_lm_head=False,
|
121 |
+
freeze_vision_layers=True,
|
122 |
+
vision_model_name="google/vit-base-patch16-224",
|
123 |
+
vision_model_params="{}",
|
124 |
+
vision_embed_dim=768,
|
125 |
+
image_token_index=32128,
|
126 |
+
**kwargs,
|
127 |
+
):
|
128 |
+
self.vocab_size = vocab_size
|
129 |
+
self.additional_vocab_size = additional_vocab_size
|
130 |
+
self.d_model = d_model
|
131 |
+
self.d_kv = d_kv
|
132 |
+
self.d_ff = d_ff
|
133 |
+
self.num_layers = num_layers
|
134 |
+
self.num_decoder_layers = (
|
135 |
+
num_decoder_layers if num_decoder_layers is not None else self.num_layers
|
136 |
+
) # default = symmetry
|
137 |
+
self.num_heads = num_heads
|
138 |
+
self.relative_attention_num_buckets = relative_attention_num_buckets
|
139 |
+
self.relative_attention_max_distance = relative_attention_max_distance
|
140 |
+
self.dropout_rate = dropout_rate
|
141 |
+
self.layer_norm_epsilon = layer_norm_epsilon
|
142 |
+
self.initializer_factor = initializer_factor
|
143 |
+
self.feed_forward_proj = feed_forward_proj
|
144 |
+
self.use_cache = use_cache
|
145 |
+
|
146 |
+
act_info = self.feed_forward_proj.split("-")
|
147 |
+
self.dense_act_fn = act_info[-1]
|
148 |
+
self.is_gated_act = act_info[0] == "gated"
|
149 |
+
|
150 |
+
if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
|
151 |
+
raise ValueError(
|
152 |
+
f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer."
|
153 |
+
"Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
|
154 |
+
"'gated-gelu' or 'relu'"
|
155 |
+
)
|
156 |
+
|
157 |
+
# for backwards compatibility
|
158 |
+
if feed_forward_proj == "gated-gelu":
|
159 |
+
self.dense_act_fn = "gelu_new"
|
160 |
+
|
161 |
+
self.alpha_initializer = alpha_initializer
|
162 |
+
self.alphas_initializer_range = alphas_initializer_range
|
163 |
+
self.alpha_type = alpha_type
|
164 |
+
|
165 |
+
self.cross_layer_interval = cross_layer_interval
|
166 |
+
self.freeze_vision_layers = freeze_vision_layers
|
167 |
+
self.vision_model_name = vision_model_name
|
168 |
+
self.vision_model_params = vision_model_params
|
169 |
+
|
170 |
+
self.tie_word_embeddings = tie_word_embeddings
|
171 |
+
self.freeze_text_layers = freeze_text_layers
|
172 |
+
self.freeze_lm_head = freeze_lm_head
|
173 |
+
self.image_token_index = image_token_index
|
174 |
+
|
175 |
+
self.vision_embed_dim = vision_embed_dim
|
176 |
+
|
177 |
+
# IMPORTANT: Do not do any __init__ args-based checks in the constructor, since
|
178 |
+
# PretrainedConfig.from_dict first instantiates the class with the config dict and only then
|
179 |
+
# updates the config object with `kwargs` from from_pretrained, so during the instantiation
|
180 |
+
# of this object many attributes have default values and haven't yet been overridden.
|
181 |
+
# Do any required checks inside `from_pretrained` once the superclass' `from_pretrained` was run.
|
182 |
+
|
183 |
+
super().__init__(
|
184 |
+
pad_token_id=pad_token_id,
|
185 |
+
eos_token_id=eos_token_id,
|
186 |
+
is_encoder_decoder=is_encoder_decoder,
|
187 |
+
tie_word_embeddings=tie_word_embeddings,
|
188 |
+
**kwargs,
|
189 |
+
)
|
190 |
+
|
191 |
+
def check_compatibilities(self):
|
192 |
+
if self.tie_word_embeddings and (self.freeze_text_layers != self.freeze_lm_head):
|
193 |
+
raise ValueError(
|
194 |
+
"if `tie_word_embeddings` is True, then `freeze_lm_head` and `freeze_text_layers` must be equal."
|
195 |
+
)
|
196 |
+
|
197 |
+
vision_model_params = eval(self.vision_model_params)
|
198 |
+
config = AutoConfig.from_pretrained(self.vision_model_name, **vision_model_params)
|
199 |
+
if hasattr(config, "vision_config"):
|
200 |
+
vison_config = config.vision_config
|
201 |
+
else:
|
202 |
+
vison_config = config
|
203 |
+
vision_embed_dim = vison_config.hidden_size
|
204 |
+
if self.vision_embed_dim != vision_embed_dim:
|
205 |
+
raise ValueError(
|
206 |
+
f"vision_embed_dim ({self.vision_embed_dim}) must match the hidden size of the vision model"
|
207 |
+
f" ({vision_embed_dim})"
|
208 |
+
)
|
209 |
+
|
210 |
+
@classmethod
|
211 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
212 |
+
outputs = super(VT5Config, cls).from_pretrained(pretrained_model_name_or_path, **kwargs)
|
213 |
+
if isinstance(outputs, Tuple):
|
214 |
+
# When called with return_unused_kwargs=True, the first item will be the config
|
215 |
+
outputs[0].check_compatibilities()
|
216 |
+
else:
|
217 |
+
outputs.check_compatibilities()
|
218 |
+
return outputs
|
m4/models/vt5/modeling_vt5.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
m4/models/zero_checkpoint_to_hf.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
# This script combines the 2 steps of
|
4 |
+
# 1. calling zero_to_fp32.py to reconsolidate the shared deepspeed checkpoint
|
5 |
+
# 2. then resaving it as HF checkpoint, which also takes care of sharding large checkpoints
|
6 |
+
#
|
7 |
+
# example usage:
|
8 |
+
#
|
9 |
+
# this will generate the converted checkpoint under save_dir/opt_step-40/unwrapped_model
|
10 |
+
#
|
11 |
+
# ./m4/models/zero_checkpoint_to_hf.py save_dir/opt_step-40
|
12 |
+
#
|
13 |
+
# or you can override the destination by passing an explicit target dir, e.g.:
|
14 |
+
#
|
15 |
+
# ./m4/models/zero_checkpoint_to_hf.py save_dir/opt_step-40 save_dir/opt_step-40/output_dir
|
16 |
+
|
17 |
+
import argparse
|
18 |
+
import sys
|
19 |
+
from pathlib import Path
|
20 |
+
|
21 |
+
import torch
|
22 |
+
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
|
23 |
+
|
24 |
+
|
25 |
+
# auto-append the repo path to load m4 modules from instead of needing to set PYTHONPATH
|
26 |
+
repodir = str(Path(__file__).resolve().parents[2])
|
27 |
+
sys.path.insert(0, repodir)
|
28 |
+
|
29 |
+
import m4.models
|
30 |
+
from m4.testing_utils import read_json_file
|
31 |
+
|
32 |
+
|
33 |
+
if __name__ == "__main__":
|
34 |
+
parser = argparse.ArgumentParser()
|
35 |
+
parser.add_argument(
|
36 |
+
"checkpoint_dir", type=str, help="path to the desired checkpoint folder, e.g., path/to/opt_step-100"
|
37 |
+
)
|
38 |
+
parser.add_argument(
|
39 |
+
"output_dir",
|
40 |
+
type=str,
|
41 |
+
nargs="?",
|
42 |
+
help="path to pass to save_pretrained, defaults to 'unwrapped_model' relative to the checkpoint_dir argument",
|
43 |
+
)
|
44 |
+
args = parser.parse_args()
|
45 |
+
|
46 |
+
checkpoint_dir = Path(args.checkpoint_dir)
|
47 |
+
config_dir = checkpoint_dir / "unwrapped_model"
|
48 |
+
ds_checkpoint_dir = checkpoint_dir / "accelerator_state"
|
49 |
+
config_file_path = config_dir / "config.json"
|
50 |
+
|
51 |
+
if args.output_dir is None:
|
52 |
+
output_dir = checkpoint_dir / "unwrapped_model"
|
53 |
+
else:
|
54 |
+
output_dir = args.output_dir
|
55 |
+
|
56 |
+
config = read_json_file(config_file_path)
|
57 |
+
config_class = m4.models._SUPPORTED_MODELS.get(config["model_type"], None)
|
58 |
+
if config_class is None:
|
59 |
+
raise ValueError(f"{config['model_type']=} isn't supported by m4")
|
60 |
+
modeling_class = m4.models.model_type_to_modeling_class.get(config["model_type"], None)
|
61 |
+
|
62 |
+
print(f"Detected {config_class}")
|
63 |
+
|
64 |
+
print("Reconsolidating fp32 model from checkpoint shards (can take a long time)")
|
65 |
+
state_dict = get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir) # already on cpu
|
66 |
+
|
67 |
+
# Keeping debug to use if you ever need to debug state dict
|
68 |
+
# print("Saved State Dict")
|
69 |
+
# for k, v in state_dict.items():
|
70 |
+
# print(f"{k} {v.shape}")
|
71 |
+
|
72 |
+
kwargs = {}
|
73 |
+
print(f"Loading config from {config_dir}")
|
74 |
+
model_config = config_class.from_pretrained(config_dir)
|
75 |
+
|
76 |
+
print(f"Instantiating a {modeling_class} model in bf16")
|
77 |
+
model = modeling_class.from_pretrained(
|
78 |
+
None, config=model_config, state_dict=state_dict, torch_dtype=torch.bfloat16
|
79 |
+
)
|
80 |
+
|
81 |
+
# Keeping debug to use if you ever need to debug state dict
|
82 |
+
# print("Model State Dict")
|
83 |
+
# for k, v in model.state_dict().items():
|
84 |
+
# print(f"{k} {v.shape}")
|
85 |
+
|
86 |
+
print(f"Saving model to {output_dir}")
|
87 |
+
model.save_pretrained(output_dir)
|
m4/scripts/cleanup-checkpoints.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
#
|
4 |
+
# This tool deletes checkpoints found at given path that are no longer needed
|
5 |
+
#
|
6 |
+
# we have 2 parts to each checkpoints to cleanup
|
7 |
+
#
|
8 |
+
# 1. the original deepspeed checkpoint
|
9 |
+
# 2. the converted hf checkpoint
|
10 |
+
#
|
11 |
+
# we will start with a combined requirement for eval to be completed and s3 synced to nuke the checkpoint
|
12 |
+
#
|
13 |
+
# Example:
|
14 |
+
#
|
15 |
+
# ./cleanup-checkpoints.py checkpoints-path
|
16 |
+
#
|
17 |
+
# Use `-h` for more options
|
18 |
+
|
19 |
+
import argparse
|
20 |
+
import shutil # noqa
|
21 |
+
import subprocess
|
22 |
+
import sys
|
23 |
+
import time
|
24 |
+
from pathlib import Path
|
25 |
+
|
26 |
+
|
27 |
+
repo_path = Path(__file__).parents[2]
|
28 |
+
|
29 |
+
# we have to deal with potentially overlapping slurm jobs running on different nodes, so we can't
|
30 |
+
# rely on PIDs of a running process. Will use a control file instead as the filesystem is shared.
|
31 |
+
#
|
32 |
+
# If that file is there it means:
|
33 |
+
#
|
34 |
+
# 1. either the cleanup is still running
|
35 |
+
# 2. the cleanup got aborted (e.g. cpu-oom)
|
36 |
+
#
|
37 |
+
# to detect aborted cleanups we will check if the control file is older than a reasonable time to perform such a cleanup
|
38 |
+
control_file_name = "started-cleanup-checkpoint"
|
39 |
+
finished_uploading_file_name = "finished-upload-checkpoint"
|
40 |
+
# should fine tune - but surely 1h per checkpoint is plenty
|
41 |
+
reasonable_cleanup_time_in_secs = 1 * 60 * 60
|
42 |
+
|
43 |
+
|
44 |
+
def run_cmd(cmd, check=True):
|
45 |
+
try:
|
46 |
+
response = subprocess.run(
|
47 |
+
cmd,
|
48 |
+
stderr=subprocess.PIPE,
|
49 |
+
stdout=subprocess.PIPE,
|
50 |
+
check=check,
|
51 |
+
encoding="utf-8",
|
52 |
+
).stdout.strip()
|
53 |
+
except subprocess.CalledProcessError as exc:
|
54 |
+
raise EnvironmentError(exc.stderr)
|
55 |
+
|
56 |
+
return response
|
57 |
+
|
58 |
+
|
59 |
+
def get_args():
|
60 |
+
parser = argparse.ArgumentParser()
|
61 |
+
parser.add_argument("checkpoints_path", type=str, help="base dir with checkpoints")
|
62 |
+
parser.add_argument("--skip-evals-check", action="store_true", help="skip evals done checks")
|
63 |
+
return parser.parse_args()
|
64 |
+
|
65 |
+
|
66 |
+
def exit(msg):
|
67 |
+
print(msg)
|
68 |
+
sys.exit()
|
69 |
+
|
70 |
+
|
71 |
+
def should_process(path, control_file_path, args):
|
72 |
+
"""Heuristics to decide whether to cleanup this opt_step-XXX checkpoint or not"""
|
73 |
+
|
74 |
+
s3_completed_path = path / finished_uploading_file_name
|
75 |
+
eval_completed_paths = [
|
76 |
+
path / "run_evals_0_shots_done",
|
77 |
+
path / "run_evals_4_shots_done",
|
78 |
+
path / "run_evals_perplexity_validation_done",
|
79 |
+
path / "run_evals_0_shots_a_la_flamingo_done",
|
80 |
+
]
|
81 |
+
|
82 |
+
# check s3 sync is completed
|
83 |
+
if not s3_completed_path.exists():
|
84 |
+
print(f"[N] {path} hasn't been synced to s3 yet. Skipping")
|
85 |
+
return False
|
86 |
+
|
87 |
+
# check evals are completed
|
88 |
+
if not args.skip_evals_check:
|
89 |
+
for eval_path in eval_completed_paths:
|
90 |
+
if not eval_path.exists():
|
91 |
+
print(f"[N] {path} hasn't been evaled yet. Skipping")
|
92 |
+
return False
|
93 |
+
|
94 |
+
# complicated checks - has another job already started processing? or did it crash?
|
95 |
+
if control_file_path.exists():
|
96 |
+
if control_file_path.stat().st_mtime < time.time() - reasonable_cleanup_time_in_secs:
|
97 |
+
print(f"[Y] {path} looks stale - probably aborted cleanup job. Deleting")
|
98 |
+
return True
|
99 |
+
else:
|
100 |
+
print(
|
101 |
+
f"[N] {path} either another job is doing the cleanup or less than"
|
102 |
+
f" {reasonable_cleanup_time_in_secs} secs has passed since it was launched. Skipping"
|
103 |
+
)
|
104 |
+
return False
|
105 |
+
else:
|
106 |
+
print(f"[Y] {path} completed s3 sync + eval. Deleting")
|
107 |
+
return True
|
108 |
+
|
109 |
+
|
110 |
+
def main():
|
111 |
+
args = get_args()
|
112 |
+
|
113 |
+
checkpoints_path = Path(args.checkpoints_path)
|
114 |
+
if not (checkpoints_path.exists() and checkpoints_path.is_dir()):
|
115 |
+
raise FileNotFoundError(f"can't find a directory '{checkpoints_path}'")
|
116 |
+
|
117 |
+
checkpoint_dirs = list(checkpoints_path.glob("opt_step-*"))
|
118 |
+
if len(checkpoint_dirs) == 0:
|
119 |
+
exit("No checkpoints found, exiting")
|
120 |
+
|
121 |
+
# Check each checkpoint folder in real time to allow for overlapping jobs starting at different times
|
122 |
+
# Additionally do not delete the last 2 checkpoints
|
123 |
+
#
|
124 |
+
# sort numerically to sort correctly different number of digits: opt_step-10, opt_step-100
|
125 |
+
checkpoint_dirs_sorted = sorted(checkpoint_dirs, key=lambda x: int(str(x).split("-")[-1]))
|
126 |
+
for i, checkpoint_dir in enumerate(checkpoint_dirs_sorted):
|
127 |
+
print(f"\n*** Checking {checkpoint_dir}")
|
128 |
+
|
129 |
+
if i + 1 == len(checkpoint_dirs_sorted):
|
130 |
+
print(f"[N] {checkpoint_dir} is a last checkpoint. Skipping")
|
131 |
+
continue
|
132 |
+
|
133 |
+
if i + 2 == len(checkpoint_dirs_sorted):
|
134 |
+
print(f"[N] {checkpoint_dir} is a second to last checkpoint. Skipping")
|
135 |
+
continue
|
136 |
+
|
137 |
+
control_file_path = checkpoint_dir / "unwrapped_model" / control_file_name
|
138 |
+
|
139 |
+
if not should_process(checkpoint_dir, control_file_path, args):
|
140 |
+
continue
|
141 |
+
|
142 |
+
print(f"Launching cleanup for {checkpoint_dir}")
|
143 |
+
# we could use flock here, to avoid a race condition, but it'd be pointless since each
|
144 |
+
# cronjob is likely to run on a different node and flock only works within a single node
|
145 |
+
control_file_path.touch()
|
146 |
+
|
147 |
+
# cleanup
|
148 |
+
# XXX: enable the actual delete once tested a lot
|
149 |
+
# The delete should be relatively safe since it'll only run if it finds 2 files:
|
150 |
+
# save_dir/opt_step-XXX/s3_sync_is_completed save_dir/opt_step-XXX/eval_is_completed
|
151 |
+
shutil.rmtree(checkpoint_dir, ignore_errors=True)
|
152 |
+
print(f"Checkpoint {checkpoint_dir} deleted")
|
153 |
+
|
154 |
+
|
155 |
+
if __name__ == "__main__":
|
156 |
+
main()
|
m4/scripts/convert-checkpoints.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
#
|
4 |
+
# This tool converts any deepspeed checkpoints found at given path to hf format
|
5 |
+
#
|
6 |
+
# Example:
|
7 |
+
#
|
8 |
+
# ./convert-checkpoints.py checkpoints-path
|
9 |
+
#
|
10 |
+
|
11 |
+
import argparse
|
12 |
+
import subprocess
|
13 |
+
import sys
|
14 |
+
import time
|
15 |
+
from pathlib import Path
|
16 |
+
|
17 |
+
|
18 |
+
repo_path = Path(__file__).parents[2]
|
19 |
+
zero_checkpoint_to_hf_path = repo_path / "m4/models/zero_checkpoint_to_hf.py"
|
20 |
+
|
21 |
+
# we have to deal with potentially overlapping slurm jobs running on different nodes, so we can't
|
22 |
+
# rely on PIDs of a running process. Will use a control file instead as the filesystem is shared.
|
23 |
+
#
|
24 |
+
# If that file is there it means:
|
25 |
+
#
|
26 |
+
# 1. either the conversion is still running
|
27 |
+
# 2. the conversion got aborted (e.g. cpu-oom)
|
28 |
+
#
|
29 |
+
# to detect aborted conversions we will check if the control file is older than a reasonable time to perform such a conversion
|
30 |
+
control_file_name = "started-convert-checkpoint"
|
31 |
+
# should fine tune - but surely 2h per checkpoint is plenty
|
32 |
+
reasonable_conversion_time_in_secs = 2 * 60 * 60
|
33 |
+
|
34 |
+
|
35 |
+
def run_cmd(cmd, check=True):
|
36 |
+
try:
|
37 |
+
response = subprocess.run(
|
38 |
+
cmd,
|
39 |
+
stderr=subprocess.PIPE,
|
40 |
+
stdout=subprocess.PIPE,
|
41 |
+
check=check,
|
42 |
+
encoding="utf-8",
|
43 |
+
).stdout.strip()
|
44 |
+
except subprocess.CalledProcessError as exc:
|
45 |
+
raise EnvironmentError(exc.stderr)
|
46 |
+
|
47 |
+
return response
|
48 |
+
|
49 |
+
|
50 |
+
def get_args():
|
51 |
+
parser = argparse.ArgumentParser()
|
52 |
+
parser.add_argument("checkpoints_path", type=str, help="base dir with checkpoints")
|
53 |
+
parser.add_argument("-f", "--force", action="store_true", help="force rebuilding of all checkpoints")
|
54 |
+
return parser.parse_args()
|
55 |
+
|
56 |
+
|
57 |
+
def exit(msg):
|
58 |
+
print(msg)
|
59 |
+
sys.exit()
|
60 |
+
|
61 |
+
|
62 |
+
def should_process(path, force, control_file_path):
|
63 |
+
"""Heuristics to decide whether to convert this opt_step-XXX checkpoint or not"""
|
64 |
+
|
65 |
+
target_dir = path / "unwrapped_model"
|
66 |
+
|
67 |
+
# easy checks - the conversion is clearly completed
|
68 |
+
if (target_dir / "pytorch_model.bin").exists() or (target_dir / "pytorch_model.bin.index.json").exists():
|
69 |
+
print(f"[N] {path} appears to be already converted. Skipping")
|
70 |
+
return False
|
71 |
+
|
72 |
+
if force:
|
73 |
+
print("[Y] Forced to re-convert {checkpoint_dir}")
|
74 |
+
return True
|
75 |
+
|
76 |
+
# complicated checks - has another job already started processing? or did it crash?
|
77 |
+
control_file_path = target_dir / control_file_name
|
78 |
+
if control_file_path.exists():
|
79 |
+
if control_file_path.stat().st_mtime < time.time() - reasonable_conversion_time_in_secs:
|
80 |
+
print(f"[Y] {path} looks stale - probably aborted job. Re-converting")
|
81 |
+
return True
|
82 |
+
else:
|
83 |
+
print(
|
84 |
+
f"[N] {path} either another job is converting it or less than"
|
85 |
+
f" {reasonable_conversion_time_in_secs} secs has passed since it was launched. Skipping"
|
86 |
+
)
|
87 |
+
return False
|
88 |
+
else:
|
89 |
+
print(f"[Y] {path} is a new checkpoint. Converting")
|
90 |
+
return True
|
91 |
+
|
92 |
+
|
93 |
+
def main():
|
94 |
+
args = get_args()
|
95 |
+
|
96 |
+
checkpoints_path = Path(args.checkpoints_path)
|
97 |
+
if not (checkpoints_path.exists() and checkpoints_path.is_dir()):
|
98 |
+
raise FileNotFoundError(f"can't find a directory '{checkpoints_path}'")
|
99 |
+
|
100 |
+
checkpoint_dirs = list(checkpoints_path.glob("opt_step-*"))
|
101 |
+
if len(checkpoint_dirs) == 0:
|
102 |
+
exit("No checkpoints found, exiting")
|
103 |
+
|
104 |
+
# Check each folder in real time to allow for overlapping jobs starting at different times
|
105 |
+
for checkpoint_dir in checkpoint_dirs:
|
106 |
+
print(f"\n*** Checking {checkpoint_dir}")
|
107 |
+
|
108 |
+
control_file_path = checkpoint_dir / "unwrapped_model" / control_file_name
|
109 |
+
|
110 |
+
if not should_process(checkpoint_dir, args.force, control_file_path):
|
111 |
+
continue
|
112 |
+
|
113 |
+
print(f"Launching conversion for {checkpoint_dir} - it could take a long time")
|
114 |
+
cmd = [zero_checkpoint_to_hf_path, checkpoint_dir]
|
115 |
+
# we could use flock here, to avoid a race condition, but it'd be pointless since each
|
116 |
+
# cronjob is likely to run on a different node and flock only works within a single node
|
117 |
+
control_file_path.touch()
|
118 |
+
response = run_cmd(cmd)
|
119 |
+
control_file_path.unlink()
|
120 |
+
print(response)
|
121 |
+
|
122 |
+
|
123 |
+
if __name__ == "__main__":
|
124 |
+
main()
|
m4/scripts/s3-upload-checkpoints.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
#
|
4 |
+
# This tool uploads any new deepspeed checkpoints found at given path to s3 (and also various non-checkpoint files, like logs)
|
5 |
+
#
|
6 |
+
# Example:
|
7 |
+
#
|
8 |
+
# ./s3-upload-checkpoints.py checkpoints-path
|
9 |
+
#
|
10 |
+
# Use `-h` for more options
|
11 |
+
#
|
12 |
+
|
13 |
+
|
14 |
+
import argparse
|
15 |
+
import subprocess
|
16 |
+
import sys
|
17 |
+
import time
|
18 |
+
from pathlib import Path
|
19 |
+
|
20 |
+
|
21 |
+
repo_path = Path(__file__).resolve().parents[2]
|
22 |
+
zero_checkpoint_to_hf_path = repo_path / "m4/models/zero_checkpoint_to_hf.py"
|
23 |
+
|
24 |
+
RETRIES = 5
|
25 |
+
|
26 |
+
# what dir/file glob patterns to include in the upload besides checkpoints
|
27 |
+
include_patterns = ["tb_run_*", "logs", "config.yaml"]
|
28 |
+
|
29 |
+
|
30 |
+
# we have to deal with potentially overlapping slurm jobs running on different nodes, so we can't
|
31 |
+
# rely on PIDs of a running process. Will use a control file instead as the filesystem is shared.
|
32 |
+
#
|
33 |
+
# If that file is there it means:
|
34 |
+
#
|
35 |
+
# 1. either the upload is still running
|
36 |
+
# 2. the upload got aborted (e.g. cpu-oom)
|
37 |
+
#
|
38 |
+
# to detect aborted uploads we will check if the control file is older than a reasonable time to perform such a upload
|
39 |
+
control_file_name = "started-upload-checkpoint"
|
40 |
+
finished_uploading_file_name = "finished-upload-checkpoint"
|
41 |
+
# should fine tune - but surely 2h per checkpoint is plenty
|
42 |
+
reasonable_upload_time_in_secs = 2 * 60 * 60
|
43 |
+
|
44 |
+
|
45 |
+
def run_cmd(cmd, check=True):
|
46 |
+
try:
|
47 |
+
response = subprocess.run(
|
48 |
+
cmd,
|
49 |
+
stderr=subprocess.PIPE,
|
50 |
+
stdout=subprocess.PIPE,
|
51 |
+
check=check,
|
52 |
+
encoding="utf-8",
|
53 |
+
).stdout.strip()
|
54 |
+
except subprocess.CalledProcessError as exc:
|
55 |
+
raise EnvironmentError(exc.stderr)
|
56 |
+
|
57 |
+
return response
|
58 |
+
|
59 |
+
|
60 |
+
def get_args():
|
61 |
+
parser = argparse.ArgumentParser()
|
62 |
+
parser.add_argument("checkpoints_path", type=str, help="base dir with checkpoints")
|
63 |
+
# parser.add_argument("experiment_name", type=str, help="experiment name as a s3 sub-dir")
|
64 |
+
parser.add_argument("-f", "--force", action="store_true", help="force uploading of all checkpoints")
|
65 |
+
parser.add_argument(
|
66 |
+
"--skip-conversion-check", action="store_true", help="skip checkpoint conversion is done check"
|
67 |
+
)
|
68 |
+
return parser.parse_args()
|
69 |
+
|
70 |
+
|
71 |
+
def exit(msg):
|
72 |
+
print(msg)
|
73 |
+
sys.exit()
|
74 |
+
|
75 |
+
|
76 |
+
def should_process(path, force, control_file_path, finished_uploading_file_path, args):
|
77 |
+
"""Heuristics to decide whether to upload this opt_step-XXX checkpoint or not"""
|
78 |
+
|
79 |
+
# check if checkpoint is fully saved
|
80 |
+
finished_saving_path = path / "finished-saving" # defined in from trainer.py
|
81 |
+
if not finished_saving_path.exists():
|
82 |
+
print(f"[N] {path} isn't finished saving. Skipping")
|
83 |
+
return False
|
84 |
+
|
85 |
+
if force:
|
86 |
+
print("[Y] Forced to re-process {checkpoint_dir}")
|
87 |
+
return True
|
88 |
+
|
89 |
+
# check if already uploaded
|
90 |
+
if finished_uploading_file_path.exists():
|
91 |
+
print(f"[N] {path} has already been uploaded. Skipping")
|
92 |
+
return False
|
93 |
+
|
94 |
+
# check conversion is completed
|
95 |
+
if not args.skip_conversion_check:
|
96 |
+
converted_model_path_1 = path / "unwrapped_model" / "pytorch_model.bin.index.json"
|
97 |
+
converted_model_path_2 = path / "unwrapped_model" / "pytorch_model.bin"
|
98 |
+
if not converted_model_path_1.exists() and not converted_model_path_2.exists():
|
99 |
+
print(f"[N] {path} doesn't have a converted model. Skipping")
|
100 |
+
return False
|
101 |
+
|
102 |
+
# complicated checks - has another job already started uploading? or did it crash?
|
103 |
+
if control_file_path.exists():
|
104 |
+
if control_file_path.stat().st_mtime < time.time() - reasonable_upload_time_in_secs:
|
105 |
+
print(f"[Y] {path} looks stale - probably aborted job. Re-uploading")
|
106 |
+
return True
|
107 |
+
else:
|
108 |
+
print(
|
109 |
+
f"[N] {path} either another job is uploading it or less than"
|
110 |
+
f" {reasonable_upload_time_in_secs} secs has passed since it was launched. Skipping"
|
111 |
+
)
|
112 |
+
return False
|
113 |
+
else:
|
114 |
+
print(f"[Y] {path} is a new checkpoint. Uploading")
|
115 |
+
return True
|
116 |
+
|
117 |
+
|
118 |
+
def main():
|
119 |
+
args = get_args()
|
120 |
+
|
121 |
+
checkpoints_path = Path(args.checkpoints_path)
|
122 |
+
if not (checkpoints_path.exists() and checkpoints_path.is_dir()):
|
123 |
+
raise FileNotFoundError(f"can't find a directory '{checkpoints_path}'")
|
124 |
+
|
125 |
+
checkpoint_dirs = list(checkpoints_path.glob("opt_step-*"))
|
126 |
+
if len(checkpoint_dirs) == 0:
|
127 |
+
exit("No checkpoints found, exiting")
|
128 |
+
|
129 |
+
exp_name = checkpoints_path.name
|
130 |
+
|
131 |
+
# Check each folder in real time to allow for overlapping jobs starting at different times
|
132 |
+
for checkpoint_dir in checkpoint_dirs:
|
133 |
+
print(f"\n*** Checking {checkpoint_dir}")
|
134 |
+
|
135 |
+
control_file_path = checkpoint_dir / control_file_name
|
136 |
+
finished_uploading_file_path = checkpoint_dir / finished_uploading_file_name
|
137 |
+
|
138 |
+
if not should_process(checkpoint_dir, args.force, control_file_path, finished_uploading_file_path, args):
|
139 |
+
continue
|
140 |
+
|
141 |
+
opt_step = checkpoint_dir.name
|
142 |
+
bucket_name = "m4-exps"
|
143 |
+
bucket_path = f"{exp_name}/{opt_step}"
|
144 |
+
|
145 |
+
print(f"Launching upload for {checkpoint_dir} - it could take a long time")
|
146 |
+
cmd = f"s5cmd sync {checkpoint_dir}/ s3://{bucket_name}/{bucket_path}/".split()
|
147 |
+
# we could use flock here, to avoid a race condition, but it'd be pointless since each
|
148 |
+
# cronjob is likely to run on a different node and flock only works within a single node
|
149 |
+
control_file_path.touch()
|
150 |
+
# print(f"mock running {cmd}")
|
151 |
+
|
152 |
+
# s5cmd will fail with an error like this when MD5 checksum doesn't match on upload (it won't retry)
|
153 |
+
# ERROR "cp data4.tar s3://m4-datasets/cm4-test/data4.tar": InvalidDigest: The Content-MD5
|
154 |
+
# you specified was invalid. status code: 400, request id: SZEHBJ4QQ33JSMH7, host id:
|
155 |
+
# XTeMYKd2KECiVKbFnwVbXo3LgnuA2OHWk5S+tHKAOKO95Os/pje2ZEbCfO5pojQtCTFOovvnVME=
|
156 |
+
|
157 |
+
tries = 0
|
158 |
+
while tries < RETRIES:
|
159 |
+
tries += 1
|
160 |
+
try:
|
161 |
+
response = run_cmd(cmd)
|
162 |
+
print(response)
|
163 |
+
break
|
164 |
+
except EnvironmentError as e:
|
165 |
+
if "InvalidDigest" in str(e):
|
166 |
+
print(f"MD5 checksum failed, upload retry {tries}")
|
167 |
+
continue
|
168 |
+
except Exception:
|
169 |
+
# some other possible failure?
|
170 |
+
raise
|
171 |
+
|
172 |
+
# for now disable this as large files don't have sha256 checksums
|
173 |
+
# result = integrity_check_recursive(checkpoint_dir, bucket_name, bucket_path)
|
174 |
+
# print(f"Integrity check was {result}")
|
175 |
+
|
176 |
+
control_file_path.unlink()
|
177 |
+
finished_uploading_file_path.touch()
|
178 |
+
|
179 |
+
# now upload non-checkpoint files
|
180 |
+
print("\n*** Uploading non-checkpoint files")
|
181 |
+
upload_dirs = []
|
182 |
+
for pat in include_patterns:
|
183 |
+
upload_dirs += list(checkpoints_path.glob(pat))
|
184 |
+
|
185 |
+
for dir in upload_dirs:
|
186 |
+
print(f"Launching upload for {dir}")
|
187 |
+
cmd = f"s5cmd sync {dir} s3://m4-exps/{exp_name}/".split()
|
188 |
+
print(f"running {cmd}")
|
189 |
+
response = run_cmd(cmd)
|
190 |
+
print(response)
|
191 |
+
|
192 |
+
|
193 |
+
if __name__ == "__main__":
|
194 |
+
main()
|
m4/scripts/s3_checkpoint_download_convert_upload.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
#
|
4 |
+
# This tool converts any deepspeed checkpoints found at given path to hf format
|
5 |
+
#
|
6 |
+
# Example:
|
7 |
+
#
|
8 |
+
# ./convert-checkpoints.py checkpoints-path
|
9 |
+
#
|
10 |
+
|
11 |
+
import argparse
|
12 |
+
import subprocess
|
13 |
+
import sys
|
14 |
+
from pathlib import Path
|
15 |
+
|
16 |
+
import boto3
|
17 |
+
|
18 |
+
|
19 |
+
def check_s3_directory(directory_path):
|
20 |
+
s3 = boto3.client("s3")
|
21 |
+
|
22 |
+
# Add a trailing slash to the directory path
|
23 |
+
if not directory_path.endswith("/"):
|
24 |
+
directory_path += "/"
|
25 |
+
|
26 |
+
# Check if any objects exist with the given directory prefix
|
27 |
+
response = s3.list_objects_v2(Bucket="m4-exps", Prefix=directory_path)
|
28 |
+
|
29 |
+
# If any objects are found, the directory exists
|
30 |
+
if "Contents" in response:
|
31 |
+
return True
|
32 |
+
|
33 |
+
return False
|
34 |
+
|
35 |
+
|
36 |
+
def check_s3_file(file_key):
|
37 |
+
s3 = boto3.client("s3")
|
38 |
+
|
39 |
+
try:
|
40 |
+
s3.head_object(Bucket="m4-exps", Key=file_key)
|
41 |
+
return True
|
42 |
+
except Exception:
|
43 |
+
return False
|
44 |
+
|
45 |
+
|
46 |
+
def run_cmd(cmd, check=True):
|
47 |
+
try:
|
48 |
+
response = subprocess.run(
|
49 |
+
cmd,
|
50 |
+
stderr=subprocess.PIPE,
|
51 |
+
stdout=subprocess.PIPE,
|
52 |
+
check=check,
|
53 |
+
encoding="utf-8",
|
54 |
+
).stdout.strip()
|
55 |
+
except subprocess.CalledProcessError as exc:
|
56 |
+
raise EnvironmentError(exc.stderr)
|
57 |
+
|
58 |
+
return response
|
59 |
+
|
60 |
+
|
61 |
+
def get_args():
|
62 |
+
parser = argparse.ArgumentParser()
|
63 |
+
parser.add_argument("run_name", type=str, help="run name")
|
64 |
+
parser.add_argument("opt_step_num_list", nargs="+", help="list of opt-steps to download")
|
65 |
+
parser.add_argument("repo_path", type=str, help="repo path")
|
66 |
+
|
67 |
+
parser.add_argument("-f", "--force", action="store_true", help="force rebuilding of all checkpoints")
|
68 |
+
return parser.parse_args()
|
69 |
+
|
70 |
+
|
71 |
+
def exit(msg):
|
72 |
+
print(msg)
|
73 |
+
sys.exit()
|
74 |
+
|
75 |
+
|
76 |
+
def cmd_retry_loop(cmd, max_retries=5):
|
77 |
+
# s5cmd will fail with an error like this when MD5 checksum doesn't match on upload (it won't retry)
|
78 |
+
# ERROR "cp data4.tar s3://m4-datasets/cm4-test/data4.tar": InvalidDigest: The Content-MD5
|
79 |
+
# you specified was invalid. status code: 400, request id: SZEHBJ4QQ33JSMH7, host id:
|
80 |
+
# XTeMYKd2KECiVKbFnwVbXo3LgnuA2OHWk5S+tHKAOKO95Os/pje2ZEbCfO5pojQtCTFOovvnVME=
|
81 |
+
|
82 |
+
tries = 0
|
83 |
+
while tries < max_retries:
|
84 |
+
tries += 1
|
85 |
+
try:
|
86 |
+
response = run_cmd(cmd)
|
87 |
+
print(response)
|
88 |
+
break
|
89 |
+
except EnvironmentError as e:
|
90 |
+
if "InvalidDigest" in str(e):
|
91 |
+
print(f"MD5 checksum failed, download retry {tries}")
|
92 |
+
continue
|
93 |
+
except Exception:
|
94 |
+
# some other possible failure?
|
95 |
+
raise
|
96 |
+
return response
|
97 |
+
|
98 |
+
|
99 |
+
def main():
|
100 |
+
args = get_args()
|
101 |
+
|
102 |
+
run_name = args.run_name
|
103 |
+
opt_step_num_list = args.opt_step_num_list
|
104 |
+
repo_path = Path(args.repo_path)
|
105 |
+
zero_checkpoint_to_hf_path = repo_path / "m4/models/zero_checkpoint_to_hf.py"
|
106 |
+
bucket_name = "m4-exps"
|
107 |
+
opt_step_s3_file_keys = [f"{run_name}/opt_step-{opt_step_num}" for opt_step_num in opt_step_num_list]
|
108 |
+
|
109 |
+
check_s3_directory(run_name)
|
110 |
+
|
111 |
+
# Check each folder in real time to allow for overlapping jobs starting at different times
|
112 |
+
for opt_step_s3_file_key in opt_step_s3_file_keys:
|
113 |
+
print(f"\n*** Checking {opt_step_s3_file_key}")
|
114 |
+
if not check_s3_directory(opt_step_s3_file_key):
|
115 |
+
print(f"The checkpoint {opt_step_s3_file_key} does not exist - skipping")
|
116 |
+
continue
|
117 |
+
unwrapped_model_s3_file_key = f"{opt_step_s3_file_key}/unwrapped_model"
|
118 |
+
bin_s3_file_key = f"{unwrapped_model_s3_file_key}/pytorch_model.bin"
|
119 |
+
index_s3_file_key = f"{unwrapped_model_s3_file_key}/pytorch_model.bin.index.json"
|
120 |
+
is_not_converted = not check_s3_file(bin_s3_file_key) and not check_s3_file(index_s3_file_key)
|
121 |
+
if is_not_converted:
|
122 |
+
print(
|
123 |
+
f"The checkpoint hasn't been converted, launching download for {opt_step_s3_file_key} - it could take"
|
124 |
+
" a long time"
|
125 |
+
)
|
126 |
+
|
127 |
+
opt_step_dirname = opt_step_s3_file_key.split("/")[-1]
|
128 |
+
cluster_opt_step_dir = f"/fsx/m4/experiments/local_experiment_dir/s3_async_temporary_checkpoint_folder/{run_name}/{opt_step_dirname}"
|
129 |
+
cmd = f"s5cmd sync s3://{bucket_name}/{opt_step_s3_file_key}/* {cluster_opt_step_dir}".split()
|
130 |
+
download_response_opt_step_dir = cmd_retry_loop(cmd, max_retries=5)
|
131 |
+
print(f"download_response_opt_step_dir: {download_response_opt_step_dir}")
|
132 |
+
else:
|
133 |
+
print(
|
134 |
+
"The checkpoint has been converted already, downloading only the unwrapped checkpoint and"
|
135 |
+
" tokenizer dir"
|
136 |
+
)
|
137 |
+
opt_step_dirname = opt_step_s3_file_key.split("/")[-1]
|
138 |
+
cluster_opt_step_dir = f"/fsx/m4/experiments/local_experiment_dir/s3_async_temporary_checkpoint_folder/{run_name}/{opt_step_dirname}"
|
139 |
+
unwrapped_model_dir = f"{cluster_opt_step_dir}/unwrapped_model"
|
140 |
+
tokenizer_dir = f"{cluster_opt_step_dir}/tokenizer"
|
141 |
+
cmd_model = (
|
142 |
+
f"s5cmd sync s3://{bucket_name}/{opt_step_s3_file_key}/unwrapped_model/* {unwrapped_model_dir}".split()
|
143 |
+
)
|
144 |
+
cmd_tokenizer = f"s5cmd sync s3://{bucket_name}/{opt_step_s3_file_key}/tokenizer/* {tokenizer_dir}".split()
|
145 |
+
download_response_model = cmd_retry_loop(cmd_model, max_retries=5)
|
146 |
+
print(f"download_response_model: {download_response_model}")
|
147 |
+
download_response_tokenizer = cmd_retry_loop(cmd_tokenizer, max_retries=5)
|
148 |
+
print(f"download_response_tokenizer: {download_response_tokenizer}")
|
149 |
+
|
150 |
+
print(f"opt_step_dirname: {opt_step_dirname} downloaded to cluster_opt_step_dir: {cluster_opt_step_dir}")
|
151 |
+
|
152 |
+
if is_not_converted:
|
153 |
+
print(f"Converting {cluster_opt_step_dir}")
|
154 |
+
convert_cmd = [zero_checkpoint_to_hf_path, cluster_opt_step_dir]
|
155 |
+
conversion_response = run_cmd(convert_cmd)
|
156 |
+
print(f"conversion_response: {conversion_response}")
|
157 |
+
print(f"upload converted checkpoint: {cluster_opt_step_dir}")
|
158 |
+
upload_cmd = (
|
159 |
+
f"s5cmd sync {cluster_opt_step_dir}/unwrapped_model/"
|
160 |
+
f" s3://{bucket_name}/{opt_step_s3_file_key}/unwrapped_model/ ".split()
|
161 |
+
)
|
162 |
+
upload_response = cmd_retry_loop(upload_cmd, max_retries=5)
|
163 |
+
print(f"upload_response: {upload_response}")
|
164 |
+
print(
|
165 |
+
f"Uploaded {cluster_opt_step_dir}/unwrapped_model to"
|
166 |
+
f" s3://{bucket_name}/{opt_step_s3_file_key}/unwrapped_model"
|
167 |
+
)
|
168 |
+
|
169 |
+
|
170 |
+
if __name__ == "__main__":
|
171 |
+
main()
|
m4/scripts/s3_checkpoint_download_convert_upload.slurm
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
#SBATCH --job-name=tr_test-s3-download-and-convert-checkpoints
|
3 |
+
#SBATCH --ntasks=1
|
4 |
+
#SBATCH --nodes=1
|
5 |
+
#SBATCH --time=3:00:00
|
6 |
+
#SBATCH --partition=production-cluster
|
7 |
+
#SBATCH --output=/fsx/m4/experiments/local_experiment_dir/s3_async_temporary_checkpoint_folder/logs/%x-%j.out
|
8 |
+
|
9 |
+
|
10 |
+
set -e
|
11 |
+
|
12 |
+
# ----------------- Auto-Workdir -----------------
|
13 |
+
if [ -n $SLURM_JOB_ID ]; then
|
14 |
+
# check the original location through scontrol and $SLURM_JOB_ID
|
15 |
+
SCRIPT_PATH=$(scontrol show job $SLURM_JOB_ID | awk -F= '/Command=/{print $2}')
|
16 |
+
else
|
17 |
+
# otherwise: started with bash. Get the real location.
|
18 |
+
SCRIPT_PATH=$(realpath $0)
|
19 |
+
fi
|
20 |
+
SCRIPT_DIR=$(dirname ${SCRIPT_PATH})
|
21 |
+
M4_REPO_PATH=$(builtin cd $SCRIPT_DIR/../../; pwd)
|
22 |
+
|
23 |
+
# --------------------------------------------------
|
24 |
+
|
25 |
+
### EDIT ME START ###
|
26 |
+
|
27 |
+
CONDA_ENV_NAME=shared-m4
|
28 |
+
|
29 |
+
EXPERIMENT_NAME=tr_194_laion_cm4_mix
|
30 |
+
|
31 |
+
opt_step_num_list=(
|
32 |
+
"1000"
|
33 |
+
"2000"
|
34 |
+
)
|
35 |
+
|
36 |
+
### EDIT ME END ###
|
37 |
+
|
38 |
+
|
39 |
+
echo "START TIME: $(date)"
|
40 |
+
|
41 |
+
source /fsx/m4/start-m4-user
|
42 |
+
conda activate base
|
43 |
+
conda activate $CONDA_ENV_NAME
|
44 |
+
pushd $M4_REPO_PATH
|
45 |
+
export PYTHONPATH=$WORKING_DIR:$PYTHONPATH
|
46 |
+
|
47 |
+
echo "running checkpoint download, convert, upload for opt-steps: ${opt_step_num_list[@]} of experiment: $EXPERIMENT_NAME"
|
48 |
+
|
49 |
+
python $M4_REPO_PATH/m4/scripts/s3_checkpoint_download_convert_upload.py $EXPERIMENT_NAME ${opt_step_num_list[@]} $M4_REPO_PATH
|
50 |
+
|
51 |
+
echo "END TIME: $(date)"
|
m4/scripts/s3_downloaded_checkpoints_cleanup.slurm
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
#SBATCH --job-name=tr_test-s3-cleanup-checkpoints
|
3 |
+
#SBATCH --ntasks=1
|
4 |
+
#SBATCH --nodes=1
|
5 |
+
#SBATCH --time=3:00:00
|
6 |
+
#SBATCH --partition=production-cluster
|
7 |
+
#SBATCH --output=/fsx/m4/experiments/local_experiment_dir/s3_async_temporary_checkpoint_folder/logs/%x-%j.out
|
8 |
+
|
9 |
+
|
10 |
+
set -e
|
11 |
+
|
12 |
+
# ----------------- Auto-Workdir -----------------
|
13 |
+
if [ -n $SLURM_JOB_ID ]; then
|
14 |
+
# check the original location through scontrol and $SLURM_JOB_ID
|
15 |
+
SCRIPT_PATH=$(scontrol show job $SLURM_JOB_ID | awk -F= '/Command=/{print $2}')
|
16 |
+
else
|
17 |
+
# otherwise: started with bash. Get the real location.
|
18 |
+
SCRIPT_PATH=$(realpath $0)
|
19 |
+
fi
|
20 |
+
SCRIPT_DIR=$(dirname ${SCRIPT_PATH})
|
21 |
+
M4_REPO_PATH=$(builtin cd $SCRIPT_DIR/../../; pwd)
|
22 |
+
|
23 |
+
# --------------------------------------------------
|
24 |
+
|
25 |
+
### EDIT ME START ###
|
26 |
+
|
27 |
+
CONDA_ENV_NAME=shared-m4
|
28 |
+
|
29 |
+
EXPERIMENT_NAME=tr_194_laion_cm4_mix
|
30 |
+
|
31 |
+
opt_step_num_list=(
|
32 |
+
"1000"
|
33 |
+
"2000"
|
34 |
+
)
|
35 |
+
|
36 |
+
### EDIT ME END ###
|
37 |
+
|
38 |
+
|
39 |
+
echo "START TIME: $(date)"
|
40 |
+
|
41 |
+
source /fsx/m4/start-m4-user
|
42 |
+
conda activate base
|
43 |
+
conda activate $CONDA_ENV_NAME
|
44 |
+
pushd $M4_REPO_PATH
|
45 |
+
export PYTHONPATH=$WORKING_DIR:$PYTHONPATH
|
46 |
+
|
47 |
+
for opt_step_num in ${opt_step_num_list[@]}
|
48 |
+
do
|
49 |
+
OPT_STEP_DIR="/fsx/m4/experiments/local_experiment_dir/s3_async_temporary_checkpoint_folder/${EXPERIMENT_NAME}/opt_step-${opt_step_num}"
|
50 |
+
rm -r $OPT_STEP_DIR
|
51 |
+
echo "Deleted $OPT_STEP_DIR of experiment: $EXPERIMENT_NAME"
|
52 |
+
done
|
53 |
+
|
54 |
+
echo "END TIME: $(date)"
|
m4/scripts/schedule-evals.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
#
|
4 |
+
# This tool checks if evaluation is needed
|
5 |
+
#
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import os
|
9 |
+
import subprocess
|
10 |
+
import sys
|
11 |
+
import time
|
12 |
+
from pathlib import Path
|
13 |
+
|
14 |
+
|
15 |
+
repo_path = Path(__file__).parents[2]
|
16 |
+
|
17 |
+
# we have to deal with potentially overlapping slurm jobs running on different nodes, so we can't
|
18 |
+
# rely on PIDs of a running process. Will use a control file instead as the filesystem is shared.
|
19 |
+
#
|
20 |
+
# If that file is there it means:
|
21 |
+
#
|
22 |
+
# 1. either the eval is still running
|
23 |
+
# 2. the eval got aborted (e.g. gpu-oom)
|
24 |
+
#
|
25 |
+
|
26 |
+
# should fine tune - but surely 9h per checkpoint is plenty
|
27 |
+
reasonable_eval_time_in_secs = 9 * 60 * 60
|
28 |
+
|
29 |
+
|
30 |
+
def run_cmd(cmd, check=True):
|
31 |
+
try:
|
32 |
+
response = subprocess.run(
|
33 |
+
cmd,
|
34 |
+
stderr=subprocess.PIPE,
|
35 |
+
stdout=subprocess.PIPE,
|
36 |
+
check=check,
|
37 |
+
encoding="utf-8",
|
38 |
+
).stdout.strip()
|
39 |
+
except subprocess.CalledProcessError as exc:
|
40 |
+
raise EnvironmentError(exc.stderr)
|
41 |
+
|
42 |
+
return response
|
43 |
+
|
44 |
+
|
45 |
+
def get_args():
|
46 |
+
parser = argparse.ArgumentParser()
|
47 |
+
parser.add_argument("checkpoints_path", type=str, help="base dir with checkpoints")
|
48 |
+
return parser.parse_args()
|
49 |
+
|
50 |
+
|
51 |
+
def exit(msg):
|
52 |
+
print(msg)
|
53 |
+
sys.exit()
|
54 |
+
|
55 |
+
|
56 |
+
def check_eval_crash(path):
|
57 |
+
"""Heuristics to decide whether to restart this opt_step-XXX checkpoint evaluation or not"""
|
58 |
+
eval_0_completed_path = path / "start_run_evals_0_shots"
|
59 |
+
eval_4_completed_path = path / "start_run_evals_4_shots"
|
60 |
+
eval_perplexity_path = path / "start_run_evals_perplexity_validation"
|
61 |
+
# complicated checks - has another job already started processing? or did it crash?
|
62 |
+
for eval_start_path in [eval_0_completed_path, eval_4_completed_path, eval_perplexity_path]:
|
63 |
+
if eval_start_path.exists():
|
64 |
+
if eval_start_path.stat().st_mtime < time.time() - reasonable_eval_time_in_secs:
|
65 |
+
print(f"[Y] {path} looks stale - Probably crashed - Restart evals")
|
66 |
+
os.remove(eval_start_path)
|
67 |
+
|
68 |
+
|
69 |
+
def main():
|
70 |
+
args = get_args()
|
71 |
+
|
72 |
+
checkpoints_path = Path(args.checkpoints_path)
|
73 |
+
if not (checkpoints_path.exists() and checkpoints_path.is_dir()):
|
74 |
+
raise FileNotFoundError(f"can't find a directory '{checkpoints_path}'")
|
75 |
+
|
76 |
+
checkpoint_dirs = list(checkpoints_path.glob("opt_step-*"))
|
77 |
+
if len(checkpoint_dirs) == 0:
|
78 |
+
exit("No checkpoints found, exiting")
|
79 |
+
|
80 |
+
checkpoint_dirs_sorted = sorted(checkpoint_dirs, key=lambda x: int(str(x).split("-")[-1]))
|
81 |
+
for i, checkpoint_dir in enumerate(checkpoint_dirs_sorted):
|
82 |
+
print(f"\n*** Checking {checkpoint_dir} for evals")
|
83 |
+
check_eval_crash(checkpoint_dir)
|
84 |
+
|
85 |
+
|
86 |
+
if __name__ == "__main__":
|
87 |
+
main()
|
m4/testing_utils.py
ADDED
@@ -0,0 +1,1116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# adapted from https://stackoverflow.com/a/59041913/9201239
|
16 |
+
import asyncio # noqa
|
17 |
+
import contextlib
|
18 |
+
import importlib.util
|
19 |
+
import inspect
|
20 |
+
import json
|
21 |
+
import logging
|
22 |
+
import os
|
23 |
+
import random
|
24 |
+
import re
|
25 |
+
import shutil
|
26 |
+
import sys
|
27 |
+
import tempfile
|
28 |
+
import unittest
|
29 |
+
from distutils.util import strtobool
|
30 |
+
from io import StringIO
|
31 |
+
from pathlib import Path
|
32 |
+
from typing import Iterator, Union
|
33 |
+
from unittest import mock
|
34 |
+
from unittest.case import SkipTest
|
35 |
+
|
36 |
+
import numpy as np
|
37 |
+
from packaging import version
|
38 |
+
from parameterized import parameterized
|
39 |
+
|
40 |
+
|
41 |
+
try:
|
42 |
+
import torch
|
43 |
+
|
44 |
+
_torch_available = True
|
45 |
+
except Exception:
|
46 |
+
_torch_available = False
|
47 |
+
|
48 |
+
|
49 |
+
def is_torch_available():
|
50 |
+
return _torch_available
|
51 |
+
|
52 |
+
|
53 |
+
def parse_flag_from_env(key, default=False):
|
54 |
+
try:
|
55 |
+
value = os.environ[key]
|
56 |
+
except KeyError:
|
57 |
+
# KEY isn't set, default to `default`.
|
58 |
+
_value = default
|
59 |
+
else:
|
60 |
+
# KEY is set, convert it to True or False.
|
61 |
+
try:
|
62 |
+
_value = strtobool(value)
|
63 |
+
except ValueError:
|
64 |
+
# More values are supported, but let's keep the message simple.
|
65 |
+
raise ValueError(f"If set, {key} must be yes or no.")
|
66 |
+
return _value
|
67 |
+
|
68 |
+
|
69 |
+
def parse_int_from_env(key, default=None):
|
70 |
+
try:
|
71 |
+
value = os.environ[key]
|
72 |
+
except KeyError:
|
73 |
+
_value = default
|
74 |
+
else:
|
75 |
+
try:
|
76 |
+
_value = int(value)
|
77 |
+
except ValueError:
|
78 |
+
raise ValueError(f"If set, {key} must be a int.")
|
79 |
+
return _value
|
80 |
+
|
81 |
+
|
82 |
+
def require_torch(test_case):
|
83 |
+
"""
|
84 |
+
Decorator marking a test that requires PyTorch.
|
85 |
+
|
86 |
+
These tests are skipped when PyTorch isn't installed.
|
87 |
+
|
88 |
+
"""
|
89 |
+
if not is_torch_available():
|
90 |
+
return unittest.skip("test requires PyTorch")(test_case)
|
91 |
+
else:
|
92 |
+
return test_case
|
93 |
+
|
94 |
+
|
95 |
+
def require_torch_no_gpus(test_case):
|
96 |
+
"""
|
97 |
+
Decorator marking a test that requires a setup without GPUs (in PyTorch). These tests are skipped on a machine with GPUs.
|
98 |
+
|
99 |
+
To run *only* the no gpu tests, assuming all test names contain no_gpu: $ pytest -sv ./tests -k "no_gpu"
|
100 |
+
"""
|
101 |
+
import torch
|
102 |
+
|
103 |
+
if is_torch_available() and torch.cuda.device_count() > 0:
|
104 |
+
return unittest.skip("test requires an environment w/o GPUs")(test_case)
|
105 |
+
else:
|
106 |
+
return test_case
|
107 |
+
|
108 |
+
|
109 |
+
def require_torch_multi_gpu(test_case):
|
110 |
+
"""
|
111 |
+
Decorator marking a test that requires a multi-GPU setup (in PyTorch). These tests are skipped on a machine without
|
112 |
+
multiple GPUs.
|
113 |
+
|
114 |
+
To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests -k "multi_gpu"
|
115 |
+
"""
|
116 |
+
if not is_torch_available():
|
117 |
+
return unittest.skip("test requires PyTorch")(test_case)
|
118 |
+
|
119 |
+
import torch
|
120 |
+
|
121 |
+
if torch.cuda.device_count() < 2:
|
122 |
+
return unittest.skip("test requires multiple GPUs")(test_case)
|
123 |
+
else:
|
124 |
+
return test_case
|
125 |
+
|
126 |
+
|
127 |
+
def require_torch_non_multi_gpu(test_case):
|
128 |
+
"""
|
129 |
+
Decorator marking a test that requires 0 or 1 GPU setup (in PyTorch).
|
130 |
+
"""
|
131 |
+
if not is_torch_available():
|
132 |
+
return unittest.skip("test requires PyTorch")(test_case)
|
133 |
+
|
134 |
+
import torch
|
135 |
+
|
136 |
+
if torch.cuda.device_count() > 1:
|
137 |
+
return unittest.skip("test requires 0 or 1 GPU")(test_case)
|
138 |
+
else:
|
139 |
+
return test_case
|
140 |
+
|
141 |
+
|
142 |
+
def require_torch_up_to_2_gpus(test_case):
|
143 |
+
"""
|
144 |
+
Decorator marking a test that requires 0 or 1 or 2 GPU setup (in PyTorch).
|
145 |
+
"""
|
146 |
+
if not is_torch_available():
|
147 |
+
return unittest.skip("test requires PyTorch")(test_case)
|
148 |
+
|
149 |
+
import torch
|
150 |
+
|
151 |
+
if torch.cuda.device_count() > 2:
|
152 |
+
return unittest.skip("test requires 0 or 1 or 2 GPUs")(test_case)
|
153 |
+
else:
|
154 |
+
return test_case
|
155 |
+
|
156 |
+
|
157 |
+
if is_torch_available():
|
158 |
+
# Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
|
159 |
+
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
160 |
+
else:
|
161 |
+
torch_device = None
|
162 |
+
|
163 |
+
|
164 |
+
def require_torch_gpu(test_case):
|
165 |
+
"""Decorator marking a test that requires CUDA and PyTorch."""
|
166 |
+
if torch_device != "cuda":
|
167 |
+
return unittest.skip("test requires CUDA")(test_case)
|
168 |
+
else:
|
169 |
+
return test_case
|
170 |
+
|
171 |
+
|
172 |
+
def is_deepspeed_available():
|
173 |
+
return importlib.util.find_spec("deepspeed") is not None
|
174 |
+
|
175 |
+
|
176 |
+
def require_deepspeed(test_case):
|
177 |
+
"""
|
178 |
+
Decorator marking a test that requires deepspeed
|
179 |
+
"""
|
180 |
+
if not is_deepspeed_available():
|
181 |
+
return unittest.skip("test requires deepspeed")(test_case)
|
182 |
+
else:
|
183 |
+
return test_case
|
184 |
+
|
185 |
+
|
186 |
+
def is_bnb_available():
|
187 |
+
return importlib.util.find_spec("bitsandbytes") is not None
|
188 |
+
|
189 |
+
|
190 |
+
def require_bnb(test_case):
|
191 |
+
"""
|
192 |
+
Decorator marking a test that requires bitsandbytes
|
193 |
+
"""
|
194 |
+
if not is_bnb_available():
|
195 |
+
return unittest.skip("test requires bitsandbytes from https://github.com/facebookresearch/bitsandbytes")(
|
196 |
+
test_case
|
197 |
+
)
|
198 |
+
else:
|
199 |
+
return test_case
|
200 |
+
|
201 |
+
|
202 |
+
def require_bnb_non_decorator():
|
203 |
+
"""
|
204 |
+
Non-Decorator function that would skip a test if bitsandbytes is missing
|
205 |
+
"""
|
206 |
+
if not is_bnb_available():
|
207 |
+
raise SkipTest("Test requires bitsandbytes from https://github.com/facebookresearch/bitsandbytes")
|
208 |
+
|
209 |
+
|
210 |
+
def set_seed(seed: int = 42):
|
211 |
+
"""
|
212 |
+
Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch``
|
213 |
+
|
214 |
+
Args:
|
215 |
+
seed (:obj:`int`): The seed to set.
|
216 |
+
"""
|
217 |
+
random.seed(seed)
|
218 |
+
np.random.seed(seed)
|
219 |
+
if is_torch_available():
|
220 |
+
torch.manual_seed(seed)
|
221 |
+
torch.cuda.manual_seed_all(seed)
|
222 |
+
# ^^ safe to call this function even if cuda is not available
|
223 |
+
|
224 |
+
|
225 |
+
def get_gpu_count():
|
226 |
+
"""
|
227 |
+
Return the number of available gpus (regardless of whether torch or tf is used)
|
228 |
+
"""
|
229 |
+
if is_torch_available():
|
230 |
+
import torch
|
231 |
+
|
232 |
+
return torch.cuda.device_count()
|
233 |
+
else:
|
234 |
+
return 0
|
235 |
+
|
236 |
+
|
237 |
+
def torch_assert_equal(actual, expected, **kwargs):
|
238 |
+
"""
|
239 |
+
compare two tensors or non-tensor numbers for their equality
|
240 |
+
"""
|
241 |
+
# assert_close was added around pt-1.9, it does better checks - e.g will check dimensions match
|
242 |
+
return torch.testing.assert_close(actual, expected, rtol=0.0, atol=0.0, **kwargs)
|
243 |
+
|
244 |
+
|
245 |
+
def torch_assert_close(actual, expected, **kwargs):
|
246 |
+
"""
|
247 |
+
compare two tensors or non-tensor numbers for their closeness.
|
248 |
+
"""
|
249 |
+
# assert_close was added around pt-1.9, it does better checks - e.g will check dimensions match
|
250 |
+
return torch.testing.assert_close(actual, expected, **kwargs)
|
251 |
+
|
252 |
+
|
253 |
+
def is_torch_bf16_available():
|
254 |
+
# from https://github.com/huggingface/transformers/blob/26eb566e43148c80d0ea098c76c3d128c0281c16/src/transformers/file_utils.py#L301
|
255 |
+
if is_torch_available():
|
256 |
+
import torch
|
257 |
+
|
258 |
+
if not torch.cuda.is_available() or torch.version.cuda is None:
|
259 |
+
return False
|
260 |
+
if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
|
261 |
+
return False
|
262 |
+
if int(torch.version.cuda.split(".")[0]) < 11:
|
263 |
+
return False
|
264 |
+
if not version.parse(torch.__version__) >= version.parse("1.09"):
|
265 |
+
return False
|
266 |
+
return True
|
267 |
+
else:
|
268 |
+
return False
|
269 |
+
|
270 |
+
|
271 |
+
def require_torch_bf16(test_case):
|
272 |
+
"""Decorator marking a test that requires CUDA hardware supporting bf16 and PyTorch >= 1.9."""
|
273 |
+
if not is_torch_bf16_available():
|
274 |
+
return unittest.skip("test requires CUDA hardware supporting bf16 and PyTorch >= 1.9")(test_case)
|
275 |
+
else:
|
276 |
+
return test_case
|
277 |
+
|
278 |
+
|
279 |
+
def get_tests_dir(append_path=None):
|
280 |
+
"""
|
281 |
+
Args:
|
282 |
+
append_path: optional path to append to the tests dir path
|
283 |
+
|
284 |
+
Return:
|
285 |
+
The full path to the `tests` dir, so that the tests can be invoked from anywhere. Optionally `append_path` is
|
286 |
+
joined after the `tests` dir the former is provided.
|
287 |
+
|
288 |
+
"""
|
289 |
+
# this function caller's __file__
|
290 |
+
caller__file__ = inspect.stack()[1][1]
|
291 |
+
tests_dir = os.path.abspath(os.path.dirname(caller__file__))
|
292 |
+
if append_path:
|
293 |
+
return os.path.join(tests_dir, append_path)
|
294 |
+
else:
|
295 |
+
return tests_dir
|
296 |
+
|
297 |
+
|
298 |
+
def parameterized_custom_name_func_join_params(func, param_num, param):
|
299 |
+
"""
|
300 |
+
customize the test name generator function as we want all params to appear in the sub-test
|
301 |
+
name, as by default it shows only the first param or for multiple params it just uses a unique sequence of ids and no params at all.
|
302 |
+
|
303 |
+
Usage:
|
304 |
+
|
305 |
+
@parameterized.expand(
|
306 |
+
[
|
307 |
+
(0, True),
|
308 |
+
(0, False),
|
309 |
+
(1, True),
|
310 |
+
],
|
311 |
+
name_func=parameterized_custom_name_func_join_params,
|
312 |
+
)
|
313 |
+
def test_determinism_wrt_rank(self, num_workers, pad_dataset):
|
314 |
+
|
315 |
+
which gives:
|
316 |
+
|
317 |
+
test_determinism_wrt_rank_0_true
|
318 |
+
test_determinism_wrt_rank_0_false
|
319 |
+
test_determinism_wrt_rank_1_true
|
320 |
+
|
321 |
+
"""
|
322 |
+
param_based_name = parameterized.to_safe_name("_".join(str(x) for x in param.args))
|
323 |
+
return f"{func.__name__}_{param_based_name}"
|
324 |
+
|
325 |
+
|
326 |
+
#
|
327 |
+
# Helper functions for dealing with testing text outputs
|
328 |
+
# The original code came from:
|
329 |
+
# https://github.com/fastai/fastai/blob/master/tests/utils/text.py
|
330 |
+
|
331 |
+
|
332 |
+
# When any function contains print() calls that get overwritten, like progress bars,
|
333 |
+
# a special care needs to be applied, since under pytest -s captured output (capsys
|
334 |
+
# or contextlib.redirect_stdout) contains any temporary printed strings, followed by
|
335 |
+
# \r's. This helper function ensures that the buffer will contain the same output
|
336 |
+
# with and without -s in pytest, by turning:
|
337 |
+
# foo bar\r tar mar\r final message
|
338 |
+
# into:
|
339 |
+
# final message
|
340 |
+
# it can handle a single string or a multiline buffer
|
341 |
+
def apply_print_resets(buf):
|
342 |
+
return re.sub(r"^.*\r", "", buf, 0, re.M)
|
343 |
+
|
344 |
+
|
345 |
+
def assert_screenout(out, what):
|
346 |
+
out_pr = apply_print_resets(out).lower()
|
347 |
+
match_str = out_pr.find(what.lower())
|
348 |
+
assert match_str != -1, f"expecting to find {what} in output: f{out_pr}"
|
349 |
+
|
350 |
+
|
351 |
+
class CaptureStd:
|
352 |
+
"""
|
353 |
+
Context manager to capture:
|
354 |
+
|
355 |
+
- stdout: replay it, clean it up and make it available via ``obj.out``
|
356 |
+
- stderr: replay it and make it available via ``obj.err``
|
357 |
+
- combined: combined the chosen streams and make it available via ``obj.combined``
|
358 |
+
|
359 |
+
init arguments:
|
360 |
+
|
361 |
+
- out - capture stdout:`` True``/``False``, default ``True``
|
362 |
+
- err - capture stdout: ``True``/``False``, default ``True``
|
363 |
+
- replay - whether to replay or not: ``True``/``False``, default ``True``. By default each
|
364 |
+
captured stream gets replayed back on context's exit, so that one can see what the test was
|
365 |
+
doing. If this is a not wanted behavior and the captured data shouldn't be replayed, pass
|
366 |
+
``replay=False`` to disable this feature.
|
367 |
+
|
368 |
+
Examples::
|
369 |
+
|
370 |
+
# to capture stdout only with auto-replay
|
371 |
+
with CaptureStdout() as cs:
|
372 |
+
print("Secret message")
|
373 |
+
assert "message" in cs.out
|
374 |
+
|
375 |
+
# to capture stderr only with auto-replay
|
376 |
+
import sys
|
377 |
+
with CaptureStderr() as cs:
|
378 |
+
print("Warning: ", file=sys.stderr)
|
379 |
+
assert "Warning" in cs.err
|
380 |
+
|
381 |
+
# to capture both streams with auto-replay
|
382 |
+
with CaptureStd() as cs:
|
383 |
+
print("Secret message")
|
384 |
+
print("Warning: ", file=sys.stderr)
|
385 |
+
assert "message" in cs.out
|
386 |
+
assert "Warning" in cs.err
|
387 |
+
|
388 |
+
# to capture just one of the streams, and not the other, with auto-replay
|
389 |
+
with CaptureStd(err=False) as cs:
|
390 |
+
print("Secret message")
|
391 |
+
assert "message" in cs.out
|
392 |
+
# but best use the stream-specific subclasses
|
393 |
+
|
394 |
+
# to capture without auto-replay
|
395 |
+
with CaptureStd(replay=False) as cs:
|
396 |
+
print("Secret message")
|
397 |
+
assert "message" in cs.out
|
398 |
+
|
399 |
+
# sometimes it's easier to not try to figure out if it's stdout or stderr, and yet at
|
400 |
+
# other times the software may send the same output to stderr or stdout depending on
|
401 |
+
# environment, so to make the test robust a combined entry of both streams is available
|
402 |
+
|
403 |
+
"""
|
404 |
+
|
405 |
+
def __init__(self, out=True, err=True, replay=True):
|
406 |
+
self.replay = replay
|
407 |
+
|
408 |
+
if out:
|
409 |
+
self.out_buf = StringIO()
|
410 |
+
self.out = "error: CaptureStd context is unfinished yet, called too early"
|
411 |
+
else:
|
412 |
+
self.out_buf = None
|
413 |
+
self.out = "not capturing stdout"
|
414 |
+
|
415 |
+
if err:
|
416 |
+
self.err_buf = StringIO()
|
417 |
+
self.err = "error: CaptureStd context is unfinished yet, called too early"
|
418 |
+
else:
|
419 |
+
self.err_buf = None
|
420 |
+
self.err = "not capturing stderr"
|
421 |
+
|
422 |
+
self.combined = "error: CaptureStd context is unfinished yet, called too early"
|
423 |
+
|
424 |
+
def __enter__(self):
|
425 |
+
if self.out_buf is not None:
|
426 |
+
self.out_old = sys.stdout
|
427 |
+
sys.stdout = self.out_buf
|
428 |
+
|
429 |
+
if self.err_buf is not None:
|
430 |
+
self.err_old = sys.stderr
|
431 |
+
sys.stderr = self.err_buf
|
432 |
+
|
433 |
+
self.combined = ""
|
434 |
+
|
435 |
+
return self
|
436 |
+
|
437 |
+
def __exit__(self, *exc):
|
438 |
+
if self.out_buf is not None:
|
439 |
+
sys.stdout = self.out_old
|
440 |
+
captured = self.out_buf.getvalue()
|
441 |
+
if self.replay:
|
442 |
+
sys.stdout.write(captured)
|
443 |
+
self.out = apply_print_resets(captured)
|
444 |
+
self.combined += self.out
|
445 |
+
|
446 |
+
if self.err_buf is not None:
|
447 |
+
sys.stderr = self.err_old
|
448 |
+
captured = self.err_buf.getvalue()
|
449 |
+
if self.replay:
|
450 |
+
sys.stderr.write(captured)
|
451 |
+
self.err = captured
|
452 |
+
self.combined += self.err
|
453 |
+
|
454 |
+
def __repr__(self):
|
455 |
+
msg = ""
|
456 |
+
if self.out_buf:
|
457 |
+
msg += f"stdout: {self.out}\n"
|
458 |
+
if self.err_buf:
|
459 |
+
msg += f"stderr: {self.err}\n"
|
460 |
+
return msg
|
461 |
+
|
462 |
+
|
463 |
+
# in tests it's the best to capture only the stream that's wanted, otherwise
|
464 |
+
# it's easy to miss things, so unless you need to capture both streams, use the
|
465 |
+
# subclasses below (less typing). Or alternatively, configure `CaptureStd` to
|
466 |
+
# disable the stream you don't need to test.
|
467 |
+
|
468 |
+
|
469 |
+
class CaptureStdout(CaptureStd):
|
470 |
+
"""Same as CaptureStd but captures only stdout"""
|
471 |
+
|
472 |
+
def __init__(self, replay=True):
|
473 |
+
super().__init__(err=False, replay=replay)
|
474 |
+
|
475 |
+
|
476 |
+
class CaptureStderr(CaptureStd):
|
477 |
+
"""Same as CaptureStd but captures only stderr"""
|
478 |
+
|
479 |
+
def __init__(self, replay=True):
|
480 |
+
super().__init__(out=False, replay=replay)
|
481 |
+
|
482 |
+
|
483 |
+
class CaptureLogger:
|
484 |
+
"""
|
485 |
+
Context manager to capture `logging` streams
|
486 |
+
|
487 |
+
Args:
|
488 |
+
|
489 |
+
- logger: 'logging` logger object
|
490 |
+
|
491 |
+
Results:
|
492 |
+
The captured output is available via `self.out`
|
493 |
+
|
494 |
+
Example::
|
495 |
+
|
496 |
+
>>> from transformers import logging
|
497 |
+
>>> from transformers.testing_utils import CaptureLogger
|
498 |
+
|
499 |
+
>>> msg = "Testing 1, 2, 3"
|
500 |
+
>>> logging.set_verbosity_info()
|
501 |
+
>>> logger = logging.get_logger("transformers.models.bart.tokenization_bart")
|
502 |
+
>>> with CaptureLogger(logger) as cl:
|
503 |
+
... logger.info(msg)
|
504 |
+
>>> assert cl.out, msg+"\n"
|
505 |
+
"""
|
506 |
+
|
507 |
+
def __init__(self, logger):
|
508 |
+
self.logger = logger
|
509 |
+
self.io = StringIO()
|
510 |
+
self.sh = logging.StreamHandler(self.io)
|
511 |
+
self.out = ""
|
512 |
+
|
513 |
+
def __enter__(self):
|
514 |
+
self.logger.addHandler(self.sh)
|
515 |
+
return self
|
516 |
+
|
517 |
+
def __exit__(self, *exc):
|
518 |
+
self.logger.removeHandler(self.sh)
|
519 |
+
self.out = self.io.getvalue()
|
520 |
+
|
521 |
+
def __repr__(self):
|
522 |
+
return f"captured: {self.out}\n"
|
523 |
+
|
524 |
+
|
525 |
+
@contextlib.contextmanager
|
526 |
+
# adapted from https://stackoverflow.com/a/64789046/9201239
|
527 |
+
def ExtendSysPath(path: Union[str, os.PathLike]) -> Iterator[None]:
|
528 |
+
"""
|
529 |
+
Temporary add given path to `sys.path`.
|
530 |
+
|
531 |
+
Usage ::
|
532 |
+
|
533 |
+
with ExtendSysPath('/path/to/dir'):
|
534 |
+
mymodule = importlib.import_module('mymodule')
|
535 |
+
|
536 |
+
"""
|
537 |
+
|
538 |
+
path = os.fspath(path)
|
539 |
+
try:
|
540 |
+
sys.path.insert(0, path)
|
541 |
+
yield
|
542 |
+
finally:
|
543 |
+
sys.path.remove(path)
|
544 |
+
|
545 |
+
|
546 |
+
class TestCasePlus(unittest.TestCase):
|
547 |
+
"""This class extends `unittest.TestCase` with additional features.
|
548 |
+
|
549 |
+
Feature 1: A set of fully resolved important file and dir path accessors.
|
550 |
+
|
551 |
+
In tests often we need to know where things are relative to the current test file, and it's not trivial since the
|
552 |
+
test could be invoked from more than one directory or could reside in sub-directories with different depths. This
|
553 |
+
class solves this problem by sorting out all the basic paths and provides easy accessors to them:
|
554 |
+
|
555 |
+
* ``pathlib`` objects (all fully resolved):
|
556 |
+
|
557 |
+
- ``test_file_path`` - the current test file path (=``__file__``)
|
558 |
+
- ``test_file_dir`` - the directory containing the current test file
|
559 |
+
- ``tests_dir`` - the directory of the ``tests`` test suite
|
560 |
+
- ``data_dir`` - the directory of the ``tests/data`` test suite
|
561 |
+
- ``repo_root_dir`` - the directory of the repository
|
562 |
+
- ``src_dir`` - the directory where the ``m4`` sub-dir resides (same as repo_root_dir in this case)
|
563 |
+
|
564 |
+
* stringified paths---same as above but these return paths as strings, rather than ``pathlib`` objects:
|
565 |
+
|
566 |
+
- ``test_file_path_str``
|
567 |
+
- ``test_file_dir_str``
|
568 |
+
- ``tests_dir_str``
|
569 |
+
- ``data_dir_str``
|
570 |
+
- ``repo_root_dir_str``
|
571 |
+
- ``src_dir_str``
|
572 |
+
|
573 |
+
Feature 2: Flexible auto-removable temporary dirs which are guaranteed to get removed at the end of test.
|
574 |
+
|
575 |
+
1. Create a unique temporary dir:
|
576 |
+
|
577 |
+
::
|
578 |
+
|
579 |
+
def test_whatever(self):
|
580 |
+
tmp_dir = self.get_auto_remove_tmp_dir()
|
581 |
+
|
582 |
+
``tmp_dir`` will contain the pathlib path to the created temporary dir. It will be automatically
|
583 |
+
removed at the end of the test.
|
584 |
+
|
585 |
+
|
586 |
+
2. Create a temporary dir of my choice, ensure it's empty before the test starts and don't
|
587 |
+
empty it after the test.
|
588 |
+
|
589 |
+
::
|
590 |
+
|
591 |
+
def test_whatever(self):
|
592 |
+
tmp_dir = self.get_auto_remove_tmp_dir("./xxx")
|
593 |
+
|
594 |
+
This is useful for debug when you want to monitor a specific directory and want to make sure the previous tests
|
595 |
+
didn't leave any data in there.
|
596 |
+
|
597 |
+
3. You can override the first two options by directly overriding the ``before`` and ``after`` args, leading to the
|
598 |
+
following behavior:
|
599 |
+
|
600 |
+
``before=True``: the temporary dir will always be cleared at the beginning of the test.
|
601 |
+
|
602 |
+
``before=False``: if the temporary dir already existed, any existing files will remain there.
|
603 |
+
|
604 |
+
``after=True``: the temporary dir will always be deleted at the end of the test.
|
605 |
+
|
606 |
+
``after=False``: the temporary dir will always be left intact at the end of the test.
|
607 |
+
|
608 |
+
Use `self.get_auto_remove_tmp_dir_str()` instead if you want the returned value to be a non-pathlib version.
|
609 |
+
|
610 |
+
Note 1: In order to run the equivalent of ``rm -r`` safely, only subdirs of the project repository checkout are
|
611 |
+
allowed if an explicit ``tmp_dir`` is used, so that by mistake no ``/tmp`` or similar important part of the
|
612 |
+
filesystem will get nuked. i.e. please always pass paths that start with ``./``
|
613 |
+
|
614 |
+
Note 2: Each test can register multiple temporary dirs and they all will get auto-removed, unless requested
|
615 |
+
otherwise.
|
616 |
+
|
617 |
+
Feature 3: Get a copy of the ``os.environ`` object that sets up ``PYTHONPATH`` specific to the current test suite.
|
618 |
+
This is useful for invoking external programs from the test suite - e.g. distributed training.
|
619 |
+
|
620 |
+
|
621 |
+
::
|
622 |
+
def test_whatever(self):
|
623 |
+
env = self.get_env()
|
624 |
+
|
625 |
+
"""
|
626 |
+
|
627 |
+
def setUp(self):
|
628 |
+
# get_auto_remove_tmp_dir feature:
|
629 |
+
self.teardown_tmp_dirs = []
|
630 |
+
|
631 |
+
# figure out the resolved paths for repo_root, tests, etc.
|
632 |
+
self._test_file_path = inspect.getfile(self.__class__)
|
633 |
+
path = Path(self._test_file_path).resolve()
|
634 |
+
self._test_file_dir = path.parents[0]
|
635 |
+
for up in [1, 2, 3]:
|
636 |
+
tmp_dir = path.parents[up]
|
637 |
+
if (tmp_dir / "m4").is_dir() and (tmp_dir / "tests").is_dir():
|
638 |
+
break
|
639 |
+
if tmp_dir:
|
640 |
+
self._repo_root_dir = tmp_dir
|
641 |
+
else:
|
642 |
+
raise ValueError(f"can't figure out the root of the repo from {self._test_file_path}")
|
643 |
+
self._tests_dir = self._repo_root_dir / "tests"
|
644 |
+
self._data_dir = self._repo_root_dir / "tests" / "test_data"
|
645 |
+
self._src_dir = self._repo_root_dir # m4 doesn't use "src/" prefix in the repo
|
646 |
+
|
647 |
+
@property
|
648 |
+
def test_file_path(self):
|
649 |
+
return self._test_file_path
|
650 |
+
|
651 |
+
@property
|
652 |
+
def test_file_path_str(self):
|
653 |
+
return str(self._test_file_path)
|
654 |
+
|
655 |
+
@property
|
656 |
+
def test_file_dir(self):
|
657 |
+
return self._test_file_dir
|
658 |
+
|
659 |
+
@property
|
660 |
+
def test_file_dir_str(self):
|
661 |
+
return str(self._test_file_dir)
|
662 |
+
|
663 |
+
@property
|
664 |
+
def tests_dir(self):
|
665 |
+
return self._tests_dir
|
666 |
+
|
667 |
+
@property
|
668 |
+
def tests_dir_str(self):
|
669 |
+
return str(self._tests_dir)
|
670 |
+
|
671 |
+
@property
|
672 |
+
def data_dir(self):
|
673 |
+
return self._data_dir
|
674 |
+
|
675 |
+
@property
|
676 |
+
def data_dir_str(self):
|
677 |
+
return str(self._data_dir)
|
678 |
+
|
679 |
+
@property
|
680 |
+
def repo_root_dir(self):
|
681 |
+
return self._repo_root_dir
|
682 |
+
|
683 |
+
@property
|
684 |
+
def repo_root_dir_str(self):
|
685 |
+
return str(self._repo_root_dir)
|
686 |
+
|
687 |
+
@property
|
688 |
+
def src_dir(self):
|
689 |
+
return self._src_dir
|
690 |
+
|
691 |
+
@property
|
692 |
+
def src_dir_str(self):
|
693 |
+
return str(self._src_dir)
|
694 |
+
|
695 |
+
def get_env(self):
|
696 |
+
"""
|
697 |
+
Return a copy of the ``os.environ`` object that sets up ``PYTHONPATH`` correctly. This is useful
|
698 |
+
for invoking external programs from the test suite - e.g. distributed training.
|
699 |
+
|
700 |
+
It always inserts ``.`` first, then ``./tests`` depending on the test suite type and
|
701 |
+
finally the preset ``PYTHONPATH`` if any (all full resolved paths).
|
702 |
+
|
703 |
+
"""
|
704 |
+
env = os.environ.copy()
|
705 |
+
paths = [self.src_dir_str]
|
706 |
+
paths.append(self.tests_dir_str)
|
707 |
+
paths.append(env.get("PYTHONPATH", ""))
|
708 |
+
|
709 |
+
env["PYTHONPATH"] = ":".join(paths)
|
710 |
+
return env
|
711 |
+
|
712 |
+
def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None):
|
713 |
+
"""
|
714 |
+
Args:
|
715 |
+
tmp_dir (:obj:`string`, `optional`):
|
716 |
+
if :obj:`None`:
|
717 |
+
|
718 |
+
- a unique temporary path will be created
|
719 |
+
- sets ``before=True`` if ``before`` is :obj:`None`
|
720 |
+
- sets ``after=True`` if ``after`` is :obj:`None`
|
721 |
+
else:
|
722 |
+
|
723 |
+
- :obj:`tmp_dir` will be created
|
724 |
+
- sets ``before=True`` if ``before`` is :obj:`None`
|
725 |
+
- sets ``after=False`` if ``after`` is :obj:`None`
|
726 |
+
before (:obj:`bool`, `optional`):
|
727 |
+
If :obj:`True` and the :obj:`tmp_dir` already exists, make sure to empty it right away if :obj:`False`
|
728 |
+
and the :obj:`tmp_dir` already exists, any existing files will remain there.
|
729 |
+
after (:obj:`bool`, `optional`):
|
730 |
+
If :obj:`True`, delete the :obj:`tmp_dir` at the end of the test if :obj:`False`, leave the
|
731 |
+
:obj:`tmp_dir` and its contents intact at the end of the test.
|
732 |
+
|
733 |
+
Returns:
|
734 |
+
tmp_dir(:obj:`string`): either the same value as passed via `tmp_dir` or the path to the auto-selected tmp
|
735 |
+
dir
|
736 |
+
"""
|
737 |
+
if tmp_dir is not None:
|
738 |
+
# defining the most likely desired behavior for when a custom path is provided.
|
739 |
+
# this most likely indicates the debug mode where we want an easily locatable dir that:
|
740 |
+
# 1. gets cleared out before the test (if it already exists)
|
741 |
+
# 2. is left intact after the test
|
742 |
+
if before is None:
|
743 |
+
before = True
|
744 |
+
if after is None:
|
745 |
+
after = False
|
746 |
+
|
747 |
+
# to avoid nuking parts of the filesystem, only relative paths are allowed
|
748 |
+
if not tmp_dir.startswith("./"):
|
749 |
+
raise ValueError(
|
750 |
+
f"`tmp_dir` can only be a relative path, i.e. `./some/path`, but received `{tmp_dir}`"
|
751 |
+
)
|
752 |
+
|
753 |
+
# using provided path
|
754 |
+
tmp_dir = Path(tmp_dir).resolve()
|
755 |
+
|
756 |
+
# ensure the dir is empty to start with
|
757 |
+
if before is True and tmp_dir.exists():
|
758 |
+
shutil.rmtree(tmp_dir, ignore_errors=True)
|
759 |
+
|
760 |
+
tmp_dir.mkdir(parents=True, exist_ok=True)
|
761 |
+
|
762 |
+
else:
|
763 |
+
# defining the most likely desired behavior for when a unique tmp path is auto generated
|
764 |
+
# (not a debug mode), here we require a unique tmp dir that:
|
765 |
+
# 1. is empty before the test (it will be empty in this situation anyway)
|
766 |
+
# 2. gets fully removed after the test
|
767 |
+
if before is None:
|
768 |
+
before = True
|
769 |
+
if after is None:
|
770 |
+
after = True
|
771 |
+
|
772 |
+
# using unique tmp dir (always empty, regardless of `before`)
|
773 |
+
tmp_dir = Path(tempfile.mkdtemp())
|
774 |
+
|
775 |
+
if after is True:
|
776 |
+
# register for deletion
|
777 |
+
self.teardown_tmp_dirs.append(tmp_dir)
|
778 |
+
|
779 |
+
return tmp_dir
|
780 |
+
|
781 |
+
def get_auto_remove_tmp_dir_str(self, *args, **kwargs):
|
782 |
+
return str(self.get_auto_remove_tmp_dir(*args, **kwargs))
|
783 |
+
|
784 |
+
def tearDown(self):
|
785 |
+
# get_auto_remove_tmp_dir feature: remove registered temp dirs
|
786 |
+
for path in self.teardown_tmp_dirs:
|
787 |
+
shutil.rmtree(path, ignore_errors=True)
|
788 |
+
self.teardown_tmp_dirs = []
|
789 |
+
|
790 |
+
|
791 |
+
def mockenv(**kwargs):
|
792 |
+
"""
|
793 |
+
this is a convenience wrapper, that allows this ::
|
794 |
+
|
795 |
+
@mockenv(RUN_SLOW=True, USE_TF=False)
|
796 |
+
def test_something():
|
797 |
+
run_slow = os.getenv("RUN_SLOW", False)
|
798 |
+
use_tf = os.getenv("USE_TF", False)
|
799 |
+
|
800 |
+
"""
|
801 |
+
return mock.patch.dict(os.environ, kwargs)
|
802 |
+
|
803 |
+
|
804 |
+
# from https://stackoverflow.com/a/34333710/9201239
|
805 |
+
@contextlib.contextmanager
|
806 |
+
def mockenv_context(*remove, **update):
|
807 |
+
"""
|
808 |
+
Temporarily updates the ``os.environ`` dictionary in-place. Similar to mockenv
|
809 |
+
|
810 |
+
The ``os.environ`` dictionary is updated in-place so that the modification is sure to work in all situations.
|
811 |
+
|
812 |
+
Args:
|
813 |
+
remove: Environment variables to remove.
|
814 |
+
update: Dictionary of environment variables and values to add/update.
|
815 |
+
"""
|
816 |
+
env = os.environ
|
817 |
+
update = update or {}
|
818 |
+
remove = remove or []
|
819 |
+
|
820 |
+
# List of environment variables being updated or removed.
|
821 |
+
stomped = (set(update.keys()) | set(remove)) & set(env.keys())
|
822 |
+
# Environment variables and values to restore on exit.
|
823 |
+
update_after = {k: env[k] for k in stomped}
|
824 |
+
# Environment variables and values to remove on exit.
|
825 |
+
remove_after = frozenset(k for k in update if k not in env)
|
826 |
+
|
827 |
+
try:
|
828 |
+
env.update(update)
|
829 |
+
[env.pop(k, None) for k in remove]
|
830 |
+
yield
|
831 |
+
finally:
|
832 |
+
env.update(update_after)
|
833 |
+
[env.pop(k) for k in remove_after]
|
834 |
+
|
835 |
+
|
836 |
+
# --- test network helper functions --- #
|
837 |
+
|
838 |
+
|
839 |
+
def get_xdist_worker_id():
|
840 |
+
"""
|
841 |
+
when run under pytest-xdist returns the worker id (int), otherwise returns 0
|
842 |
+
"""
|
843 |
+
worker_id_string = os.environ.get("PYTEST_XDIST_WORKER", "gw0")
|
844 |
+
return int(worker_id_string[2:]) # strip "gw"
|
845 |
+
|
846 |
+
|
847 |
+
DEFAULT_MASTER_PORT = 10999
|
848 |
+
|
849 |
+
|
850 |
+
def get_unique_port_number():
|
851 |
+
"""
|
852 |
+
When the test suite runs under pytest-xdist we need to make sure that concurrent tests won't use
|
853 |
+
the same port number. We can accomplish that by using the same base and always adding the xdist
|
854 |
+
worker id to it, or 0 if not running under pytest-xdist
|
855 |
+
"""
|
856 |
+
return DEFAULT_MASTER_PORT + get_xdist_worker_id()
|
857 |
+
|
858 |
+
|
859 |
+
# --- test IO helper functions --- #
|
860 |
+
|
861 |
+
|
862 |
+
def write_file(file, content):
|
863 |
+
with open(file, "w") as f:
|
864 |
+
f.write(content)
|
865 |
+
|
866 |
+
|
867 |
+
def read_json_file(file):
|
868 |
+
with open(file, "r") as fh:
|
869 |
+
return json.load(fh)
|
870 |
+
|
871 |
+
|
872 |
+
def replace_str_in_file(file, text_to_search, replacement_text):
|
873 |
+
file = Path(file)
|
874 |
+
text = file.read_text()
|
875 |
+
text = text.replace(text_to_search, replacement_text)
|
876 |
+
file.write_text(text)
|
877 |
+
|
878 |
+
|
879 |
+
# --- pytest conf functions --- #
|
880 |
+
|
881 |
+
# to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once
|
882 |
+
pytest_opt_registered = {}
|
883 |
+
|
884 |
+
|
885 |
+
def pytest_addoption_shared(parser):
|
886 |
+
"""
|
887 |
+
This function is to be called from `conftest.py` via `pytest_addoption` wrapper that has to be defined there.
|
888 |
+
|
889 |
+
It allows loading both `conftest.py` files at once without causing a failure due to adding the same `pytest`
|
890 |
+
option.
|
891 |
+
|
892 |
+
"""
|
893 |
+
option = "--make-reports"
|
894 |
+
if option not in pytest_opt_registered:
|
895 |
+
parser.addoption(
|
896 |
+
option,
|
897 |
+
action="store",
|
898 |
+
default=False,
|
899 |
+
help="generate report files. The value of this option is used as a prefix to report names",
|
900 |
+
)
|
901 |
+
pytest_opt_registered[option] = 1
|
902 |
+
|
903 |
+
|
904 |
+
def pytest_terminal_summary_main(tr, id):
|
905 |
+
"""
|
906 |
+
Generate multiple reports at the end of test suite run - each report goes into a dedicated file in the current
|
907 |
+
directory. The report files are prefixed with the test suite name.
|
908 |
+
|
909 |
+
This function emulates --duration and -rA pytest arguments.
|
910 |
+
|
911 |
+
This function is to be called from `conftest.py` via `pytest_terminal_summary` wrapper that has to be defined
|
912 |
+
there.
|
913 |
+
|
914 |
+
Args:
|
915 |
+
- tr: `terminalreporter` passed from `conftest.py`
|
916 |
+
- id: unique id like `tests` or `examples` that will be incorporated into the final reports filenames - this is
|
917 |
+
needed as some jobs have multiple runs of pytest, so we can't have them overwrite each other.
|
918 |
+
|
919 |
+
NB: this functions taps into a private _pytest API and while unlikely, it could break should pytest do internal
|
920 |
+
changes - also it calls default internal methods of terminalreporter which can be hijacked by various `pytest-`
|
921 |
+
plugins and interfere.
|
922 |
+
|
923 |
+
"""
|
924 |
+
from _pytest.config import create_terminal_writer
|
925 |
+
|
926 |
+
if not len(id):
|
927 |
+
id = "tests"
|
928 |
+
|
929 |
+
config = tr.config
|
930 |
+
orig_writer = config.get_terminal_writer()
|
931 |
+
orig_tbstyle = config.option.tbstyle
|
932 |
+
orig_reportchars = tr.reportchars
|
933 |
+
|
934 |
+
dir = f"reports/{id}"
|
935 |
+
Path(dir).mkdir(parents=True, exist_ok=True)
|
936 |
+
report_files = {
|
937 |
+
k: f"{dir}/{k}.txt"
|
938 |
+
for k in [
|
939 |
+
"durations",
|
940 |
+
"errors",
|
941 |
+
"failures_long",
|
942 |
+
"failures_short",
|
943 |
+
"failures_line",
|
944 |
+
"passes",
|
945 |
+
"stats",
|
946 |
+
"summary_short",
|
947 |
+
"warnings",
|
948 |
+
]
|
949 |
+
}
|
950 |
+
|
951 |
+
# custom durations report
|
952 |
+
# note: there is no need to call pytest --durations=XX to get this separate report
|
953 |
+
# adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66
|
954 |
+
dlist = []
|
955 |
+
for replist in tr.stats.values():
|
956 |
+
for rep in replist:
|
957 |
+
if hasattr(rep, "duration"):
|
958 |
+
dlist.append(rep)
|
959 |
+
if dlist:
|
960 |
+
dlist.sort(key=lambda x: x.duration, reverse=True)
|
961 |
+
with open(report_files["durations"], "w") as f:
|
962 |
+
durations_min = 0.05 # sec
|
963 |
+
f.write("slowest durations\n")
|
964 |
+
for i, rep in enumerate(dlist):
|
965 |
+
if rep.duration < durations_min:
|
966 |
+
f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted")
|
967 |
+
break
|
968 |
+
f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n")
|
969 |
+
|
970 |
+
def summary_failures_short(tr):
|
971 |
+
# expecting that the reports were --tb=long (default) so we chop them off here to the last frame
|
972 |
+
reports = tr.getreports("failed")
|
973 |
+
if not reports:
|
974 |
+
return
|
975 |
+
tr.write_sep("=", "FAILURES SHORT STACK")
|
976 |
+
for rep in reports:
|
977 |
+
msg = tr._getfailureheadline(rep)
|
978 |
+
tr.write_sep("_", msg, red=True, bold=True)
|
979 |
+
# chop off the optional leading extra frames, leaving only the last one
|
980 |
+
longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, 0, re.M | re.S)
|
981 |
+
tr._tw.line(longrepr)
|
982 |
+
# note: not printing out any rep.sections to keep the report short
|
983 |
+
|
984 |
+
# use ready-made report funcs, we are just hijacking the filehandle to log to a dedicated file each
|
985 |
+
# adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814
|
986 |
+
# note: some pytest plugins may interfere by hijacking the default `terminalreporter` (e.g.
|
987 |
+
# pytest-instafail does that)
|
988 |
+
|
989 |
+
# report failures with line/short/long styles
|
990 |
+
config.option.tbstyle = "auto" # full tb
|
991 |
+
with open(report_files["failures_long"], "w") as f:
|
992 |
+
tr._tw = create_terminal_writer(config, f)
|
993 |
+
tr.summary_failures()
|
994 |
+
|
995 |
+
# config.option.tbstyle = "short" # short tb
|
996 |
+
with open(report_files["failures_short"], "w") as f:
|
997 |
+
tr._tw = create_terminal_writer(config, f)
|
998 |
+
summary_failures_short(tr)
|
999 |
+
|
1000 |
+
config.option.tbstyle = "line" # one line per error
|
1001 |
+
with open(report_files["failures_line"], "w") as f:
|
1002 |
+
tr._tw = create_terminal_writer(config, f)
|
1003 |
+
tr.summary_failures()
|
1004 |
+
|
1005 |
+
with open(report_files["errors"], "w") as f:
|
1006 |
+
tr._tw = create_terminal_writer(config, f)
|
1007 |
+
tr.summary_errors()
|
1008 |
+
|
1009 |
+
with open(report_files["warnings"], "w") as f:
|
1010 |
+
tr._tw = create_terminal_writer(config, f)
|
1011 |
+
tr.summary_warnings() # normal warnings
|
1012 |
+
tr.summary_warnings() # final warnings
|
1013 |
+
|
1014 |
+
tr.reportchars = "wPpsxXEf" # emulate -rA (used in summary_passes() and short_test_summary())
|
1015 |
+
|
1016 |
+
# Skip the `passes` report, as it starts to take more than 5 minutes, and sometimes it timeouts on CircleCI if it
|
1017 |
+
# takes > 10 minutes (as this part doesn't generate any output on the terminal).
|
1018 |
+
# (also, it seems there is no useful information in this report, and we rarely need to read it)
|
1019 |
+
# with open(report_files["passes"], "w") as f:
|
1020 |
+
# tr._tw = create_terminal_writer(config, f)
|
1021 |
+
# tr.summary_passes()
|
1022 |
+
|
1023 |
+
with open(report_files["summary_short"], "w") as f:
|
1024 |
+
tr._tw = create_terminal_writer(config, f)
|
1025 |
+
tr.short_test_summary()
|
1026 |
+
|
1027 |
+
with open(report_files["stats"], "w") as f:
|
1028 |
+
tr._tw = create_terminal_writer(config, f)
|
1029 |
+
tr.summary_stats()
|
1030 |
+
|
1031 |
+
# restore:
|
1032 |
+
tr._tw = orig_writer
|
1033 |
+
tr.reportchars = orig_reportchars
|
1034 |
+
config.option.tbstyle = orig_tbstyle
|
1035 |
+
|
1036 |
+
|
1037 |
+
# --- distributed testing functions --- #
|
1038 |
+
|
1039 |
+
|
1040 |
+
class _RunOutput:
|
1041 |
+
def __init__(self, returncode, stdout, stderr):
|
1042 |
+
self.returncode = returncode
|
1043 |
+
self.stdout = stdout
|
1044 |
+
self.stderr = stderr
|
1045 |
+
|
1046 |
+
|
1047 |
+
async def _read_stream(stream, callback):
|
1048 |
+
while True:
|
1049 |
+
line = await stream.readline()
|
1050 |
+
if line:
|
1051 |
+
callback(line)
|
1052 |
+
else:
|
1053 |
+
break
|
1054 |
+
|
1055 |
+
|
1056 |
+
async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput:
|
1057 |
+
if echo:
|
1058 |
+
print("\nRunning: ", " ".join(cmd))
|
1059 |
+
|
1060 |
+
p = await asyncio.create_subprocess_exec(
|
1061 |
+
cmd[0],
|
1062 |
+
*cmd[1:],
|
1063 |
+
stdin=stdin,
|
1064 |
+
stdout=asyncio.subprocess.PIPE,
|
1065 |
+
stderr=asyncio.subprocess.PIPE,
|
1066 |
+
env=env,
|
1067 |
+
)
|
1068 |
+
|
1069 |
+
# note: there is a warning for a possible deadlock when using `wait` with huge amounts of data in the pipe
|
1070 |
+
# https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait
|
1071 |
+
#
|
1072 |
+
# If it starts hanging, will need to switch to the following code. The problem is that no data
|
1073 |
+
# will be seen until it's done and if it hangs for example there will be no debug info.
|
1074 |
+
# out, err = await p.communicate()
|
1075 |
+
# return _RunOutput(p.returncode, out, err)
|
1076 |
+
|
1077 |
+
out = []
|
1078 |
+
err = []
|
1079 |
+
|
1080 |
+
def tee(line, sink, pipe, label=""):
|
1081 |
+
line = line.decode("utf-8").rstrip()
|
1082 |
+
sink.append(line)
|
1083 |
+
if not quiet:
|
1084 |
+
print(label, line, file=pipe)
|
1085 |
+
|
1086 |
+
# XXX: the timeout doesn't seem to make any difference here
|
1087 |
+
await asyncio.wait(
|
1088 |
+
[
|
1089 |
+
_read_stream(p.stdout, lambda line: tee(line, out, sys.stdout, label="stdout:")),
|
1090 |
+
_read_stream(p.stderr, lambda line: tee(line, err, sys.stderr, label="stderr:")),
|
1091 |
+
],
|
1092 |
+
timeout=timeout,
|
1093 |
+
)
|
1094 |
+
return _RunOutput(await p.wait(), out, err)
|
1095 |
+
|
1096 |
+
|
1097 |
+
def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput:
|
1098 |
+
loop = asyncio.get_event_loop()
|
1099 |
+
result = loop.run_until_complete(
|
1100 |
+
_stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)
|
1101 |
+
)
|
1102 |
+
|
1103 |
+
cmd_str = " ".join(cmd)
|
1104 |
+
if result.returncode > 0:
|
1105 |
+
stderr = "\n".join(result.stderr)
|
1106 |
+
raise RuntimeError(
|
1107 |
+
f"'{cmd_str}' failed with returncode {result.returncode}\n\n"
|
1108 |
+
f"The combined stderr from workers follows:\n{stderr}"
|
1109 |
+
)
|
1110 |
+
|
1111 |
+
# check that the subprocess actually did run and produced some output, should the test rely on
|
1112 |
+
# the remote side to do the testing
|
1113 |
+
if not result.stdout and not result.stderr:
|
1114 |
+
raise RuntimeError(f"'{cmd_str}' produced no output.")
|
1115 |
+
|
1116 |
+
return result
|
m4/training/__init__.py
ADDED
File without changes
|
m4/training/config.py
ADDED
@@ -0,0 +1,545 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import time
|
4 |
+
from dataclasses import InitVar, asdict, dataclass
|
5 |
+
from pathlib import Path
|
6 |
+
from typing import Any, Dict, List, Optional
|
7 |
+
|
8 |
+
import git
|
9 |
+
import yaml
|
10 |
+
from simple_parsing import ArgumentParser, Serializable
|
11 |
+
from simple_parsing.helpers import dict_field, list_field
|
12 |
+
|
13 |
+
from m4.training.types import DatasetNames, DatasetTypes
|
14 |
+
from m4.training.utils import FAKE_TOKEN_AROUND_IMAGE_V2, IMAGE_TOKEN, LoggingTypes
|
15 |
+
|
16 |
+
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
|
19 |
+
|
20 |
+
@dataclass
|
21 |
+
class CfgFileConfig:
|
22 |
+
"""Config file args"""
|
23 |
+
|
24 |
+
# path to config file
|
25 |
+
config: Optional[Path] = None
|
26 |
+
# set to false if you don't want to save config automatically
|
27 |
+
save_config: bool = True
|
28 |
+
|
29 |
+
|
30 |
+
@dataclass
|
31 |
+
class GlobalBatchSizeRampUp:
|
32 |
+
"""These are init variables that are used to set up the GBS ramp up protocol"""
|
33 |
+
|
34 |
+
# global batch size ramp up protocol:
|
35 |
+
#
|
36 |
+
# 1. start with global batch size `start`
|
37 |
+
# 2. every time the number of `samples` is consumed increment global batch size by `increment`
|
38 |
+
# 3. repeat step 2 until global batch size reaches `finish`
|
39 |
+
start: Optional[int] = None
|
40 |
+
finish: Optional[int] = None
|
41 |
+
increment: Optional[int] = None
|
42 |
+
samples: Optional[int] = None
|
43 |
+
|
44 |
+
|
45 |
+
@dataclass
|
46 |
+
class GlobalBatchSizeRampUpRunningParams:
|
47 |
+
"""The are running variables that are used to tell when to increment GBS and when to stop doing
|
48 |
+
that, they are never set directly in the config file, but are calculated when the training starts.
|
49 |
+
"""
|
50 |
+
|
51 |
+
global_seen_samples: int = 0
|
52 |
+
global_batch_size_current: int = 0
|
53 |
+
next_goal_samples: int = 0
|
54 |
+
grad_acc_size_current: int = 1
|
55 |
+
|
56 |
+
|
57 |
+
@dataclass
|
58 |
+
class Hparams:
|
59 |
+
"""General Hyperparameters"""
|
60 |
+
|
61 |
+
# --------------------
|
62 |
+
# General parameters
|
63 |
+
# --------------------
|
64 |
+
|
65 |
+
seed: int = 13
|
66 |
+
# If set to True, the sole purpose of the job is to pre-process the dataset (i.e. the map
|
67 |
+
# operations). The job will exit as soon as the dataset is pre-processed.
|
68 |
+
just_preprocess: bool = False
|
69 |
+
jz_job_time_sec: Optional[float] = None
|
70 |
+
jz_start_time: float = time.time()
|
71 |
+
job_id: Optional[int] = None
|
72 |
+
timeout: int = 1800 # 30 min
|
73 |
+
# set to False to ignore the optimizer states when loading from a checkpoint
|
74 |
+
load_optimizer_states: Optional[bool] = True
|
75 |
+
# set to False to disable this gpu memory saving method
|
76 |
+
gradient_checkpointing: Optional[bool] = True
|
77 |
+
|
78 |
+
# --------------------
|
79 |
+
# Model-related hparams
|
80 |
+
# --------------------
|
81 |
+
tokenizer_name: str = "gpt2"
|
82 |
+
# The value of the string will evaluated (i.e. interpreted) and must be a dict
|
83 |
+
tokenizer_params: str = '{"use_fast":True}'
|
84 |
+
tokenizer_add_tokens: str = (
|
85 |
+
f'[AddedToken("{FAKE_TOKEN_AROUND_IMAGE_V2}", rstrip=False, lstrip=False), AddedToken("{IMAGE_TOKEN}",'
|
86 |
+
" rstrip=False, lstrip=False)]"
|
87 |
+
)
|
88 |
+
# The value of the string will evaluated (i.e. interpreted). Unnecessary if tokenizer has a `pad_token`.
|
89 |
+
tokenizer_add_special_tokens: str = '{"pad_token": tokenizer.eos_token}'
|
90 |
+
model_name: str = "gpt2"
|
91 |
+
revision: str = "main"
|
92 |
+
model_params: Dict[str, Any] = dict_field(
|
93 |
+
dict(
|
94 |
+
vision_embed_dim=768,
|
95 |
+
vision_image_size=224,
|
96 |
+
vision_model_name="google/vit-base-patch16-224",
|
97 |
+
# The value of the string will evaluated (i.e. interpreted) and must be a dict
|
98 |
+
vision_model_params="{}",
|
99 |
+
# Ties the word embedding with LM head's weights
|
100 |
+
# Since word embedding is frozen, use in conjuncation with freeze_lm_head=True
|
101 |
+
tie_word_embeddings=False,
|
102 |
+
# Freeze different parts of the model
|
103 |
+
freeze_lm_head=False,
|
104 |
+
freeze_text_layers=True,
|
105 |
+
freeze_text_module_exceptions=[],
|
106 |
+
freeze_vision_layers=True,
|
107 |
+
freeze_vision_module_exceptions=[],
|
108 |
+
# Perceiver Resampler Parameters
|
109 |
+
use_resampler=False,
|
110 |
+
resampler_n_latents=64,
|
111 |
+
resampler_depth=6,
|
112 |
+
resampler_n_heads=16,
|
113 |
+
resampler_head_dim=96,
|
114 |
+
)
|
115 |
+
)
|
116 |
+
|
117 |
+
# --------------------
|
118 |
+
# Training parameters
|
119 |
+
# --------------------
|
120 |
+
resume_run: Optional[bool] = None
|
121 |
+
do_validation: bool = True
|
122 |
+
|
123 |
+
# deprecated in favor of batch_size_per_gpu
|
124 |
+
batch_size: Optional[int] = None
|
125 |
+
batch_size_per_gpu: int = 1
|
126 |
+
global_batch_size: Optional[int] = None
|
127 |
+
|
128 |
+
global_batch_size_ramp_up: GlobalBatchSizeRampUp = GlobalBatchSizeRampUp()
|
129 |
+
grad_acc_size: Optional[int] = 1
|
130 |
+
|
131 |
+
grad_clip: float = 1.0
|
132 |
+
|
133 |
+
# weights by which to multiply the loss of each dataset when accumulating gradients over datasets
|
134 |
+
loss_weights_per_dataset: Optional[List[float]] = None
|
135 |
+
# int(max_num_tokens / (batch_size * max_seq_len * grad_acc_size * num_processes))
|
136 |
+
max_num_opt_steps: Optional[int] = 500_000
|
137 |
+
max_num_opt_steps_this_run: Optional[int] = None
|
138 |
+
max_num_epochs: Optional[int] = None
|
139 |
+
|
140 |
+
# If the path appears the program will stop after finishing the current training step
|
141 |
+
kill_switch_path: Optional[Path] = None
|
142 |
+
|
143 |
+
# If the path appears the program will save a checkpoint and immediately delete this flag
|
144 |
+
save_switch_path: Optional[Path] = None
|
145 |
+
|
146 |
+
# --------------------
|
147 |
+
# Logging parameters
|
148 |
+
# --------------------
|
149 |
+
train_logging_opt_steps: int = 50
|
150 |
+
train_logging_per_dataset_suffix: str = ""
|
151 |
+
|
152 |
+
# If a specific logging type is specified, per dataset information will be inserted inside
|
153 |
+
# those logs.
|
154 |
+
train_logging_per_dataset_info: List[LoggingTypes] = list_field(LoggingTypes.JSONL, LoggingTypes.WANDB)
|
155 |
+
|
156 |
+
# If `train_logging_activations` is not empty, hooks will be inserted to the model to track
|
157 |
+
# the min/max/std/norm of the activations and weights. This will slow down training.
|
158 |
+
# See https://huggingface.co/docs/transformers/main/en/debugging#underflow-and-overflow-detection
|
159 |
+
train_logging_activations: List[LoggingTypes] = list_field()
|
160 |
+
train_logging_activations_opt_steps: Optional[int] = 25
|
161 |
+
train_logging_grad_param_deepspeed: List[LoggingTypes] = list_field()
|
162 |
+
train_logging_grad_param_deepspeed_opt_steps: int = 50
|
163 |
+
val_logging_opt_steps: int = train_logging_opt_steps * 5
|
164 |
+
val_inline_logging_opt_steps: int = train_logging_opt_steps
|
165 |
+
train_saving_opt_steps: int = train_logging_opt_steps * 5
|
166 |
+
save_dir: Optional[Path] = None
|
167 |
+
upload_to_s3: bool = False
|
168 |
+
train_log_mem_usage: bool = False
|
169 |
+
timing_break_down: bool = False
|
170 |
+
|
171 |
+
save_batch_max_idx: Optional[int] = None
|
172 |
+
save_batch_min_idx: Optional[int] = None
|
173 |
+
|
174 |
+
# ----------------------
|
175 |
+
# Wandb Parameters
|
176 |
+
# ----------------------
|
177 |
+
wandb_enable: bool = False
|
178 |
+
# name of the project
|
179 |
+
wandb_project: str = "VLOOM"
|
180 |
+
wandb_entity: str = "huggingfacem4"
|
181 |
+
# name of the wandb entity
|
182 |
+
wandb_log_freq: int = 50
|
183 |
+
wandb_run_id: str = ""
|
184 |
+
wandb_tags: Optional[List[str]] = None
|
185 |
+
|
186 |
+
repo_commit_id: Optional[str] = None
|
187 |
+
|
188 |
+
# ----------------------
|
189 |
+
# Debug Parameters
|
190 |
+
# ----------------------
|
191 |
+
use_torch_profiler: bool = False
|
192 |
+
|
193 |
+
|
194 |
+
@dataclass
|
195 |
+
class ResumeParams:
|
196 |
+
# ----------------------
|
197 |
+
# Resume run Parameters
|
198 |
+
# ----------------------
|
199 |
+
# Need to make sure that resume_run is True to give an input here
|
200 |
+
opt_step_dir: Optional[Path] = None
|
201 |
+
accelerator_state_dir: Optional[Path] = None
|
202 |
+
model_file: Optional[Path] = None
|
203 |
+
model_config_file: Optional[Path] = None
|
204 |
+
# Automatically resumes last run of the save_dir. Set to False to choose a specific run
|
205 |
+
resume_last: bool = True
|
206 |
+
train_logs: Dict = dict_field()
|
207 |
+
resume_opt_step: int = 0
|
208 |
+
resume_epoch: int = 0
|
209 |
+
resume_dataset_state: List = list_field()
|
210 |
+
|
211 |
+
gbs_running: GlobalBatchSizeRampUpRunningParams = GlobalBatchSizeRampUpRunningParams()
|
212 |
+
|
213 |
+
|
214 |
+
@dataclass
|
215 |
+
class DatasetParams:
|
216 |
+
# This always need to be specified as it is needed by dataset utils down the line
|
217 |
+
dataset_name: DatasetNames
|
218 |
+
# max number of images per sample
|
219 |
+
max_num_images: int = 5
|
220 |
+
# maximum sequence length
|
221 |
+
max_seq_len: int = 256
|
222 |
+
training_datasets_paths: List[Path] = list_field()
|
223 |
+
validation_datasets_paths: List[Path] = list_field()
|
224 |
+
# if True, instead of split and pack, each instance in sample will be
|
225 |
+
# either truncated or padded to the same length.
|
226 |
+
pad_dataset: bool = False
|
227 |
+
map_batch_size: int = 64
|
228 |
+
# Preprocessing number of processes in map (not useful for processing on the fly)
|
229 |
+
map_num_proc: Optional[int] = None
|
230 |
+
# Decides how many number of samples/subsequence should be extracted from the
|
231 |
+
# CM4 corpus when the dataset is to be padded irrelavent otherwise as full packing
|
232 |
+
# is used
|
233 |
+
max_num_samples_per_document: int = 10
|
234 |
+
|
235 |
+
# Strategy for detecting blur, laplacian or fft
|
236 |
+
blur_strategy: str = "fft"
|
237 |
+
# Threshold for blur detection, 0.0 means disabled. Set 32 for "laplacian" and
|
238 |
+
# 10 for "fft" for starters
|
239 |
+
blur_threshold: float = 0.0
|
240 |
+
|
241 |
+
add_begin_of_doc_token: bool = False
|
242 |
+
add_end_of_doc_token: bool = True
|
243 |
+
|
244 |
+
shuffle_after_packing: bool = False
|
245 |
+
|
246 |
+
# Parameters for T5 MLM
|
247 |
+
t5_mlm_noise_density: float = 0.15
|
248 |
+
t5_mlm_mean_noise_span_length: int = 3
|
249 |
+
|
250 |
+
dataset_type: Optional[DatasetTypes] = None
|
251 |
+
|
252 |
+
# Parameters for webdataset pipeline
|
253 |
+
shuffle_initial_urls_list: bool = False
|
254 |
+
shuffle_before_split_by_node_buffer_size: Optional[int] = None
|
255 |
+
shuffle_before_split_by_worker_buffer_size: Optional[int] = None
|
256 |
+
shuffle_after_tarfile_to_samples_buffer_size: Optional[int] = None
|
257 |
+
shuffle_after_batching_buffer_size: Optional[int] = None
|
258 |
+
|
259 |
+
|
260 |
+
@dataclass
|
261 |
+
class ImageCaptionPairedDatasetParams(DatasetParams):
|
262 |
+
# PMD only: This value decides the probability of the image token being at the start
|
263 |
+
# of the text or at the end of the text. Set to 0.5 for equal probability.
|
264 |
+
# Set to 0 for the image always at start.
|
265 |
+
prob_image_at_end: float = 0.5
|
266 |
+
# PMD only: Specifies the tolerance for the amount of padding in a sequence. If set
|
267 |
+
# to -1, then all padding will be tolerated. If set to 0, then no padding will be tolerated.
|
268 |
+
# Continuously increase this value to allow more padding in the sequence.
|
269 |
+
padding_tolerance: int = -1
|
270 |
+
dataset_type: DatasetTypes = DatasetTypes.IMAGE_CAPTION_PAIRS
|
271 |
+
|
272 |
+
|
273 |
+
@dataclass
|
274 |
+
class WebDocumentsDatasetParams(DatasetParams):
|
275 |
+
# Decide how often should the image attention mask is such that the
|
276 |
+
# the text attends to next image. Set to 0 for just perceding images
|
277 |
+
# NOTE: For PMD, this option doesn't apply anymore. Use `prob_image_at_end`
|
278 |
+
# to control the position of the image and corresponding image.
|
279 |
+
p_next: float = 0.5
|
280 |
+
dataset_type: DatasetTypes = DatasetTypes.WEB_DOCUMENTS
|
281 |
+
|
282 |
+
|
283 |
+
@dataclass
|
284 |
+
class DataParams(Serializable):
|
285 |
+
"""Data Parameters"""
|
286 |
+
|
287 |
+
# what software to use for the dataset
|
288 |
+
use_webdataset: bool = False
|
289 |
+
|
290 |
+
# number of workers for dataloaders int
|
291 |
+
num_workers: int = 1
|
292 |
+
# allow async faster data transfer to GPUs (only make sense when CUDA GPUs are available)
|
293 |
+
# known to cause memory issues
|
294 |
+
pin_memory: bool = False
|
295 |
+
# Whether to use persistent workers for the dataloaders
|
296 |
+
persistent_workers: bool = True
|
297 |
+
realtime_processing: bool = False
|
298 |
+
|
299 |
+
train_seed: int = 1
|
300 |
+
val_seed: int = 2
|
301 |
+
|
302 |
+
# can use one config for both train + validation or specific ones if need to be different
|
303 |
+
select_n_examples: Optional[int] = None
|
304 |
+
select_n_examples_train: Optional[int] = None
|
305 |
+
select_n_examples_validation: Optional[int] = None
|
306 |
+
|
307 |
+
# TODO: Move to per dataset params as it makes more sense there
|
308 |
+
proba_interleaving_dataset: Optional[List[float]] = None
|
309 |
+
|
310 |
+
pmd: ImageCaptionPairedDatasetParams = ImageCaptionPairedDatasetParams(dataset_name=DatasetNames.PMD)
|
311 |
+
laion: ImageCaptionPairedDatasetParams = ImageCaptionPairedDatasetParams(dataset_name=DatasetNames.LAION)
|
312 |
+
cm4: WebDocumentsDatasetParams = WebDocumentsDatasetParams(dataset_name=DatasetNames.CM4)
|
313 |
+
wiki: WebDocumentsDatasetParams = WebDocumentsDatasetParams(dataset_name=DatasetNames.WIKI)
|
314 |
+
|
315 |
+
|
316 |
+
@dataclass
|
317 |
+
class OptimizerParams:
|
318 |
+
"""Optimization parameters"""
|
319 |
+
|
320 |
+
# --------------------
|
321 |
+
# vl optim parameters
|
322 |
+
# --------------------
|
323 |
+
vl_optim: str = "AdamW"
|
324 |
+
vl_optim_params: Dict[str, Any] = dict_field(
|
325 |
+
dict(
|
326 |
+
# learning rate
|
327 |
+
lr=1e-4,
|
328 |
+
# betas for adam
|
329 |
+
betas=(0.9, 0.999),
|
330 |
+
weight_decay=0.1,
|
331 |
+
no_decay=["bias", "alpha", "layernorm", "ln", "layer_norm", "perceiver_resampler"],
|
332 |
+
)
|
333 |
+
)
|
334 |
+
|
335 |
+
vl_lr_scheduler: str = "get_constant_schedule_with_warmup"
|
336 |
+
# number of warmup steps for the learning rate
|
337 |
+
vl_lr_scheduler_params: Dict[str, Any] = dict_field(dict(num_warmup_steps=5_000, last_epoch=-1))
|
338 |
+
z_loss: float = 0.0
|
339 |
+
|
340 |
+
|
341 |
+
@dataclass
|
342 |
+
class Parameters(Serializable):
|
343 |
+
"""base options."""
|
344 |
+
|
345 |
+
hparams: Hparams = Hparams()
|
346 |
+
optim_param: OptimizerParams = OptimizerParams()
|
347 |
+
data_param: DataParams = DataParams()
|
348 |
+
resume_param: ResumeParams = ResumeParams()
|
349 |
+
should_verify: InitVar[bool] = True
|
350 |
+
|
351 |
+
def verify(self, should_verify: bool):
|
352 |
+
if not should_verify:
|
353 |
+
return
|
354 |
+
|
355 |
+
dict_rep = vars(self)
|
356 |
+
expected = vars(self.__class__(should_verify=False))
|
357 |
+
for key, value in dict_rep.items():
|
358 |
+
if isinstance(value, dict):
|
359 |
+
diff = set(value.keys()) - set(asdict(expected[key]).keys())
|
360 |
+
raise TypeError(
|
361 |
+
f"{key} in {self.__class__.__name__} has extra keys: {diff}. Please fix your config if you are"
|
362 |
+
" using one."
|
363 |
+
)
|
364 |
+
if key not in expected:
|
365 |
+
raise ValueError(f"{key} is not a valid parameter for {self.__class__.__name__}")
|
366 |
+
|
367 |
+
def __post_init__(self, should_verify: bool = True):
|
368 |
+
"""Post-initialization code"""
|
369 |
+
self.verify(should_verify=should_verify)
|
370 |
+
|
371 |
+
# copy select_n_examples to the more specific ones if the latter haven't been preset
|
372 |
+
if self.data_param.select_n_examples is not None:
|
373 |
+
if self.data_param.select_n_examples_train is None:
|
374 |
+
self.data_param.select_n_examples_train = self.data_param.select_n_examples
|
375 |
+
if self.data_param.select_n_examples_validation is None:
|
376 |
+
self.data_param.select_n_examples_validation = self.data_param.select_n_examples
|
377 |
+
|
378 |
+
# Get commit id
|
379 |
+
if self.hparams.repo_commit_id is None:
|
380 |
+
self.hparams.repo_commit_id = git.Repo(search_parent_directories=True).head.object.hexsha
|
381 |
+
|
382 |
+
# If processing on the fly, with the current implementation, we can't have `num_workers=0`
|
383 |
+
if self.data_param.realtime_processing and self.data_param.num_workers == 0:
|
384 |
+
raise ValueError(
|
385 |
+
"If doing processing on the fly (and thus using the `IterableDataset`), you can't have `num_workers`"
|
386 |
+
" equal to 0."
|
387 |
+
)
|
388 |
+
|
389 |
+
# batch_size deprecation
|
390 |
+
if self.hparams.batch_size is not None:
|
391 |
+
if self.hparams.batch_size_per_gpu > 1:
|
392 |
+
raise ValueError(
|
393 |
+
"as hparams.batch_size is deprecated - don't know how to proceed with both hparams.batch_size>1"
|
394 |
+
" and hparams.batch_size_per_gpu > 1"
|
395 |
+
)
|
396 |
+
else:
|
397 |
+
logger.warning(
|
398 |
+
"will use the deprecated hparams.batch_size, but transition to hparams.batch_size_per_gpu instead"
|
399 |
+
)
|
400 |
+
self.hparams.batch_size_per_gpu = self.hparams.batch_size
|
401 |
+
self.hparams.batch_size = None
|
402 |
+
|
403 |
+
# Assign batch size to data_param as well for dataloaders
|
404 |
+
self.data_param.batch_size = self.hparams.batch_size_per_gpu
|
405 |
+
|
406 |
+
# note: all global batch_size-related configs including hparams.grad_acc_size will be
|
407 |
+
# checked/set in trainer's setup_batch_size_related_configs since we need to know the value
|
408 |
+
# of num_processes
|
409 |
+
|
410 |
+
# Assign loggingtypes given values
|
411 |
+
self.hparams.train_logging_activations = [LoggingTypes(val) for val in self.hparams.train_logging_activations]
|
412 |
+
|
413 |
+
# Check that proba_interleaving_dataset is mutually exclusive to loss_weights_per_dataset
|
414 |
+
if self.data_param.proba_interleaving_dataset and self.hparams.loss_weights_per_dataset:
|
415 |
+
raise ValueError(
|
416 |
+
"Can't have hparams.loss_weights_per_dataset and proba_interleaving_dataset. If we have"
|
417 |
+
" loss_weights_per_dataset, it means the gradients are accumulated over datasets. Therefore a batch of"
|
418 |
+
" each given at each update and there is no use of proba_interleaving_dataset"
|
419 |
+
)
|
420 |
+
|
421 |
+
if (
|
422 |
+
self.data_param.proba_interleaving_dataset is not None
|
423 |
+
and sum(self.data_param.proba_interleaving_dataset) != 1
|
424 |
+
):
|
425 |
+
raise ValueError("proba_interleaving_dataset must sum to 1")
|
426 |
+
|
427 |
+
self.hparams.train_logging_grad_param_deepspeed = [
|
428 |
+
LoggingTypes(val) for val in self.hparams.train_logging_grad_param_deepspeed
|
429 |
+
]
|
430 |
+
|
431 |
+
# Resume run if there is already an existing folder for this run
|
432 |
+
if self.hparams.save_dir is not None and self.hparams.save_dir.exists():
|
433 |
+
save_dir_has_checkpoints = (
|
434 |
+
len([dir for dir in self.hparams.save_dir.iterdir() if (dir.is_dir() and "opt_step" in str(dir))]) > 0
|
435 |
+
)
|
436 |
+
if self.hparams.resume_run is not None and not self.hparams.resume_run and save_dir_has_checkpoints:
|
437 |
+
logger.warning(
|
438 |
+
"`resume_run` was explicitely set to False (i.e. starting from scratch), but the experiment"
|
439 |
+
" folder already has been populated with previous runs.\nAlready saved checkpoints will be"
|
440 |
+
" overwritten (at best, when `train_saving_opt_steps` is the same) or will be mixed with the new"
|
441 |
+
" checkpoints of a potentially brand new experiment. Would it make sense to create a new"
|
442 |
+
" `save_dir`?"
|
443 |
+
)
|
444 |
+
self.hparams.resume_run = save_dir_has_checkpoints
|
445 |
+
|
446 |
+
# Setup all args needed to resume a run
|
447 |
+
if self.hparams.resume_run:
|
448 |
+
# Get last step directory
|
449 |
+
if self.resume_param.opt_step_dir is None and not self.resume_param.resume_last:
|
450 |
+
raise ValueError(
|
451 |
+
"`opt_step_dir` cannot be None while `resume_last` is False. Choose which dir you want to resume"
|
452 |
+
" from..."
|
453 |
+
)
|
454 |
+
if self.resume_param.resume_last:
|
455 |
+
if self.resume_param.opt_step_dir is not None:
|
456 |
+
raise ValueError(
|
457 |
+
"`resume_last` cannot be True while `opt_step_dir` is not None. Choose which dir you want to"
|
458 |
+
" resume from..."
|
459 |
+
)
|
460 |
+
latest_path = self.hparams.save_dir / "latest_opt_step_dir"
|
461 |
+
with open(latest_path, "r") as fd:
|
462 |
+
self.resume_param.opt_step_dir = Path(fd.read().strip())
|
463 |
+
if not (self.resume_param.opt_step_dir.exists() and self.resume_param.opt_step_dir.is_dir()):
|
464 |
+
raise ValueError(
|
465 |
+
f"It appears that the path in the `latest_opt_step_dir` file {latest_path} is invalid. It's"
|
466 |
+
" either does not exist or is not a directory. Please fix that."
|
467 |
+
)
|
468 |
+
|
469 |
+
with open(self.resume_param.opt_step_dir / "resume_run_infos.json", "r") as f:
|
470 |
+
resume_infos = json.load(f)
|
471 |
+
logger.info(f"Resuming from {self.resume_param.opt_step_dir}")
|
472 |
+
self.resume_param.accelerator_state_dir = self.resume_param.opt_step_dir / "accelerator_state"
|
473 |
+
self.resume_param.model_file = self.resume_param.opt_step_dir / "unwrapped_model"
|
474 |
+
self.resume_param.model_config_file = self.resume_param.opt_step_dir / "unwrapped_model/config.json"
|
475 |
+
self.resume_param.tokenizer = self.resume_param.opt_step_dir / "tokenizer"
|
476 |
+
|
477 |
+
self.resume_param.train_logs = resume_infos["train_logs"]
|
478 |
+
self.resume_param.resume_opt_step = resume_infos["resume_opt_step"]
|
479 |
+
self.resume_param.resume_epoch = resume_infos["resume_epoch"]
|
480 |
+
self.resume_param.resume_dataset_state = resume_infos.get("resume_dataset_state", list())
|
481 |
+
|
482 |
+
gbs_running = resume_infos["gbs_running"]
|
483 |
+
self.resume_param.gbs_running.global_batch_size_current = gbs_running["global_batch_size_current"]
|
484 |
+
self.resume_param.gbs_running.global_seen_samples = gbs_running["global_seen_samples"]
|
485 |
+
self.resume_param.gbs_running.next_goal_samples = gbs_running["next_goal_samples"]
|
486 |
+
self.resume_param.gbs_running.grad_acc_size_current = gbs_running["grad_acc_size_current"]
|
487 |
+
|
488 |
+
self.hparams.wandb_run_id = resume_infos["wandb_run_id"]
|
489 |
+
self.hparams.seed = resume_infos["seed"]
|
490 |
+
|
491 |
+
# Should not happen, but this is in case there is a run mixing
|
492 |
+
# wandb_enable = True and wandb_enable = False between jobs
|
493 |
+
if not self.hparams.wandb_enable:
|
494 |
+
self.hparams.wandb_run_id = ""
|
495 |
+
|
496 |
+
@classmethod
|
497 |
+
def parse(cls):
|
498 |
+
cfgfile_parser = ArgumentParser(add_help=False)
|
499 |
+
cfgfile_parser.add_arguments(CfgFileConfig, dest="cfgfile")
|
500 |
+
cfgfile_args, rest = cfgfile_parser.parse_known_args()
|
501 |
+
|
502 |
+
cfgfile: CfgFileConfig = cfgfile_args.cfgfile
|
503 |
+
|
504 |
+
file_config: Optional[Parameters] = None
|
505 |
+
if cfgfile.config is not None:
|
506 |
+
file_config = Parameters.load(cfgfile.config, load_fn=yaml.safe_load)
|
507 |
+
|
508 |
+
parser = ArgumentParser()
|
509 |
+
|
510 |
+
# add cfgfile args so they appear in the help message
|
511 |
+
parser.add_arguments(CfgFileConfig, dest="cfgfile")
|
512 |
+
parser.add_arguments(Parameters, dest="parameters", default=file_config)
|
513 |
+
|
514 |
+
# XXX: currently when called from tests we don't want to parse pytest arguments, so either
|
515 |
+
# this whole logic needs to be rewritten to not always call `parser.parse_args` but only
|
516 |
+
# when needed, for now as a workaround using `parse_known_args` and ignoring the args which
|
517 |
+
# don't belong to this program
|
518 |
+
args, unknown = parser.parse_known_args()
|
519 |
+
|
520 |
+
parameters: Parameters = args.parameters
|
521 |
+
|
522 |
+
parameters.save_config = cfgfile.save_config
|
523 |
+
|
524 |
+
return parameters
|
525 |
+
|
526 |
+
def save_config_state(self):
|
527 |
+
if self.save_config:
|
528 |
+
self.hparams.save_dir.mkdir(parents=True, exist_ok=True)
|
529 |
+
if self.hparams.job_id is not None:
|
530 |
+
config_file_name = f"{self.hparams.job_id}_config.yaml"
|
531 |
+
else:
|
532 |
+
config_file_name = "config.yaml"
|
533 |
+
self.save(self.hparams.save_dir / config_file_name, indent=4)
|
534 |
+
|
535 |
+
|
536 |
+
def get_config(print_config: bool = True):
|
537 |
+
parameters: Parameters = Parameters.parse()
|
538 |
+
if print_config:
|
539 |
+
print(parameters)
|
540 |
+
return parameters
|
541 |
+
|
542 |
+
|
543 |
+
if __name__ == "__main__":
|
544 |
+
config = get_config()
|
545 |
+
config.save_config_state()
|
m4/training/dataset_utils.py
ADDED
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import random
|
3 |
+
|
4 |
+
import webdataset as wds
|
5 |
+
from webdataset.tariterators import group_by_keys, tar_file_expander, url_opener
|
6 |
+
|
7 |
+
from m4.training.types import DatasetTypes
|
8 |
+
|
9 |
+
|
10 |
+
meta_prefix = "__"
|
11 |
+
meta_suffix = "__"
|
12 |
+
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
trace = False
|
15 |
+
|
16 |
+
|
17 |
+
def webdoc_valid_sample(sample):
|
18 |
+
"""Check whether a sample is valid.
|
19 |
+
|
20 |
+
:param sample: sample to be checked
|
21 |
+
"""
|
22 |
+
return (
|
23 |
+
sample is not None
|
24 |
+
and isinstance(sample, dict)
|
25 |
+
and len(list(sample.keys())) > 0
|
26 |
+
and not sample.get("__bad__", False)
|
27 |
+
and sample_has_all_files(sample)
|
28 |
+
)
|
29 |
+
|
30 |
+
|
31 |
+
def sample_has_all_files(current_sample):
|
32 |
+
meta = current_sample.get("metadata.value", None)
|
33 |
+
if meta is None:
|
34 |
+
return False
|
35 |
+
meta = meta.decode("utf-8")
|
36 |
+
if len(meta) == 0:
|
37 |
+
return False
|
38 |
+
target_file_list = meta.split("\n")
|
39 |
+
fname_keys = [key for key in current_sample.keys() if key.endswith(".fname")]
|
40 |
+
fnames = [current_sample[key] for key in fname_keys]
|
41 |
+
check = all([fname in fnames for fname in target_file_list])
|
42 |
+
if not check:
|
43 |
+
return False
|
44 |
+
return True
|
45 |
+
|
46 |
+
|
47 |
+
class ImageDecoder:
|
48 |
+
def __call__(self, bytes_):
|
49 |
+
import io
|
50 |
+
|
51 |
+
import PIL.Image
|
52 |
+
|
53 |
+
img = PIL.Image.open(io.BytesIO(bytes_))
|
54 |
+
img.load()
|
55 |
+
return img
|
56 |
+
|
57 |
+
|
58 |
+
# Taken from https://github.com/mlfoundations/open_clip/blob/c48111dacac55db24878af229d8a5662c03e6f1c/src/training/data.py#L180-L183
|
59 |
+
def log_and_continue(exn):
|
60 |
+
"""Call in an exception handler to ignore any exception, issue a warning, and continue."""
|
61 |
+
logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
|
62 |
+
return True
|
63 |
+
|
64 |
+
|
65 |
+
# Adapt group_by_keys to our webdocument format in which each samples contains several text and image files
|
66 |
+
# https://github.com/webdataset/webdataset/blob/039d74319ae55e5696dcef89829be9671802cf70/webdataset/tariterators.py#L195-L250
|
67 |
+
def group_by_keys_interleaved(data, handler=log_and_continue):
|
68 |
+
"""Return function over iterator that groups key, value pairs into samples."""
|
69 |
+
current_sample = None
|
70 |
+
for filesample in data:
|
71 |
+
try:
|
72 |
+
assert isinstance(filesample, dict)
|
73 |
+
fname, value = filesample["fname"], filesample["data"]
|
74 |
+
fname = fname.strip("./")
|
75 |
+
if fname.endswith(".metadata.txt"):
|
76 |
+
prefix, data_type, extension = fname.split(".")
|
77 |
+
suffix = data_type
|
78 |
+
else:
|
79 |
+
prefix, idx, data_type, extension = fname.split(".")
|
80 |
+
if data_type not in ["text", "image"]:
|
81 |
+
raise ValueError(f"{fname}: unknown data type {data_type}")
|
82 |
+
suffix = idx
|
83 |
+
if trace:
|
84 |
+
print(
|
85 |
+
f"prefix: {prefix}, idx: {idx}, data_type: {data_type}, extension: {extension}, keys:"
|
86 |
+
f" {current_sample.keys() if isinstance(current_sample, dict) else None}"
|
87 |
+
)
|
88 |
+
if prefix is None:
|
89 |
+
continue
|
90 |
+
if current_sample is None or prefix != current_sample["__key__"]:
|
91 |
+
valid = webdoc_valid_sample(current_sample)
|
92 |
+
if valid:
|
93 |
+
yield current_sample
|
94 |
+
elif current_sample is not None:
|
95 |
+
logging.warning(f"{fname}: invalid sample {current_sample} ignored")
|
96 |
+
current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
|
97 |
+
if suffix in current_sample:
|
98 |
+
raise ValueError(f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}")
|
99 |
+
current_sample[f"{suffix}.value"] = value
|
100 |
+
current_sample[f"{suffix}.type"] = data_type
|
101 |
+
current_sample[f"{suffix}.fname"] = fname
|
102 |
+
except Exception as exn:
|
103 |
+
exn.args = exn.args + (filesample.get("stream"), filesample.get("url"))
|
104 |
+
if handler(exn):
|
105 |
+
continue
|
106 |
+
else:
|
107 |
+
break
|
108 |
+
|
109 |
+
if webdoc_valid_sample(current_sample):
|
110 |
+
yield current_sample
|
111 |
+
|
112 |
+
|
113 |
+
def _tarfile_to_webdocument_samples(src, handler=log_and_continue):
|
114 |
+
streams = url_opener(src, handler=handler)
|
115 |
+
files = tar_file_expander(streams, handler=handler)
|
116 |
+
samples = group_by_keys_interleaved(files, handler=handler)
|
117 |
+
return samples
|
118 |
+
|
119 |
+
|
120 |
+
tarfile_to_webdocument_samples = wds.filters.pipelinefilter(_tarfile_to_webdocument_samples)
|
121 |
+
|
122 |
+
|
123 |
+
def _collate_texts_and_images_webdocument(data, handler=log_and_continue):
|
124 |
+
for sample in data:
|
125 |
+
try:
|
126 |
+
max_example_indices = max(
|
127 |
+
[int(key.split(".")[0]) for key in sample.keys() if key.endswith(".value") and key != "metadata.value"]
|
128 |
+
)
|
129 |
+
texts = [None for _ in range(max_example_indices + 1)]
|
130 |
+
images = [None for _ in range(max_example_indices + 1)]
|
131 |
+
for idx in range(max_example_indices + 1):
|
132 |
+
if f"{idx}.value" not in sample:
|
133 |
+
continue
|
134 |
+
if "text" in sample[f"{idx}.type"]:
|
135 |
+
texts[idx] = sample[f"{idx}.value"]
|
136 |
+
elif "image" in sample[f"{idx}.type"]:
|
137 |
+
images[idx] = sample[f"{idx}.value"]
|
138 |
+
else:
|
139 |
+
raise ValueError(f"Unknown data type: {sample[f'{idx}.type']}")
|
140 |
+
example = {"__key__": sample["__key__"], "__url__": sample["__url__"], "texts": texts, "images": images}
|
141 |
+
yield example
|
142 |
+
except Exception as exn:
|
143 |
+
exn.args = exn.args + (sample.get("stream"), sample.get("url"))
|
144 |
+
if handler(exn):
|
145 |
+
continue
|
146 |
+
else:
|
147 |
+
break
|
148 |
+
|
149 |
+
|
150 |
+
collate_texts_and_images_webdocument = wds.filters.pipelinefilter(_collate_texts_and_images_webdocument)
|
151 |
+
|
152 |
+
|
153 |
+
def _decode_image_and_text_webdocument(data, handler=log_and_continue):
|
154 |
+
image_decoder = ImageDecoder()
|
155 |
+
for sample in data:
|
156 |
+
try:
|
157 |
+
sample["images"] = [image_decoder(image) if image is not None else None for image in sample["images"]]
|
158 |
+
sample["texts"] = [text.decode("utf-8") if text is not None else None for text in sample["texts"]]
|
159 |
+
yield sample
|
160 |
+
except Exception as exn:
|
161 |
+
exn.args = exn.args + (sample.get("stream"), sample.get("url"))
|
162 |
+
if handler(exn):
|
163 |
+
continue
|
164 |
+
else:
|
165 |
+
break
|
166 |
+
|
167 |
+
|
168 |
+
decode_image_and_text_webdocument = wds.filters.pipelinefilter(_decode_image_and_text_webdocument)
|
169 |
+
|
170 |
+
|
171 |
+
def collate_dicts(samples):
|
172 |
+
keys = samples[0].keys()
|
173 |
+
batched_samples = {key: [sample[key] for sample in samples] for key in keys}
|
174 |
+
return batched_samples
|
175 |
+
|
176 |
+
|
177 |
+
def get_webdocuments_webdataset(
|
178 |
+
urls,
|
179 |
+
batch_size,
|
180 |
+
shuffle_initial_urls_list=False,
|
181 |
+
shuffle_before_split_by_node_buffer_size=100,
|
182 |
+
shuffle_before_split_by_worker_buffer_size=100,
|
183 |
+
shuffle_after_tarfile_to_samples_buffer_size=100,
|
184 |
+
shuffle_after_batching_buffer_size=1000,
|
185 |
+
):
|
186 |
+
if shuffle_initial_urls_list:
|
187 |
+
random.shuffle(urls)
|
188 |
+
|
189 |
+
pipeline_list = [wds.SimpleShardList(urls)]
|
190 |
+
|
191 |
+
if shuffle_before_split_by_node_buffer_size is not None:
|
192 |
+
pipeline_list.append(wds.shuffle(shuffle_before_split_by_node_buffer_size))
|
193 |
+
|
194 |
+
pipeline_list.append(wds.split_by_node)
|
195 |
+
|
196 |
+
if shuffle_before_split_by_worker_buffer_size is not None:
|
197 |
+
pipeline_list.append(wds.shuffle(shuffle_before_split_by_worker_buffer_size))
|
198 |
+
|
199 |
+
pipeline_list.extend(
|
200 |
+
[
|
201 |
+
wds.split_by_worker,
|
202 |
+
tarfile_to_webdocument_samples(),
|
203 |
+
]
|
204 |
+
)
|
205 |
+
|
206 |
+
if shuffle_after_tarfile_to_samples_buffer_size is not None:
|
207 |
+
pipeline_list.append(wds.shuffle(shuffle_after_tarfile_to_samples_buffer_size))
|
208 |
+
|
209 |
+
pipeline_list.extend(
|
210 |
+
[
|
211 |
+
collate_texts_and_images_webdocument(),
|
212 |
+
decode_image_and_text_webdocument(),
|
213 |
+
wds.batched(batch_size, collation_fn=collate_dicts, partial=True),
|
214 |
+
]
|
215 |
+
)
|
216 |
+
|
217 |
+
if shuffle_after_batching_buffer_size is not None:
|
218 |
+
pipeline_list.append(wds.shuffle(shuffle_after_batching_buffer_size))
|
219 |
+
|
220 |
+
dataset = wds.DataPipeline(pipeline_list)
|
221 |
+
return dataset
|
222 |
+
|
223 |
+
|
224 |
+
def split_keep_2(x):
|
225 |
+
x = x.strip("./")
|
226 |
+
x_splitter = x.split(".")
|
227 |
+
return x_splitter[0], x_splitter[1]
|
228 |
+
|
229 |
+
|
230 |
+
def _tarfile_to_pair_samples(src, handler=log_and_continue):
|
231 |
+
streams = url_opener(src, handler=handler)
|
232 |
+
files = tar_file_expander(streams, handler=handler)
|
233 |
+
samples = group_by_keys(files, keys=split_keep_2, handler=handler)
|
234 |
+
return samples
|
235 |
+
|
236 |
+
|
237 |
+
tarfile_to_pair_samples = wds.filters.pipelinefilter(_tarfile_to_pair_samples)
|
238 |
+
|
239 |
+
|
240 |
+
def _decode_image_and_text_pairs(data, handler=log_and_continue):
|
241 |
+
image_decoder = ImageDecoder()
|
242 |
+
for sample in data:
|
243 |
+
try:
|
244 |
+
sample["image"] = image_decoder(sample["image"])
|
245 |
+
sample["text"] = sample["text"].decode("utf-8")
|
246 |
+
yield sample
|
247 |
+
except Exception as exn:
|
248 |
+
exn.args = exn.args + (sample.get("stream"), sample.get("url"))
|
249 |
+
if handler(exn):
|
250 |
+
continue
|
251 |
+
else:
|
252 |
+
break
|
253 |
+
|
254 |
+
|
255 |
+
decode_image_and_text_pairs = wds.filters.pipelinefilter(_decode_image_and_text_pairs)
|
256 |
+
|
257 |
+
|
258 |
+
def get_image_caption_pairs_webdataset(
|
259 |
+
urls,
|
260 |
+
batch_size,
|
261 |
+
shuffle_initial_urls_list=False,
|
262 |
+
shuffle_before_split_by_node_buffer_size=100,
|
263 |
+
shuffle_before_split_by_worker_buffer_size=100,
|
264 |
+
shuffle_after_tarfile_to_samples_buffer_size=100,
|
265 |
+
shuffle_after_batching_buffer_size=1000,
|
266 |
+
):
|
267 |
+
if shuffle_initial_urls_list:
|
268 |
+
random.shuffle(urls)
|
269 |
+
|
270 |
+
pipeline_list = [wds.SimpleShardList(urls)]
|
271 |
+
|
272 |
+
if shuffle_before_split_by_node_buffer_size is not None:
|
273 |
+
pipeline_list.append(wds.shuffle(shuffle_before_split_by_node_buffer_size))
|
274 |
+
|
275 |
+
pipeline_list.append(wds.split_by_node)
|
276 |
+
|
277 |
+
if shuffle_before_split_by_worker_buffer_size is not None:
|
278 |
+
pipeline_list.append(wds.shuffle(shuffle_before_split_by_worker_buffer_size))
|
279 |
+
|
280 |
+
pipeline_list.extend(
|
281 |
+
[
|
282 |
+
wds.split_by_worker,
|
283 |
+
tarfile_to_pair_samples(handler=log_and_continue),
|
284 |
+
]
|
285 |
+
)
|
286 |
+
|
287 |
+
if shuffle_after_tarfile_to_samples_buffer_size is not None:
|
288 |
+
pipeline_list.append(wds.shuffle(shuffle_after_tarfile_to_samples_buffer_size))
|
289 |
+
|
290 |
+
pipeline_list.extend(
|
291 |
+
[
|
292 |
+
decode_image_and_text_pairs(),
|
293 |
+
wds.batched(batch_size, collation_fn=collate_dicts, partial=True), # todo: check if partial is needed
|
294 |
+
]
|
295 |
+
)
|
296 |
+
|
297 |
+
if shuffle_after_batching_buffer_size is not None:
|
298 |
+
pipeline_list.append(wds.shuffle(shuffle_after_batching_buffer_size))
|
299 |
+
|
300 |
+
dataset = wds.DataPipeline(pipeline_list)
|
301 |
+
return dataset
|
302 |
+
|
303 |
+
|
304 |
+
def get_webdataset(
|
305 |
+
urls,
|
306 |
+
ds_type: DatasetTypes,
|
307 |
+
batch_size: int,
|
308 |
+
shuffle_initial_urls_list,
|
309 |
+
shuffle_before_split_by_node_buffer_size,
|
310 |
+
shuffle_before_split_by_worker_buffer_size,
|
311 |
+
shuffle_after_tarfile_to_samples_buffer_size,
|
312 |
+
shuffle_after_batching_buffer_size,
|
313 |
+
):
|
314 |
+
if ds_type == DatasetTypes.WEB_DOCUMENTS:
|
315 |
+
return get_webdocuments_webdataset(
|
316 |
+
urls,
|
317 |
+
batch_size,
|
318 |
+
shuffle_initial_urls_list,
|
319 |
+
shuffle_before_split_by_node_buffer_size,
|
320 |
+
shuffle_before_split_by_worker_buffer_size,
|
321 |
+
shuffle_after_tarfile_to_samples_buffer_size,
|
322 |
+
shuffle_after_batching_buffer_size,
|
323 |
+
)
|
324 |
+
elif ds_type == DatasetTypes.IMAGE_CAPTION_PAIRS:
|
325 |
+
return get_image_caption_pairs_webdataset(
|
326 |
+
urls,
|
327 |
+
batch_size,
|
328 |
+
shuffle_initial_urls_list,
|
329 |
+
shuffle_before_split_by_node_buffer_size,
|
330 |
+
shuffle_before_split_by_worker_buffer_size,
|
331 |
+
shuffle_after_tarfile_to_samples_buffer_size,
|
332 |
+
shuffle_after_batching_buffer_size,
|
333 |
+
)
|
334 |
+
else:
|
335 |
+
raise ValueError(f"Unknown dataset type: {ds_type}")
|
336 |
+
|
337 |
+
|
338 |
+
def check_webdataset_command(command):
|
339 |
+
if "s3:/" not in command:
|
340 |
+
return True
|
341 |
+
|
342 |
+
command = command.strip()
|
343 |
+
if not command.startswith("pipe:bash"):
|
344 |
+
return False
|
345 |
+
|
346 |
+
if not command.endswith(".tar"):
|
347 |
+
return False
|
348 |
+
|
349 |
+
if "get_file.sh" not in command:
|
350 |
+
return False
|
351 |
+
|
352 |
+
return True
|
m4/training/debug_utils.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Trainer debug utils """
|
2 |
+
|
3 |
+
|
4 |
+
def dump_optim_states(self):
|
5 |
+
"""dumps basic information about the state of the optimizer"""
|
6 |
+
|
7 |
+
print("*** Optim States Dump:")
|
8 |
+
param_groups_cnt = len(self.vl_optim.param_groups)
|
9 |
+
# state dict has more than param_groups info, so extract only the param groups
|
10 |
+
param_group_states = list(self.vl_optim.state.values())[:param_groups_cnt]
|
11 |
+
for i, state in enumerate(param_group_states):
|
12 |
+
print(f"param group: {i}")
|
13 |
+
print(f" step={state['step']}")
|
14 |
+
print(f" exp_avg all_zero={all(state['exp_avg'] == 0)}")
|
15 |
+
print(f" exp_avg_sq all_zero={all(state['exp_avg_sq'] == 0)}")
|
16 |
+
|
17 |
+
# can also dump LR state if need be
|
18 |
+
# print(f"LR={self.vl_scheduler.get_last_lr()}")
|
19 |
+
|
20 |
+
|
21 |
+
def validate_optim_states_are_reset(self):
|
22 |
+
"""
|
23 |
+
for a new or fully reset optimizer we expect all zeros `exp_avg` and `exp_avg_sq` state tensors and step=1
|
24 |
+
"""
|
25 |
+
|
26 |
+
param_groups_cnt = len(self.vl_optim.param_groups)
|
27 |
+
param_group_states = list(self.vl_optim.state.values())[:param_groups_cnt]
|
28 |
+
for i, state in enumerate(param_group_states):
|
29 |
+
if state["step"] != 1:
|
30 |
+
raise ValueError(f"optimizer reset didn't seem to work: state={i} step={state['step']}")
|
31 |
+
if not all(state["exp_avg"] == 0):
|
32 |
+
raise ValueError(f"optimizer reset didn't seem to work: state={i} step={state['exp_avg']}")
|
33 |
+
if not all(state["exp_avg_sq"] == 0):
|
34 |
+
raise ValueError(f"optimizer reset didn't seem to work: state={i} step={state['exp_avg_sq']}")
|
m4/training/packing.py
ADDED
@@ -0,0 +1,755 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from bisect import bisect_left
|
3 |
+
from collections import OrderedDict
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from m4.training.utils import FAKE_TOKEN_AROUND_IMAGE_V2, IMAGE_TOKEN, _convert_to_rgb
|
10 |
+
|
11 |
+
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
|
15 |
+
# Hyper-parameters
|
16 |
+
_IMAGE_BONUS_VALUE = 2 # The bonus value for tokens preceding the image token
|
17 |
+
_MIN_LENGTH_DOCUMENTS_TO_PACK = (
|
18 |
+
5 # Minimum lengths of documents to pack together (lenghts is measures in number of tokens)
|
19 |
+
)
|
20 |
+
|
21 |
+
|
22 |
+
def incremental_to_binary_attention_mask(incremental_mask, num_classes=-1):
|
23 |
+
# This function converts: [-1, 0, 1] => [[0, 0], [1, 0], [0, 1]]
|
24 |
+
|
25 |
+
# If any of images index are more than num_classes, set them to -1.
|
26 |
+
# Words after the max number of images allowed have been seen don't attend on anything
|
27 |
+
if num_classes != -1:
|
28 |
+
incremental_mask[incremental_mask >= num_classes] = -1
|
29 |
+
|
30 |
+
negatives = incremental_mask == -1
|
31 |
+
incremental_mask[negatives] = 0
|
32 |
+
attn_mask = torch.nn.functional.one_hot(incremental_mask, num_classes=num_classes)
|
33 |
+
attn_mask[negatives, :] = 0
|
34 |
+
return attn_mask
|
35 |
+
|
36 |
+
|
37 |
+
def image_attention_mask_for_packed_input_ids(input_ids, tokenizer):
|
38 |
+
image_attention_mask = torch.full_like(input_ids, fill_value=-1)
|
39 |
+
next_image_attention_mask = torch.full_like(input_ids, fill_value=-1)
|
40 |
+
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
|
41 |
+
eod_token_id = tokenizer.eos_token_id
|
42 |
+
for batch_idx in range(input_ids.size(0)):
|
43 |
+
count = -1
|
44 |
+
seen_eod = False
|
45 |
+
for idx, token_id in enumerate(input_ids[batch_idx]):
|
46 |
+
if token_id == image_token_id:
|
47 |
+
count += 1
|
48 |
+
image_attention_mask[batch_idx][idx] = count
|
49 |
+
seen_eod = False
|
50 |
+
else:
|
51 |
+
image_attention_mask[batch_idx][idx] = count
|
52 |
+
|
53 |
+
if seen_eod:
|
54 |
+
image_attention_mask[batch_idx][idx] = -1
|
55 |
+
|
56 |
+
if token_id == eod_token_id:
|
57 |
+
seen_eod = True
|
58 |
+
|
59 |
+
for batch_idx in range(input_ids.size(0)):
|
60 |
+
count = -1
|
61 |
+
seen_eod = False
|
62 |
+
for idx in range(input_ids[batch_idx].size(0) - 1, -1, -1):
|
63 |
+
token_id = input_ids[batch_idx][idx]
|
64 |
+
if token_id == image_token_id:
|
65 |
+
count += 1
|
66 |
+
next_image_attention_mask[batch_idx][idx] = count
|
67 |
+
seen_eod = False
|
68 |
+
else:
|
69 |
+
next_image_attention_mask[batch_idx][idx] = count
|
70 |
+
|
71 |
+
if token_id == eod_token_id:
|
72 |
+
seen_eod = True
|
73 |
+
|
74 |
+
if seen_eod:
|
75 |
+
next_image_attention_mask[batch_idx][idx] = -1
|
76 |
+
|
77 |
+
non_negative_indices = next_image_attention_mask[batch_idx] != -1
|
78 |
+
next_image_attention_mask[batch_idx][non_negative_indices] -= count
|
79 |
+
next_image_attention_mask[batch_idx][non_negative_indices] *= -1
|
80 |
+
|
81 |
+
return image_attention_mask, next_image_attention_mask
|
82 |
+
|
83 |
+
|
84 |
+
def laplacian_blur_detection(image, threshold=0.0):
|
85 |
+
# compute the Laplacian of the image and then return the focus
|
86 |
+
# measure, which is simply the variance of the Laplacian
|
87 |
+
if threshold == 0.0:
|
88 |
+
return False
|
89 |
+
|
90 |
+
image = np.array(image)
|
91 |
+
|
92 |
+
if len(image.shape) == 3 and image.shape[2] == 3:
|
93 |
+
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
|
94 |
+
return cv2.Laplacian(gray, cv2.CV_64F).var() < threshold
|
95 |
+
else:
|
96 |
+
# Don't remove grayscale images
|
97 |
+
return False
|
98 |
+
|
99 |
+
|
100 |
+
def fft_blur_detection(image, size=50, threshold=0.0):
|
101 |
+
if threshold == 0.0:
|
102 |
+
return False
|
103 |
+
(h, w) = image.shape
|
104 |
+
(cX, cY) = (int(w / 2.0), int(h / 2.0))
|
105 |
+
fft = np.fft.fft2(image)
|
106 |
+
fftShift = np.fft.fftshift(fft)
|
107 |
+
fftShift[cY - size : cY + size, cX - size : cX + size] = 0
|
108 |
+
fftShift = np.fft.ifftshift(fftShift)
|
109 |
+
recon = np.fft.ifft2(fftShift)
|
110 |
+
magnitude = 20 * np.log(np.abs(recon))
|
111 |
+
mean = np.mean(magnitude)
|
112 |
+
return mean < threshold
|
113 |
+
|
114 |
+
|
115 |
+
def split_pack_and_pad(
|
116 |
+
sample,
|
117 |
+
tokenizer,
|
118 |
+
max_seq_len,
|
119 |
+
image_transform,
|
120 |
+
max_num_images,
|
121 |
+
max_num_samples_per_document=10,
|
122 |
+
prefix_seed=(0, 0),
|
123 |
+
is_blurred_fn=None,
|
124 |
+
blur_threshold=0.0,
|
125 |
+
add_begin_of_doc_token=False,
|
126 |
+
add_end_of_doc_token=True,
|
127 |
+
max_num_images_per_document=None,
|
128 |
+
):
|
129 |
+
"""
|
130 |
+
Return a batch of samples in the format expected by the model which
|
131 |
+
includes `input_ids`, `pixel_values`, `attention_mask`, `image_attention_mask`,
|
132 |
+
and `next_image_attention_mask`. The `input_ids` are sampled from the document to
|
133 |
+
ensure it has `max_seq_len` tokens otherwise, the shorter documents are packed together.
|
134 |
+
For each document, we sample a maximum of `max_num_samples_per_document` or `max_num_samples_for_curr_document`
|
135 |
+
(where the latter is proportional to the length of the document and inversely proportional to the length of subsequences)
|
136 |
+
`input_ids` with sequence length `max_seq_len` from the document. This means that
|
137 |
+
each sample sampled can have different start index. Based on the start index of sample that
|
138 |
+
has been sampled, we also sample a maximum of `max_num_images` images from the document.
|
139 |
+
If there are less than `max_num_images` images in the document, we pad the images with zeros.
|
140 |
+
The start indexes are skewed towards subsequences that contain images.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
sample (Dict): A sample object containing the document with images and text.
|
144 |
+
tokenizer (PretrainedTokenizer): Text tokenizer to be used.
|
145 |
+
max_seq_len (int): Maximum sequence length of the returned text tokens.
|
146 |
+
image_transform (Callable): Transform to be applied on the images
|
147 |
+
max_num_images (int): Maximum number of images to be sampled per sample. If less, they are padded with zeros.
|
148 |
+
max_num_samples_per_document (int, optional): Maximum number of samples per document to be sampled. Defaults to 10.
|
149 |
+
prefix_seed: Prefix seed sequence for "reproducible randomness" in calls to `np.random.choice`
|
150 |
+
|
151 |
+
Returns:
|
152 |
+
_type_: _description_
|
153 |
+
"""
|
154 |
+
text_batch = sample["texts"]
|
155 |
+
|
156 |
+
image_batch = sample.get("image_embeddings", None)
|
157 |
+
is_raw_images = False
|
158 |
+
if image_batch is None:
|
159 |
+
image_batch = sample.get("images", None)
|
160 |
+
is_raw_images = True
|
161 |
+
if image_batch is None:
|
162 |
+
raise ValueError("Either image_embeddings or images must be present in the sample")
|
163 |
+
|
164 |
+
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
|
165 |
+
last_was_image = False
|
166 |
+
|
167 |
+
if is_blurred_fn is None:
|
168 |
+
is_blurred_fn = fft_blur_detection
|
169 |
+
|
170 |
+
all_images = []
|
171 |
+
all_texts = []
|
172 |
+
for raw_images, raw_texts in zip(image_batch, text_batch):
|
173 |
+
# Filter ones that don't have either one image and one text word
|
174 |
+
if not any(raw_images) or not any(raw_texts):
|
175 |
+
continue
|
176 |
+
|
177 |
+
if max_num_images_per_document:
|
178 |
+
num_images = sum([1 if image is not None else 0 for image in raw_images])
|
179 |
+
if num_images > max_num_images_per_document:
|
180 |
+
continue
|
181 |
+
|
182 |
+
any_blurred = False
|
183 |
+
|
184 |
+
if is_raw_images and blur_threshold > 0.0:
|
185 |
+
for image in raw_images:
|
186 |
+
if image is not None:
|
187 |
+
image = _convert_to_rgb(image)
|
188 |
+
any_blurred = any_blurred or is_blurred_fn(image, threshold=blur_threshold)
|
189 |
+
if any_blurred:
|
190 |
+
break
|
191 |
+
|
192 |
+
if any_blurred:
|
193 |
+
continue
|
194 |
+
|
195 |
+
inds_of_texts_to_split = [
|
196 |
+
i
|
197 |
+
for i, text in enumerate(raw_texts)
|
198 |
+
if text is not None and isinstance(text, str) and "END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED" in text
|
199 |
+
]
|
200 |
+
if inds_of_texts_to_split:
|
201 |
+
splitted_raw_images, splitted_raw_texts = [], []
|
202 |
+
previous_i = 0
|
203 |
+
for i in inds_of_texts_to_split:
|
204 |
+
splitting = raw_texts[i].split("END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED")
|
205 |
+
part1, part2 = splitting[0], splitting[-1]
|
206 |
+
|
207 |
+
sub_doc_images = raw_images[previous_i:i] + [None]
|
208 |
+
sub_doc_texts = raw_texts[previous_i:i] + [part1.strip()]
|
209 |
+
if not any(sub_doc_images): # This can happen if all images in raw_images[0:i] are all None
|
210 |
+
continue
|
211 |
+
|
212 |
+
splitted_raw_images.append(sub_doc_images)
|
213 |
+
splitted_raw_texts.append(sub_doc_texts)
|
214 |
+
|
215 |
+
if part2.strip() == "":
|
216 |
+
previous_i = i + 1
|
217 |
+
else:
|
218 |
+
raw_texts[i] = part2.strip()
|
219 |
+
previous_i = i
|
220 |
+
|
221 |
+
if previous_i < len(raw_images) and any(raw_images[previous_i:]):
|
222 |
+
splitted_raw_images.append(raw_images[previous_i:])
|
223 |
+
splitted_raw_texts.append(raw_texts[previous_i:])
|
224 |
+
|
225 |
+
else:
|
226 |
+
splitted_raw_images, splitted_raw_texts = [raw_images], [raw_texts]
|
227 |
+
|
228 |
+
# Sanity check
|
229 |
+
if [len(ims) for ims in splitted_raw_images] != [len(txts) for txts in splitted_raw_texts]:
|
230 |
+
raise ValueError(
|
231 |
+
"Number of images and texts don't match after splitting on `END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED`."
|
232 |
+
" Something core went wrong during the splitting and needs to be fixed."
|
233 |
+
)
|
234 |
+
|
235 |
+
for s_r_ims, s_r_txts in zip(splitted_raw_images, splitted_raw_texts):
|
236 |
+
images, web_text = [], ""
|
237 |
+
for image, text in zip(s_r_ims, s_r_txts):
|
238 |
+
if text is None and image is None:
|
239 |
+
continue
|
240 |
+
|
241 |
+
if image is not None:
|
242 |
+
web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}{IMAGE_TOKEN}"
|
243 |
+
if is_raw_images:
|
244 |
+
images.append(image_transform(image))
|
245 |
+
else:
|
246 |
+
images.append(torch.tensor(image))
|
247 |
+
last_was_image = True
|
248 |
+
elif text is not None:
|
249 |
+
if last_was_image:
|
250 |
+
web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}{text}"
|
251 |
+
last_was_image = False
|
252 |
+
else:
|
253 |
+
web_text += f" {text}" if web_text != "" else text
|
254 |
+
|
255 |
+
if last_was_image:
|
256 |
+
web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}"
|
257 |
+
|
258 |
+
web_text = web_text.strip(" ")
|
259 |
+
|
260 |
+
# This is mostly a sanity check. Cases like that should not happen at that point.
|
261 |
+
if web_text == "" or len(images) == 0:
|
262 |
+
continue
|
263 |
+
|
264 |
+
images = torch.stack(images)
|
265 |
+
all_images.append(images)
|
266 |
+
|
267 |
+
web_text_ids = tokenizer.encode(web_text, add_special_tokens=False)
|
268 |
+
if add_end_of_doc_token:
|
269 |
+
web_text_ids += [tokenizer.eos_token_id]
|
270 |
+
|
271 |
+
if add_begin_of_doc_token:
|
272 |
+
web_text_ids = [tokenizer.bos_token_id] + web_text_ids
|
273 |
+
all_texts.append(web_text_ids)
|
274 |
+
|
275 |
+
output_input_ids = []
|
276 |
+
output_images = []
|
277 |
+
output_attention_masks = []
|
278 |
+
output_num_images = []
|
279 |
+
output_num_text_tokens = []
|
280 |
+
|
281 |
+
input_ids_to_pack = []
|
282 |
+
images_to_pack = []
|
283 |
+
for images, text in zip(all_images, all_texts):
|
284 |
+
# We save all the documents which are shorter than the max_seq_len to pack them together.
|
285 |
+
if len(text) <= max_seq_len:
|
286 |
+
if len(text) < _MIN_LENGTH_DOCUMENTS_TO_PACK: # Filter out extremely short sequences
|
287 |
+
continue
|
288 |
+
input_ids_to_pack.extend(text)
|
289 |
+
images_to_pack.extend(images)
|
290 |
+
else:
|
291 |
+
# Computing the bonus scores for tokens near images to skew the sampling towards them
|
292 |
+
# The main idea is to give a bonus to tokens that are closely before an image token, so that these tokens have more chance to be sampled.
|
293 |
+
# Bonuses are computed for each image, which means a given token can receive bonuses from multiple images if this token is closely preceding multiple images.
|
294 |
+
# We sum all the bonuses and L1 normalized along the seq_len axis to get a probability distribution.
|
295 |
+
# Each token start with a regular bonus of 1, which corresponds to the uniform distribution over the sequence when there are no bonuses added.
|
296 |
+
|
297 |
+
# Now the remaining question is which precedding tokens do we distribue bonuses to.
|
298 |
+
# We first observe that for the sampled sub-sequence to be considered valid (i.e. sub-sequence contains an image), the start index can only be among [image_idx - max_seq_len + 1, image_idx].
|
299 |
+
# For the sake of the explanation, let's split the [image_idx - max_seq_len + 1, image_idx] interval in 3 parts: left, middle and right (in increasing order).
|
300 |
+
# If we give bonuses to the tokens just before the image (right part), then we are favoring p_next=0 because only the tokens after the image have an image to attend to.
|
301 |
+
# In practice, images will tend to be at the beginning of the sampled sub-sequence.
|
302 |
+
# If we give bonuses very far before the image (left part), then we are favoring p_next=1 because only the tokens before the image gave an image to attend to.
|
303 |
+
# In practice, images will tend to be at the end of the sampled sub-sequence.
|
304 |
+
# To avoid choosing favoring p_next=0 or p_next=1, we can give bonuses to the tokens in the middle part.
|
305 |
+
# In practise, images will tend to be in the middle of the sampled sequence.
|
306 |
+
|
307 |
+
# Ultimately, we don't want to skew the distribution fed to model in that way (i.e. whether images are in the beginning, middle or end of the sampled sub-sequence),
|
308 |
+
# and have all these cases represented equally in the data. So the easiest is to distribute a bonus to all of the max_seq_len tokens preceding the image.
|
309 |
+
all_scores = np.array([1] * len(text))
|
310 |
+
for img_token_idx in np.where(np.array(text) == image_token_id)[0]:
|
311 |
+
all_scores[max(0, img_token_idx - max_seq_len) : img_token_idx + 1] += _IMAGE_BONUS_VALUE
|
312 |
+
# all_scores = np.clip(all_scores, a_min=1, a_max=3 * _IMAGE_BONUS_VALUE * max_num_images + 1) # We can optionally clip the bonuses to avoid having too high values (i.e. outliers documents)
|
313 |
+
all_scores = all_scores[:-_MIN_LENGTH_DOCUMENTS_TO_PACK]
|
314 |
+
|
315 |
+
# The number of samples is proportional to the length of the text and inversely proportional to the maximum sequence length
|
316 |
+
max_num_samples_for_curr_document = len(text) // max_seq_len
|
317 |
+
# Set "reproducible randomness" by creating an np.default_rng seeded by (main seed, epoch, rank_idx, worker_idx, mapped_batch_index, text len)
|
318 |
+
choices = np.random.default_rng(seed=list(prefix_seed) + [len(text)]).choice(
|
319 |
+
range(len(text) - _MIN_LENGTH_DOCUMENTS_TO_PACK), # shorter sub-sequences are reserved for packing
|
320 |
+
min(
|
321 |
+
len(text) - max_seq_len, 2 * max_num_samples_per_document
|
322 |
+
), # Sampling more than necessary and then breaking out of the for loop once we have enough samples
|
323 |
+
p=all_scores / np.linalg.norm(all_scores, ord=1),
|
324 |
+
replace=False,
|
325 |
+
)
|
326 |
+
|
327 |
+
nb_effective_sequences_out_of_sampling = 0
|
328 |
+
for start_index in choices:
|
329 |
+
image_start_index = text[:start_index].count(image_token_id)
|
330 |
+
text_sub_sequence = text[start_index : start_index + max_seq_len]
|
331 |
+
image_count = text_sub_sequence.count(image_token_id)
|
332 |
+
if image_count == 0:
|
333 |
+
# Skip if there are no images in the sequence
|
334 |
+
continue
|
335 |
+
|
336 |
+
if len(text_sub_sequence) < max_seq_len:
|
337 |
+
# If the sub-sequence is shorter than max_seq_len, we reserve it for packing
|
338 |
+
# It necessarily mean that the sub-sequence was sampled towards the end of the document,
|
339 |
+
# which implies that we only need the `image_start_index` and not the `image_end_index`
|
340 |
+
if text_sub_sequence.count(image_token_id) != len(images[image_start_index:]):
|
341 |
+
# A safeguard for this
|
342 |
+
logger.warning(
|
343 |
+
"Skipping this sample because of mismatch in actual number of images and "
|
344 |
+
"the '<image>' tokens in the text"
|
345 |
+
)
|
346 |
+
continue
|
347 |
+
input_ids_to_pack.extend(text_sub_sequence)
|
348 |
+
images_to_pack.extend(images[image_start_index:])
|
349 |
+
continue
|
350 |
+
|
351 |
+
current_images = images[image_start_index : image_start_index + min(max_num_images, image_count)]
|
352 |
+
if len(current_images) != min(max_num_images, image_count):
|
353 |
+
# A safeguard for something off about this document, maybe `<image>` tag that
|
354 |
+
# by there from before or some issue in parsing the image?
|
355 |
+
logger.warning(
|
356 |
+
"Skipping this sample because of mismatch in actual number of images and "
|
357 |
+
"the '<image>' tokens in the text"
|
358 |
+
)
|
359 |
+
break
|
360 |
+
padded_image_tensor = torch.zeros(max_num_images, *images.size()[1:])
|
361 |
+
padded_image_tensor[: min(max_num_images, image_count)] = current_images
|
362 |
+
output_images.append(padded_image_tensor)
|
363 |
+
output_num_images.append(min(max_num_images, image_count))
|
364 |
+
|
365 |
+
output_input_ids.append(torch.tensor(text_sub_sequence))
|
366 |
+
output_num_text_tokens.append(len(text_sub_sequence))
|
367 |
+
|
368 |
+
attention_mask = torch.ones((max_seq_len,), dtype=torch.long)
|
369 |
+
output_attention_masks.append(attention_mask)
|
370 |
+
|
371 |
+
nb_effective_sequences_out_of_sampling += 1
|
372 |
+
if nb_effective_sequences_out_of_sampling >= min(
|
373 |
+
max_num_samples_for_curr_document, max_num_samples_per_document
|
374 |
+
):
|
375 |
+
# We got all the samples we need for this document, so breaking out
|
376 |
+
break
|
377 |
+
|
378 |
+
# Pack the remaining sequences from `input_ids_to_pack` x `images_to_pack`
|
379 |
+
if input_ids_to_pack:
|
380 |
+
image_counter = 0
|
381 |
+
for i in range(0, len(input_ids_to_pack), max_seq_len):
|
382 |
+
current_input_ids = input_ids_to_pack[i : i + max_seq_len]
|
383 |
+
unpadded_seq_len = len(current_input_ids)
|
384 |
+
num_images = current_input_ids.count(image_token_id)
|
385 |
+
if num_images == 0:
|
386 |
+
continue
|
387 |
+
current_images = images_to_pack[image_counter : image_counter + num_images]
|
388 |
+
image_counter += num_images
|
389 |
+
if unpadded_seq_len < max_seq_len:
|
390 |
+
padded_input_ids = [tokenizer.pad_token_id] * max_seq_len
|
391 |
+
padded_input_ids[:unpadded_seq_len] = current_input_ids
|
392 |
+
current_input_ids = padded_input_ids
|
393 |
+
elif unpadded_seq_len > max_seq_len:
|
394 |
+
# This case has no purpose other than safeguard
|
395 |
+
continue
|
396 |
+
try:
|
397 |
+
current_images = torch.stack(current_images)[:max_num_images]
|
398 |
+
except Exception:
|
399 |
+
continue
|
400 |
+
padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:])
|
401 |
+
padded_image_tensor[: current_images.size(0)] = current_images
|
402 |
+
attention_mask = torch.zeros((max_seq_len,), dtype=torch.long)
|
403 |
+
attention_mask[:unpadded_seq_len] = 1
|
404 |
+
|
405 |
+
output_images.append(padded_image_tensor)
|
406 |
+
output_input_ids.append(torch.tensor(current_input_ids))
|
407 |
+
output_num_text_tokens.append(unpadded_seq_len)
|
408 |
+
output_num_images.append(min(max_num_images, num_images))
|
409 |
+
|
410 |
+
output_attention_masks.append(attention_mask)
|
411 |
+
|
412 |
+
if len(output_images) == 0 or len(output_input_ids) == 0:
|
413 |
+
result = {
|
414 |
+
"input_ids": torch.tensor([], dtype=torch.long),
|
415 |
+
"attention_mask": torch.tensor([], dtype=torch.bool),
|
416 |
+
"image_attention_mask": torch.tensor([], dtype=torch.bool),
|
417 |
+
"next_image_attention_mask": torch.tensor([], dtype=torch.bool),
|
418 |
+
"num_images": torch.tensor([], dtype=torch.long),
|
419 |
+
"num_text_tokens": torch.tensor([], dtype=torch.long),
|
420 |
+
}
|
421 |
+
if is_raw_images:
|
422 |
+
result["pixel_values"] = torch.tensor([], dtype=torch.float32)
|
423 |
+
else:
|
424 |
+
result["image_embeddings"] = torch.tensor([], dtype=torch.float32)
|
425 |
+
return result
|
426 |
+
|
427 |
+
output_input_ids = torch.stack(output_input_ids)
|
428 |
+
output_images = torch.stack(output_images)
|
429 |
+
output_attention_masks = torch.stack(output_attention_masks)
|
430 |
+
|
431 |
+
image_attention_mask, next_image_attention_mask = image_attention_mask_for_packed_input_ids(
|
432 |
+
output_input_ids, tokenizer
|
433 |
+
)
|
434 |
+
image_attention_mask = incremental_to_binary_attention_mask(image_attention_mask, num_classes=max_num_images)
|
435 |
+
next_image_attention_mask = incremental_to_binary_attention_mask(
|
436 |
+
next_image_attention_mask, num_classes=max_num_images
|
437 |
+
)
|
438 |
+
|
439 |
+
result = {
|
440 |
+
"input_ids": output_input_ids,
|
441 |
+
"attention_mask": output_attention_masks,
|
442 |
+
"image_attention_mask": image_attention_mask,
|
443 |
+
"next_image_attention_mask": next_image_attention_mask,
|
444 |
+
"num_images": torch.tensor(output_num_images),
|
445 |
+
"num_text_tokens": torch.tensor(output_num_text_tokens),
|
446 |
+
}
|
447 |
+
if is_raw_images:
|
448 |
+
result["pixel_values"] = output_images
|
449 |
+
else:
|
450 |
+
result["image_embeddings"] = output_images
|
451 |
+
return result
|
452 |
+
|
453 |
+
|
454 |
+
def split_and_pad_pmd(
|
455 |
+
sample,
|
456 |
+
tokenizer,
|
457 |
+
max_seq_len,
|
458 |
+
image_transform,
|
459 |
+
max_num_images,
|
460 |
+
prefix_seed=(0, 0),
|
461 |
+
is_blurred_fn=None,
|
462 |
+
blur_threshold=0.0,
|
463 |
+
prob_image_at_end=0.5, # If 1, the <image> token is always added at the end of the text
|
464 |
+
# If set to -1, all padding will be tolerated. If set to 0, no padding will be tolerated.
|
465 |
+
padding_tolerance=-1,
|
466 |
+
add_begin_of_doc_token=False,
|
467 |
+
add_end_of_doc_token=True,
|
468 |
+
):
|
469 |
+
if is_blurred_fn is None:
|
470 |
+
is_blurred_fn = fft_blur_detection
|
471 |
+
|
472 |
+
text_batch = sample["text"]
|
473 |
+
image_batch = sample.get("image_embedding", None)
|
474 |
+
is_raw_images = False
|
475 |
+
if image_batch is None:
|
476 |
+
image_batch = sample.get("image", None)
|
477 |
+
is_raw_images = True
|
478 |
+
|
479 |
+
filtered_image_batch = []
|
480 |
+
filtered_input_ids = []
|
481 |
+
|
482 |
+
# Define whether for the current PMD batch whether the images will be at the start or at the end.
|
483 |
+
rng = np.random.default_rng(seed=list(prefix_seed))
|
484 |
+
is_image_at_end = False
|
485 |
+
|
486 |
+
# rng.random is between 0 and 1, so if prob_image_at_end is 1, random value will
|
487 |
+
# always be less than `prob_image_at_end` and `is_image_at_end` will always be True.
|
488 |
+
# This means that images will always be at the end of the text.
|
489 |
+
if rng.random() < prob_image_at_end:
|
490 |
+
is_image_at_end = True
|
491 |
+
|
492 |
+
for image, text in zip(image_batch, text_batch):
|
493 |
+
if text is None or image is None:
|
494 |
+
continue
|
495 |
+
|
496 |
+
if is_raw_images and is_blurred_fn(image, threshold=blur_threshold):
|
497 |
+
continue
|
498 |
+
|
499 |
+
sample_text = f"{FAKE_TOKEN_AROUND_IMAGE_V2}{IMAGE_TOKEN}{FAKE_TOKEN_AROUND_IMAGE_V2}"
|
500 |
+
|
501 |
+
# Remove trailing and leading whitespaces, including newlines and tabs
|
502 |
+
text = text.strip()
|
503 |
+
|
504 |
+
if is_image_at_end:
|
505 |
+
sample_text = f"{text}{sample_text}"
|
506 |
+
else:
|
507 |
+
sample_text = f"{sample_text}{text}"
|
508 |
+
|
509 |
+
sample_input_ids = tokenizer.encode(sample_text, add_special_tokens=False)
|
510 |
+
if add_end_of_doc_token:
|
511 |
+
sample_input_ids += [tokenizer.eos_token_id]
|
512 |
+
|
513 |
+
if add_begin_of_doc_token:
|
514 |
+
sample_input_ids = [tokenizer.bos_token_id] + sample_input_ids
|
515 |
+
|
516 |
+
filtered_image_batch.append(image)
|
517 |
+
filtered_input_ids.append(sample_input_ids)
|
518 |
+
|
519 |
+
# sort by length of text and save same length elements in a mapping so we
|
520 |
+
# can retrieve candidates later.
|
521 |
+
filtered_image_batch, filtered_input_ids = zip(
|
522 |
+
*sorted(zip(filtered_image_batch, filtered_input_ids), key=lambda x: len(x[1]))
|
523 |
+
)
|
524 |
+
mapping_by_len = OrderedDict()
|
525 |
+
for i, sample_input_ids in enumerate(filtered_input_ids):
|
526 |
+
if len(sample_input_ids) not in mapping_by_len:
|
527 |
+
mapping_by_len[len(sample_input_ids)] = []
|
528 |
+
mapping_by_len[len(sample_input_ids)].append((filtered_image_batch[i], sample_input_ids))
|
529 |
+
|
530 |
+
all_images = []
|
531 |
+
all_texts = []
|
532 |
+
all_attention_masks = []
|
533 |
+
all_num_images = []
|
534 |
+
all_num_text_tokens = []
|
535 |
+
current_text = []
|
536 |
+
current_images = []
|
537 |
+
|
538 |
+
while True:
|
539 |
+
current_lens = list(mapping_by_len.keys())
|
540 |
+
if len(current_text) > 0:
|
541 |
+
# Now we try to do a binary search to find the biggest sequence that
|
542 |
+
# we can fit into the current sequence.
|
543 |
+
# This will eventually use up bigger sequences faster which is good
|
544 |
+
# and leave smaller sequences to pack with each other later.
|
545 |
+
diff = max_seq_len - len(current_text)
|
546 |
+
if len(current_lens) == 0:
|
547 |
+
possible_index = -1
|
548 |
+
else:
|
549 |
+
possible_index = bisect_left(current_lens, diff)
|
550 |
+
if possible_index == len(current_lens) or current_lens[possible_index] != diff:
|
551 |
+
possible_index -= 1
|
552 |
+
|
553 |
+
if possible_index >= 0:
|
554 |
+
best_possible_length = current_lens[possible_index]
|
555 |
+
image, sample_input_ids = mapping_by_len[best_possible_length].pop(0)
|
556 |
+
|
557 |
+
# If we have used up all the samples of a certain length, remove
|
558 |
+
# that length from the mapping.
|
559 |
+
if len(mapping_by_len[best_possible_length]) == 0:
|
560 |
+
del mapping_by_len[best_possible_length]
|
561 |
+
current_text.extend(sample_input_ids)
|
562 |
+
if is_raw_images:
|
563 |
+
current_images.append(image_transform(image))
|
564 |
+
else:
|
565 |
+
current_images.append(torch.tensor(image))
|
566 |
+
elif diff > padding_tolerance and padding_tolerance != -1:
|
567 |
+
# If we are here, it means that we still have padding left
|
568 |
+
# and we have exhausted our current unique options that will allow us to
|
569 |
+
# fill this sequence completely.
|
570 |
+
# So, we will try to fill the sequence with whatever we get from the unchanged
|
571 |
+
# copy of all sequences.
|
572 |
+
while diff > padding_tolerance:
|
573 |
+
# Find a random sequence to fit
|
574 |
+
# Why we need to add more stuff to prefix seed?
|
575 |
+
# prefix_seed will be same in the same batch which means that it might sample
|
576 |
+
# same thing again and again if there are multiple cases of padding in the
|
577 |
+
# same batch which means we need to make this part as random as possible.
|
578 |
+
rng = np.random.default_rng(
|
579 |
+
prefix_seed
|
580 |
+
+ (
|
581 |
+
diff,
|
582 |
+
len(current_text),
|
583 |
+
len(all_texts),
|
584 |
+
all_num_images,
|
585 |
+
)
|
586 |
+
)
|
587 |
+
choice = rng.choice(range(len(filtered_input_ids)))
|
588 |
+
image, sample_input_ids = filtered_image_batch[choice], filtered_input_ids[choice]
|
589 |
+
current_text.extend(sample_input_ids)
|
590 |
+
if is_raw_images:
|
591 |
+
current_images.append(image_transform(image))
|
592 |
+
else:
|
593 |
+
current_images.append(torch.tensor(image))
|
594 |
+
diff = max_seq_len - len(current_text)
|
595 |
+
# In the next top-level while loop iteration, this should go into the else
|
596 |
+
# clause which should also handle the sequences longer than max_seq_len
|
597 |
+
else:
|
598 |
+
current_images = torch.stack(current_images)
|
599 |
+
padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:])
|
600 |
+
padded_image_tensor[: current_images.size(0)] = current_images[
|
601 |
+
: min(max_num_images, current_images.size(0))
|
602 |
+
]
|
603 |
+
all_num_images.append(min(max_num_images, current_images.size(0)))
|
604 |
+
all_images.append(padded_image_tensor)
|
605 |
+
|
606 |
+
padded_input_ids = torch.full((max_seq_len,), tokenizer.pad_token_id)
|
607 |
+
current_max_len = min(max_seq_len, len(current_text))
|
608 |
+
padded_input_ids[:current_max_len] = torch.tensor(current_text)[:current_max_len]
|
609 |
+
all_num_text_tokens.append(current_max_len)
|
610 |
+
all_texts.append(padded_input_ids)
|
611 |
+
|
612 |
+
attention_mask = torch.zeros((max_seq_len,), dtype=torch.long)
|
613 |
+
attention_mask[: len(current_text)] = 1
|
614 |
+
all_attention_masks.append(attention_mask)
|
615 |
+
|
616 |
+
# Make sure to reset the current text and images.
|
617 |
+
current_images = []
|
618 |
+
current_text = []
|
619 |
+
if len(current_lens) == 0:
|
620 |
+
break
|
621 |
+
else:
|
622 |
+
# A case where we might not have any samples left over after the initial filtering step.
|
623 |
+
if len(current_lens) == 0:
|
624 |
+
break
|
625 |
+
image, sample_input_ids = mapping_by_len[current_lens[-1]].pop(0)
|
626 |
+
if len(mapping_by_len[current_lens[-1]]) == 0:
|
627 |
+
del mapping_by_len[current_lens[-1]]
|
628 |
+
current_text = sample_input_ids[:max_seq_len]
|
629 |
+
if is_raw_images:
|
630 |
+
current_images = [image_transform(image)]
|
631 |
+
else:
|
632 |
+
current_images = [torch.tensor(image)]
|
633 |
+
|
634 |
+
if len(all_images) == 0 or len(all_texts) == 0:
|
635 |
+
result = {
|
636 |
+
"input_ids": torch.tensor([], dtype=torch.long),
|
637 |
+
"attention_mask": torch.tensor([], dtype=torch.bool),
|
638 |
+
"image_attention_mask": torch.tensor([], dtype=torch.bool),
|
639 |
+
"num_images": torch.tensor([], dtype=torch.long),
|
640 |
+
"num_text_tokens": torch.tensor([], dtype=torch.long),
|
641 |
+
}
|
642 |
+
if is_raw_images:
|
643 |
+
result["pixel_values"] = torch.tensor([], dtype=torch.float32)
|
644 |
+
else:
|
645 |
+
result["image_embeddings"] = torch.tensor([], dtype=torch.float32)
|
646 |
+
return result
|
647 |
+
|
648 |
+
all_texts = torch.stack(all_texts)
|
649 |
+
all_images = torch.stack(all_images)
|
650 |
+
all_attention_masks = torch.stack(all_attention_masks)
|
651 |
+
|
652 |
+
image_attention_mask, next_image_attention_mask = image_attention_mask_for_packed_input_ids(all_texts, tokenizer)
|
653 |
+
image_attention_mask = incremental_to_binary_attention_mask(image_attention_mask, num_classes=max_num_images)
|
654 |
+
next_image_attention_mask = incremental_to_binary_attention_mask(
|
655 |
+
next_image_attention_mask, num_classes=max_num_images
|
656 |
+
)
|
657 |
+
|
658 |
+
output = {
|
659 |
+
"input_ids": all_texts,
|
660 |
+
"attention_mask": all_attention_masks,
|
661 |
+
"image_attention_mask": image_attention_mask,
|
662 |
+
"num_images": torch.tensor(all_num_images),
|
663 |
+
"num_text_tokens": torch.tensor(all_num_text_tokens),
|
664 |
+
}
|
665 |
+
if is_raw_images:
|
666 |
+
output["pixel_values"] = all_images
|
667 |
+
else:
|
668 |
+
output["image_embeddings"] = all_images
|
669 |
+
|
670 |
+
if is_image_at_end:
|
671 |
+
# Set the correct attention mask based on whether the image is at the start
|
672 |
+
# or not. When it is at the end, we need next image attention mask.
|
673 |
+
output["image_attention_mask"] = next_image_attention_mask
|
674 |
+
|
675 |
+
return output
|
676 |
+
|
677 |
+
|
678 |
+
# Copied from https://github.com/google-research/text-to-text-transfer-transformer/blob/main/t5/data/preprocessors.py
|
679 |
+
def random_spans_helper(
|
680 |
+
inputs_length,
|
681 |
+
noise_density,
|
682 |
+
mean_noise_span_length,
|
683 |
+
extra_tokens_per_span_inputs,
|
684 |
+
extra_tokens_per_span_targets,
|
685 |
+
verbose=False,
|
686 |
+
):
|
687 |
+
"""Training parameters to avoid padding with random_spans_noise_mask.
|
688 |
+
|
689 |
+
When training a model with random_spans_noise_mask, we would like to set the
|
690 |
+
other training hyperparmeters in a way that avoids padding. This function
|
691 |
+
helps us compute these hyperparameters.
|
692 |
+
|
693 |
+
We assume that each noise span in the input is replaced by
|
694 |
+
extra_tokens_per_span_inputs sentinel tokens, and each non-noise span in the
|
695 |
+
targets is replaced by extra_tokens_per_span_targets sentinel tokens.
|
696 |
+
|
697 |
+
This function tells us the required number of tokens in the raw example (for
|
698 |
+
split_tokens()) as well as the length of the encoded targets.
|
699 |
+
|
700 |
+
Note that this function assumes the inputs and targets will have EOS appended
|
701 |
+
and includes that in the reported length.
|
702 |
+
|
703 |
+
Args:
|
704 |
+
inputs_length: an integer - desired length of the tokenized inputs sequence
|
705 |
+
noise_density: a float
|
706 |
+
mean_noise_span_length: a float
|
707 |
+
extra_tokens_per_span_inputs: an integer
|
708 |
+
extra_tokens_per_span_targets: an integer
|
709 |
+
verbose: a bool indicating whether to log sequence lengths
|
710 |
+
Returns:
|
711 |
+
tokens_length: length of original text in tokens
|
712 |
+
targets_length: an integer - length in tokens of encoded targets sequence
|
713 |
+
"""
|
714 |
+
|
715 |
+
if extra_tokens_per_span_inputs != 1:
|
716 |
+
raise NotImplementedError(
|
717 |
+
"extra_tokens_per_span_inputs != 1 not supported yet. You need to check"
|
718 |
+
" `get_model_tflops_per_batch_per_gpu` of `VT5ForConditionalGeneration` if you change it."
|
719 |
+
)
|
720 |
+
if extra_tokens_per_span_targets != 1:
|
721 |
+
raise NotImplementedError(
|
722 |
+
"extra_tokens_per_span_targets != 1 not supported yet. You need to check"
|
723 |
+
" `get_model_tflops_per_batch_per_gpu` of `VT5ForConditionalGeneration` if you change it."
|
724 |
+
)
|
725 |
+
|
726 |
+
def _tokens_length_to_inputs_length_targets_length(tokens_length):
|
727 |
+
num_noise_tokens = int(round(tokens_length * noise_density))
|
728 |
+
num_nonnoise_tokens = tokens_length - num_noise_tokens
|
729 |
+
num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
|
730 |
+
# inputs contain all nonnoise tokens, sentinels for all noise spans
|
731 |
+
# and one EOS token.
|
732 |
+
return (
|
733 |
+
num_nonnoise_tokens + num_noise_spans * extra_tokens_per_span_inputs + 1,
|
734 |
+
num_noise_tokens + num_noise_spans * extra_tokens_per_span_targets + 1,
|
735 |
+
)
|
736 |
+
|
737 |
+
tokens_length = inputs_length - 1
|
738 |
+
while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length:
|
739 |
+
tokens_length += 1
|
740 |
+
inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length)
|
741 |
+
# minor hack to get the targets length to be equal to inputs length
|
742 |
+
# which is more likely to have been set to a nice round number.
|
743 |
+
if noise_density == 0.5 and targets_length > inputs_length:
|
744 |
+
tokens_length -= 1
|
745 |
+
targets_length -= 1
|
746 |
+
if verbose:
|
747 |
+
logging.info(
|
748 |
+
"tokens_length=%s inputs_length=%s targets_length=%s noise_density=%s mean_noise_span_length=%s ",
|
749 |
+
tokens_length,
|
750 |
+
inputs_length,
|
751 |
+
targets_length,
|
752 |
+
noise_density,
|
753 |
+
mean_noise_span_length,
|
754 |
+
)
|
755 |
+
return tokens_length, targets_length
|
m4/training/setup_language_model.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
from m4.models.vbloom.configuration_vbloom import VBloomConfig
|
4 |
+
from m4.models.vbloom.modeling_vbloom import VBloomForCausalLM
|
5 |
+
from m4.models.vgpt2.configuration_vgpt2 import VGPT2Config
|
6 |
+
from m4.models.vgpt2.modeling_vgpt2 import VGPT2LMHeadModel
|
7 |
+
from m4.models.vgpt_neo.configuration_vgpt_neo import VGPTNeoConfig
|
8 |
+
from m4.models.vgpt_neo.modeling_vgpt_neo import VGPTNeoForCausalLM
|
9 |
+
from m4.models.vllama.configuration_vllama import VLlamaConfig
|
10 |
+
from m4.models.vllama.modeling_vllama import VLlamaForCausalLM
|
11 |
+
from m4.models.vopt.configuration_vopt import VOPTConfig
|
12 |
+
from m4.models.vopt.modeling_vopt import VOPTForCausalLM
|
13 |
+
from m4.models.vt5.configuration_vt5 import VT5Config
|
14 |
+
from m4.models.vt5.modeling_vt5 import VT5ForConditionalGeneration
|
15 |
+
|
16 |
+
|
17 |
+
model_name2classes = {
|
18 |
+
r"bloom|bigscience-small-testing": [VBloomConfig, VBloomForCausalLM],
|
19 |
+
r"gpt-neo|gptneo": [VGPTNeoConfig, VGPTNeoForCausalLM],
|
20 |
+
r"gpt2": [VGPT2Config, VGPT2LMHeadModel],
|
21 |
+
r"opt": [VOPTConfig, VOPTForCausalLM],
|
22 |
+
r"t5": [VT5Config, VT5ForConditionalGeneration],
|
23 |
+
r"llama": [VLlamaConfig, VLlamaForCausalLM],
|
24 |
+
}
|
25 |
+
|
26 |
+
|
27 |
+
def model_name_to_classes(model_name_or_path):
|
28 |
+
"""returns config_class, model_class for a given model name or path"""
|
29 |
+
|
30 |
+
model_name_lowcase = model_name_or_path.lower()
|
31 |
+
for rx, classes in model_name2classes.items():
|
32 |
+
if re.search(rx, model_name_lowcase):
|
33 |
+
return classes
|
34 |
+
else:
|
35 |
+
raise ValueError(
|
36 |
+
f"Unknown type of backbone LM. Got {model_name_or_path}, supported regexes:"
|
37 |
+
f" {list(model_name2classes.keys())}."
|
38 |
+
)
|
m4/training/setup_vision_model.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
from transformers import AutoModel
|
4 |
+
|
5 |
+
|
6 |
+
# map to check the supported cv archs and also how to extract the model - in some arch, we want to
|
7 |
+
# go through a specific prefix to get to the model as in `model.vision_model` for clip
|
8 |
+
vision_model_name2model = {
|
9 |
+
r"clip": lambda model: model.vision_model,
|
10 |
+
r"vit": lambda model: model,
|
11 |
+
}
|
12 |
+
|
13 |
+
|
14 |
+
def vision_model_name_to_model(model_name_or_path, model):
|
15 |
+
"""returns the model if supported, asserts otherwise"""
|
16 |
+
|
17 |
+
model_name_lowcase = model_name_or_path.lower()
|
18 |
+
for rx, lookup in vision_model_name2model.items():
|
19 |
+
if re.search(rx, model_name_lowcase):
|
20 |
+
return lookup(model)
|
21 |
+
else:
|
22 |
+
raise ValueError(
|
23 |
+
f"Unknown type of backbone vision model. Got {model_name_or_path}, supported regexes:"
|
24 |
+
f" {list(vision_model_name2model.keys())}."
|
25 |
+
)
|
26 |
+
|
27 |
+
|
28 |
+
def get_vision_model(config):
|
29 |
+
vision_model_name = config.vision_model_name
|
30 |
+
vision_model_params = eval(config.vision_model_params)
|
31 |
+
|
32 |
+
model = AutoModel.from_pretrained(vision_model_name, **vision_model_params)
|
33 |
+
return vision_model_name_to_model(vision_model_name, model)
|
m4/training/types.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
+
|
3 |
+
|
4 |
+
class DatasetNames(Enum):
|
5 |
+
PMD = "pmd"
|
6 |
+
LAION = "laion"
|
7 |
+
CM4 = "cm4"
|
8 |
+
WIKI = "wiki"
|
9 |
+
|
10 |
+
|
11 |
+
class DatasetTypes(Enum):
|
12 |
+
WEB_DOCUMENTS = "wd"
|
13 |
+
IMAGE_CAPTION_PAIRS = "icp"
|
m4/training/utils.py
ADDED
@@ -0,0 +1,539 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
import gc
|
3 |
+
import json
|
4 |
+
import logging
|
5 |
+
from contextlib import contextmanager
|
6 |
+
from enum import Enum
|
7 |
+
|
8 |
+
import accelerate
|
9 |
+
import psutil
|
10 |
+
import pynvml
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torchvision.transforms as transforms
|
14 |
+
from accelerate.state import AcceleratorState
|
15 |
+
from PIL import Image
|
16 |
+
from transformers import ( # AddedToken is needed for the eval of the tokenizer params # noqa: F401
|
17 |
+
AddedToken,
|
18 |
+
AutoTokenizer,
|
19 |
+
)
|
20 |
+
|
21 |
+
|
22 |
+
IMAGE_TOKEN = "<image>"
|
23 |
+
FAKE_TOKEN_AROUND_IMAGE_V2 = "<fake_token_around_image>"
|
24 |
+
FAKE_TOKEN_AROUND_IMAGE_V1 = "\n\n"
|
25 |
+
# Originally taken from the values used in OpenCLIP
|
26 |
+
IMAGE_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
27 |
+
IMAGE_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
|
28 |
+
logger = logging.getLogger(__name__)
|
29 |
+
|
30 |
+
|
31 |
+
class LoggingTypes(Enum):
|
32 |
+
"""Types of logging to use for the gradient and parameter statistics"""
|
33 |
+
|
34 |
+
JSONL = "jsonl"
|
35 |
+
WANDB = "wandb"
|
36 |
+
PRINT = "print"
|
37 |
+
|
38 |
+
|
39 |
+
class JSONEncoderForDataclasses(json.JSONEncoder):
|
40 |
+
"""
|
41 |
+
Use to serialize dataclass object, like so:
|
42 |
+
json.dump(data, fp, indent=2, cls=JSONEncoderForDataclasses)
|
43 |
+
"""
|
44 |
+
|
45 |
+
def default(self, obj):
|
46 |
+
if dataclasses.is_dataclass(obj):
|
47 |
+
return dataclasses.asdict(obj)
|
48 |
+
return super().default(obj)
|
49 |
+
|
50 |
+
|
51 |
+
def freeze_model(model, module_exceptions=[]):
|
52 |
+
mapping = {
|
53 |
+
"LayerNorm": nn.LayerNorm,
|
54 |
+
"Linear": nn.Linear,
|
55 |
+
"Embedding": nn.Embedding,
|
56 |
+
}
|
57 |
+
module_exceptions_mapped = [mapping[m] for m in module_exceptions]
|
58 |
+
for module in model.modules():
|
59 |
+
if module_exceptions and any([isinstance(module, t) for t in module_exceptions_mapped]):
|
60 |
+
module.requires_grad_(True) # Explicitly setting it to true to avoid any mistakes
|
61 |
+
else:
|
62 |
+
module.requires_grad_(False)
|
63 |
+
return model
|
64 |
+
|
65 |
+
|
66 |
+
def _convert_to_rgb(image):
|
67 |
+
# `image.convert("RGB")` would only work for .jpg images, as it creates
|
68 |
+
# a wrong background for transparent images. The call to `alpha_composite`
|
69 |
+
# handles this case
|
70 |
+
if image.mode == "RGB":
|
71 |
+
return image
|
72 |
+
image_rgba = image.convert("RGBA")
|
73 |
+
background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
|
74 |
+
alpha_composite = Image.alpha_composite(background, image_rgba)
|
75 |
+
alpha_composite = alpha_composite.convert("RGB")
|
76 |
+
return alpha_composite
|
77 |
+
|
78 |
+
|
79 |
+
# TODO(aps): Take parameters from config
|
80 |
+
def build_image_transform(image_size=224, eval=False):
|
81 |
+
return transforms.Compose(
|
82 |
+
[
|
83 |
+
_convert_to_rgb,
|
84 |
+
(
|
85 |
+
transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC)
|
86 |
+
if eval
|
87 |
+
else transforms.RandomResizedCrop(
|
88 |
+
(image_size, image_size), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC
|
89 |
+
)
|
90 |
+
),
|
91 |
+
transforms.ToTensor(),
|
92 |
+
transforms.Normalize(mean=IMAGE_DATASET_MEAN, std=IMAGE_DATASET_STD),
|
93 |
+
]
|
94 |
+
)
|
95 |
+
|
96 |
+
|
97 |
+
def get_tokenizer(
|
98 |
+
tokenizer_name: str,
|
99 |
+
tokenizer_add_tokens,
|
100 |
+
tokenizer_add_special_tokens,
|
101 |
+
tokenizer_params,
|
102 |
+
additional_vocab_size,
|
103 |
+
model_vocab_size=None,
|
104 |
+
):
|
105 |
+
"""
|
106 |
+
We artificially separate `tokenizer_add_tokens` and `tokenizer_add_special_tokens` is a dictionary whose keys only takes into account special tokens (eos, pad, cls, etc.).
|
107 |
+
On the contrary, `tokenizer_add_tokens` is a list of string of `AddedToken`.
|
108 |
+
In practise, we use `tokenizer.add_special_tokens` to add all of these new special tokens or update the existing ones.
|
109 |
+
|
110 |
+
NB: we constraint to tokenizer to be a fast tokenizer because with the slow tokenizer, we can't set the arguments of the added tokens (cf `.add_tokens`) and by default, the separators are stripped.
|
111 |
+
"""
|
112 |
+
tokenizer_params = eval(tokenizer_params)
|
113 |
+
assert isinstance(tokenizer_params, dict)
|
114 |
+
|
115 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **tokenizer_params)
|
116 |
+
|
117 |
+
if model_vocab_size is not None:
|
118 |
+
if model_vocab_size > len(tokenizer):
|
119 |
+
logger.warning(
|
120 |
+
f"The model vocabulary size ({model_vocab_size}) is larger than the tokenizer vocabulary size "
|
121 |
+
f"({len(tokenizer)}). Updating the tokenizer to match."
|
122 |
+
)
|
123 |
+
if "additional_special_tokens" in tokenizer_params:
|
124 |
+
raise ValueError(
|
125 |
+
"You can't use `additional_special_tokens` in `tokenizer_params` with a model vocab "
|
126 |
+
"size > tokenizer vocab size. We need to adjust tokenizer before adding special "
|
127 |
+
"tokens. Please use `tokenizer_add_tokens` instead."
|
128 |
+
)
|
129 |
+
# We need to pad the tokenizer vocab with fake tokens
|
130 |
+
tokenizer.add_tokens(["<fake_token_{}>".format(i) for i in range(model_vocab_size - len(tokenizer))])
|
131 |
+
|
132 |
+
assert str(eval(tokenizer_add_tokens)[-1]) == IMAGE_TOKEN
|
133 |
+
assert str(eval(tokenizer_add_tokens)[-2]) == FAKE_TOKEN_AROUND_IMAGE_V2
|
134 |
+
# This check ensures that the image token and the fake token around it will be in the `DecoupledEmbedding.additional_weight`.
|
135 |
+
existing_special_tokens = (
|
136 |
+
[*tokenizer.special_tokens_map_extended["additional_special_tokens"]]
|
137 |
+
if "additional_special_tokens" in tokenizer.special_tokens_map_extended
|
138 |
+
else []
|
139 |
+
)
|
140 |
+
add_special_tokens_dict = {"additional_special_tokens": existing_special_tokens + eval(tokenizer_add_tokens)}
|
141 |
+
if tokenizer_add_special_tokens is not None:
|
142 |
+
add_special_tokens_dict.update(eval(tokenizer_add_special_tokens))
|
143 |
+
|
144 |
+
tokenizer.add_special_tokens(add_special_tokens_dict)
|
145 |
+
|
146 |
+
assert IMAGE_TOKEN in tokenizer.convert_ids_to_tokens(
|
147 |
+
[idx for idx in range(len(tokenizer) - additional_vocab_size, len(tokenizer))]
|
148 |
+
)
|
149 |
+
assert FAKE_TOKEN_AROUND_IMAGE_V2 in tokenizer.convert_ids_to_tokens(
|
150 |
+
[idx for idx in range(len(tokenizer) - additional_vocab_size, len(tokenizer))]
|
151 |
+
)
|
152 |
+
# This verifies that `<image>` was correctly added to the tokenizer vocabulary
|
153 |
+
# XXX: opt-1.3b fails here
|
154 |
+
# assert tokenizer.is_fast == tokenizer_params.get("use_fast", True)
|
155 |
+
|
156 |
+
return tokenizer
|
157 |
+
|
158 |
+
|
159 |
+
def pynmvl_handle(accelerator):
|
160 |
+
if not torch.cuda.is_available():
|
161 |
+
return None
|
162 |
+
|
163 |
+
pynvml.nvmlInit()
|
164 |
+
return pynvml.nvmlDeviceGetHandleByIndex(accelerator.local_process_index)
|
165 |
+
|
166 |
+
|
167 |
+
def pynvml_get_total_energy_in_joules(handle):
|
168 |
+
if not torch.cuda.is_available():
|
169 |
+
return 0
|
170 |
+
return pynvml.nvmlDeviceGetTotalEnergyConsumption(handle) / 1000
|
171 |
+
|
172 |
+
|
173 |
+
def compute_tflops_per_batch_per_gpu(
|
174 |
+
num_layers,
|
175 |
+
batch_size,
|
176 |
+
q_seq_len,
|
177 |
+
k_seq_len,
|
178 |
+
hidden_size,
|
179 |
+
kv_in_dim,
|
180 |
+
ff_exp_factor=None,
|
181 |
+
grad_acc_size=1,
|
182 |
+
swiglu=False,
|
183 |
+
vocab_size=None,
|
184 |
+
count_backward=False,
|
185 |
+
use_grad_checkpointing=False,
|
186 |
+
):
|
187 |
+
multiply_add_factor = torch.tensor(2)
|
188 |
+
query_transformation = multiply_add_factor * batch_size * q_seq_len * hidden_size**2
|
189 |
+
# k_seq_len == v_seq_len
|
190 |
+
key_value_transformation = multiply_add_factor * batch_size * k_seq_len * (2 * hidden_size * kv_in_dim)
|
191 |
+
attention_matrix_computation = multiply_add_factor * batch_size * q_seq_len * k_seq_len * hidden_size
|
192 |
+
attention_softmax = multiply_add_factor * q_seq_len * k_seq_len
|
193 |
+
att_over_values_computation = multiply_add_factor * batch_size * q_seq_len * k_seq_len * hidden_size
|
194 |
+
post_attention_linear_proj = multiply_add_factor * batch_size * q_seq_len * hidden_size**2
|
195 |
+
|
196 |
+
# There are usually 2 expansion_linear_layers because first one expands, and second one retracts back to hidden_size
|
197 |
+
# When using a classic decoder, some blocks don't have those feed-forward layers
|
198 |
+
# Swiglu duplicates the first linear layer, so we have to account for 3 of them instead of 2
|
199 |
+
if ff_exp_factor and swiglu:
|
200 |
+
expansion_linear_layers = 3 * (
|
201 |
+
multiply_add_factor * batch_size * q_seq_len * (hidden_size * ff_exp_factor) * hidden_size
|
202 |
+
)
|
203 |
+
elif ff_exp_factor:
|
204 |
+
expansion_linear_layers = 2 * (
|
205 |
+
multiply_add_factor * batch_size * q_seq_len * (hidden_size * ff_exp_factor) * hidden_size
|
206 |
+
)
|
207 |
+
else:
|
208 |
+
expansion_linear_layers = torch.tensor(0)
|
209 |
+
|
210 |
+
transformer_block_flops = (
|
211 |
+
query_transformation
|
212 |
+
+ key_value_transformation
|
213 |
+
+ attention_matrix_computation
|
214 |
+
+ attention_softmax
|
215 |
+
+ att_over_values_computation
|
216 |
+
+ post_attention_linear_proj
|
217 |
+
+ expansion_linear_layers
|
218 |
+
)
|
219 |
+
|
220 |
+
# This computation should only be added if the model has a language head
|
221 |
+
if vocab_size:
|
222 |
+
language_head_computation = multiply_add_factor * batch_size * q_seq_len * hidden_size * vocab_size
|
223 |
+
else:
|
224 |
+
language_head_computation = torch.tensor(0)
|
225 |
+
|
226 |
+
forward_fact = 1
|
227 |
+
backward_factor = 2 if count_backward else 0
|
228 |
+
grad_checkpointing_factor = 1 if use_grad_checkpointing else 0
|
229 |
+
model_flops = (forward_fact + backward_factor + grad_checkpointing_factor) * (
|
230 |
+
num_layers * transformer_block_flops + language_head_computation
|
231 |
+
)
|
232 |
+
model_tflops = model_flops / (10**12)
|
233 |
+
|
234 |
+
return model_tflops
|
235 |
+
|
236 |
+
|
237 |
+
def compute_perceiver_tflops_per_batch_per_gpu(
|
238 |
+
num_layers,
|
239 |
+
batch_size,
|
240 |
+
q_seq_len,
|
241 |
+
vision_embed_seq_len,
|
242 |
+
q_k_v_input_dim,
|
243 |
+
attention_hidden_size,
|
244 |
+
ff_exp_factor=None,
|
245 |
+
count_backward=False,
|
246 |
+
use_grad_checkpointing=False,
|
247 |
+
):
|
248 |
+
multiply_add_factor = torch.tensor(2)
|
249 |
+
query_transformation = multiply_add_factor * batch_size * q_seq_len * q_k_v_input_dim * attention_hidden_size
|
250 |
+
# k_seq_len == v_seq_len
|
251 |
+
key_value_transformation = (
|
252 |
+
multiply_add_factor * batch_size * vision_embed_seq_len * (2 * attention_hidden_size * q_k_v_input_dim)
|
253 |
+
)
|
254 |
+
|
255 |
+
k_seq_len = vision_embed_seq_len + q_seq_len
|
256 |
+
attention_matrix_computation = multiply_add_factor * batch_size * q_seq_len * k_seq_len * attention_hidden_size
|
257 |
+
attention_softmax = multiply_add_factor * q_seq_len * k_seq_len
|
258 |
+
att_over_values_computation = multiply_add_factor * batch_size * q_seq_len * k_seq_len * attention_hidden_size
|
259 |
+
post_attention_linear_proj = multiply_add_factor * batch_size * q_seq_len * attention_hidden_size * q_k_v_input_dim
|
260 |
+
|
261 |
+
# There are usually 2 expansion_linear_layers because first one expands, and second one retracts back to hidden_size
|
262 |
+
# When using a classic decoder, some blocks don't have those feed-forward layers
|
263 |
+
if ff_exp_factor:
|
264 |
+
expansion_linear_layers = 2 * (
|
265 |
+
multiply_add_factor * batch_size * q_seq_len * (q_k_v_input_dim * ff_exp_factor) * q_k_v_input_dim
|
266 |
+
)
|
267 |
+
else:
|
268 |
+
expansion_linear_layers = torch.tensor(0)
|
269 |
+
|
270 |
+
transformer_block_flops = (
|
271 |
+
query_transformation
|
272 |
+
+ key_value_transformation
|
273 |
+
+ attention_matrix_computation
|
274 |
+
+ attention_softmax
|
275 |
+
+ att_over_values_computation
|
276 |
+
+ post_attention_linear_proj
|
277 |
+
+ expansion_linear_layers
|
278 |
+
)
|
279 |
+
|
280 |
+
forward_fact = 1
|
281 |
+
backward_factor = 2 if count_backward else 0
|
282 |
+
grad_checkpointing_factor = 1 if use_grad_checkpointing else 0
|
283 |
+
model_flops = (forward_fact + backward_factor + grad_checkpointing_factor) * (num_layers * transformer_block_flops)
|
284 |
+
model_tflops = model_flops / (10**12)
|
285 |
+
|
286 |
+
return model_tflops
|
287 |
+
|
288 |
+
|
289 |
+
def mem_usage_formatted(logging_type=LoggingTypes.PRINT):
|
290 |
+
# adapted from deepspeed's see_memory_usage
|
291 |
+
|
292 |
+
torch.cuda.empty_cache()
|
293 |
+
|
294 |
+
# python doesn't do real-time garbage collection so do it explicitly to get the correct usage reports
|
295 |
+
gc.collect()
|
296 |
+
vm_stats = psutil.virtual_memory()
|
297 |
+
|
298 |
+
mem = {
|
299 |
+
"gpu mem alloc": f"{torch.cuda.memory_allocated()/2**30:0.2f}GB",
|
300 |
+
"max alloc": f"{torch.cuda.max_memory_allocated()/2**30:0.2f}GB",
|
301 |
+
"reserv": f"{torch.cuda.memory_reserved()/2**30:0.2f}GB",
|
302 |
+
"max reserv": f"{torch.cuda.max_memory_reserved()/2**30:0.2f}GB",
|
303 |
+
"cpu vm used": f"{(vm_stats.total-vm_stats.available)/2**30:0.2f}GB {vm_stats.percent}%",
|
304 |
+
}
|
305 |
+
|
306 |
+
if logging_type == LoggingTypes.PRINT:
|
307 |
+
mem = " | ".join([f"{k}: {v}" for k, v in mem.items()]) + " | "
|
308 |
+
|
309 |
+
# get the peak memory to report correct data, so reset the max_memory_allocated counter for the next call
|
310 |
+
torch.cuda.reset_peak_memory_stats()
|
311 |
+
|
312 |
+
return mem
|
313 |
+
|
314 |
+
|
315 |
+
def is_deepspeed_used():
|
316 |
+
deepspeed_plugin = get_deepspeed_plugin()
|
317 |
+
return deepspeed_plugin is not None
|
318 |
+
|
319 |
+
|
320 |
+
def get_deepspeed_stage():
|
321 |
+
deepspeed_plugin = get_deepspeed_plugin()
|
322 |
+
if deepspeed_plugin is None:
|
323 |
+
return 0
|
324 |
+
ds_config = deepspeed_plugin.deepspeed_config
|
325 |
+
stage = ds_config.get("zero_optimization", {}).get("stage", 0)
|
326 |
+
# from accelerate>=0.17.1 can do instead:
|
327 |
+
# stage = deepspeed_plugin.zero_stage
|
328 |
+
return stage
|
329 |
+
|
330 |
+
|
331 |
+
def is_deepspeed_zero3_used():
|
332 |
+
return get_deepspeed_stage() == 3
|
333 |
+
|
334 |
+
|
335 |
+
def accelerate_torch_dtype():
|
336 |
+
"""
|
337 |
+
derive and return `torch_dtype` to be used in `from_pretrained` from either Deepspeed config or if
|
338 |
+
Deepspeed isn't used than accelerator state
|
339 |
+
"""
|
340 |
+
if not is_accelerate_initialized():
|
341 |
+
return None
|
342 |
+
|
343 |
+
accelerator_state = AcceleratorState()
|
344 |
+
|
345 |
+
if is_deepspeed_used():
|
346 |
+
deepspeed_plugin = accelerator_state.deepspeed_plugin
|
347 |
+
ds_config = deepspeed_plugin.deepspeed_config
|
348 |
+
if ds_config.get("fp16", {}).get("enabled", False):
|
349 |
+
torch_dtype = torch.float16
|
350 |
+
elif ds_config.get("bf16", {}).get("enabled", False):
|
351 |
+
torch_dtype = torch.bfloat16
|
352 |
+
else:
|
353 |
+
torch_dtype = None
|
354 |
+
else: # no Deepspeed
|
355 |
+
if accelerator_state.mixed_precision == "fp16":
|
356 |
+
torch_dtype = torch.float16
|
357 |
+
elif accelerator_state.mixed_precision == "bf16":
|
358 |
+
torch_dtype = torch.bfloat16
|
359 |
+
else:
|
360 |
+
torch_dtype = None
|
361 |
+
|
362 |
+
return torch_dtype
|
363 |
+
|
364 |
+
|
365 |
+
def is_accelerate_initialized():
|
366 |
+
return accelerate.state.is_initialized()
|
367 |
+
|
368 |
+
|
369 |
+
def get_deepspeed_plugin():
|
370 |
+
if is_accelerate_initialized():
|
371 |
+
return AcceleratorState().deepspeed_plugin
|
372 |
+
else:
|
373 |
+
return None
|
374 |
+
|
375 |
+
|
376 |
+
def get_deepspeed_engine(accelerator):
|
377 |
+
return accelerator.deepspeed_engine_wrapped.engine
|
378 |
+
|
379 |
+
|
380 |
+
def is_deepspeed_zero_init_enabled():
|
381 |
+
deepspeed_plugin = get_deepspeed_plugin()
|
382 |
+
if deepspeed_plugin is not None:
|
383 |
+
return deepspeed_plugin.is_zero3_init_enabled()
|
384 |
+
else:
|
385 |
+
return False
|
386 |
+
|
387 |
+
|
388 |
+
@contextmanager
|
389 |
+
def hf_trainer_disable_zero3_init_context_manager():
|
390 |
+
# monkey patch hack to emulate a context that has zero_init disabled as it's used in
|
391 |
+
# modeling_utils.py in transformers for from_config and from_pretrained.
|
392 |
+
import transformers.modeling_utils # noqa
|
393 |
+
|
394 |
+
orig = transformers.modeling_utils.is_deepspeed_zero3_enabled
|
395 |
+
transformers.modeling_utils.is_deepspeed_zero3_enabled = lambda: False
|
396 |
+
yield
|
397 |
+
transformers.modeling_utils.is_deepspeed_zero3_enabled = orig
|
398 |
+
|
399 |
+
|
400 |
+
def deepspeed_zero_init_disabled_context_manager():
|
401 |
+
"""
|
402 |
+
returns either a context list that includes one that will disable zero.Init or an empty context list
|
403 |
+
"""
|
404 |
+
deepspeed_plugin = get_deepspeed_plugin()
|
405 |
+
if deepspeed_plugin is not None:
|
406 |
+
return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
|
407 |
+
else:
|
408 |
+
return [hf_trainer_disable_zero3_init_context_manager()]
|
409 |
+
|
410 |
+
|
411 |
+
def deepspeed_gathered_parameters_context_manager(params, modify=True):
|
412 |
+
"""
|
413 |
+
Under zero.Init returns a context manager that will gather the sharded param, otherwise returns an empty list
|
414 |
+
|
415 |
+
If `modify` is `True`, gather the shards and once the context exits update the shards with the
|
416 |
+
modified data - one wants that when modifying the gathered param. If one wants to just gather
|
417 |
+
the shards in order to read the param and no modifications are done to it, use `modify=False` as
|
418 |
+
it's more efficient.
|
419 |
+
|
420 |
+
`params` - can be a single parameter, a list, or a tuple of parameters to collect.
|
421 |
+
|
422 |
+
Example:
|
423 |
+
|
424 |
+
from transformers.utils import ContextManagers
|
425 |
+
from m4.training.utils import deepspeed_gathered_parameters_context_manager
|
426 |
+
with ContextManagers(deepspeed_gathered_parameters_context_manager(module.weight, modify=True)):
|
427 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
428 |
+
if module.padding_idx is not None:
|
429 |
+
module.weight.data[module.padding_idx].zero_()
|
430 |
+
|
431 |
+
|
432 |
+
"""
|
433 |
+
if is_deepspeed_zero_init_enabled():
|
434 |
+
import deepspeed
|
435 |
+
|
436 |
+
# 0 is for updating `params` shards after modifying it, `None` is for read-only (only gather)
|
437 |
+
modifier_rank = 0 if modify else None
|
438 |
+
return [deepspeed.zero.GatheredParameters(params, modifier_rank=modifier_rank)]
|
439 |
+
else:
|
440 |
+
return []
|
441 |
+
|
442 |
+
|
443 |
+
# adapted from https://github.com/huggingface/transformers/blob/a081f292ca8479eaf66d7396186021268f128829/src/transformers/modeling_utils.py#L438-L496
|
444 |
+
# as it appears to be a private function
|
445 |
+
def load_state_dict_into_model(model_to_load, state_dict, start_prefix):
|
446 |
+
# Convert old format to new format if needed from a PyTorch state_dict
|
447 |
+
old_keys = []
|
448 |
+
new_keys = []
|
449 |
+
for key in state_dict.keys():
|
450 |
+
new_key = None
|
451 |
+
if "gamma" in key:
|
452 |
+
new_key = key.replace("gamma", "weight")
|
453 |
+
if "beta" in key:
|
454 |
+
new_key = key.replace("beta", "bias")
|
455 |
+
if new_key:
|
456 |
+
old_keys.append(key)
|
457 |
+
new_keys.append(new_key)
|
458 |
+
for old_key, new_key in zip(old_keys, new_keys):
|
459 |
+
state_dict[new_key] = state_dict.pop(old_key)
|
460 |
+
|
461 |
+
# copy state_dict so _load_from_state_dict can modify it
|
462 |
+
metadata = getattr(state_dict, "_metadata", None)
|
463 |
+
state_dict = state_dict.copy()
|
464 |
+
if metadata is not None:
|
465 |
+
state_dict._metadata = metadata
|
466 |
+
|
467 |
+
error_msgs = []
|
468 |
+
|
469 |
+
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
470 |
+
# so we need to apply the function recursively.
|
471 |
+
def load(module: torch.nn.Module, state_dict, prefix=""):
|
472 |
+
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
473 |
+
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
|
474 |
+
# Parameters of module and children will start with prefix. We can exit early if there are none in this
|
475 |
+
# state_dict
|
476 |
+
if len([key for key in state_dict if key.startswith(prefix)]) > 0:
|
477 |
+
if is_deepspeed_zero_init_enabled():
|
478 |
+
import deepspeed
|
479 |
+
|
480 |
+
# In sharded models, each shard has only part of the full state_dict, so only gather
|
481 |
+
# parameters that are in the current state_dict.
|
482 |
+
named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
|
483 |
+
params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
|
484 |
+
if len(params_to_gather) > 0:
|
485 |
+
# because zero3 puts placeholders in model params, this context
|
486 |
+
# manager gathers (unpartitions) the params of the current layer, then loads from
|
487 |
+
# the state dict and then re-partitions them again
|
488 |
+
with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
|
489 |
+
if torch.distributed.get_rank() == 0:
|
490 |
+
module._load_from_state_dict(*args)
|
491 |
+
else:
|
492 |
+
module._load_from_state_dict(*args)
|
493 |
+
|
494 |
+
for name, child in module._modules.items():
|
495 |
+
if child is not None:
|
496 |
+
load(child, state_dict, prefix + name + ".")
|
497 |
+
|
498 |
+
load(model_to_load, state_dict, prefix=start_prefix)
|
499 |
+
# Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
|
500 |
+
# it's safe to delete it.
|
501 |
+
del state_dict
|
502 |
+
|
503 |
+
return error_msgs
|
504 |
+
|
505 |
+
|
506 |
+
def get_stats(var, ctx):
|
507 |
+
if var is None:
|
508 |
+
return {}
|
509 |
+
var = var.float()
|
510 |
+
abs_var = var.abs()
|
511 |
+
return {
|
512 |
+
f"{ctx}_var_min": var.min().item(),
|
513 |
+
f"{ctx}_var_max": var.max().item(),
|
514 |
+
f"{ctx}_var_mean": var.mean().item(),
|
515 |
+
f"{ctx}_var_std": var.std().item(),
|
516 |
+
f"{ctx}_abs_var_min": abs_var.min().item(),
|
517 |
+
f"{ctx}_abs_var_max": abs_var.max().item(),
|
518 |
+
f"{ctx}_abs_var_mean": abs_var.mean().item(),
|
519 |
+
f"{ctx}_abs_var_std": abs_var.std().item(),
|
520 |
+
f"{ctx}_var_norm_2": (var.norm(p=2) / var.numel()).item(),
|
521 |
+
f"{ctx}_var_norm_1": (var.norm(p=1) / var.numel()).item(),
|
522 |
+
f"{ctx}_nonzero": (var != 0).sum().item(),
|
523 |
+
}
|
524 |
+
|
525 |
+
|
526 |
+
def get_stats_format(ctx):
|
527 |
+
return {
|
528 |
+
f"{ctx}_var_min": "e",
|
529 |
+
f"{ctx}_var_max": "e",
|
530 |
+
f"{ctx}_var_mean": "e",
|
531 |
+
f"{ctx}_var_std": "e",
|
532 |
+
f"{ctx}_abs_var_min": "e",
|
533 |
+
f"{ctx}_abs_var_max": "e",
|
534 |
+
f"{ctx}_abs_var_mean": "e",
|
535 |
+
f"{ctx}_abs_var_std": "e",
|
536 |
+
f"{ctx}_var_norm_2": "e",
|
537 |
+
f"{ctx}_var_norm_1": "e",
|
538 |
+
f"{ctx}_nonzero": "",
|
539 |
+
}
|
m4/utils/__init__.py
ADDED
File without changes
|