| """Modified from https://github.com/mlfoundations/open_flamingo""" | |
| import argparse | |
| import copy | |
| import glob | |
| import os | |
| import random | |
| import time | |
| import numpy as np | |
| import torch | |
| import wandb | |
| from mmengine import Config | |
| from torch.nn.parallel import DistributedDataParallel as DDP | |
| from torch.utils.data import DataLoader, DistributedSampler | |
| from tqdm import tqdm | |
| from transformers import ( | |
| get_constant_schedule_with_warmup, | |
| get_cosine_schedule_with_warmup, | |
| get_linear_schedule_with_warmup, | |
| ) | |
| from mmgpt import create_model_and_transforms | |
| from mmgpt.models.builder import create_toy_model_and_transforms | |
| from mmgpt.datasets import InfiniteSampler, build_dataset | |
| from mmgpt.train.distributed import init_distributed_device, world_info_from_env | |
| from mmgpt.train.train_utils import AverageMeter, get_autocast, get_cast_dtype, get_checkpoint | |


def random_seed(seed=42, rank=0):
    torch.manual_seed(seed + rank)
    np.random.seed(seed + rank)
    random.seed(seed + rank)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str)
    parser.add_argument("--vision_encoder_pretrained", default="openai", type=str)
    parser.add_argument("--lm_path", default="checkpoints/llama-7b_hf", type=str)
    parser.add_argument(
        "--tokenizer_path",
        default="checkpoints/llama-7b_hf",
        type=str,
        help="path to tokenizer",
    )
    parser.add_argument(
        "--pretrained_path",
        default="checkpoints/OpenFlamingo-9B/checkpoint.pt",
        type=str,
        help="path to pretrained model",
    )
    parser.add_argument(
        "--run_name",
        type=str,
        default="train-my-gpt4",
        help="used to name saving directory and wandb run",
    )
    parser.add_argument("--use_media_placement_augmentation", action="store_true")
    parser.add_argument("--offline", action="store_true")
    parser.add_argument("--num_epochs", type=int, default=1)
    parser.add_argument("--logging_steps", type=int, default=100, help="log loss every n steps")
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states",
        default=None,
    )
    parser.add_argument(
        "--delete_previous_checkpoint",
        action="store_true",
        help="delete previous checkpoint when saving new checkpoint",
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--learning_rate", default=1e-5, type=float)
    parser.add_argument(
        "--lr_scheduler",
        default="constant",
        type=str,
        help="constant, linear, or cosine",
    )
    parser.add_argument("--warmup_steps", default=100, type=int)
    parser.add_argument("--weight_decay", default=0.1, type=float)
    parser.add_argument(
        "--precision",
        choices=["amp", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"],
        default="amp",
        help="Floating point precision.",
    )
    # data args
    parser.add_argument("--workers", type=int, default=0)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--dataset_config", type=str, default=None, help="path to dataset config file")
    # effective optimization batch size per GPU is batch_size * gradient_accumulation_steps
    parser.add_argument("--gradient_accumulation_steps", type=int, default=16)
    # Finetune config
    parser.add_argument("--tuning_config", type=str, default=None, help="path to tuning config file")
    # distributed training args
    parser.add_argument(
        "--dist-url",
        default="env://",
        type=str,
        help="url used to set up distributed training",
    )
    parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
    parser.add_argument(
        "--horovod",
        default=False,
        action="store_true",
        help="Use horovod for distributed training.",
    )
    parser.add_argument(
        "--no-set-device-rank",
        default=False,
        action="store_true",
        help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
    )
    # wandb args
    parser.add_argument("--report_to_wandb", default=False, action="store_true")
    parser.add_argument(
        "--wandb_project",
        type=str,
    )
    parser.add_argument(
        "--wandb_entity",
        type=str,
    )
    parser.add_argument(
        "--save_checkpoints_to_wandb",
        default=False,
        action="store_true",
        help="save checkpoints to wandb",
    )
    args = parser.parse_args()

    if args.save_checkpoints_to_wandb and not args.report_to_wandb:
        raise ValueError("save_checkpoints_to_wandb requires report_to_wandb")

    if args.offline:
        os.environ["WANDB_MODE"] = "offline"
        os.environ["TRANSFORMERS_OFFLINE"] = "1"

    args.local_rank, args.rank, args.world_size = world_info_from_env()

    if args.rank == 0:
        if not os.path.exists(args.run_name):
            os.makedirs(args.run_name)

    device_id = init_distributed_device(args)

    random_seed(args.seed)

    if args.tuning_config is not None:
        tuning_config = Config.fromfile(args.tuning_config)
    else:
        raise ValueError("tuning_config must be specified")

    model, image_processor, tokenizer = create_model_and_transforms(
        model_name="open_flamingo",
        clip_vision_encoder_path=args.vision_encoder_path,
        clip_vision_encoder_pretrained=args.vision_encoder_pretrained,
        lang_encoder_path=args.lm_path,
        tokenizer_path=args.tokenizer_path if args.tokenizer_path else args.lm_path,
        use_media_placement_augmentation=args.use_media_placement_augmentation,
        pretrained_model_path=args.pretrained_path,
        tuning_config=tuning_config.tuning_config,
    )

    if args.dataset_config is not None:
        dataset_config = Config.fromfile(args.dataset_config)
    else:
        raise ValueError("dataset_config must be specified")

    dataset = build_dataset(
        dataset_config=dataset_config.visual_datasets,
        vis_processor=image_processor,
        tokenizer=tokenizer,
    )

    train_dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.workers,
        sampler=DistributedSampler(dataset, shuffle=True, drop_last=True),
        collate_fn=dataset.collater,
    )

    # build language dataset and dataloader for multi-modality training
    if dataset_config.get("language_datasets") is not None and len(dataset_config.language_datasets) > 0:
        lang_dataset = build_dataset(
            dataset_config=dataset_config.language_datasets,
            tokenizer=tokenizer,
        )
        lang_dataloader = DataLoader(
            lang_dataset,
            batch_size=args.batch_size,
            num_workers=args.workers,
            sampler=InfiniteSampler(lang_dataset, shuffle=True),
            collate_fn=lang_dataset.collater,
        )
        lang_dataloader = iter(lang_dataloader)
    else:
        lang_dataloader = None

    random_seed(args.seed, args.rank)

    print(f"Start running training on rank {args.rank}.")

    if args.rank == 0 and args.report_to_wandb:
        wandb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=args.run_name,
            config=vars(args),
        )

    device_id = args.rank % torch.cuda.device_count()
    model = model.to(device_id)
    ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters=True)

    def get_grouped_params(model):
        params_with_wd, params_without_wd = [], []

        def apply_decay(x):
            # apply weight decay only to the gated cross-attention weights,
            # not to gates, norms, or biases
            return (
                "gated_cross_attn_layer" in x
                and "ff_gate" not in x
                and "attn_gate" not in x
                and "norm" not in x
                and "bias" not in x
            )

        for n, p in model.named_parameters():
            # if p.requires_grad:
            if apply_decay(n):
                params_with_wd.append(p)
            else:
                params_without_wd.append(p)
        return [
            {"params": params_with_wd, "weight_decay": args.weight_decay},
            {"params": params_without_wd, "weight_decay": 0.0},
        ]

    optimizer = torch.optim.AdamW(get_grouped_params(ddp_model), lr=args.learning_rate)

    total_training_steps = len(train_dataloader) * args.num_epochs

    if args.rank == 0:
        print(f"Total training steps: {total_training_steps}")

    if args.lr_scheduler == "linear":
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=total_training_steps // args.gradient_accumulation_steps,
        )
    elif args.lr_scheduler == "cosine":
        lr_scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=total_training_steps // args.gradient_accumulation_steps,
        )
    else:
        lr_scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps)

    # check if a checkpoint exists for this run
    if os.path.exists(f"{args.run_name}") and args.resume_from_checkpoint is None:
        checkpoint_list = glob.glob(f"{args.run_name}/checkpoint_*.pt")
        if len(checkpoint_list) == 0:
            print(f"Found no checkpoints for run {args.run_name}.")
        else:
            args.resume_from_checkpoint = sorted(checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0]))[-1]
            print(f"Found checkpoint {args.resume_from_checkpoint} for run {args.run_name}.")

    resume_from_epoch = 0
    if args.resume_from_checkpoint is not None:
        if args.rank == 0:
            print(f"Loading checkpoint from {args.resume_from_checkpoint}")
        checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu")
        ddp_model.load_state_dict(checkpoint["model_state_dict"], strict=False)
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
        resume_from_epoch = checkpoint["epoch"] + 1

    ddp_model.train()

    for epoch in range(resume_from_epoch, args.num_epochs):
        train_dataloader.sampler.set_epoch(epoch)

        train_one_epoch(
            args=args,
            model=ddp_model,
            epoch=epoch,
            tokenizer=tokenizer,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            train_dataloader=train_dataloader,
            language_dataloader=lang_dataloader,
            device_id=device_id,
            wandb=wandb,
        )

        if args.rank == 0:
            if not os.path.exists(args.run_name):
                os.makedirs(args.run_name)

            checkpoint_dict = {
                "epoch": epoch,
                "model_state_dict": get_checkpoint(ddp_model),
                "optimizer_state_dict": optimizer.state_dict(),
                "lr_scheduler_state_dict": lr_scheduler.state_dict(),
                "tuning_config": tuning_config,
            }
            print(f"Saving checkpoint to {args.run_name}/checkpoint_{epoch}.pt")
            torch.save(checkpoint_dict, f"{args.run_name}/checkpoint_{epoch}.pt")
            if args.report_to_wandb and args.save_checkpoints_to_wandb:
                wandb.save(f"{args.run_name}/checkpoint_{epoch}.pt")

            if args.delete_previous_checkpoint:
                if epoch > 0:
                    os.remove(f"{args.run_name}/checkpoint_{epoch-1}.pt")

    if args.rank == 0:
        torch.save(
            {"model_state_dict": get_checkpoint(ddp_model.module), "tuning_config": tuning_config},
            f"{args.run_name}/final_weights.pt",
        )
        if args.report_to_wandb and args.save_checkpoints_to_wandb:
            wandb.save(f"{args.run_name}/final_weights.pt")


def train_one_epoch(
    args,
    model,
    epoch,
    train_dataloader,
    language_dataloader,
    tokenizer,
    optimizer,
    lr_scheduler,
    device_id,
    wandb,
):
    num_batches_per_epoch = len(train_dataloader)
    total_training_steps = num_batches_per_epoch * args.num_epochs

    autocast = get_autocast(args.precision)
    cast_dtype = get_cast_dtype(args.precision)

    model.train()

    # setup logging
    step_time_m = AverageMeter()  # time for one optimizer step (> 1 batch if using gradient accum)
    data_time_m = AverageMeter()  # avg time to load one batch (= 1 batch regardless of gradient accum)

    end = time.time()

    # loop through dataloader
    for num_steps, batch in tqdm(
        enumerate(train_dataloader),
        disable=args.rank != 0,
        total=total_training_steps,
        initial=(epoch * num_batches_per_epoch),
    ):
        data_time_m.update(time.time() - end)

        global_step = num_steps + epoch * num_batches_per_epoch

        #### VISION FORWARD PASS ####
        # vision_x is expected as (B, T_img, F, C, H, W); add singleton T_img and F dims
        images = batch["image"].to(device_id, dtype=cast_dtype, non_blocking=True).unsqueeze(1).unsqueeze(1)
        input_ids = batch["input_ids"].to(device_id, dtype=cast_dtype, non_blocking=True)
        attention_mask = batch["attention_mask"].to(device_id, dtype=cast_dtype, non_blocking=True)
        labels = batch["labels"].to(device_id, dtype=cast_dtype, non_blocking=True)

        with autocast():
            loss_batch = model(
                vision_x=images,
                lang_x=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )[0]
        loss = loss_batch / args.gradient_accumulation_steps
        loss_vision = loss  # for logging

        #### BACKWARD PASS ####
        loss.backward()

        #### LANGUAGE FORWARD PASS ####
        if language_dataloader is not None:
            batch_lang = next(language_dataloader)
            lang_input_ids = batch_lang["input_ids"].to(device_id, dtype=cast_dtype, non_blocking=True)
            lang_attention_mask = batch_lang["attention_mask"].to(device_id, dtype=cast_dtype, non_blocking=True)
            lang_labels = batch_lang["labels"].to(device_id, dtype=cast_dtype, non_blocking=True)

            with autocast():
                lang_loss_batch = model(
                    vision_x=None,
                    lang_x=lang_input_ids,
                    attention_mask=lang_attention_mask,
                    labels=lang_labels,
                )[0]
            lang_loss = lang_loss_batch / args.gradient_accumulation_steps

            #### BACKWARD PASS ####
            lang_loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # step optimizer and log
        if (((num_steps + 1) % args.gradient_accumulation_steps) == 0) or (num_steps == num_batches_per_epoch - 1):
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            # step time and reset end outside of rank 0
            step_time_m.update(time.time() - end)
            end = time.time()

            if args.rank == 0 and args.report_to_wandb:
                # compute within rank 0
                samples_per_second = (
                    args.gradient_accumulation_steps * args.batch_size * args.world_size / step_time_m.val
                )
                samples_per_second_per_gpu = args.gradient_accumulation_steps * args.batch_size / step_time_m.val
                wandb.log(
                    {
                        "data_time": data_time_m.avg,
                        "step_time": step_time_m.avg,
                        "samples_per_second": samples_per_second,
                        "samples_per_second_per_gpu": samples_per_second_per_gpu,
                        "lr": optimizer.param_groups[0]["lr"],
                    },
                    commit=False,
                )
                step_time_m.reset()
                data_time_m.reset()

                loss_log = {
                    "loss": loss.item(),
                    "loss_vision": loss_vision.item(),
                    "global_step": global_step,
                }
                if language_dataloader is not None:
                    loss_log["loss_lang"] = lang_loss.item()
                wandb.log(loss_log, commit=True)

        # Log loss to console
        if ((num_steps + 1) % args.logging_steps == 0) and args.rank == 0:
            print(
                f"Step {num_steps+1}/{num_batches_per_epoch} of epoch {epoch+1}/{args.num_epochs} complete. Loss: {loss.item():.3f}"
            )
| if __name__ == "__main__": | |
| main() | |
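

# Example launch (a minimal sketch, not the repo's canonical invocation: the
# script path and config filenames below are placeholders — substitute your own).
# `torchrun` populates the RANK / LOCAL_RANK / WORLD_SIZE environment variables
# that `world_info_from_env()` and `init_distributed_device()` read.
#
#   torchrun --nproc_per_node=8 path/to/this_script.py \
#       --lm_path checkpoints/llama-7b_hf \
#       --tokenizer_path checkpoints/llama-7b_hf \
#       --pretrained_path checkpoints/OpenFlamingo-9B/checkpoint.pt \
#       --dataset_config configs/dataset_config.py \
#       --tuning_config configs/lora_config.py \
#       --batch_size 1 --gradient_accumulation_steps 16 \
#       --report_to_wandb --wandb_project my-project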