import dataclasses
import gc
import json
import logging
from contextlib import contextmanager
from enum import Enum

import accelerate
import psutil
import pynvml
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from accelerate.state import AcceleratorState
from PIL import Image
from transformers import (  # AddedToken is needed for the eval of the tokenizer params  # noqa: F401
    AddedToken,
    AutoTokenizer,
)


IMAGE_TOKEN = "<image>"
FAKE_TOKEN_AROUND_IMAGE_V2 = "<fake_token_around_image>"
FAKE_TOKEN_AROUND_IMAGE_V1 = "\n\n"

# Originally taken from the values used in OpenCLIP
IMAGE_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
IMAGE_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)

logger = logging.getLogger(__name__)


class LoggingTypes(Enum):
    """Types of logging to use for the gradient and parameter statistics"""

    JSONL = "jsonl"
    WANDB = "wandb"
    PRINT = "print"


class JSONEncoderForDataclasses(json.JSONEncoder):
    """
    Use to serialize dataclass objects, like so:
    json.dump(data, fp, indent=2, cls=JSONEncoderForDataclasses)
    """

    def default(self, obj):
        if dataclasses.is_dataclass(obj):
            return dataclasses.asdict(obj)
        return super().default(obj)


def freeze_model(model, module_exceptions=[]):
    mapping = {
        "LayerNorm": nn.LayerNorm,
        "Linear": nn.Linear,
        "Embedding": nn.Embedding,
    }
    module_exceptions_mapped = [mapping[m] for m in module_exceptions]
    for module in model.modules():
        if module_exceptions and any(isinstance(module, t) for t in module_exceptions_mapped):
            module.requires_grad_(True)  # Explicitly setting it to true to avoid any mistakes
        else:
            module.requires_grad_(False)
    return model
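
# Illustrative usage sketch (commented out, not executed as part of this module): freeze an
# entire backbone except its LayerNorm layers, which stay trainable. `vision_model` is a
# hypothetical nn.Module standing in for whatever backbone the training config selects.
#
#   vision_model = freeze_model(vision_model, module_exceptions=["LayerNorm"])
#   trainable = [n for n, p in vision_model.named_parameters() if p.requires_grad]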


def _convert_to_rgb(image):
    # `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background
    # for transparent images. The call to `alpha_composite` handles this case.
    if image.mode == "RGB":
        return image

    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    alpha_composite = Image.alpha_composite(background, image_rgba)
    alpha_composite = alpha_composite.convert("RGB")
    return alpha_composite


# TODO(aps): Take parameters from config
def build_image_transform(image_size=224, eval=False):
    return transforms.Compose(
        [
            _convert_to_rgb,
            (
                transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC)
                if eval
                else transforms.RandomResizedCrop(
                    (image_size, image_size), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC
                )
            ),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGE_DATASET_MEAN, std=IMAGE_DATASET_STD),
        ]
    )
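
# Illustrative usage sketch (commented out): build the deterministic eval-time transform and
# apply it to a PIL image. The image path below is hypothetical.
#
#   transform = build_image_transform(image_size=224, eval=True)
#   pixel_values = transform(Image.open("example.jpg"))  # tensor of shape (3, 224, 224)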


def get_tokenizer(
    tokenizer_name: str,
    tokenizer_add_tokens,
    tokenizer_add_special_tokens,
    tokenizer_params,
    additional_vocab_size,
    model_vocab_size=None,
):
    """
    We artificially separate `tokenizer_add_tokens` and `tokenizer_add_special_tokens`:
    `tokenizer_add_special_tokens` is a dictionary whose keys only take into account special
    tokens (eos, pad, cls, etc.), whereas `tokenizer_add_tokens` is a list of strings /
    `AddedToken` objects.
    In practice, we use `tokenizer.add_special_tokens` to add all of these new special tokens
    or update the existing ones.

    NB: we constrain the tokenizer to be a fast tokenizer because with a slow tokenizer we
    can't set the arguments of the added tokens (cf. `.add_tokens`) and by default the
    separators are stripped.
    """
    tokenizer_params = eval(tokenizer_params)
    assert isinstance(tokenizer_params, dict)

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **tokenizer_params)

    if model_vocab_size is not None:
        if model_vocab_size > len(tokenizer):
            logger.warning(
                f"The model vocabulary size ({model_vocab_size}) is larger than the tokenizer vocabulary size "
                f"({len(tokenizer)}). Updating the tokenizer to match."
            )
            if "additional_special_tokens" in tokenizer_params:
                raise ValueError(
                    "You can't use `additional_special_tokens` in `tokenizer_params` with a model vocab "
                    "size > tokenizer vocab size. We need to adjust tokenizer before adding special "
                    "tokens. Please use `tokenizer_add_tokens` instead."
                )
            # We need to pad the tokenizer vocab with fake tokens
            tokenizer.add_tokens(["<fake_token_{}>".format(i) for i in range(model_vocab_size - len(tokenizer))])

    assert str(eval(tokenizer_add_tokens)[-1]) == IMAGE_TOKEN
    assert str(eval(tokenizer_add_tokens)[-2]) == FAKE_TOKEN_AROUND_IMAGE_V2
    # This check ensures that the image token and the fake token around it will be in the
    # `DecoupledEmbedding.additional_weight`.
    existing_special_tokens = (
        [*tokenizer.special_tokens_map_extended["additional_special_tokens"]]
        if "additional_special_tokens" in tokenizer.special_tokens_map_extended
        else []
    )
    add_special_tokens_dict = {"additional_special_tokens": existing_special_tokens + eval(tokenizer_add_tokens)}
    if tokenizer_add_special_tokens is not None:
        add_special_tokens_dict.update(eval(tokenizer_add_special_tokens))

    tokenizer.add_special_tokens(add_special_tokens_dict)

    assert IMAGE_TOKEN in tokenizer.convert_ids_to_tokens(
        [idx for idx in range(len(tokenizer) - additional_vocab_size, len(tokenizer))]
    )
    assert FAKE_TOKEN_AROUND_IMAGE_V2 in tokenizer.convert_ids_to_tokens(
        [idx for idx in range(len(tokenizer) - additional_vocab_size, len(tokenizer))]
    )
    # This verifies that `<image>` was correctly added to the tokenizer vocabulary
    # XXX: opt-1.3b fails here
    # assert tokenizer.is_fast == tokenizer_params.get("use_fast", True)

    return tokenizer
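
# Illustrative usage sketch (commented out; the concrete values normally come from the training
# config, and the ones below are hypothetical). Note that `tokenizer_add_tokens` and
# `tokenizer_params` are passed as strings and `eval`-ed inside `get_tokenizer`.
#
#   tokenizer = get_tokenizer(
#       tokenizer_name="gpt2",
#       tokenizer_add_tokens=(
#           '[AddedToken("<fake_token_around_image>", rstrip=False, lstrip=False),'
#           ' AddedToken("<image>", rstrip=False, lstrip=False)]'
#       ),
#       tokenizer_add_special_tokens=None,
#       tokenizer_params='{"use_fast": True}',
#       additional_vocab_size=2,
#   )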


def pynmvl_handle(accelerator):
    if not torch.cuda.is_available():
        return None

    pynvml.nvmlInit()
    return pynvml.nvmlDeviceGetHandleByIndex(accelerator.local_process_index)


def pynvml_get_total_energy_in_joules(handle):
    if not torch.cuda.is_available():
        return 0

    return pynvml.nvmlDeviceGetTotalEnergyConsumption(handle) / 1000
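
# Illustrative usage sketch (commented out; `accelerator` is assumed to be an
# `accelerate.Accelerator` instance): measure the GPU energy consumed by one training step by
# taking the difference of two counter readings.
#
#   handle = pynmvl_handle(accelerator)
#   energy_before = pynvml_get_total_energy_in_joules(handle)
#   ...  # run one training step
#   energy_consumed = pynvml_get_total_energy_in_joules(handle) - energy_before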


def compute_tflops_per_batch_per_gpu(
    num_layers,
    batch_size,
    q_seq_len,
    k_seq_len,
    hidden_size,
    kv_in_dim,
    ff_exp_factor=None,
    grad_acc_size=1,
    swiglu=False,
    vocab_size=None,
    count_backward=False,
    use_grad_checkpointing=False,
):
    multiply_add_factor = torch.tensor(2)
    query_transformation = multiply_add_factor * batch_size * q_seq_len * hidden_size**2
    # k_seq_len == v_seq_len
    key_value_transformation = multiply_add_factor * batch_size * k_seq_len * (2 * hidden_size * kv_in_dim)
    attention_matrix_computation = multiply_add_factor * batch_size * q_seq_len * k_seq_len * hidden_size
    attention_softmax = multiply_add_factor * q_seq_len * k_seq_len
    att_over_values_computation = multiply_add_factor * batch_size * q_seq_len * k_seq_len * hidden_size
    post_attention_linear_proj = multiply_add_factor * batch_size * q_seq_len * hidden_size**2

    # There are usually 2 expansion_linear_layers, because the first one expands and the second
    # one projects back to hidden_size. When using a classic decoder, some blocks don't have
    # those feed-forward layers. SwiGLU duplicates the first linear layer, so we have to account
    # for 3 of them instead of 2.
    if ff_exp_factor and swiglu:
        expansion_linear_layers = 3 * (
            multiply_add_factor * batch_size * q_seq_len * (hidden_size * ff_exp_factor) * hidden_size
        )
    elif ff_exp_factor:
        expansion_linear_layers = 2 * (
            multiply_add_factor * batch_size * q_seq_len * (hidden_size * ff_exp_factor) * hidden_size
        )
    else:
        expansion_linear_layers = torch.tensor(0)

    transformer_block_flops = (
        query_transformation
        + key_value_transformation
        + attention_matrix_computation
        + attention_softmax
        + att_over_values_computation
        + post_attention_linear_proj
        + expansion_linear_layers
    )

    # This computation should only be added if the model has a language head
    if vocab_size:
        language_head_computation = multiply_add_factor * batch_size * q_seq_len * hidden_size * vocab_size
    else:
        language_head_computation = torch.tensor(0)

    forward_factor = 1
    backward_factor = 2 if count_backward else 0
    grad_checkpointing_factor = 1 if use_grad_checkpointing else 0
    model_flops = (forward_factor + backward_factor + grad_checkpointing_factor) * (
        num_layers * transformer_block_flops + language_head_computation
    )
    model_tflops = model_flops / (10**12)

    return model_tflops
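
# Illustrative usage sketch (commented out; the dimensions below are hypothetical and would
# normally come from the model config):
#
#   tflops = compute_tflops_per_batch_per_gpu(
#       num_layers=24,
#       batch_size=8,
#       q_seq_len=1024,
#       k_seq_len=1024,
#       hidden_size=2048,
#       kv_in_dim=2048,
#       ff_exp_factor=4,
#       vocab_size=50257,
#       count_backward=True,           # forward + backward => factor 1 + 2
#       use_grad_checkpointing=False,  # +1 extra forward if enabled
#   )
#   tflops_per_s = tflops / step_time_in_seconds  # step_time_in_seconds measured elsewhere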


def compute_perceiver_tflops_per_batch_per_gpu(
    num_layers,
    batch_size,
    q_seq_len,
    vision_embed_seq_len,
    q_k_v_input_dim,
    attention_hidden_size,
    ff_exp_factor=None,
    count_backward=False,
    use_grad_checkpointing=False,
):
    multiply_add_factor = torch.tensor(2)
    query_transformation = multiply_add_factor * batch_size * q_seq_len * q_k_v_input_dim * attention_hidden_size
    # k_seq_len == v_seq_len
    key_value_transformation = (
        multiply_add_factor * batch_size * vision_embed_seq_len * (2 * attention_hidden_size * q_k_v_input_dim)
    )
    k_seq_len = vision_embed_seq_len + q_seq_len
    attention_matrix_computation = multiply_add_factor * batch_size * q_seq_len * k_seq_len * attention_hidden_size
    attention_softmax = multiply_add_factor * q_seq_len * k_seq_len
    att_over_values_computation = multiply_add_factor * batch_size * q_seq_len * k_seq_len * attention_hidden_size
    post_attention_linear_proj = multiply_add_factor * batch_size * q_seq_len * attention_hidden_size * q_k_v_input_dim

    # There are usually 2 expansion_linear_layers, because the first one expands and the second
    # one projects back to hidden_size. When using a classic decoder, some blocks don't have
    # those feed-forward layers.
    if ff_exp_factor:
        expansion_linear_layers = 2 * (
            multiply_add_factor * batch_size * q_seq_len * (q_k_v_input_dim * ff_exp_factor) * q_k_v_input_dim
        )
    else:
        expansion_linear_layers = torch.tensor(0)

    transformer_block_flops = (
        query_transformation
        + key_value_transformation
        + attention_matrix_computation
        + attention_softmax
        + att_over_values_computation
        + post_attention_linear_proj
        + expansion_linear_layers
    )

    forward_factor = 1
    backward_factor = 2 if count_backward else 0
    grad_checkpointing_factor = 1 if use_grad_checkpointing else 0
    model_flops = (forward_factor + backward_factor + grad_checkpointing_factor) * (
        num_layers * transformer_block_flops
    )
    model_tflops = model_flops / (10**12)

    return model_tflops


def mem_usage_formatted(logging_type=LoggingTypes.PRINT):
    # adapted from deepspeed's see_memory_usage

    torch.cuda.empty_cache()

    # python doesn't do real-time garbage collection, so do it explicitly to get the correct usage reports
    gc.collect()
    vm_stats = psutil.virtual_memory()

    mem = {
        "gpu mem alloc": f"{torch.cuda.memory_allocated()/2**30:0.2f}GB",
        "max alloc": f"{torch.cuda.max_memory_allocated()/2**30:0.2f}GB",
        "reserv": f"{torch.cuda.memory_reserved()/2**30:0.2f}GB",
        "max reserv": f"{torch.cuda.max_memory_reserved()/2**30:0.2f}GB",
        "cpu vm used": f"{(vm_stats.total-vm_stats.available)/2**30:0.2f}GB {vm_stats.percent}%",
    }

    if logging_type == LoggingTypes.PRINT:
        mem = " | ".join([f"{k}: {v}" for k, v in mem.items()]) + " | "

    # the peak memory has been reported, so reset the max_memory_allocated counter for the next call
    torch.cuda.reset_peak_memory_stats()

    return mem
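
# Illustrative usage sketch (commented out): log memory stats around a training step, either as
# a human-readable string or as a dict for structured (e.g. jsonl) logging.
#
#   logger.info(f"before step: {mem_usage_formatted()}")
#   mem_dict = mem_usage_formatted(logging_type=LoggingTypes.JSONL)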


def is_deepspeed_used():
    deepspeed_plugin = get_deepspeed_plugin()
    return deepspeed_plugin is not None


def get_deepspeed_stage():
    deepspeed_plugin = get_deepspeed_plugin()
    if deepspeed_plugin is None:
        return 0
    ds_config = deepspeed_plugin.deepspeed_config
    stage = ds_config.get("zero_optimization", {}).get("stage", 0)
    # from accelerate>=0.17.1 can do instead:
    # stage = deepspeed_plugin.zero_stage
    return stage


def is_deepspeed_zero3_used():
    return get_deepspeed_stage() == 3


def accelerate_torch_dtype():
    """
    Derive and return the `torch_dtype` to be used in `from_pretrained` from either the
    Deepspeed config or, if Deepspeed isn't used, from the accelerator state.
    """
    if not is_accelerate_initialized():
        return None

    accelerator_state = AcceleratorState()

    if is_deepspeed_used():
        deepspeed_plugin = accelerator_state.deepspeed_plugin
        ds_config = deepspeed_plugin.deepspeed_config
        if ds_config.get("fp16", {}).get("enabled", False):
            torch_dtype = torch.float16
        elif ds_config.get("bf16", {}).get("enabled", False):
            torch_dtype = torch.bfloat16
        else:
            torch_dtype = None
    else:  # no Deepspeed
        if accelerator_state.mixed_precision == "fp16":
            torch_dtype = torch.float16
        elif accelerator_state.mixed_precision == "bf16":
            torch_dtype = torch.bfloat16
        else:
            torch_dtype = None

    return torch_dtype
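
# Illustrative usage sketch (commented out; assumes `AutoModelForCausalLM` is imported from
# `transformers` and the checkpoint name is hypothetical):
#
#   torch_dtype = accelerate_torch_dtype()
#   model = AutoModelForCausalLM.from_pretrained("some/checkpoint", torch_dtype=torch_dtype)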


def is_accelerate_initialized():
    return accelerate.state.is_initialized()


def get_deepspeed_plugin():
    if is_accelerate_initialized():
        return AcceleratorState().deepspeed_plugin
    else:
        return None


def get_deepspeed_engine(accelerator):
    return accelerator.deepspeed_engine_wrapped.engine


def is_deepspeed_zero_init_enabled():
    deepspeed_plugin = get_deepspeed_plugin()
    if deepspeed_plugin is not None:
        return deepspeed_plugin.is_zero3_init_enabled()
    else:
        return False


@contextmanager
def hf_trainer_disable_zero3_init_context_manager():
    # monkey patch hack to emulate a context that has zero_init disabled as it's used in
    # modeling_utils.py in transformers for from_config and from_pretrained.
    import transformers.modeling_utils  # noqa

    orig = transformers.modeling_utils.is_deepspeed_zero3_enabled
    transformers.modeling_utils.is_deepspeed_zero3_enabled = lambda: False
    yield
    transformers.modeling_utils.is_deepspeed_zero3_enabled = orig


def deepspeed_zero_init_disabled_context_manager():
    """
    Returns a list with a single context manager that disables zero.Init: either the Deepspeed
    plugin's own `zero3_init_context_manager(enable=False)` or, if no Deepspeed plugin is
    available, the HF Trainer monkey-patch fallback above.
    """
    deepspeed_plugin = get_deepspeed_plugin()
    if deepspeed_plugin is not None:
        return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
    else:
        return [hf_trainer_disable_zero3_init_context_manager()]
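
# Illustrative usage sketch (commented out; assumes `ContextManagers` from `transformers.utils`,
# `AutoModel` from `transformers`, and a hypothetical checkpoint name). This is how a sub-model
# can be instantiated without zero.Init sharding its weights at construction time:
#
#   from transformers.utils import ContextManagers
#
#   with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
#       vision_model = AutoModel.from_pretrained("some/vision-checkpoint")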


def deepspeed_gathered_parameters_context_manager(params, modify=True):
    """
    Under zero.Init returns a list with a context manager that will gather the sharded param,
    otherwise returns an empty list.

    If `modify` is `True`, gather the shards and, once the context exits, update the shards with
    the modified data - one wants that when modifying the gathered param. If one wants to just
    gather the shards in order to read the param and no modifications are done to it, use
    `modify=False` as it's more efficient.

    `params` - can be a single parameter, a list, or a tuple of parameters to collect.

    Example:

    from transformers.utils import ContextManagers
    from m4.training.utils import deepspeed_gathered_parameters_context_manager

    with ContextManagers(deepspeed_gathered_parameters_context_manager(module.weight, modify=True)):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    """
    if is_deepspeed_zero_init_enabled():
        import deepspeed

        # 0 is for updating `params` shards after modifying it, `None` is for read-only (only gather)
        modifier_rank = 0 if modify else None
        return [deepspeed.zero.GatheredParameters(params, modifier_rank=modifier_rank)]
    else:
        return []


# adapted from https://github.com/huggingface/transformers/blob/a081f292ca8479eaf66d7396186021268f128829/src/transformers/modeling_utils.py#L438-L496
# as it appears to be a private function
def load_state_dict_into_model(model_to_load, state_dict, start_prefix):
    # Convert old format to new format if needed from a PyTorch state_dict
    old_keys = []
    new_keys = []
    for key in state_dict.keys():
        new_key = None
        if "gamma" in key:
            new_key = key.replace("gamma", "weight")
        if "beta" in key:
            new_key = key.replace("beta", "bias")
        if new_key:
            old_keys.append(key)
            new_keys.append(new_key)
    for old_key, new_key in zip(old_keys, new_keys):
        state_dict[new_key] = state_dict.pop(old_key)

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
    # so we need to apply the function recursively.
    def load(module: torch.nn.Module, state_dict, prefix=""):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
        # Parameters of module and children will start with prefix. We can exit early if there are none in this
        # state_dict
        if len([key for key in state_dict if key.startswith(prefix)]) > 0:
            if is_deepspeed_zero_init_enabled():
                import deepspeed

                # In sharded models, each shard has only part of the full state_dict, so only gather
                # parameters that are in the current state_dict.
                named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
                params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
                if len(params_to_gather) > 0:
                    # because zero3 puts placeholders in model params, this context
                    # manager gathers (unpartitions) the params of the current layer, then loads from
                    # the state dict and then re-partitions them again
                    with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
                        if torch.distributed.get_rank() == 0:
                            module._load_from_state_dict(*args)
            else:
                module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is not None:
                load(child, state_dict, prefix + name + ".")

    load(model_to_load, state_dict, prefix=start_prefix)
    # Delete `state_dict` so it can be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
    # it's safe to delete it.
    del state_dict

    return error_msgs
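
# Illustrative usage sketch (commented out; the checkpoint path is hypothetical). An empty
# `start_prefix` means the state_dict keys already match the model's parameter names:
#
#   state_dict = torch.load("checkpoint/pytorch_model.bin", map_location="cpu")
#   error_msgs = load_state_dict_into_model(model, state_dict, start_prefix="")
#   if error_msgs:
#       logger.warning(f"Problems loading state_dict: {error_msgs}")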


def get_stats(var, ctx):
    if var is None:
        return {}

    var = var.float()
    abs_var = var.abs()
    return {
        f"{ctx}_var_min": var.min().item(),
        f"{ctx}_var_max": var.max().item(),
        f"{ctx}_var_mean": var.mean().item(),
        f"{ctx}_var_std": var.std().item(),
        f"{ctx}_abs_var_min": abs_var.min().item(),
        f"{ctx}_abs_var_max": abs_var.max().item(),
        f"{ctx}_abs_var_mean": abs_var.mean().item(),
        f"{ctx}_abs_var_std": abs_var.std().item(),
        f"{ctx}_var_norm_2": (var.norm(p=2) / var.numel()).item(),
        f"{ctx}_var_norm_1": (var.norm(p=1) / var.numel()).item(),
        f"{ctx}_nonzero": (var != 0).sum().item(),
    }


def get_stats_format(ctx):
    return {
        f"{ctx}_var_min": "e",
        f"{ctx}_var_max": "e",
        f"{ctx}_var_mean": "e",
        f"{ctx}_var_std": "e",
        f"{ctx}_abs_var_min": "e",
        f"{ctx}_abs_var_max": "e",
        f"{ctx}_abs_var_mean": "e",
        f"{ctx}_abs_var_std": "e",
        f"{ctx}_var_norm_2": "e",
        f"{ctx}_var_norm_1": "e",
        f"{ctx}_nonzero": "",
    }