Spaces:
Running
on
Zero
Running
on
Zero
import logging | |
import math | |
import os | |
import mup | |
import numpy as np | |
import torch | |
from accelerate import Accelerator | |
from accelerate.logging import get_logger | |
from accelerate.utils import set_seed | |
from torch.utils.data import DataLoader | |
from tqdm.auto import tqdm | |
import transformers | |
from transformers import ( | |
default_data_collator, | |
get_scheduler, | |
) | |
import wandb | |
from data import RawTokenDataset, get_maskgit_collator | |
from genie.st_mask_git import GenieConfig, STMaskGIT | |
from datetime import datetime | |
from accelerate import DistributedDataParallelKwargs | |
from common import data_sampler | |
import yaml | |
from train import parse_args, train | |
# Script start time, rendered as text; used to prefix the tracker run name in `main`.
now = datetime.now()
formatted_date = now.strftime("%Y-%m-%d %H:%M:%S")

logger = get_logger(__name__)

# Trade float32 matmul precision for speed: with "medium", PyTorch may use bfloat16
# internally for float32 matrix multiplications where a fast kernel exists.
torch.set_float32_matmul_precision("medium")
# NOTE(review): anomaly detection is a debugging aid that PyTorch documents as slowing
# down execution; consider turning it off for long production training runs.
torch.autograd.set_detect_anomaly(True)
def parse_args_multi():
    """Parse command-line arguments for multi-dataset training.

    Takes the parser returned by ``train.parse_args()`` (it is used as an argparse
    parser here: options are added to it and ``parse_args`` is called on it), registers
    the multi-dataset options below, and parses the command line.

    Returns:
        The parsed argument namespace.
    """
    base_parser = parse_args()

    # (flag, add_argument keyword options) for every option this script adds.
    multi_dataset_options = [
        # Data
        ("--train_split", dict(
            type=str,
            default="experiments/datasplit/dataset2.yaml",
            help="Config files for using multiple datasets.",
        )),
        ("--num_episodes_per_dataset", dict(
            type=int,
            default=1000000,
            help="Maximum number of trajectories per dataset",
        )),
        ("--image_maskgit_path", dict(
            type=str,
            default=None,
            help="Optional path to the official MaskGIT checkpoint. "
                 "If specified, will copy relevant weights from the checkpoint. "
                 "These weights will have a different (hard-coded) warmup schedule.",
        )),
        ("--action_network", dict(
            type=str,
            default=None,
            # TODO: add other methods (resampler_concat, modulate, etc)
            choices=["concat", "cross_attention"],
            help="If specified, will override the action in the config. Helps reduce the number of config jsons.",
        )),
    ]
    for flag, options in multi_dataset_options:
        base_parser.add_argument(flag, **options)

    return base_parser.parse_args()
def get_lr_wrapper(warmup_steps, max_steps, end_ratio=0.1):
    """Build a ``LambdaLR`` multiplier: linear warmup, then cosine decay to ``end_ratio``.

    Args:
        warmup_steps: number of scheduler steps of linear warmup; the multiplier at
            step ``s < warmup_steps`` is ``(s + 1) / warmup_steps``.
        max_steps: scheduler step at which the cosine decay reaches ``end_ratio``.
        end_ratio: final multiplier, as a fraction of the peak learning rate.

    Returns:
        A callable mapping a scheduler step to a learning-rate multiplier. After
        ``max_steps`` (or immediately after warmup when ``max_steps <= warmup_steps``)
        the multiplier stays at ``end_ratio``.
    """
    def get_lr(step):
        if step < warmup_steps:
            return (step + 1) / warmup_steps
        remaining_steps = max_steps - warmup_steps
        if remaining_steps <= 0:
            # No room for a decay phase: treat the decay as already complete.
            return end_ratio
        # Clamp so the cosine does not wrap past pi and push the LR back up.
        elapsed = min(step - warmup_steps, remaining_steps)
        return ((1 + math.cos(math.pi * elapsed / remaining_steps)) / 2) \
            * (1 - end_ratio) + end_ratio

    return get_lr


