"""Modified from https://github.com/mlfoundations/open_flamingo""" import argparse import copy import glob import os import random import time import numpy as np import torch import wandb from mmengine import Config from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader, DistributedSampler from tqdm import tqdm from transformers import ( get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup, get_linear_schedule_with_warmup, ) from mmgpt import create_model_and_transforms from mmgpt.models.builder import create_toy_model_and_transforms from mmgpt.datasets import InfiniteSampler, build_dataset from mmgpt.train.distributed import init_distributed_device, world_info_from_env from mmgpt.train.train_utils import AverageMeter, get_autocast, get_cast_dtype, get_checkpoint def random_seed(seed=42, rank=0): torch.manual_seed(seed + rank) np.random.seed(seed + rank) random.seed(seed + rank) def main(): parser = argparse.ArgumentParser() parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str) parser.add_argument("--vision_encoder_pretrained", default="openai", type=str) parser.add_argument("--lm_path", default="checkpoints/llama-7b_hf", type=str) parser.add_argument( "--tokenizer_path", default="checkpoints/llama-7b_hf", type=str, help="path to tokenizer", ) parser.add_argument( "--pretrained_path", default="checkpoints/OpenFlamingo-9B/checkpoint.pt", type=str, help="path to pretrained model", ) parser.add_argument( "--run_name", type=str, default="train-my-gpt4", help="used to name saving directory and wandb run", ) parser.add_argument("--use_media_placement_augmentation", action="store_true") parser.add_argument("--offline", action="store_true") parser.add_argument("--num_epochs", type=int, default=1) parser.add_argument("--logging_steps", type=int, default=100, help="log loss every n steps") # Sum of gradient optimization batch size parser.add_argument( "--resume_from_checkpoint", type=str, help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states", default=None, ) parser.add_argument( "--delete_previous_checkpoint", action="store_true", help="delete previous checkpoint when saving new checkpoint", ) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--learning_rate", default=1e-5, type=float) parser.add_argument( "--lr_scheduler", default="constant", type=str, help="constant, linear, or cosine", ) parser.add_argument("--warmup_steps", default=100, type=int) parser.add_argument("--weight_decay", default=0.1, type=float) parser.add_argument( "--precision", choices=["amp", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"], default="amp", help="Floating point precision.", ) # data args parser.add_argument("--workers", type=int, default=0) parser.add_argument("--batch_size", type=int, default=1) parser.add_argument("--dataset_config", type=str, default=None, help="path to dataset config file") parser.add_argument("--gradient_accumulation_steps", type=int, default=16) # Finetune config parser.add_argument("--tuning_config", type=str, default=None, help="path to tuning config file") # distributed training args parser.add_argument( "--dist-url", default="env://", type=str, help="url used to set up distributed training", ) parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend") parser.add_argument( "--horovod", default=False, action="store_true", help="Use horovod for distributed training.", ) parser.add_argument( 
"--no-set-device-rank", default=False, action="store_true", help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).", ) # wandb args parser.add_argument("--report_to_wandb", default=False, action="store_true") parser.add_argument( "--wandb_project", type=str, ) parser.add_argument( "--wandb_entity", type=str, ) parser.add_argument( "--save_checkpoints_to_wandb", default=False, action="store_true", help="save checkpoints to wandb", ) args = parser.parse_args() if args.save_checkpoints_to_wandb and not args.report_to_wandb: raise ValueError("save_checkpoints_to_wandb requires report_to_wandb") if args.offline: os.environ["WANDB_MODE"] = "offline" os.environ["TRANSFORMERS_OFFLINE"] = "1" args.local_rank, args.rank, args.world_size = world_info_from_env() if args.rank == 0: if not os.path.exists(args.run_name): os.makedirs(args.run_name) device_id = init_distributed_device(args) random_seed(args.seed) if args.tuning_config is not None: tuning_config = Config.fromfile(args.tuning_config) else: raise ValueError("tuning_config must be specified") model, image_processor, tokenizer = create_model_and_transforms( model_name="open_flamingo", clip_vision_encoder_path=args.vision_encoder_path, clip_vision_encoder_pretrained=args.vision_encoder_pretrained, lang_encoder_path=args.lm_path, tokenizer_path=args.tokenizer_path if args.tokenizer_path else args.lm_path, use_media_placement_augmentation=args.use_media_placement_augmentation, pretrained_model_path=args.pretrained_path, tuning_config=tuning_config.tuning_config, ) if args.dataset_config is not None: dataset_config = Config.fromfile(args.dataset_config) else: raise ValueError("dataset_config must be specified") dataset = build_dataset( dataset_config=dataset_config.visual_datasets, vis_processor=image_processor, tokenizer=tokenizer, ) train_dataloader = DataLoader( dataset, batch_size=args.batch_size, num_workers=args.workers, sampler=DistributedSampler(dataset, shuffle=True, drop_last=True), collate_fn=dataset.collater, ) # build language dataset and dataloader for multi-modality training if dataset_config.get('language_datasets') is not None and len(dataset_config.language_datasets) > 0: lang_dataset = build_dataset( dataset_config=dataset_config.language_datasets, tokenizer=tokenizer, ) lang_dataloader = DataLoader( lang_dataset, batch_size=args.batch_size, num_workers=args.workers, sampler=InfiniteSampler(lang_dataset, shuffle=True), collate_fn=lang_dataset.collater, ) lang_dataloader = iter(lang_dataloader) else: lang_dataloader = None random_seed(args.seed, args.rank) print(f"Start running training on rank {args.rank}.") if args.rank == 0 and args.report_to_wandb: wandb.init( project=args.wandb_project, entity=args.wandb_entity, name=args.run_name, config=vars(args), ) device_id = args.rank % torch.cuda.device_count() model = model.to(device_id) ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters=True) def get_grouped_params(model): params_with_wd, params_without_wd = [], [] def apply_decay(x): return ( "gated_cross_attn_layer" in x and "ff_gate" not in x and "attn_gate" not in x and "norm" not in x and "bias" not in x ) for n, p in model.named_parameters(): # if p.requires_grad: if apply_decay(n): params_with_wd.append(p) else: params_without_wd.append(p) return [ {"params": params_with_wd, "weight_decay": args.weight_decay}, {"params": params_without_wd, "weight_decay": 0.0}, ] optimizer = torch.optim.AdamW(get_grouped_params(ddp_model), lr=args.learning_rate) 
    total_training_steps = len(train_dataloader) * args.num_epochs

    if args.rank == 0:
        print(f"Total training steps: {total_training_steps}")

    if args.lr_scheduler == "linear":
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=total_training_steps // args.gradient_accumulation_steps,
        )
    elif args.lr_scheduler == "cosine":
        lr_scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=total_training_steps // args.gradient_accumulation_steps,
        )
    else:
        lr_scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps)

    # check if a checkpoint exists for this run
    if os.path.exists(f"{args.run_name}") and args.resume_from_checkpoint is None:
        checkpoint_list = glob.glob(f"{args.run_name}/checkpoint_*.pt")
        if len(checkpoint_list) == 0:
            print(f"Found no checkpoints for run {args.run_name}.")
        else:
            args.resume_from_checkpoint = sorted(checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0]))[-1]
            print(f"Found checkpoint {args.resume_from_checkpoint} for run {args.run_name}.")

    resume_from_epoch = 0
    if args.resume_from_checkpoint is not None:
        if args.rank == 0:
            print(f"Loading checkpoint from {args.resume_from_checkpoint}")
        checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu")
        ddp_model.load_state_dict(checkpoint["model_state_dict"], False)
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
        resume_from_epoch = checkpoint["epoch"] + 1

    ddp_model.train()

    for epoch in range(resume_from_epoch, args.num_epochs):
        train_dataloader.sampler.set_epoch(epoch)

        train_one_epoch(
            args=args,
            model=ddp_model,
            epoch=epoch,
            tokenizer=tokenizer,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            train_dataloader=train_dataloader,
            language_dataloader=lang_dataloader,
            device_id=device_id,
            wandb=wandb,
        )

        if args.rank == 0:
            if not os.path.exists(args.run_name):
                os.makedirs(args.run_name)

            checkpoint_dict = {
                "epoch": epoch,
                "model_state_dict": get_checkpoint(ddp_model),
                "optimizer_state_dict": optimizer.state_dict(),
                "lr_scheduler_state_dict": lr_scheduler.state_dict(),
                "tuning_config": tuning_config,
            }

            print(f"Saving checkpoint to {args.run_name}/checkpoint_{epoch}.pt")
            torch.save(checkpoint_dict, f"{args.run_name}/checkpoint_{epoch}.pt")
            if args.report_to_wandb and args.save_checkpoints_to_wandb:
                wandb.save(f"{args.run_name}/checkpoint_{epoch}.pt")

            if args.delete_previous_checkpoint:
                if epoch > 0:
                    os.remove(f"{args.run_name}/checkpoint_{epoch-1}.pt")

    if args.rank == 0:
        torch.save(
            {"model_state_dict": get_checkpoint(ddp_model.module), "tuning_config": tuning_config},
            f"{args.run_name}/final_weights.pt",
        )
        if args.report_to_wandb and args.save_checkpoints_to_wandb:
            wandb.save(f"{args.run_name}/final_weights.pt")


def train_one_epoch(
    args,
    model,
    epoch,
    train_dataloader,
    language_dataloader,
    tokenizer,
    optimizer,
    lr_scheduler,
    device_id,
    wandb,
):
    num_batches_per_epoch = len(train_dataloader)
    total_training_steps = num_batches_per_epoch * args.num_epochs

    autocast = get_autocast(args.precision)
    cast_dtype = get_cast_dtype(args.precision)

    model.train()

    # setup logging
    step_time_m = AverageMeter()  # time for one optimizer step (> 1 batch if using gradient accum)
    data_time_m = AverageMeter()  # avg time to load one batch (= 1 batch regardless of gradient accum)
    end = time.time()
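
    # Each iteration below runs a forward/backward pass on one vision-language
    # batch and, when a language dataloader is configured, an extra
    # forward/backward pass on one text-only batch. Both losses are divided by
    # gradient_accumulation_steps, and the optimizer/scheduler only step every
    # gradient_accumulation_steps batches (or on the last batch of the epoch).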
    # loop through dataloader
    for num_steps, batch in tqdm(
        enumerate(train_dataloader),
        disable=args.rank != 0,
        total=total_training_steps,
        initial=(epoch * num_batches_per_epoch),
    ):
        data_time_m.update(time.time() - end)

        global_step = num_steps + epoch * num_batches_per_epoch

        #### VISION FORWARD PASS ####
        images = batch["image"].to(device_id, dtype=cast_dtype, non_blocking=True).unsqueeze(1).unsqueeze(1)
        input_ids = batch["input_ids"].to(device_id, dtype=cast_dtype, non_blocking=True)
        attention_mask = batch["attention_mask"].to(device_id, dtype=cast_dtype, non_blocking=True)
        labels = batch["labels"].to(device_id, dtype=cast_dtype, non_blocking=True)

        with autocast():
            loss_batch = model(
                vision_x=images,
                lang_x=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )[0]
        loss = loss_batch / args.gradient_accumulation_steps
        loss_vision = loss  # for logging

        #### BACKWARD PASS ####
        loss.backward()

        #### LANGUAGE FORWARD PASS ####
        if language_dataloader is not None:
            batch_lang = next(language_dataloader)
            lang_input_ids = batch_lang["input_ids"].to(device_id, dtype=cast_dtype, non_blocking=True)
            lang_attention_mask = batch_lang["attention_mask"].to(device_id, dtype=cast_dtype, non_blocking=True)
            lang_labels = batch_lang["labels"].to(device_id, dtype=cast_dtype, non_blocking=True)

            with autocast():
                lang_loss_batch = model(
                    vision_x=None,
                    lang_x=lang_input_ids,
                    attention_mask=lang_attention_mask,
                    labels=lang_labels,
                )[0]
            lang_loss = lang_loss_batch / args.gradient_accumulation_steps

            #### BACKWARD PASS ####
            lang_loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # step optimizer and log
        if (((num_steps + 1) % args.gradient_accumulation_steps) == 0) or (num_steps == num_batches_per_epoch - 1):
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            # step time and reset end outside of rank 0
            step_time_m.update(time.time() - end)
            end = time.time()

            if args.rank == 0 and args.report_to_wandb:
                # compute within rank 0
                samples_per_second = (
                    args.gradient_accumulation_steps * args.batch_size * args.world_size / step_time_m.val
                )
                samples_per_second_per_gpu = args.gradient_accumulation_steps * args.batch_size / step_time_m.val

                wandb.log(
                    {
                        "data_time": data_time_m.avg,
                        "step_time": step_time_m.avg,
                        "samples_per_second": samples_per_second,
                        "samples_per_second_per_gpu": samples_per_second_per_gpu,
                        "lr": optimizer.param_groups[0]["lr"],
                    },
                    commit=False,
                )
                step_time_m.reset()
                data_time_m.reset()

                loss_log = {
                    "loss": loss.item(),
                    "loss_vision": loss_vision.item(),
                    "global_step": global_step,
                }
                if language_dataloader is not None:
                    loss_log["loss_lang"] = lang_loss.item()
                wandb.log(loss_log, commit=True)

        # Log loss to console
        if ((num_steps + 1) % args.logging_steps == 0) and args.rank == 0:
            print(
                f"Step {num_steps+1}/{num_batches_per_epoch} of epoch {epoch+1}/{args.num_epochs} complete. Loss: {loss.item():.3f}"
            )


if __name__ == "__main__":
    main()
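
# Example launch (illustrative only; the script and config paths below are
# placeholders, not values shipped with the repo). Distributed setup is read
# from the environment (world_info_from_env / --dist-url env://), so the script
# is typically started with torchrun, e.g.:
#
#   torchrun --nproc_per_node=8 path/to/this_script.py \
#       --lm_path checkpoints/llama-7b_hf \
#       --tokenizer_path checkpoints/llama-7b_hf \
#       --pretrained_path checkpoints/OpenFlamingo-9B/checkpoint.pt \
#       --dataset_config path/to/dataset_config.py \
#       --tuning_config path/to/tuning_config.py \
#       --batch_size 1 --gradient_accumulation_steps 16 \
#       --lr_scheduler cosine --learning_rate 1e-5 \
#       --run_name train-my-gpt4 --report_to_wandb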