import torch
import os
import deepspeed
import wandb
from torch.utils.data import random_split, ConcatDataset
from torch.optim import AdamW
from tqdm import tqdm
from functools import partial
from magma.datasets import (
    collate_fn,
    ImgCptDataset,
)
from magma.magma import (
    Magma,
)
from magma.utils import (
    is_main,
    cycle,
    parse_args,
    wandb_log,
    wandb_init,
    save_model,
    load_model,
    print_main,
    configure_param_groups,
)
from magma.train_loop import (
    eval_step,
    inference_step,
    train_step,
)


def _load_img_cpt_datasets(dataset_dir, tokenizer, transforms):
    if isinstance(dataset_dir, (list, tuple)):
        return ConcatDataset(
            [_load_img_cpt_datasets(d, tokenizer, transforms) for d in dataset_dir]
        )
    elif isinstance(dataset_dir, str):
        return ImgCptDataset(dataset_dir, tokenizer=tokenizer, transforms=transforms)
    else:
        raise TypeError("dataset dir wrong type")


def get_pretraining_datasets(config, tokenizer, transforms):
    # if config.train_dataset_dir is a list, load all datasets + join together
    train_dataset = _load_img_cpt_datasets(
        config.train_dataset_dir, tokenizer, transforms
    )
    # if no dedicated eval sets are given, use a percentage of the train dataset
    if config.eval_dataset_dir is None:
        eval_len = int(len(train_dataset) * config.eval_dataset_pct)
        train_len = len(train_dataset) - eval_len
        print(
            f"Randomly splitting train_dataset into two datasets of length {train_len} and {eval_len}"
        )
        train_dataset, eval_dataset = random_split(train_dataset, [train_len, eval_len])
    else:
        eval_dataset = _load_img_cpt_datasets(
            config.eval_dataset_dir, tokenizer, transforms
        )

    print_main(f"Loaded train dataset with {len(train_dataset)} samples")
    print_main(f"Loaded eval dataset with {len(eval_dataset)} samples")

    return train_dataset, eval_dataset


# tell tokenizers not to do parallelism
os.environ["TOKENIZERS_PARALLELISM"] = "false"

if __name__ == "__main__":
    # parse command line arguments:
    args = parse_args()
    deepspeed.init_distributed()

    # load model + tokenizer:
    model = Magma(
        args.config
    )  # for finetuning one might want to load the model via Magma.from_checkpoint(...) here
    tokenizer, config, transforms = model.tokenizer, model.config, model.transforms

    # filter frozen from trainable parameters:
    trainable_parameters = configure_param_groups(model, config)

    # load data:
    train_dataset, eval_dataset = get_pretraining_datasets(
        config, tokenizer, transforms
    )

    print_main(f"Loaded train dataset with {len(train_dataset)} samples")
    print_main(f"Loaded eval dataset with {len(eval_dataset)} samples")

    opt = AdamW(
        trainable_parameters,
        config.lr,
        betas=(0.9, 0.95),
        weight_decay=config.weight_decay,
    )

    model_engine, opt, train_loader, lr_scheduler = deepspeed.initialize(
        args=args,
        model=model,
        optimizer=opt,
        model_parameters=trainable_parameters,
        training_data=train_dataset,
        collate_fn=partial(collate_fn, seq_len=model.seq_len),
        config_params=config.deepspeed_config_params,
    )
    # cycle() wraps the loaders so batches can be drawn indefinitely,
    # independent of epoch boundaries
    eval_loader = cycle(model_engine.deepspeed_io(eval_dataset))
    train_loader = cycle(train_loader)

    # initialize training
    global_step = 0
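    # Resume bookkeeping: load_model returns the global step stored in the
    # checkpoint, and it is adopted only when the optimizer state is restored
    # as well, so a finetuning run with load_optimizer=False counts from step 0.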
    if config.load:
        # loads a deepspeed checkpoint if provided. For finetuning, set load_optimizer to false
        previous_global_step = load_model(
            model_engine,
            config.load,
            load_optimizer_states=config.load_optimizer,
            load_lr_scheduler_states=config.load_optimizer,
        )

        if config.load_optimizer:
            global_step = previous_global_step

    pbar = tqdm(
        range(0, config.train_steps),
        desc="training...",
        initial=global_step,
        total=config.train_steps,
        disable=not is_main(),
    )
    wandb_init(
        project=config.wandb_project,
        name=config.name or wandb.util.generate_id(),
        config=config,
    )

    # training loop
    for i in pbar:
        if global_step >= config.train_steps:
            break

        ##### train step
        loss = train_step(config, train_loader, model_engine)

        global_step += 1

        if global_step % config.log_every == 0:
            pbar.set_description(f"training... Step: {global_step} Loss: {loss}")
            current_lr = (
                [lr for lr in lr_scheduler.get_lr()]
                if lr_scheduler is not None
                else config.lr
            )
            to_log = {"train/loss": loss, "train/lr": current_lr}
            wandb_log(to_log, step=global_step)

        ##### Evaluation phase
        if global_step % config.eval_every == 0:
            model_engine.eval()
            with torch.no_grad():

                ##### eval step:
                eval_loss = eval_step(config, eval_loader, model_engine)
                wandb_log({"eval/loss": eval_loss}, step=global_step)
                pbar.set_description(
                    f"evaluating... Step: {global_step} Eval Loss: {eval_loss}"
                )

                ##### inference:
                image_grid, caption = inference_step(config, eval_loader, model_engine)
                wandb_log(
                    {"inference/image": wandb.Image(image_grid, caption=caption)},
                    step=global_step,
                )

            model_engine.train()

        ##### Save model
        if global_step % config.save_every == 0:
            if config.save is not None:
                save_model(model_engine, config.save, global_step)
                print_main(f"saving model at step {global_step}")

    ##### Save model after training is finished
    if config.save is not None:
        save_model(model_engine, config.save, global_step)
        print_main(f"saving model at end of training (step {global_step})")
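# Launch sketch: this script is typically started via the deepspeed launcher,
# which supplies the distributed arguments that parse_args() and
# deepspeed.initialize() consume. The config path below is a placeholder,
# not a file shipped with this script:
#
#   deepspeed train.py --config configs/my_config.yml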