def _load_domain_datasets(domain, args, config, accelerator):
    """Load the ``(train_dataset, eval_dataset)`` pair of ``RawTokenDataset``s for one domain.

    With ``args.overfit_first_batch`` the train set is truncated to one effective batch
    (per-device batch size x gradient accumulation steps x number of processes) and is
    reused as the eval set. The eval set's start indices are then permuted once with a
    fixed seed (0).
    """
    # NOTE: directory names are hard-coded to the `traj1000000` export; the number of
    # trajectories actually used is capped through `max_traj_num` below.
    train_data_dir = f"data/{domain}_magvit_traj1000000_train"
    val_data_dir = f"data/{domain}_magvit_traj1000000_val"
    train_dataset = RawTokenDataset(train_data_dir, window_size=args.window_size, name=domain,
                                    stride=args.stride, filter_overlaps=args.filter_overlaps,
                                    max_traj_num=args.num_episodes_per_dataset,
                                    use_actions=config.use_actions, drop_action_ratio=config.drop_action_ratio)
    if not args.overfit_first_batch:
        eval_dataset = RawTokenDataset(val_data_dir, window_size=args.window_size, name=domain,
                                       stride=args.stride, filter_overlaps=True,
                                       use_actions=config.use_actions, drop_action_ratio=config.drop_action_ratio)
    else:
        train_dataset.valid_start_inds = train_dataset.valid_start_inds[:args.per_device_train_batch_size
                                                                        * args.gradient_accumulation_steps
                                                                        * accelerator.num_processes]
        eval_dataset = train_dataset

    # Shuffle eval dataset and then set shuffle=False on the dataloader.
    # Shuffling in the dataloader results in reshuffling with each iteration.
    eval_dataset.valid_start_inds = torch.tensor(eval_dataset.valid_start_inds)[
        torch.randperm(len(eval_dataset), generator=torch.Generator().manual_seed(0))
    ].tolist()
    return train_dataset, eval_dataset


def _build_param_groups(model, args):
    """Split ``model``'s parameters into four optimizer groups.

    Parameters whose names contain "bias" or "layer_norm.weight" get no weight decay.
    When ``args.image_maskgit_path`` is set, parameters whose names contain
    "spatial_attn.qkv", "spatial_attn.proj" or "mlp" get a 10x lower learning rate.
    Group order: (decay, regular lr), (no decay, regular lr), (decay, low lr), (no decay, low lr).
    """
    no_decay = ("bias", "layer_norm.weight")
    # More accurately: the params we want a lower lr for. Some weights such as
    # pos_embed_TSC are pre-trained but are not given the lower lr.
    pretrained_params = {
        param_name
        for param_name, _ in model.named_parameters()
        if any(term in param_name for term in ("spatial_attn.qkv", "spatial_attn.proj", "mlp"))
    } if args.image_maskgit_path is not None else set()

    def select(use_decay, pretrained):
        return [p for n, p in model.named_parameters()
                if (not any(nd in n for nd in no_decay)) == use_decay
                and (n in pretrained_params) == pretrained]

    base_lr = args.learning_rate
    return [
        {"params": select(True, False), "weight_decay": args.weight_decay, "lr": base_lr},
        {"params": select(False, False), "weight_decay": 0.0, "lr": base_lr},
        # Give pre-trained weights 10x lower learning rate
        {"params": select(True, True), "weight_decay": args.weight_decay, "lr": base_lr * 0.1},
        {"params": select(False, True), "weight_decay": 0.0, "lr": base_lr * 0.1},
    ]


