import logging
import math
import os
import traceback

import mup
import numpy as np
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

import transformers
from transformers import (
    default_data_collator,
    get_scheduler,
)
import wandb

from cont_data import RawFeatureDataset, get_maskgit_collator_feature
from genie.config import DiffusionGenieConfig
from genie.st_mar import STMAR
from datetime import datetime
from accelerate import DistributedDataParallelKwargs
from common import data_sampler
import yaml
from train_diffusion import parse_args, train

# Timestamp (taken once, at import time) used to prefix the tracker run name.
now = datetime.now()
formatted_date = now.strftime("%Y-%m-%d %H:%M:%S")

torch.set_float32_matmul_precision("medium")

logger = get_logger(__name__)

# NOTE(review): anomaly detection is a debugging aid that adds substantial
# autograd overhead; consider turning it off for full-scale runs.
torch.autograd.set_detect_anomaly(True)

# Dataset metadata keys that must agree between a domain's train and val splits;
# the model config below is derived from (a subset of) these values.
_SHARED_METADATA_KEYS = ("s", "h", "w", "vocab_size", "latent_channels", "encoder_type",
                         "encoder_name_or_path", "quantized")


def parse_args_multi():
    """Parse CLI arguments: the single-dataset options plus multi-dataset options.

    ``train_diffusion.parse_args()`` is used as a parser factory here: its return
    value is extended with ``add_argument`` and then parsed.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = parse_args()

    # Data
    parser.add_argument(
        "--train_split",
        type=str,
        default="experiments/datasplit/dataset2.yaml",
        help="Config files for using multiple datasets."
    )
    parser.add_argument(
        "--num_episodes_per_dataset",
        type=int,
        default=1000000,
        help="Maximum number of trajectories per dataset",
    )
    parser.add_argument(
        "--image_maskgit_path",
        type=str,
        default=None,
        help="Optional path to the official MaskGIT checkpoint. "
             "If specified, will copy relevant weights from the checkpoint. "
             "These weights will have a different (hard-coded) warmup schedule.",
    )
    parser.add_argument(
        "--action_network",
        type=str,
        default=None,
        choices=["concat", "cross_attention"],  # TODO: add other methods (resampler_concat, modulate, etc)
        help="If specified, will override the action in the config. Helps reduce the number of config jsons."
    )
    args = parser.parse_args()
    return args


def _load_domain_datasets(domain, args, config, accelerator):
    """Build the train and eval ``RawFeatureDataset`` for a single domain.

    Args:
        domain: Domain name. Data is read from
            ``data/{domain}_noquant_temporalvae_shard0_of_1_{train,val}``.
        args: Parsed CLI arguments.
        config: Model config; only ``use_actions`` is read here.
        accelerator: Used for ``num_processes`` when truncating in overfit mode.

    Returns:
        ``(train_dataset, eval_dataset)``. With ``--overfit_first_batch`` the
        eval dataset is the (truncated) train dataset object itself.

    Raises:
        Whatever ``RawFeatureDataset`` raises when the data cannot be loaded.
    """
    # Previously used layouts: data/{domain}_vae_traj500_{train,val} and
    # data/{domain}_vae_traj{num_episodes_per_dataset}_{train,val}.
    train_data_dir = f"data/{domain}_noquant_temporalvae_shard0_of_1_train"
    val_data_dir = f"data/{domain}_noquant_temporalvae_shard0_of_1_val"

    train_dataset = RawFeatureDataset(train_data_dir, window_size=args.window_size, stride=args.stride,
                                      filter_overlaps=args.filter_overlaps,
                                      max_traj_num=args.num_episodes_per_dataset,
                                      use_actions=config.use_actions, domain=domain)
    if not args.overfit_first_batch:
        eval_dataset = RawFeatureDataset(val_data_dir, window_size=args.window_size, stride=args.stride,
                                         filter_overlaps=True, use_actions=config.use_actions,
                                         domain=domain)
    else:
        # Keep exactly one global batch worth of windows and evaluate on the same data.
        train_dataset.valid_start_inds = train_dataset.valid_start_inds[
            :args.per_device_train_batch_size * args.gradient_accumulation_steps * accelerator.num_processes
        ]
        eval_dataset = train_dataset

    # Shuffle the eval windows once, with a fixed seed, instead of in the dataloader:
    # shuffling in the dataloader would reshuffle on every iteration.
    eval_dataset.valid_start_inds = torch.tensor(eval_dataset.valid_start_inds)[
        torch.randperm(len(eval_dataset), generator=torch.Generator().manual_seed(0))
    ].tolist()
    return train_dataset, eval_dataset


def _cosine_schedule_with_floor(warmup_steps, max_steps, end_ratio=0.1):
    """Return a ``LambdaLR`` multiplier: linear warmup, then cosine decay to ``end_ratio``.

    Args:
        warmup_steps: Scheduler steps of linear warmup (multiplier is
            ``(step + 1) / warmup_steps`` during warmup).
        max_steps: Scheduler step at which the decay reaches ``end_ratio``.
        end_ratio: Final multiplier, as a fraction of the peak learning rate.
    """
    # Guard against a zero/negative decay span (max_steps <= warmup_steps).
    decay_steps = max(1, max_steps - warmup_steps)

    def get_lr(step):
        if step < warmup_steps:
            return (step + 1) / warmup_steps
        # Clamp so the LR stays at the floor instead of rising again past max_steps.
        progress = min(1.0, (step - warmup_steps) / decay_steps)
        return ((1 + math.cos(math.pi * progress)) / 2) * (1 - end_ratio) + end_ratio

    return get_lr


def main():
    """Train an STMAR diffusion model on a temperature-weighted mixture of domain datasets.

    Parses CLI args, sets up Accelerate and its trackers, and loads one train/eval
    feature dataset per domain listed under ``domains`` in ``args.train_split``
    (domains that fail to load are skipped). It then builds the model, optimizer,
    multi-task dataloaders and LR scheduler, and hands off to
    ``train_diffusion.train``.
    """
    args = parse_args_multi()
    assert (args.llama_config is not None) ^ (args.genie_config is not None), \
        "Exactly one of `llama_config` and `genie_config` should be set."

    # Manual gradient accumulation: Accelerate's own accumulation is disabled
    # (gradient_accumulation_steps=1); presumably `train` handles it.
    # NOTE(review): find_unused_parameters=True is presumably needed because not
    # every parameter (e.g. per-domain action projectors) receives a gradient each step.
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(gradient_accumulation_steps=1, log_with=args.report_to, even_batches=False,
                              project_dir=args.output_dir, kwargs_handlers=[ddp_kwargs])
    # Trackers are created early so the run can be named and dataset plots logged
    # before training; the experiment config is attached to them further below.
    accelerator.init_trackers("video")
    if accelerator.is_main_process:
        accelerator.trackers[0].run.name = formatted_date + "_" + args.run_name

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)

    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)

    accelerator.wait_for_everyone()

    # ---- Datasets -------------------------------------------------------------
    with open(args.train_split, "r", encoding="utf-8") as file:
        datasplit = yaml.safe_load(file)

    config = DiffusionGenieConfig.from_pretrained(args.genie_config)

    # `domains` is a single comma-separated string in the split YAML.
    domains_list = [domain.strip() for domain in datasplit["domains"].split(",") if domain.strip()]

    # Checked up front: inside the per-domain try/except below this error used to be swallowed.
    if config.drop_action_ratio > 0:
        raise NotImplementedError("drop_action_ratio > 0 is not supported for multi-dataset training")

    # Parallel lists, one entry per successfully loaded domain.
    loaded_domains = []
    train_datasets = []
    val_datasets = []
    dataset_num_samples = []
    val_dataset_num_samples = []
    action_dimensions = []
    action_stats = []

    # TODO: check train/val hz per dataset?
    for domain in domains_list:
        try:
            train_dataset, eval_dataset = _load_domain_datasets(domain, args, config, accelerator)
            n_action = train_dataset.n_action
            action_stat = train_dataset.action_stat if config.use_actions else None
        except Exception:
            # Best effort: skip this domain entirely so the parallel lists stay aligned.
            print(f"Skipping domain {domain!r}: failed to load its datasets.\n{traceback.format_exc()}")
            continue

        mismatched = [key for key in _SHARED_METADATA_KEYS
                      if train_dataset.metadata.get(key) != eval_dataset.metadata.get(key)]
        if mismatched:
            raise ValueError(f"Train/val metadata mismatch for domain {domain!r} on keys {mismatched}")

        loaded_domains.append(domain)
        train_datasets.append(train_dataset)
        val_datasets.append(eval_dataset)
        # Lengths are taken after any overfit-mode truncation so they match what
        # ConcatDataset sees when it is constructed.
        dataset_num_samples.append(len(train_dataset))
        val_dataset_num_samples.append(len(eval_dataset))
        action_dimensions.append(n_action)
        if config.use_actions:
            action_stats.append(action_stat)

    if not train_datasets:
        raise RuntimeError(f"None of the domains {domains_list} could be loaded; see the tracebacks above.")

    print("dataset_num_samples:", dataset_num_samples)

    # The model config is derived from the last loaded dataset's metadata. A key that
    # is missing from the metadata is left out, so defaults can be filled in later.
    reference_metadata = train_datasets[-1].metadata
    shared_metadata = {key: reference_metadata[key] for key in _SHARED_METADATA_KEYS if key in reference_metadata}
    for domain, dataset in zip(loaded_domains, train_datasets):
        differing = [key for key in _SHARED_METADATA_KEYS
                     if dataset.metadata.get(key) != reference_metadata.get(key)]
        if differing:
            logger.warning(f"Domain {domain!r} metadata differs from the reference dataset on keys {differing}")

    # ---- Model ----------------------------------------------------------------
    config.use_mup = args.mu_transfer  # Note: changing this may affect pre-trained model due to attn scaling
    config.image_vocab_size = None
    config.T = args.window_size
    config.S = shared_metadata["h"] * shared_metadata["w"]  # TODO: make STMaskGIT use h and w instead of S
    config.vae_embed_dim = shared_metadata["latent_channels"]

    if args.action_network is not None:
        print("Using action network", args.action_network)
        config.action_network = args.action_network

    model = STMAR(config)
    if config.use_actions:
        # One action projector per successfully loaded domain.
        model.init_action_projectors(loaded_domains, action_dimensions, action_stats, config.action_network)

    if args.image_maskgit_path is not None:
        model.init_weights()
        model.load_pretrained_image_weights(args.image_maskgit_path)
        if args.mu_transfer:
            model.set_mup_shapes(rescale_params=False)
    elif args.mu_transfer:
        model.set_mup_shapes(rescale_params=True)
        # model.init_weights()  # might be unnecessary if `rescale_params` is True

    # ---- Optimizer ------------------------------------------------------------
    # Four parameter groups: {weight decay, no weight decay} x {regular, pre-trained}.
    opt_class = mup.MuAdamW if args.mu_transfer else torch.optim.AdamW

    # Scale the base LR linearly with the effective batch size relative to 64,
    # clamped to the range [1x, 8x].
    effective_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps \
        * accelerator.num_processes
    args.learning_rate = args.learning_rate * min(max(1, effective_batch_size / 64), 8)

    no_decay = ["bias", "layer_norm.weight"]
    # More accurately: the params we want a lower LR for. Some weights like
    # pos_embed_TSC are pre-trained but are not treated as lower-LR params.
    pretrained_params = {
        param_name
        for param_name, _ in model.named_parameters()
        if any(term in param_name for term in ("spatial_attn.qkv", "spatial_attn.proj", "mlp"))
    } if args.image_maskgit_path is not None else set()

    def _select(decay, pretrained):
        """Parameters whose (uses weight decay, is pre-trained) flags match the arguments."""
        return [p for n, p in model.named_parameters()
                if (not any(nd in n for nd in no_decay)) == decay and (n in pretrained_params) == pretrained]

    # Give pre-trained weights a 10x lower learning rate.
    optimizer_grouped_parameters = [
        {"params": _select(decay=True, pretrained=False), "weight_decay": args.weight_decay,
         "lr": args.learning_rate},
        {"params": _select(decay=False, pretrained=False), "weight_decay": 0.0,
         "lr": args.learning_rate},
        {"params": _select(decay=True, pretrained=True), "weight_decay": args.weight_decay,
         "lr": args.learning_rate * 0.1},
        {"params": _select(decay=False, pretrained=True), "weight_decay": 0.0,
         "lr": args.learning_rate * 0.1},
    ]
    optimizer = opt_class(optimizer_grouped_parameters, lr=args.learning_rate,
                          betas=(args.adam_beta_1, args.adam_beta_2), eps=args.adam_eps)

    # ---- DataLoaders ----------------------------------------------------------
    collate_fn = default_data_collator if args.llama_config is not None else get_maskgit_collator_feature(config)
    combined_dataset = torch.utils.data.ConcatDataset(train_datasets)

    batch_sampler = data_sampler.MultiTaskBatchSampler(
        dataset_num_samples,
        batch_size=args.per_device_train_batch_size,
        temperature=3.  # the higher the more flat the distribution
    )

    dataset_traj_image = data_sampler.make_dataset_pie_plot(loaded_domains, dataset_num_samples)
    accelerator.log(({"dataset_mixture": wandb.Image(dataset_traj_image)}),
                    log_kwargs={"wandb": {"commit": False}})

    dataset_weights = batch_sampler.generate_tasks_distribution().cpu().numpy()
    dataset_weight_image = data_sampler.make_dataset_pie_plot(loaded_domains, dataset_weights)
    accelerator.log(({"dataset_mixture_weight": wandb.Image(dataset_weight_image)}),
                    log_kwargs={"wandb": {"commit": False}})

    train_dataloader = DataLoader(combined_dataset, batch_sampler=batch_sampler, collate_fn=collate_fn,
                                  num_workers=24, pin_memory=False)

    batch_val_sampler = data_sampler.MultiTaskBatchSampler(
        val_dataset_num_samples,
        batch_size=args.per_device_train_batch_size,
        temperature=4.  # the higher the more flat the distribution
    )
    combined_val_dataset = torch.utils.data.ConcatDataset(val_datasets)
    eval_dataloader = DataLoader(combined_val_dataset, batch_sampler=batch_val_sampler, collate_fn=collate_fn,
                                 num_workers=24, pin_memory=False)

    # ---- Scheduler and training-step bookkeeping ------------------------------
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    if args.max_train_steps < 2000 and args.resume_from_checkpoint is None:  # minimal number of training steps
        args.max_train_steps = 2000
        args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Before `prepare`, len(train_dataloader) is not sharded across processes, so a
    # max_train_steps derived from it already counts scheduler steps on all processes;
    # a user-given value is per process and is scaled by num_processes.
    scheduler_total_steps = args.max_train_steps if overrode_max_train_steps \
        else args.max_train_steps * accelerator.num_processes
    if args.lr_scheduler_type == "custom_cosine":  # decay to `end_ratio` of the peak learning rate
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            _cosine_schedule_with_floor(args.num_warmup_steps * accelerator.num_processes, scheduler_total_steps)
        )
    else:
        lr_scheduler = get_scheduler(
            name=args.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=args.num_warmup_steps * accelerator.num_processes,
            num_training_steps=scheduler_total_steps,
        )

    # Enable gradient checkpointing to save memory
    if args.gradient_checkpointing:
        logger.info("Enabling gradient checkpointing")
        model.gradient_checkpointing_enable()
        model.config.use_cache = False  # incompatible with grad checkpointing

    # Prepare everything with our `accelerator`.
    accelerator.wait_for_everyone()
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    if not args.no_compile:
        torch._dynamo.config.cache_size_limit = 256
        torch._dynamo.config.optimize_ddp = False  # https://github.com/pytorch/pytorch/issues/104674
        # TODO: https://github.com/pytorch/pytorch/issues/109774#issuecomment-2046633776
        model = torch.compile(model)

    # Recalculate the total training steps: the training dataloader is now sharded,
    # so its length may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # ---- Experiment config for the trackers -----------------------------------
    experiment_config = vars(args) | vars(config)

    seq_len = shared_metadata["h"] * shared_metadata["w"] * args.window_size
    args.num_datasets = len(train_datasets)
    model_module = model.module if hasattr(model, "module") else model
    num_model_params = sum(p.numel() for p in model.parameters())
    num_trunk_params = sum(p.numel() for p in model_module.decoder.parameters())

    experiment_config.update(shared_metadata | {
        "model_parameters": num_model_params,
        "model_parameters_M": round(num_model_params / 1e6),
        "trunk_parameters": num_trunk_params,
        "trunk_parameters_M": round(num_trunk_params / 1e6),
        "seq_len": seq_len,
        # Counted over all loaded domains, not just one dataset.
        "train_data_tokens": len(combined_dataset) * seq_len,
        "effective_batch_size": effective_batch_size,
        "effective_batch_size_tokens": effective_batch_size * seq_len,
        "mixed_precision": accelerator.mixed_precision,
        "num_datasets": args.num_datasets,
    })

    experiment_config["FLOPs_per_update_step"] = 6 * experiment_config["model_parameters"] \
        * experiment_config["effective_batch_size_tokens"]

    # Trackers already exist (created by `init_trackers("video")` above). Calling
    # `init_trackers` again would append another tracker for the same run, so only
    # attach the configuration to the existing trackers.
    if accelerator.is_main_process:
        for tracker in accelerator.trackers:
            tracker.store_init_configuration(experiment_config)

    # Train!
    train(accelerator, model, optimizer, lr_scheduler, train_dataloader, eval_dataloader, experiment_config,
          config, args)


if __name__ == "__main__":
    main()