import argparse
import torch.distributed as dist
from transformers import GPT2TokenizerFast
import deepspeed
from pathlib import Path
import wandb
import os
import yaml
import torch
from collections import defaultdict
from torchtyping import TensorType
import gdown


def is_main():
    """True if this process is rank 0 (or if torch.distributed is not in use)."""
    if dist.is_initialized():
        return dist.get_rank() == 0
    return True


def print_main(*msg):
    """Print only on the main process to avoid duplicated logs."""
    if is_main():
        print(*msg)


def reduce_losses(losses):
    """Reduce a tensor of losses across all GPUs."""
    if dist.is_initialized():
        losses = losses.detach().clone()
        # We use `all_reduce` because it is better supported than `reduce`
        dist.all_reduce(losses, dist.ReduceOp.SUM)
        return losses / dist.get_world_size()
    else:
        return losses
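

# Minimal usage sketch for reduce_losses (the loss tensor here is an assumed
# per-rank scalar from a forward pass, not something defined in this file): every
# rank passes its local loss and gets back the mean over all ranks, so the value
# that ends up logged does not depend on which rank does the logging.
#
#   mean_loss = reduce_losses(loss.detach()).item()
#   print_main(f"step loss: {mean_loss}")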


def cycle(loader):
    """Endlessly iterate over a dataloader."""
    while True:
        for data in loader:
            yield data


def get_tokenizer(name="gpt2", sequence_length=2048):
    """
    Gets the tokenizer for the LM.
    """
    if name == "gpt2":
        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        tokenizer.pad_token_id = tokenizer.eos_token_id  # pad with the eos token id
        tokenizer.padding_side = "right"
        tokenizer.model_max_length = sequence_length
        # setup lm settings
        tokenizer.add_special_tokens(
            {"cls_token": "<|image|>"}
        )  # add special image token to tokenizer
    else:
        raise ValueError(f"Tokenizer {name} not recognized")
    return tokenizer
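

# Example usage (a sketch; the caption text is arbitrary): the tokenizer pads on the
# right with the eos token id and knows the extra "<|image|>" token, so captions can
# be padded to the full model sequence length in one call.
#
#   tokenizer = get_tokenizer("gpt2", sequence_length=2048)
#   batch = tokenizer(
#       ["a photo of a dog"], padding="max_length", truncation=True, return_tensors="pt"
#   )
#   caption_ids = batch.input_ids  # shape (1, 2048)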


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, required=False, help="path to your training config"
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="local rank passed from distributed launcher",
    )
    deepspeed.add_config_arguments(parser)
    args = parser.parse_args()
    args.deepspeed = True
    return args


def wandb_log(*args, **kwargs):
    """wandb.log, but only on the main process."""
    if is_main():
        wandb.log(*args, **kwargs)


def wandb_init(*args, **kwargs):
    """wandb.init, but only on the main process."""
    if is_main():
        wandb.init(*args, **kwargs)


def save_model(model_engine, save_dir, global_step, config=None):
    os.makedirs(save_dir, exist_ok=True)
    if config is not None:
        config = config.to_dict()
        with open(str(Path(save_dir) / "config.yml"), "w") as f:
            yaml.dump(config, f, default_flow_style=False)
    sd = {"global_step": global_step, "config": config}
    model_engine.save_checkpoint(save_dir, client_state=sd)


def load_model(
    model_engine, load_dir, load_optimizer_states=True, load_lr_scheduler_states=True
):
    """
    Loads a model from disk and returns the global step to resume from if loading
    was successful, otherwise returns 0.
    """
    try:
        load_path, sd = model_engine.load_checkpoint(
            load_dir,
            load_optimizer_states=load_optimizer_states,
            load_lr_scheduler_states=load_lr_scheduler_states,
        )
    except AssertionError as e:
        load_path = None
        print(e)
    if load_path is None:
        print("Model loading failed - starting from global step 0")
        return 0
    return sd["global_step"]
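

# Typical checkpointing flow (a sketch; `model_engine` is assumed to come from
# deepspeed.initialize and `config` from the training config, neither is defined here):
#
#   save_model(model_engine, config.save, global_step, config=config)
#   ...
#   global_step = load_model(model_engine, config.load)  # returns 0 if nothing resumed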


def get_params_for_weight_decay_optimization(module, config):
    """
    Divide params into with-weight-decay and without-weight-decay groups.
    Layernorms, embeddings and biases will have no weight decay, but the rest will.
    """
    weight_decay_params = {"params": []}
    no_weight_decay_params = {"params": [], "weight_decay": 0.0}
    blacklist_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
    for module_ in module.modules():
        if isinstance(module_, blacklist_modules) or (
            config.weight_decay == 0.0
        ):  # also include all parameters here if no weight decay is being done
            no_weight_decay_params["params"].extend(
                [
                    p
                    for p in list(module_._parameters.values())
                    if (p is not None) and p.requires_grad
                ]
            )
        else:
            for n, p in list(module_._parameters.items()):
                if p is not None and p.requires_grad:
                    if n != "bias":
                        weight_decay_params["params"].append(p)
                    else:
                        no_weight_decay_params["params"].append(p)
    param_dict = {
        pn: p
        for pn, p in module.named_parameters()
        if p is not None and p.requires_grad
    }
    assert len(no_weight_decay_params["params"]) + len(
        weight_decay_params["params"]
    ) == len(
        param_dict.keys()
    ), "Number of params in both groups != total number of trainable params"
    if config.weight_decay == 0.0:
        # only return a single param group if no weight decay is being used anyway
        return [no_weight_decay_params]
    return [weight_decay_params, no_weight_decay_params]
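

# Sketch of how these groups are typically consumed (the AdamW choice below is an
# assumption, not fixed by this file): each returned dict becomes one optimizer
# param group, and only the first group gets weight decay applied.
#
#   param_groups = get_params_for_weight_decay_optimization(model, config)
#   optimizer = torch.optim.AdamW(
#       param_groups, lr=config.lr, weight_decay=config.weight_decay
#   )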


def configure_param_groups(model, config):
    """
    Configures the different parameter groups in the model for training.
    If a separate learning rate for the image prefix is provided, the image encoder /
    projection groups are split out here.
    Additionally, parameters to which weight decay shouldn't be applied (layernorms /
    biases) are separated out.
    """
    if config.image_enc_lr is not None:
        # get the params for the image prefix / proj
        image_enc_params = get_params_for_weight_decay_optimization(
            model.image_prefix.enc, config
        )
        for pdict in image_enc_params:
            pdict["lr"] = config.image_enc_lr
        image_proj_params = get_params_for_weight_decay_optimization(
            model.image_prefix.proj, config
        )
        # get the params for layernorm if it exists
        if config.use_image_embed_layernorm:
            image_ln_params = get_params_for_weight_decay_optimization(
                model.image_prefix.ln, config
            )
            image_proj_params += image_ln_params
        # get the params for the lm
        lm_params = get_params_for_weight_decay_optimization(model.lm, config)
        # get params for class head if it exists
        class_params = []
        if hasattr(model, "class_head") and model.class_head is not None:
            class_params = get_params_for_weight_decay_optimization(
                model.class_head, config
            )
        all_params = []
        for p in image_enc_params + lm_params + image_proj_params + class_params:
            if p["params"]:
                all_params.append(p)
    else:
        all_params = get_params_for_weight_decay_optimization(model, config)
    # merge param dicts with shared lr / wd values
    d = defaultdict(dict)
    for param_group in all_params:
        lr = param_group.get("lr", None)
        wd = param_group.get("weight_decay", None)
        key = f"lr_{lr}_wd_{wd}"
        if d[key].get("params") is None:
            d[key]["params"] = []
        d[key]["params"].extend(param_group["params"])
        if lr is not None:
            d[key]["lr"] = lr
        if wd is not None:
            d[key]["weight_decay"] = wd
    all_params = list(d.values())
    n_params = sum([len(d["params"]) for d in all_params])
    param_dict = {
        pn: p for pn, p in model.named_parameters() if p is not None and p.requires_grad
    }
    assert n_params == len(
        param_dict
    ), f"Some parameters are missing from param groups ({n_params} | {len(param_dict)})"
    # if we're using multiple param groups, set the min / max lr for each one
    # appropriately in deepspeed's scheduler
    config.deepspeed_config_params["scheduler"]["params"]["warmup_min_lr"] = [
        config.min_lr for _ in all_params
    ]
    config.deepspeed_config_params["scheduler"]["params"]["warmup_max_lr"] = [
        d.get("lr", config.lr) for d in all_params
    ]
    return all_params


def count_parameters(model):
    """
    Counts the number of trainable parameters in a model.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def log_table(name, model_outputs, gt_answers_list, global_step):
    """Log a wandb table of model outputs vs. ground-truth answers."""
    results_table = wandb.Table(columns=["model output", "ground truth(s)"])
    for o, gt in zip(model_outputs, gt_answers_list):
        results_table.add_data(o, gt)
    wandb_log({f"eval/{name}": results_table}, step=global_step)


def get_world_info():
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    return local_rank, rank, world_size


def init_distributed(backend="nccl"):
    if not torch.distributed.is_initialized():
        deepspeed.init_distributed(
            dist_backend=backend, verbose=True, auto_mpi_discovery=True
        )
    local_rank, rank, world_size = get_world_info()
    torch.cuda.set_device(local_rank)
    return local_rank, rank, world_size


def collate_fn_classification(batch_data, seq_len=2048):
    # for nlvr2: list(zip(*batch_data)) = [l_images, r_images, captions, class_labels]
    image_list = list(zip(*batch_data))[:-2]
    captions, class_labels = list(zip(*batch_data))[-2:]
    # images, captions, class_labels = list(zip(*batch_data))
    images_list = [torch.cat(image) for image in image_list]
    captions = torch.cat([i[:, :seq_len] for i in captions])
    class_labels = torch.stack(class_labels)
    return images_list, captions, class_labels
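

# Expected batch layout (a sketch; shapes and batch size are illustrative assumptions):
# each dataset item is (left_image, right_image, caption_ids, class_label), so the
# collate function returns a list of two (B, C, H, W) image tensors, a (B, seq_len)
# caption tensor and a (B,) label tensor.
#
#   loader = DataLoader(
#       nlvr2_dataset, batch_size=8, collate_fn=collate_fn_classification
#   )
#   images_list, captions, class_labels = next(iter(loader))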


def infer_checkpoint_path_from_config(config):
    checkpoint_folder = config.save
    if checkpoint_folder is None:
        raise ValueError(
            "No checkpoint folder specified in config. Please provide a checkpoint."
        )
    # check for 'latest' tag in checkpoint folder
    if (Path(checkpoint_folder) / "latest").exists():
        latest_ckpt = (Path(checkpoint_folder) / "latest").read_text().strip()
    else:
        raise ValueError(
            f"No checkpoint found in {checkpoint_folder}. Please provide a checkpoint."
        )
    checkpoint_path = str(
        Path(checkpoint_folder) / latest_ckpt / "mp_rank_00_model_states.pt"
    )
    if not Path(checkpoint_path).exists():
        raise ValueError(
            f"No checkpoint found in {checkpoint_path}. Please provide a checkpoint."
        )
    return checkpoint_path


# [tensor_1, tensor_2], tensor_3, tensor_4 = to_cuda_half([tensor_1, tensor_2], tensor_3, tensor_4)
# probably not working yet
def to_cuda_half(*args):
    cuda_half_args = []
    for x in args:
        if isinstance(x, list):
            x_cuda_half = to_cuda_half(*x)
            cuda_half_args.append(x_cuda_half)
        elif isinstance(x, tuple):
            x_cuda_half = to_cuda_half(*x)
            cuda_half_args.append(x_cuda_half)
        else:
            if x.dtype in [torch.float32, torch.float16]:
                cuda_half_args.append(x.cuda().half())
            elif x.dtype == torch.long:
                cuda_half_args.append(x.cuda())
    if len(cuda_half_args) == 1:
        return cuda_half_args[0]
    else:
        return cuda_half_args


def build_labels(
    input_embeddings: TensorType["b", "s", "d"],
    captions: TensorType["b", "s"],
    eos_token,
    device,
) -> TensorType["b", "s"]:
    """
    Builds labels from input embeddings.
    Masks out the labels with -100 in positions up to the seq length of the embeddings,
    so loss is only computed for captions and not for image tokens.
    Additionally, masks out everything *after* the first eos token.
    """
    shape = input_embeddings.shape[:2]  # b, s
    assert captions.shape[1] >= shape[1]
    # make sure to add masked embedding tokens in the appropriate locations in the labels
    embedding_tokens = torch.zeros(shape, dtype=torch.int64).to(device) - 100
    labels = torch.cat(
        (embedding_tokens, captions[:, : -shape[1]]), dim=1
    )  # we truncate the sequence length of the captions, as they are always padded to the full sequence length
    # mask out repeating eos tokens
    for label in labels:
        for k, token in enumerate(label):
            if token == eos_token:
                label[k + 1 :] = -100
                break
    return labels
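

# Worked example (a sketch, assuming input_embeddings holds only the image-prefix
# positions and using the GPT-2 eos id 50256): with 2 embedding positions and
# captions = [[5, 7, 50256, 50256, 0, 0]], the last 2 caption positions are dropped
# and everything after the first eos is masked, giving
# labels = [[-100, -100, 5, 7, 50256, -100]].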


def is_url(string):
    return string.startswith("http://") or string.startswith("https://")


def download_checkpoint(checkpoint_url, save_as):
    gdown.download(url=checkpoint_url, output=save_as, quiet=False)