import gc
import json
import math
import os
import os.path as osp
import random
import sys
import time
import traceback
from collections import deque
from contextlib import nullcontext
from functools import partial
from distutils.util import strtobool
from typing import List, Optional, Tuple

os.environ["TOKENIZERS_PARALLELISM"] = "false"     # silence HF tokenizers' fork warning
os.environ['XFORMERS_FORCE_DISABLE_TRITON'] = '1'  # disable xFormers' Triton kernels

import numpy as np
import torch
torch._dynamo.config.cache_size_limit = 64  # allow more torch.compile graph re-compilations
from torch.nn import functional as F
from torch.profiler import record_function
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, T5EncoderModel, T5TokenizerFast
import torch.distributed as tdist

import infinity.utils.dist as dist
from infinity.dataset.build import build_joint_dataset
from infinity.utils.save_and_load import CKPTSaver, omnistoreCheckpoint, auto_resume, omnistore_auto_resume
from infinity.models.ema import get_ema_model
from infinity.utils import arg_util, misc, wandb_utils
from infinity.trainer import get_trainer

def build_everything_from_args(args: arg_util.Args, saver):
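    """
    Build the frozen T5 text encoder, the VAE + GPT models, the optimizer and
    the trainer, then attempt to resume from a checkpoint.

    Returns (text_tokenizer, text_encoder, trainer, global_it), where
    global_it is the global iteration to resume from (0 for a fresh run).
    """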
    args.set_initial_seed(benchmark=True)

    # build the frozen T5 text encoder
    print(f'Loading T5 from {args.t5_path}...')
    if 'flan-t5' in args.t5_path:
        text_tokenizer: T5TokenizerFast = AutoTokenizer.from_pretrained(args.t5_path, revision=None, legacy=True)
        text_tokenizer.model_max_length = args.tlen
        text_encoder: T5EncoderModel = T5EncoderModel.from_pretrained(args.t5_path, torch_dtype=torch.float16)
        text_encoder.to(args.device)
        text_encoder.eval()
        text_encoder.requires_grad_(False)
        args.text_tokenizer_type = 'flan_t5'
        args.text_tokenizer = text_tokenizer
    else:
        raise ValueError("Only flan-t5 is supported now.")

    # build the VAE, the GPT (plain / wrapped / EMA variants) and the optimizer
    vae_local, gpt_uncompiled, gpt_wo_ddp, gpt_ddp, gpt_wo_ddp_ema, gpt_ddp_ema, gpt_optim = build_model_optimizer(args)

    # build the trainer
    InfinityTrainer = get_trainer(args)
    trainer = InfinityTrainer(
        device=args.device,
        raw_scale_schedule=args.scale_schedule,
        vae_local=vae_local,
        gpt_wo_ddp=gpt_wo_ddp, gpt=gpt_ddp,
        gpt_opt=gpt_optim,
        label_smooth=args.label_smooth,
        zero=args.zero,
        vae_type=args.vae_type,
        reweight_loss_by_scale=args.reweight_loss_by_scale,
        gpt_wo_ddp_ema=gpt_wo_ddp_ema,
        gpt_ema=gpt_ddp_ema,
        use_fsdp_model_ema=args.use_fsdp_model_ema,
        other_args=args,
    )
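
    # try to resume: 'torch' checkpoints are discovered via auto_resume(),
    # while 'omnistore' (sharded FSDP) checkpoints are loaded through the
    # saver; only global_it, the number of iterations already trained, is
    # carried forward to the caller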
    global_it = 0
    if args.checkpoint_type == 'torch':
        auto_resume_info, start_ep, global_it, acc_str, _, trainer_state, _ = auto_resume(args, 'global_step_*')
        if trainer_state is not None and len(trainer_state):
            trainer.load_state_dict(trainer_state, strict=False, skip_vae=True)
    elif args.checkpoint_type == 'omnistore':
        resume_path, info = omnistore_auto_resume(args, 'global_step_*')
        if not resume_path and args.rush_omnistore_resume:
            resume_path = args.rush_omnistore_resume
        if resume_path:
            print(f"omnistore resume from {resume_path}", flush=True)
            args_state, start_ep, start_it, global_it, acc_str, eval_milestone = saver.load(resume_path, fsdp_object=trainer.gpt, optimizer_object=trainer.gpt_opt.optimizer)
            dist.barrier()
            if args.rush_omnistore_resume == resume_path:
                global_it = 0  # a rush resume restores weights only, not the iteration counter
        auto_resume_info, acc_str, eval_milestone, trainer_state, args_state = info, '[no acc str]', [], {}, {}

    # the trainer now holds references to everything it needs; drop the locals
    del vae_local, gpt_uncompiled, gpt_wo_ddp, gpt_ddp, gpt_wo_ddp_ema, gpt_ddp_ema, gpt_optim
    dist.barrier()
    return text_tokenizer, text_encoder, trainer, global_it


def build_model_optimizer(args):
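    """
    Build the VAE and the Infinity GPT, initialize (or partially resume) its
    weights, wrap it with FSDP (args.zero > 0) or DDP, optionally create an
    EMA copy, and construct the AMP-wrapped optimizer.
    """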
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from infinity.models.infinity import Infinity, MultipleLayers
    from infinity.models.init_param import init_weights
    from infinity.utils.amp_opt import AmpOptimizer
    from infinity.utils.lr_control import filter_params
    from infinity.utils.load import build_vae_gpt

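    # reset_parameters is replaced with a no-op, presumably to skip PyTorch's
    # slow default per-module init; init_weights() below does the real init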
    setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)
    setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)
    vae_local, gpt_wo_ddp = build_vae_gpt(args, device=args.model_init_device)
    count_p = lambda m: sum(p.numel() for p in m.parameters()) / 1e6  # in millions
    num_para = count_p(gpt_wo_ddp)
    if num_para / 1000 < 20:  # under 20B parameters: safe to move onto a single GPU
        gpt_wo_ddp = gpt_wo_ddp.to('cuda')

    if args.tini < 0:
        args.tini = math.sqrt(1 / gpt_wo_ddp.C / 3)
    init_weights(gpt_wo_ddp, other_std=args.tini)
    gpt_wo_ddp.special_init()
    if args.use_fsdp_model_ema:
        gpt_wo_ddp_ema = get_ema_model(gpt_wo_ddp)
    else:
        gpt_wo_ddp_ema = None

    if args.rush_resume:
        print(f"{args.rush_resume=}")
        cpu_d = torch.load(args.rush_resume, map_location='cpu')
        if 'trainer' in cpu_d:
            state_dict = cpu_d['trainer']['gpt_fsdp']
            ema_state_dict = cpu_d['trainer'].get('gpt_ema_fsdp', state_dict)
        else:
            state_dict = cpu_d
            ema_state_dict = state_dict
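
        # drop_unfit_weights prunes checkpoint tensors whose shapes no longer
        # match the current model (e.g. a different text-embedding width or
        # vocabulary size), so load_state_dict(..., strict=False) can proceed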
        def drop_unfit_weights(state_dict):
            if 'word_embed.weight' in state_dict and (state_dict['word_embed.weight'].shape[1] != gpt_wo_ddp.word_embed.in_features):
                print(f'[rush_resume] drop word_embed.weight')
                del state_dict['word_embed.weight']
            if 'head.weight' in state_dict and (state_dict['head.weight'].shape[0] != gpt_wo_ddp.head.out_features):
                print(f'[rush_resume] drop head.weight')
                del state_dict['head.weight']
            if 'head.bias' in state_dict and (state_dict['head.bias'].shape[0] != gpt_wo_ddp.head.bias.shape[0]):
                print(f'[rush_resume] drop head.bias')
                del state_dict['head.bias']
            if 'text_proj_for_sos.ca.mat_kv.weight' in state_dict and \
                    (state_dict['text_proj_for_sos.ca.mat_kv.weight'].shape != gpt_wo_ddp.text_proj_for_sos.ca.mat_kv.weight.shape):
                print(f'[rush_resume] drop cfg_uncond and text-related weights')
                del state_dict['cfg_uncond']
                for key in list(state_dict.keys()):
                    if 'text' in key:
                        del state_dict[key]
            if 'semantic_head.weight' in state_dict:
                print(f'[rush_resume] replace semantic_head with semantic_head2')
                state_dict['semantic_head2.weight'] = state_dict['semantic_head.weight']
                state_dict['semantic_head2.bias'] = state_dict['semantic_head.bias']
                del state_dict['semantic_head.weight']
                del state_dict['semantic_head.bias']
            if 'semantic_head2.weight' in state_dict and (state_dict['semantic_head2.weight'].shape[0] != gpt_wo_ddp.semantic_head2.out_features):
                print(f'[rush_resume] drop semantic_head2.weight, semantic_head2.bias')
                del state_dict['semantic_head2.weight']
                del state_dict['semantic_head2.bias']
            return state_dict

        print(gpt_wo_ddp.load_state_dict(drop_unfit_weights(state_dict), strict=False))
        if args.use_fsdp_model_ema:
            gpt_wo_ddp_ema.load_state_dict(drop_unfit_weights(ema_state_dict), strict=False)
    elif args.torchshard_resume:
        from transformers.modeling_utils import load_sharded_checkpoint
        load_sharded_checkpoint(gpt_wo_ddp, args.torchshard_resume, strict=False)

    ndim_dict = {name: para.ndim for name, para in gpt_wo_ddp.named_parameters() if para.requires_grad}

    print(f'[PT] GPT model = {gpt_wo_ddp}\n\n')
    print(f'[PT][#para] GPT={num_para:.2f}M\n\n')

    gpt_uncompiled = gpt_wo_ddp

    gpt_ddp_ema = None
    if args.zero:
        from torch.distributed.fsdp import ShardingStrategy
        from torch.distributed.fsdp.wrap import ModuleWrapPolicy
        from torch.distributed.device_mesh import init_device_mesh
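
        # FSDP wrapping granularity: each transformer block (or each
        # MultipleLayers chunk, when blocks are grouped into chunks) becomes
        # its own FSDP unit, so parameters are gathered/freed block by block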
        if gpt_wo_ddp.num_block_chunks == 1:
            auto_wrap_policy = ModuleWrapPolicy([type(gpt_wo_ddp.unregistered_blocks[0]), ])
        else:
            auto_wrap_policy = ModuleWrapPolicy([MultipleLayers, ])

        if args.enable_hybrid_shard:
            # hybrid sharding: shard within groups of inner_shard_degree ranks,
            # replicate across groups
            sharding_strategy = ShardingStrategy.HYBRID_SHARD if args.zero == 3 else ShardingStrategy._HYBRID_SHARD_ZERO2
            world_size = dist.get_world_size()
            assert world_size % args.inner_shard_degree == 0
            assert 1 < args.inner_shard_degree < world_size
            device_mesh = init_device_mesh('cuda', (world_size // args.inner_shard_degree, args.inner_shard_degree))
        else:
            sharding_strategy = ShardingStrategy.FULL_SHARD if args.zero == 3 else ShardingStrategy.SHARD_GRAD_OP
            device_mesh = None
        print(f'{">" * 45 + " " * 5} FSDP INIT with {args.zero=} {sharding_strategy=} {auto_wrap_policy=} {" " * 5 + "<" * 45}', flush=True)

        if args.fsdp_init_device == 'cpu':
            gpt_wo_ddp = gpt_wo_ddp.cpu()

        gpt_ddp: FSDP = FSDP(
            gpt_wo_ddp,
            device_id=dist.get_local_rank(),
            sharding_strategy=sharding_strategy,
            mixed_precision=None,
            auto_wrap_policy=auto_wrap_policy,
            use_orig_params=True,
            sync_module_states=True,
            limit_all_gathers=True,
            device_mesh=device_mesh,
        ).to(args.device)

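        # the EMA copy gets its own FSDP wrapper with the same wrap policy and
        # sharding strategy, so its shards stay aligned with the trained
        # model's (note it uses use_orig_params=args.fsdp_orig, not True)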
        if args.use_fsdp_model_ema:
            gpt_wo_ddp_ema = gpt_wo_ddp_ema.to(args.device)
            gpt_ddp_ema: FSDP = FSDP(
                gpt_wo_ddp_ema,
                device_id=dist.get_local_rank(),
                sharding_strategy=sharding_strategy,
                mixed_precision=None,
                auto_wrap_policy=auto_wrap_policy,
                use_orig_params=args.fsdp_orig,
                sync_module_states=True,
                limit_all_gathers=True,
            )
    else:
        ddp_class = DDP if dist.initialized() else misc.NullDDP
        gpt_ddp: DDP = ddp_class(gpt_wo_ddp, device_ids=[dist.get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
    torch.cuda.synchronize()

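    # build the optimizer: parameters whose names match nowd_keys (biases,
    # norms, positional/level embeddings, etc.) are excluded from weight decay
    # by filter_params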
    nowd_keys = set()
    if args.disable_weight_decay:
        nowd_keys |= {
            'cls_token', 'start_token', 'task_token', 'cfg_uncond',
            'pos_embed', 'pos_1LC', 'pos_start', 'start_pos', 'lvl_embed',
            'gamma', 'beta',
            'ada_gss', 'moe_bias',
            'scale_mul',
            'text_proj_for_sos.ca.mat_q',
        }
    names, paras, para_groups = filter_params(gpt_ddp if args.zero else gpt_wo_ddp, ndim_dict, nowd_keys=nowd_keys)
    del ndim_dict
    if '_' in args.ada:
        beta0, beta1 = map(float, args.ada.split('_'))  # e.g. '0.9_0.97'
    else:
        beta0, beta1 = float(args.ada), -1

    opt_clz = {
        'sgd': partial(torch.optim.SGD, momentum=beta0, nesterov=True),
        'adam': partial(torch.optim.AdamW, betas=(beta0, beta1), fused=args.fused_adam),  # note: 'adam' also maps to AdamW
        'adamw': partial(torch.optim.AdamW, betas=(beta0, beta1), fused=args.fused_adam),
    }[args.opt]
    opt_kw = dict(lr=args.tlr, weight_decay=0)
    if args.adam_eps: opt_kw['eps'] = args.adam_eps
    print(f'[vgpt] optim={opt_clz}, opt_kw={opt_kw}\n')
    gpt_optim = AmpOptimizer('gpt', args.fp16, opt_clz(params=para_groups, **opt_kw), gpt_ddp if args.zero else gpt_wo_ddp, args.r_accu, args.grad_clip, args.zero)
    del names, paras, para_groups
    return vae_local, gpt_uncompiled, gpt_wo_ddp, gpt_ddp, gpt_wo_ddp_ema, gpt_ddp_ema, gpt_optim

def build_dataset(args):
    train_dataset = build_joint_dataset(
        args,
        args.data_path,
        args.video_data_path,
        max_caption_len=args.tlen,
        short_prob=args.short_cap_prob,
        load_vae_instead_of_image=False,
    )
    return train_dataset

def main_train(args: arg_util.Args):
    if args.checkpoint_type == 'torch':
        saver = CKPTSaver(dist.is_master(), eval_milestone=None)
    elif args.checkpoint_type == 'omnistore':
        saver = omnistoreCheckpoint(eval_milestone=None)
    else:
        raise ValueError(f'{args.checkpoint_type=}')
    ret = build_everything_from_args(args, saver)
    if ret is None:
        return

    text_tokenizer, text_encoder, trainer, start_global_it = ret
    gc.collect(), torch.cuda.empty_cache()
    seg5 = np.linspace(1, args.epoch, 5 + 1, dtype=int).tolist()

    time.sleep(3), gc.collect(), torch.cuda.empty_cache(), time.sleep(3)
    ep_lg = max(1, args.epoch // 10) if args.epoch <= 100 else max(1, args.epoch // 20)

    if dist.is_master():
        wandb_utils.wandb.init(project=args.project_name, name=args.exp_name, config={})
    for ep in range(args.epoch):
        # NOTE: do not overwrite args.epoch with the current epoch here;
        # train_one_epoch relies on args.epoch being the *total* epoch count
        # when computing max_it and the log header

        if ep == 0:
            # the dataset length is only known after building it, so the
            # resume point (start_ep, start_it) is derived on the first pass
            train_dataset = build_dataset(args)
            iters_train = len(train_dataset)
            start_ep = start_global_it // iters_train
            start_it = start_global_it % iters_train
            print(f'[PT info] from ep{start_ep} it{start_it} {iters_train=} =======> bed: {args.bed} <=======\n')

        if ep < start_ep:
            continue
        if ep > start_ep:
            train_dataset = build_dataset(args)
            iters_train = len(train_dataset)
|
| | |
| | train_dataloader = DataLoader(dataset=train_dataset, num_workers=args.workers, pin_memory=True, batch_size=None) |
| | stats = train_one_epoch( |
| | epoch=ep, |
| | is_first_ep=ep == start_ep, |
| | start_it=start_it if ep == start_ep else 0, |
| | start_global_it=start_global_it, |
| | me=None, |
| | saver=saver, |
| | args=args, |
| | dataloader_iter=iter(train_dataloader), |
| | iters_train=iters_train, |
| | text_tokenizer=text_tokenizer, text_encoder=text_encoder, |
| | trainer=trainer, |
| | ) |
| | |
| | del stats, train_dataset, train_dataloader |
| | return |


g_speed_ls = deque(maxlen=128)


def train_one_epoch(
    epoch: int, is_first_ep: bool, start_it: int, start_global_it: int, me: misc.MetricLogger,
    saver: CKPTSaver, args: arg_util.Args, dataloader_iter, iters_train: int,
    text_tokenizer: T5TokenizerFast, text_encoder: T5EncoderModel, trainer,
):
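    """
    Run one epoch: tokenize captions, encode them with the frozen T5, pack
    the text features, and call trainer.train_step() per iteration. Returns
    the global average of every tracked meter.
    """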
    step_cnt = 0
    header = f'[Ep]: [{epoch:4d}/{args.epoch}]'

    last_touch = time.time()
    g_it, max_it = epoch * iters_train, args.epoch * iters_train

    # profile only the first epoch, on the master rank unless args.profall is set
    doing_profiling = args.prof and epoch == 0 and (args.profall or dist.is_master())
    maybe_record_function = record_function if doing_profiling else nullcontext
    trainer.gpt_wo_ddp.maybe_record_function = maybe_record_function

    last_t_perf = time.time()
    speed_ls: deque = g_speed_ls
    FREQ = min(args.prof_freq, iters_train // 2 - 1)
    NVIDIA_IT_PLUS_1 = set(FREQ * i for i in (1, 2, 3, 4, 6, 8))
    ranges = set(2 ** i for i in range(20))
    if epoch <= 1: ranges |= {1, 2, 3, 4, 6, 8, 10, 12, 16, 20, 24, 32, 40}
    PRINTABLE_IT_PLUS_1 = set(FREQ * i for i in ranges)

    # the caller passes me=None; build a fresh MetricLogger for this epoch
    me = misc.MetricLogger()
    me.add_meter('tlr', misc.SmoothedValue(window_size=1, fmt='{value:.2g}'))
    me.add_meter('tnm', misc.SmoothedValue(window_size=1, fmt='{median:.2f} ({global_avg:.2f})'))
    for x in ['L', 'L_i', 'L_v']:
        me.add_meter(x, misc.SmoothedValue(window_size=1, fmt='{median:.3f} ({global_avg:.3f})'))
    for x in ['Acc', 'Acc_i', 'Acc_v', 'seq_usage']:
        me.add_meter(x, misc.SmoothedValue(window_size=1, fmt='{median:.2f} ({global_avg:.2f})'))

    for it, data in me.log_every(start_it, iters_train, dataloader_iter, args.log_freq, args.log_every_iter, header, args):
        g_it = epoch * iters_train + it

        # rolling iteration-speed estimate over the last FREQ iterations
        if (it + 1) % FREQ == 0:
            speed_ls.append((time.time() - last_t_perf) / FREQ)
            last_t_perf = time.time()

        # periodic checkpointing
        if (g_it + 1) % args.save_model_iters_freq == 0:
            if args.checkpoint_type == 'torch':
                saver.sav(args=args, g_it=(g_it + 1), next_ep=epoch, next_it=it + 1, trainer=trainer, acc_str='[todo]', eval_milestone=None, also_save_to=None, best_save_to=None)
            elif args.checkpoint_type == 'omnistore':
                saver.sav(args=args, global_it=(g_it + 1), next_ep=epoch, next_it=it + 1, fsdp_object=trainer.gpt, optimizer_object=trainer.gpt_opt.optimizer, acc_str=None, eval_milestone=None)

        with maybe_record_function('before_train'):
            images, captions, raw_features_bcthw, feature_cache_files4images, media = data['images'], data['captions'], data['raw_features_bcthw'], data['feature_cache_files4images'], data['media']

            if args.text_tokenizer_type == 'flan_t5':
                tokens = text_tokenizer(text=captions, max_length=text_tokenizer.model_max_length, padding='max_length', truncation=True, return_tensors='pt')
                input_ids = tokens.input_ids.cuda(non_blocking=True)
                mask = tokens.attention_mask.cuda(non_blocking=True)
                text_features = text_encoder(input_ids=input_ids, attention_mask=mask)['last_hidden_state'].float()
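                # pack the variable-length text features into one flat tensor
                # (kv_compact) with cumulative-sequence-length boundaries, in
                # flash-attention varlen style; e.g. lens = [3, 5] gives
                # cu_seqlens_k = [0, 3, 8]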
                lens: List[int] = mask.sum(dim=-1).tolist()
                cu_seqlens_k = F.pad(mask.sum(dim=-1).to(dtype=torch.int32).cumsum_(0), (1, 0))
                Ltext = max(lens)
                kv_compact = []
                for len_i, feat_i in zip(lens, text_features.unbind(0)):
                    kv_compact.append(feat_i[:len_i])  # keep only the non-padded prefix
                kv_compact = torch.cat(kv_compact, dim=0)
                text_cond_tuple: Tuple[torch.FloatTensor, List[int], torch.LongTensor, int] = (kv_compact, lens, cu_seqlens_k, Ltext)
            else:
                text_features = text_encoder(captions, args.device)
                lens = [len(item) for item in text_features]
                cu_seqlens_k = [0]
                for len_i in lens:
                    cu_seqlens_k.append(cu_seqlens_k[-1] + len_i)
                cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32)
                Ltext = max(lens)
                kv_compact = torch.cat(text_features, dim=0).float()
                text_cond_tuple = (kv_compact, lens, cu_seqlens_k, Ltext)

            if len(images):
                images = [item.to(args.device, non_blocking=True) for item in images]
            if len(raw_features_bcthw):
                raw_features_bcthw = [item.to(args.device, non_blocking=True) for item in raw_features_bcthw]

        # periodically flush logs on the local master (at most every 90s)
        if dist.is_local_master() and (it >= start_it + 10) and (time.time() - last_touch > 90):
            args.dump_log()
            last_touch = time.time()
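
        # optional gradient-clip decay: the clip threshold is scaled by
        # 0.3**(20 * progress) + 0.2, decaying from 1.2x at the start of
        # training towards 0.2x at the end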
        progress = g_it / (max_it - 1)
        clip_decay_ratio = (0.3 ** (20 * progress) + 0.2) if args.cdec else 1

        # the optimizer steps once every args.ac micro-batches (gradient accumulation)
        stepping = (g_it + 1) % args.ac == 0
        step_cnt += int(stepping)

        with maybe_record_function('in_training'):
            grad_norm_t, scale_log2_t = trainer.train_step(
                epoch=epoch,
                it=it,
                g_it=g_it,
                stepping=stepping,
                clip_decay_ratio=clip_decay_ratio,
                metric_lg=me,
                inp_B3HW=images,
                raw_features_bcthw=raw_features_bcthw,
                feature_cache_files4images=feature_cache_files4images,
                text_cond_tuple=text_cond_tuple,
                media=media,
                args=args,
            )

        with maybe_record_function('after_train'):
            me.update(tlr=args.tlr)

    me.synchronize_between_processes()
    return {k: meter.global_avg for k, meter in me.meters.items()}


def main():
    args: arg_util.Args = arg_util.init_dist_and_get_args()
    main_train(args)
    print(f'final args:\n\n{str(args)}')
    args.dump_log()
    if isinstance(sys.stdout, dist.BackupStreamToFile) and isinstance(sys.stderr, dist.BackupStreamToFile):
        sys.stdout.close(), sys.stderr.close()
    dist.barrier()


if __name__ == '__main__':
    main()