Spaces:
Runtime error
Runtime error
File size: 6,093 Bytes
bb5cd12 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
import torch
import os
import deepspeed
import wandb
from torch.utils.data import random_split, ConcatDataset
from torch.optim import AdamW
from tqdm import tqdm
from functools import partial
from magma.datasets import (
collate_fn,
ImgCptDataset,
)
from magma.magma import (
Magma,
)
from magma.utils import (
is_main,
cycle,
parse_args,
wandb_log,
wandb_init,
save_model,
load_model,
print_main,
configure_param_groups,
)
from magma.train_loop import (
eval_step,
inference_step,
train_step,
)
def _load_img_cpt_datasets(dataset_dir, tokenizer, transforms):
if isinstance(dataset_dir, (list, tuple)):
return ConcatDataset(
[_load_img_cpt_datasets(d, tokenizer, transforms) for d in dataset_dir]
)
elif isinstance(dataset_dir, str):
return ImgCptDataset(dataset_dir, tokenizer=tokenizer, transforms=transforms)
else:
raise TypeError("dataset dir wrong type")
def get_pretraining_datasets(config, tokenizer, transforms):
    """Return the ``(train_dataset, eval_dataset)`` pair described by ``config``.

    The training set is loaded from ``config.train_dataset_dir``. If
    ``config.eval_dataset_dir`` is ``None``, a share of the training set given
    by ``config.eval_dataset_pct`` is split off at random and used for
    evaluation; otherwise the evaluation set is loaded from
    ``config.eval_dataset_dir``.

    Args:
        config: object providing ``train_dataset_dir``, ``eval_dataset_dir``
            and ``eval_dataset_pct``.
        tokenizer: forwarded to the dataset loader.
        transforms: forwarded to the dataset loader.

    Returns:
        A ``(train_dataset, eval_dataset)`` tuple. When the split path is
        taken, both elements are the subsets produced by ``random_split``.
    """
    # if config.train_dataset_dir is a list, load all datasets + join together
    train_dataset = _load_img_cpt_datasets(
        config.train_dataset_dir, tokenizer, transforms
    )

    # if no dedicated eval sets are given, use a percentage of the train dataset
    if config.eval_dataset_dir is None:
        # int() truncates, so the eval share is rounded down and the training
        # share receives the remainder
        eval_len = int(len(train_dataset) * config.eval_dataset_pct)
        train_len = len(train_dataset) - eval_len
        # Use print_main like the other messages in this function instead of a
        # bare print (presumably print_main only emits on the main process --
        # confirm in magma.utils).
        print_main(
            f"Randomly splitting train_dataset into two datasets of length {train_len} and {eval_len}"
        )
        train_dataset, eval_dataset = random_split(train_dataset, [train_len, eval_len])
    else:
        eval_dataset = _load_img_cpt_datasets(
            config.eval_dataset_dir, tokenizer, transforms
        )

    print_main(f"Loaded train dataset with {len(train_dataset)} samples")
    print_main(f"Loaded eval dataset with {len(eval_dataset)} samples")

    return train_dataset, eval_dataset
# tell tokenizers not to do parallelism
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def main():
    """Run the training script end to end.

    Steps, in order: parse the command line, initialise the distributed
    backend, build the model, load the datasets, wrap everything with
    DeepSpeed, optionally restore a checkpoint, then loop over train steps
    with periodic logging, evaluation, inference logging and checkpointing.
    A final checkpoint is written after the loop if ``config.save`` is set.
    """
    # parse command line arguments:
    args = parse_args()
    deepspeed.init_distributed()

    # load model + tokenizer:
    model = Magma(
        args.config
    )  # for finetuning one might want to load the model via Magma.from_checkpoint(...) here
    tokenizer, config, transforms = model.tokenizer, model.config, model.transforms

    # filter frozen from trainable parameters:
    trainable_parameters = configure_param_groups(model, config)

    # load data (get_pretraining_datasets already logs both dataset sizes,
    # so they are not logged a second time here):
    train_dataset, eval_dataset = get_pretraining_datasets(
        config, tokenizer, transforms
    )

    opt = AdamW(
        trainable_parameters,
        config.lr,
        betas=(0.9, 0.95),
        weight_decay=config.weight_decay,
    )

    model_engine, opt, train_loader, lr_scheduler = deepspeed.initialize(
        args=args,
        model=model,
        optimizer=opt,
        model_parameters=trainable_parameters,
        training_data=train_dataset,
        collate_fn=partial(collate_fn, seq_len=model.seq_len),
        config_params=config.deepspeed_config_params,
    )
    # both loaders are wrapped in cycle() because the loop below is driven by
    # a step count, not by passes over the data
    eval_loader = cycle(model_engine.deepspeed_io(eval_dataset))
    train_loader = cycle(train_loader)

    # initialize training
    global_step = 0
    if config.load:
        # loads a deepspeed checkpoint if provided. For finetuning, set load_optimizer to false
        previous_global_step = load_model(
            model_engine,
            config.load,
            load_optimizer_states=config.load_optimizer,
            load_lr_scheduler_states=config.load_optimizer,
        )

        # only continue the step count when the optimizer state is restored too
        if config.load_optimizer:
            global_step = previous_global_step

    pbar = tqdm(
        range(0, config.train_steps),
        desc="training...",
        initial=global_step,
        total=config.train_steps,
        disable=not is_main(),
    )
    wandb_init(
        project=config.wandb_project,
        name=config.name or wandb.util.generate_id(),
        config=config,
    )

    # step of the most recent checkpoint written by this run; lets the final
    # save be skipped when the last step was already checkpointed
    last_saved_step = None

    # training loop
    for _ in pbar:
        # the range always spans train_steps iterations; when resuming from a
        # later step this check ends the loop at the right place
        if global_step >= config.train_steps:
            break

        ##### train step
        loss = train_step(config, train_loader, model_engine)

        global_step += 1

        if global_step % config.log_every == 0:
            pbar.set_description(f"training... Step: {global_step} Loss: {loss}")
            current_lr = (
                list(lr_scheduler.get_lr())
                if lr_scheduler is not None
                else config.lr
            )
            to_log = {"train/loss": loss, "train/lr": current_lr}
            wandb_log(to_log, step=global_step)

        ##### Evaluation phase
        if global_step % config.eval_every == 0:
            model_engine.eval()
            with torch.no_grad():

                ##### eval step:
                eval_loss = eval_step(config, eval_loader, model_engine)

                wandb_log({"eval/loss": eval_loss}, step=global_step)
                pbar.set_description(
                    f"evaluating... Step: {global_step} Eval Loss: {eval_loss}"
                )

                ##### inference:
                image_grid, caption = inference_step(config, eval_loader, model_engine)
                wandb_log(
                    {"inference/image": wandb.Image(image_grid, caption=caption)},
                    step=global_step,
                )

            model_engine.train()

        ##### Save model
        if global_step % config.save_every == 0:
            if config.save is not None:
                save_model(model_engine, config.save, global_step)
                last_saved_step = global_step
                print_main(f"saving model at step {global_step}")

    ##### Save model after training is finished
    # skipped when the periodic save above already wrote this exact step, to
    # avoid writing the same checkpoint twice in a row
    if config.save is not None and last_saved_step != global_step:
        save_model(model_engine, config.save, global_step)
        print_main(f"saving model at end of training (step {global_step})")


if __name__ == "__main__":
    main()
|