def main():
    """Train STMaskGIT on the mixture of domains listed in ``--train_split``.

    Builds per-domain datasets, the model, the optimizer and LR schedule. It then
    prepares them with Accelerate, records the experiment configuration on the
    trackers, and hands off to ``train.train``.
    """
    args = parse_args_multi()
    assert (args.llama_config is not None) ^ (args.genie_config is not None), \
        "Exactly one of `llama_config` and `genie_config` should be set."

    # Manual gradient accumulation: Accelerate is configured with 1 step; presumably
    # `train` applies `args.gradient_accumulation_steps` itself (TODO confirm).
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(gradient_accumulation_steps=1, log_with=args.report_to,
                              even_batches=False, project_dir=args.output_dir, kwargs_handlers=[ddp_kwargs])
    accelerator.init_trackers("video")
    if accelerator.is_main_process:
        accelerator.trackers[0].run.name = formatted_date + "_" + args.run_name

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)

    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    # create multiple datasets
    with open(args.train_split, 'r') as file:
        datasplit = yaml.safe_load(file)

    config = GenieConfig.from_pretrained(args.genie_config)
    # The split file's 'domains' value is a comma-separated string of domain names.
    requested_domains = [domain.strip() for domain in datasplit['domains'].split(',')]

    # Every per-domain list below is index-aligned with `domains_list`, which holds only
    # the domains that loaded successfully.
    domains_list = []
    train_datasets = []
    val_datasets = []
    dataset_num_samples = []
    val_dataset_num_samples = []
    action_dimensions = []
    action_stats = []
    total_num_videos = 0
    for domain in requested_domains:
        try:
            train_dataset, eval_dataset = _load_domain_datasets(domain, args, config, accelerator)
            n_action = train_dataset.n_action
            num_videos = train_dataset.num_videos
            action_stat = train_dataset.action_stat if config.use_actions else None
        except Exception:
            # Best effort: report and skip this domain. Nothing is appended for it, so the
            # per-domain lists cannot pick up a stale dataset from a previous iteration.
            import traceback
            print(f"Skipping domain {domain!r}: failed to load its datasets.")
            print(traceback.format_exc())
            continue

        assert all(train_dataset.metadata[shared_key] == eval_dataset.metadata[shared_key]
                   for shared_key in ("s", "vocab_size", "hz"))
        domains_list.append(domain)
        train_datasets.append(train_dataset)
        val_datasets.append(eval_dataset)
        # Lengths are taken after any overfit truncation so the batch samplers' sizes
        # match the datasets concatenated below.
        dataset_num_samples.append(len(train_dataset))
        val_dataset_num_samples.append(len(eval_dataset))
        action_dimensions.append(n_action)
        if config.use_actions:
            action_stats.append(action_stat)
        total_num_videos += num_videos

    if not train_datasets:
        raise RuntimeError(f"None of the domains {requested_domains} listed in {args.train_split} could be loaded.")
    print("dataset_num_samples:", dataset_num_samples)
    # Token-grid metadata is read from the last loaded domain; presumably all domains
    # share it (this is not checked across domains here).
    latent_side_len, vocab_size, hz = [train_datasets[-1].metadata[key] for key in ("s", "vocab_size", "hz")]

    config.use_mup = args.mu_transfer  # Note: changing this may affect pre-trained model due to attn scaling
    config.image_vocab_size = vocab_size
    config.T = args.window_size
    if args.action_network is not None:
        print("Using action network", args.action_network)
        config.action_network = args.action_network

    model = STMaskGIT(config)
    if config.use_actions:
        model.init_action_projectors(domains_list, action_dimensions, action_stats, config.action_network)
    if args.image_maskgit_path is not None:
        model.init_weights()
        model.load_pretrained_image_weights(args.image_maskgit_path)
        if args.mu_transfer:
            model.set_mup_shapes(rescale_params=False)
    elif args.mu_transfer:
        # An explicit `model.init_weights()` might be unnecessary here since `rescale_params` is True.
        model.set_mup_shapes(rescale_params=True)

    # Optimizer.
    opt_class = mup.MuAdamW if args.mu_transfer else torch.optim.AdamW
    # Scale the base learning rate by effective_batch_size / 64, clamped to [1, 8].
    effective_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps \
        * accelerator.num_processes
    args.learning_rate = args.learning_rate * min(max(1, effective_batch_size / 64), 8)
    optimizer = opt_class(_build_param_groups(model, args), lr=args.learning_rate,
                          betas=(args.adam_beta_1, args.adam_beta_2), eps=args.adam_eps)

    # DataLoaders creation:
    collate_fn = default_data_collator if args.llama_config is not None else get_maskgit_collator(config)
    combined_dataset = torch.utils.data.ConcatDataset(train_datasets)
    batch_sampler = data_sampler.MultiTaskBatchSampler(
        dataset_num_samples,
        batch_size=args.per_device_train_batch_size,
        temperature=3.  # the higher the more flat the distribution
    )
    dataset_traj_image = data_sampler.make_dataset_pie_plot(domains_list, dataset_num_samples)
    accelerator.log({"dataset_mixture": wandb.Image(dataset_traj_image)}, log_kwargs={"wandb": {"commit": False}})
    dataset_weights = batch_sampler.generate_tasks_distribution().cpu().numpy()
    dataset_weight_image = data_sampler.make_dataset_pie_plot(domains_list, dataset_weights)
    accelerator.log({"dataset_mixture_weight": wandb.Image(dataset_weight_image)},
                    log_kwargs={"wandb": {"commit": False}})
    train_dataloader = DataLoader(combined_dataset, batch_sampler=batch_sampler, collate_fn=collate_fn,
                                  num_workers=16, pin_memory=True)

    batch_val_sampler = data_sampler.MultiTaskBatchSampler(
        val_dataset_num_samples,
        batch_size=args.per_device_train_batch_size,
        temperature=4.  # the higher the more flat the distribution
    )
    combined_val_dataset = torch.utils.data.ConcatDataset(val_datasets)
    eval_dataloader = DataLoader(combined_val_dataset, batch_sampler=batch_val_sampler, collate_fn=collate_fn,
                                 num_workers=16, pin_memory=True)

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True
    if args.max_train_steps < 2000 and args.resume_from_checkpoint is None:  # minimal number of training steps
        args.max_train_steps = 2000
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Step counts are scaled by the number of processes, presumably because the prepared
    # scheduler is stepped once per process per optimizer step (TODO confirm).
    num_warmup_steps = args.num_warmup_steps * accelerator.num_processes
    num_training_steps = args.max_train_steps if overrode_max_train_steps \
        else args.max_train_steps * accelerator.num_processes
    if args.lr_scheduler_type == "custom_cosine":  # decay to `end_ratio` of the peak learning rate
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, get_lr_wrapper(num_warmup_steps, num_training_steps)
        )
    else:
        lr_scheduler = get_scheduler(
            name=args.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        )

    # Enable gradient checkpointing to save memory
    if args.gradient_checkpointing:
        logger.info("Enabling gradient checkpointing")
        model.gradient_checkpointing_enable()
        model.config.use_cache = False  # incompatible with grad checkpointing

    # Prepare everything with our `accelerator`.
    accelerator.wait_for_everyone()
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    if not args.no_compile:
        torch._dynamo.config.cache_size_limit = 256
        torch._dynamo.config.optimize_ddp = False  # https://github.com/pytorch/pytorch/issues/104674
        # TODO: https://github.com/pytorch/pytorch/issues/109774#issuecomment-2046633776
        model = torch.compile(model)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Collect the experiment configuration to store on the trackers.
    experiment_config = vars(args) | vars(config)
    seq_len = latent_side_len**2 * args.window_size
    args.num_datasets = len(train_datasets)
    model_module = model.module if hasattr(model, "module") else model
    num_params = sum(p.numel() for p in model.parameters())
    num_trunk_params = sum(p.numel() for p in model_module.decoder.parameters())
    experiment_config.update({
        "model_parameters": num_params,
        "model_parameters_M": round(num_params / 1e6),
        "trunk_parameters": num_trunk_params,
        "trunk_parameters_M": round(num_trunk_params / 1e6),
        "seq_len": seq_len,
        "hz": hz / args.stride if args.stride is not None else hz,
        # Counted over the training samples of every loaded domain.
        "train_data_tokens": len(combined_dataset) * seq_len,
        "effective_batch_size": effective_batch_size,
        "effective_batch_size_tokens": effective_batch_size * seq_len,
        "mixed_precision": accelerator.mixed_precision,
        "num_datasets": args.num_datasets,
        "total_num_videos": total_num_videos,
    })
    experiment_config["FLOPs_per_update_step"] = 6 * experiment_config["model_parameters"] \
        * experiment_config["effective_batch_size_tokens"]

    # The trackers were created by `init_trackers("video")` above. Store the configuration
    # on them instead of calling `init_trackers` again, which in Accelerate appends a
    # second set of trackers.
    for tracker in accelerator.trackers:
        tracker.store_init_configuration(experiment_config)

    # Train!
    train(accelerator, model, optimizer, lr_scheduler, train_dataloader, eval_dataloader, experiment_config, config, args)


if __name__ == "__main__":
    main()