"""Modified from https://github.com/mlfoundations/open_flamingo""" | |
import argparse | |
import copy | |
import glob | |
import os | |
import random | |
import time | |
import numpy as np | |
import torch | |
import wandb | |
from mmengine import Config | |
from torch.nn.parallel import DistributedDataParallel as DDP | |
from torch.utils.data import DataLoader, DistributedSampler | |
from tqdm import tqdm | |
from transformers import ( | |
get_constant_schedule_with_warmup, | |
get_cosine_schedule_with_warmup, | |
get_linear_schedule_with_warmup, | |
) | |
from mmgpt import create_model_and_transforms | |
from mmgpt.models.builder import create_toy_model_and_transforms | |
from mmgpt.datasets import InfiniteSampler, build_dataset | |
from mmgpt.train.distributed import init_distributed_device, world_info_from_env | |
from mmgpt.train.train_utils import AverageMeter, get_autocast, get_cast_dtype, get_checkpoint | |
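
# This script fine-tunes an OpenFlamingo-style model on visual instruction data and,
# optionally, text-only instruction data, using DistributedDataParallel, gradient
# accumulation, and per-epoch checkpointing.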


def random_seed(seed=42, rank=0):
    torch.manual_seed(seed + rank)
    np.random.seed(seed + rank)
    random.seed(seed + rank)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str)
    parser.add_argument("--vision_encoder_pretrained", default="openai", type=str)
    parser.add_argument("--lm_path", default="checkpoints/llama-7b_hf", type=str)
    parser.add_argument(
        "--tokenizer_path",
        default="checkpoints/llama-7b_hf",
        type=str,
        help="path to tokenizer",
    )
    parser.add_argument(
        "--pretrained_path",
        default="checkpoints/OpenFlamingo-9B/checkpoint.pt",
        type=str,
        help="path to pretrained model",
    )
    parser.add_argument(
        "--run_name",
        type=str,
        default="train-my-gpt4",
        help="used to name saving directory and wandb run",
    )
    parser.add_argument("--use_media_placement_augmentation", action="store_true")
    parser.add_argument("--offline", action="store_true")
    parser.add_argument("--num_epochs", type=int, default=1)
    parser.add_argument("--logging_steps", type=int, default=100, help="log loss every n steps")
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        help="path to checkpoint to resume from; this should contain model, optimizer, and lr_scheduler states",
        default=None,
    )
    parser.add_argument(
        "--delete_previous_checkpoint",
        action="store_true",
        help="delete the previous checkpoint when saving a new one",
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--learning_rate", default=1e-5, type=float)
    parser.add_argument(
        "--lr_scheduler",
        default="constant",
        type=str,
        help="constant, linear, or cosine",
    )
    parser.add_argument("--warmup_steps", default=100, type=int)
    parser.add_argument("--weight_decay", default=0.1, type=float)
    parser.add_argument(
        "--precision",
        choices=["amp", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"],
        default="amp",
        help="Floating point precision.",
    )
    # data args
    parser.add_argument("--workers", type=int, default=0)
    # effective optimization batch size = batch_size * gradient_accumulation_steps * world_size
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--dataset_config", type=str, default=None, help="path to dataset config file")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=16)
    # Finetune config
    parser.add_argument("--tuning_config", type=str, default=None, help="path to tuning config file")
    # distributed training args
    parser.add_argument(
        "--dist-url",
        default="env://",
        type=str,
        help="url used to set up distributed training",
    )
    parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
    parser.add_argument(
        "--horovod",
        default=False,
        action="store_true",
        help="Use horovod for distributed training.",
    )
    parser.add_argument(
        "--no-set-device-rank",
        default=False,
        action="store_true",
        help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
    )
    # wandb args
    parser.add_argument("--report_to_wandb", default=False, action="store_true")
    parser.add_argument(
        "--wandb_project",
        type=str,
    )
    parser.add_argument(
        "--wandb_entity",
        type=str,
    )
    parser.add_argument(
        "--save_checkpoints_to_wandb",
        default=False,
        action="store_true",
        help="save checkpoints to wandb",
    )

    args = parser.parse_args()

    if args.save_checkpoints_to_wandb and not args.report_to_wandb:
        raise ValueError("save_checkpoints_to_wandb requires report_to_wandb")

    if args.offline:
        os.environ["WANDB_MODE"] = "offline"
        os.environ["TRANSFORMERS_OFFLINE"] = "1"

    args.local_rank, args.rank, args.world_size = world_info_from_env()

    if args.rank == 0:
        if not os.path.exists(args.run_name):
            os.makedirs(args.run_name)

    device_id = init_distributed_device(args)

    random_seed(args.seed)

    if args.tuning_config is not None:
        tuning_config = Config.fromfile(args.tuning_config)
    else:
        raise ValueError("tuning_config must be specified")

    model, image_processor, tokenizer = create_model_and_transforms(
        model_name="open_flamingo",
        clip_vision_encoder_path=args.vision_encoder_path,
        clip_vision_encoder_pretrained=args.vision_encoder_pretrained,
        lang_encoder_path=args.lm_path,
        tokenizer_path=args.tokenizer_path if args.tokenizer_path else args.lm_path,
        use_media_placement_augmentation=args.use_media_placement_augmentation,
        pretrained_model_path=args.pretrained_path,
        tuning_config=tuning_config.tuning_config,
    )

    if args.dataset_config is not None:
        dataset_config = Config.fromfile(args.dataset_config)
    else:
        raise ValueError("dataset_config must be specified")

    dataset = build_dataset(
        dataset_config=dataset_config.visual_datasets,
        vis_processor=image_processor,
        tokenizer=tokenizer,
    )

    train_dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.workers,
        sampler=DistributedSampler(dataset, shuffle=True, drop_last=True),
        collate_fn=dataset.collater,
    )

    # build language dataset and dataloader for multi-modality training
    if dataset_config.get("language_datasets") is not None and len(dataset_config.language_datasets) > 0:
        lang_dataset = build_dataset(
            dataset_config=dataset_config.language_datasets,
            tokenizer=tokenizer,
        )
        lang_dataloader = DataLoader(
            lang_dataset,
            batch_size=args.batch_size,
            num_workers=args.workers,
            sampler=InfiniteSampler(lang_dataset, shuffle=True),
            collate_fn=lang_dataset.collater,
        )
        lang_dataloader = iter(lang_dataloader)
    else:
        lang_dataloader = None

    random_seed(args.seed, args.rank)

    print(f"Start running training on rank {args.rank}.")

    if args.rank == 0 and args.report_to_wandb:
        wandb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=args.run_name,
            config=vars(args),
        )

    device_id = args.rank % torch.cuda.device_count()
    model = model.to(device_id)
    ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters=True)

    def get_grouped_params(model):
        # apply weight decay only to the gated cross-attention weights;
        # gates, norms, and biases are excluded from decay
        params_with_wd, params_without_wd = [], []

        def apply_decay(x):
            return (
                "gated_cross_attn_layer" in x
                and "ff_gate" not in x
                and "attn_gate" not in x
                and "norm" not in x
                and "bias" not in x
            )

        for n, p in model.named_parameters():
            # if p.requires_grad:
            if apply_decay(n):
                params_with_wd.append(p)
            else:
                params_without_wd.append(p)

        return [
            {"params": params_with_wd, "weight_decay": args.weight_decay},
            {"params": params_without_wd, "weight_decay": 0.0},
        ]

    optimizer = torch.optim.AdamW(get_grouped_params(ddp_model), lr=args.learning_rate)

    total_training_steps = len(train_dataloader) * args.num_epochs

    if args.rank == 0:
        print(f"Total training steps: {total_training_steps}")

    if args.lr_scheduler == "linear":
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=total_training_steps // args.gradient_accumulation_steps,
        )
    elif args.lr_scheduler == "cosine":
        lr_scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=total_training_steps // args.gradient_accumulation_steps,
        )
    else:
        lr_scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps)

    # check if a checkpoint exists for this run
    if os.path.exists(f"{args.run_name}") and args.resume_from_checkpoint is None:
        checkpoint_list = glob.glob(f"{args.run_name}/checkpoint_*.pt")
        if len(checkpoint_list) == 0:
            print(f"Found no checkpoints for run {args.run_name}.")
        else:
            args.resume_from_checkpoint = sorted(
                checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0])
            )[-1]
            print(f"Found checkpoint {args.resume_from_checkpoint} for run {args.run_name}.")

    resume_from_epoch = 0
    if args.resume_from_checkpoint is not None:
        if args.rank == 0:
            print(f"Loading checkpoint from {args.resume_from_checkpoint}")
        checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu")
        ddp_model.load_state_dict(checkpoint["model_state_dict"], strict=False)
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
        resume_from_epoch = checkpoint["epoch"] + 1

    ddp_model.train()

    for epoch in range(resume_from_epoch, args.num_epochs):
        train_dataloader.sampler.set_epoch(epoch)

        train_one_epoch(
            args=args,
            model=ddp_model,
            epoch=epoch,
            tokenizer=tokenizer,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            train_dataloader=train_dataloader,
            language_dataloader=lang_dataloader,
            device_id=device_id,
            wandb=wandb,
        )

        if args.rank == 0:
            if not os.path.exists(args.run_name):
                os.makedirs(args.run_name)

            checkpoint_dict = {
                "epoch": epoch,
                "model_state_dict": get_checkpoint(ddp_model),
                "optimizer_state_dict": optimizer.state_dict(),
                "lr_scheduler_state_dict": lr_scheduler.state_dict(),
                "tuning_config": tuning_config,
            }
            print(f"Saving checkpoint to {args.run_name}/checkpoint_{epoch}.pt")
            torch.save(checkpoint_dict, f"{args.run_name}/checkpoint_{epoch}.pt")
            if args.report_to_wandb and args.save_checkpoints_to_wandb:
                wandb.save(f"{args.run_name}/checkpoint_{epoch}.pt")

            if args.delete_previous_checkpoint:
                if epoch > 0:
                    os.remove(f"{args.run_name}/checkpoint_{epoch - 1}.pt")

    if args.rank == 0:
        torch.save(
            {"model_state_dict": get_checkpoint(ddp_model.module), "tuning_config": tuning_config},
            f"{args.run_name}/final_weights.pt",
        )
        if args.report_to_wandb and args.save_checkpoints_to_wandb:
            wandb.save(f"{args.run_name}/final_weights.pt")


def train_one_epoch(
    args,
    model,
    epoch,
    train_dataloader,
    language_dataloader,
    tokenizer,
    optimizer,
    lr_scheduler,
    device_id,
    wandb,
):
    num_batches_per_epoch = len(train_dataloader)
    total_training_steps = num_batches_per_epoch * args.num_epochs

    autocast = get_autocast(args.precision)
    cast_dtype = get_cast_dtype(args.precision)

    model.train()

    # setup logging
    step_time_m = AverageMeter()  # time for one optimizer step (> 1 batch if using gradient accum)
    data_time_m = AverageMeter()  # average time to load one batch (1 batch regardless of gradient accum)
    end = time.time()

    # loop through dataloader
    for num_steps, batch in tqdm(
        enumerate(train_dataloader),
        disable=args.rank != 0,
        total=total_training_steps,
        initial=(epoch * num_batches_per_epoch),
    ):
        data_time_m.update(time.time() - end)

        global_step = num_steps + epoch * num_batches_per_epoch

        #### VISION FORWARD PASS ####
        images = batch["image"].to(device_id, dtype=cast_dtype, non_blocking=True).unsqueeze(1).unsqueeze(1)
        input_ids = batch["input_ids"].to(device_id, dtype=cast_dtype, non_blocking=True)
        attention_mask = batch["attention_mask"].to(device_id, dtype=cast_dtype, non_blocking=True)
        labels = batch["labels"].to(device_id, dtype=cast_dtype, non_blocking=True)

        with autocast():
            loss_batch = model(
                vision_x=images,
                lang_x=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )[0]
        loss = loss_batch / args.gradient_accumulation_steps
        loss_vision = loss  # for logging

        #### BACKWARD PASS ####
        loss.backward()

        #### LANGUAGE FORWARD PASS ####
        if language_dataloader is not None:
            batch_lang = next(language_dataloader)
            lang_input_ids = batch_lang["input_ids"].to(device_id, dtype=cast_dtype, non_blocking=True)
            lang_attention_mask = batch_lang["attention_mask"].to(device_id, dtype=cast_dtype, non_blocking=True)
            lang_labels = batch_lang["labels"].to(device_id, dtype=cast_dtype, non_blocking=True)

            with autocast():
                lang_loss_batch = model(
                    vision_x=None,
                    lang_x=lang_input_ids,
                    attention_mask=lang_attention_mask,
                    labels=lang_labels,
                )[0]
            lang_loss = lang_loss_batch / args.gradient_accumulation_steps

            #### BACKWARD PASS ####
            lang_loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # step optimizer and log
        if (((num_steps + 1) % args.gradient_accumulation_steps) == 0) or (num_steps == num_batches_per_epoch - 1):
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            # step time and reset end outside of rank 0
            step_time_m.update(time.time() - end)
            end = time.time()

            if args.rank == 0 and args.report_to_wandb:
                # compute within rank 0
                samples_per_second = (
                    args.gradient_accumulation_steps * args.batch_size * args.world_size / step_time_m.val
                )
                samples_per_second_per_gpu = args.gradient_accumulation_steps * args.batch_size / step_time_m.val
                wandb.log(
                    {
                        "data_time": data_time_m.avg,
                        "step_time": step_time_m.avg,
                        "samples_per_second": samples_per_second,
                        "samples_per_second_per_gpu": samples_per_second_per_gpu,
                        "lr": optimizer.param_groups[0]["lr"],
                    },
                    commit=False,
                )
                step_time_m.reset()
                data_time_m.reset()

                loss_log = {
                    "loss": loss.item(),
                    "loss_vision": loss_vision.item(),
                    "global_step": global_step,
                }
                if language_dataloader is not None:
                    loss_log["loss_lang"] = lang_loss.item()
                wandb.log(loss_log, commit=True)

        # log loss to console
        if ((num_steps + 1) % args.logging_steps == 0) and args.rank == 0:
            print(
                f"Step {num_steps + 1}/{num_batches_per_epoch} of epoch {epoch + 1}/{args.num_epochs} complete. "
                f"Loss: {loss.item():.3f}"
            )


if __name__ == "__main__":
    main()
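
# Example launch command (illustrative only: the torchrun launcher, the file name
# train.py, and the config paths below are assumptions, not documented usage of this
# repository; the flags themselves are defined in main() above):
#
#   torchrun --nproc_per_node=8 train.py \
#       --lm_path checkpoints/llama-7b_hf \
#       --tokenizer_path checkpoints/llama-7b_hf \
#       --pretrained_path checkpoints/OpenFlamingo-9B/checkpoint.pt \
#       --dataset_config configs/dataset_config.py \
#       --tuning_config configs/tuning_config.py \
#       --batch_size 1 \
#       --gradient_accumulation_steps 16 \
#       --report_to_wandb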