|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Finetune utilities.""" |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from megatron import get_args |
|
from megatron import print_rank_0 |
|
from megatron import get_timers |
|
from megatron import mpu, utils |
|
from megatron.checkpointing import load_checkpoint |
|
from megatron.checkpointing import save_checkpoint |
|
from megatron.training import evaluate_and_print_results |
|
from megatron.training import setup_model_and_optimizer |
|
from megatron.training import train_step |
|
from megatron.training import training_log |
|
from megatron.utils import check_adlr_autoresume_termination |
|
from megatron.utils import average_losses_across_data_parallel_group, print_params_min_max_norm |
|
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP |
|
from megatron.model import DistributedDataParallel as LocalDDP |
|
from megatron.model import Float16Module, ModelType |
|
|
|
|
|
def process_batch(batch):
    """Move one (images, labels) batch onto the current GPU.

    Args:
        batch: indexable pair whose element 0 is the image tensor and
            element 1 is the label tensor.

    Returns:
        Tuple ``(images, labels)`` of contiguous CUDA tensors.
    """
    images, labels = batch[0], batch[1]
    return images.cuda().contiguous(), labels.cuda().contiguous()
|
|
|
|
|
def build_data_loader(dataset, micro_batch_size,
                      num_workers, drop_last, shuffle):
    """Data loader. Note that batch-size is the local (per GPU) batch-size.

    The dataset is sharded across the data-parallel group via a
    DistributedSampler, so each rank sees a disjoint subset.
    """
    # Shard by data-parallel rank; ordering (shuffled or not) is decided
    # entirely by the sampler.
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=mpu.get_data_parallel_world_size(),
        rank=mpu.get_data_parallel_rank(),
        drop_last=drop_last,
        shuffle=shuffle,
    )

    # shuffle must be False here: a custom sampler and DataLoader-level
    # shuffling are mutually exclusive.
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=micro_batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=num_workers,
        drop_last=drop_last,
        pin_memory=True,
    )
|
|
|
|
|
def _build_infinite_size_dataloader(dataloader): |
|
"""Build a looped dataloader with infinite size.""" |
|
|
|
iterator = dataloader.__iter__() |
|
while True: |
|
try: |
|
yield iterator.__next__() |
|
except StopIteration: |
|
iterator = dataloader.__iter__() |
|
|
|
|
|
def _build_train_valid_dataloaders(train_dataset, valid_dataset):
    """Build the training dataloader and an infinite validation dataloader.

    Side effects: sets ``args.train_iters_per_epoch``, ``args.train_iters``,
    ``args.orig_micro_batch_size`` and ``args.orig_global_batch_size``.
    """
    args = get_args()

    print_rank_0('building train and validation dataloaders ...')

    # Training loader: keep the last partial batch (drop_last=False),
    # shuffle each epoch.
    train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
                                         args.num_workers, False, True)

    # Derive the total number of iterations from the epoch count so the
    # rest of the training machinery can rely on args.train_iters.
    args.train_iters_per_epoch = len(train_dataloader)
    args.train_iters = args.epochs * args.train_iters_per_epoch

    # Validation loader: drop the last partial batch, no shuffling, and wrap
    # it so it can be drawn from indefinitely during training.
    valid_dataloader = _build_infinite_size_dataloader(
        build_data_loader(valid_dataset, args.micro_batch_size,
                          args.num_workers, True, False))

    # Remember the configured batch sizes — presumably so other code can
    # restore them after temporarily overriding them (TODO confirm caller).
    args.orig_micro_batch_size = args.micro_batch_size
    args.orig_global_batch_size = args.global_batch_size

    return train_dataloader, valid_dataloader
|
|
|
|
|
def _train(
    model,
    optimizer,
    opt_param_scheduler,
    forward_step,
    train_dataloader,
    valid_dataloader,
    end_of_epoch_callback,
    process_non_loss_data_func=None
):
    """Train the model.

    Runs the epoch/iteration loop from ``args.iteration`` up to
    ``args.epochs`` epochs, handling logging, autoresume checks,
    checkpointing, periodic evaluation, and an optional per-epoch callback.

    Args:
        model: list of model chunks; each is put into train mode.
        optimizer: optimizer; must expose ``param_groups`` and
            ``get_loss_scale()``.
        opt_param_scheduler: learning-rate/parameter scheduler passed
            through to ``train_step`` and checkpointing.
        forward_step: callable handed to ``train_step`` /
            ``evaluate_and_print_results``.
        train_dataloader: training dataloader; its sampler and dataset must
            both support ``set_epoch`` (see note below).
        valid_dataloader: (infinite) validation dataloader used by periodic
            evaluation.
        end_of_epoch_callback: optional ``callback(model, epoch)`` invoked
            after each epoch.
        process_non_loss_data_func: optional hook forwarded to evaluation.
    """
    args = get_args()
    timers = get_timers()

    # Turn on training mode for every model chunk.
    for m in model:
        m.train()

    # Accumulator that training_log() uses to average losses across the
    # logging interval.
    losses_dict_sum = {}

    # Resume position: args.iteration is the global iteration count saved in
    # the checkpoint; split it into (epoch, iteration-within-epoch).
    start_epoch = args.iteration // args.train_iters_per_epoch
    start_iteration = args.iteration % args.train_iters_per_epoch
    iteration = args.iteration

    # Memory usage is reported only once; training_log() flips this off.
    report_memory_flag = True

    # For each remaining epoch.
    timers("interval-time").start()
    for epoch in range(start_epoch, args.epochs):
        print_rank_0("working on epoch {} ...".format(epoch + 1))

        # Reseed the sampler's shuffling each epoch (offset by args.seed so
        # different seeds give different shuffles across runs).
        train_dataloader.sampler.set_epoch(args.seed + epoch)
        # NOTE(review): assumes the dataset object itself also exposes
        # set_epoch (e.g. for epoch-dependent augmentation) — confirm for
        # new dataset types.
        train_dataloader.dataset.set_epoch(epoch)

        # For all the batches in the dataset.
        for iteration_, batch in enumerate(train_dataloader):

            # Ignore the iterations before starting value (mid-epoch
            # resume: skip batches already consumed before the checkpoint).
            if iteration_ < start_iteration:
                continue

            # Set to zero so the next epoch starts from its first batch.
            start_iteration = 0

            # Train for one step.
            losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
                forward_step, batch, model, optimizer, opt_param_scheduler
            )
            iteration += 1

            # Logging. params_norm is not computed here (always None).
            params_norm = None

            report_memory_flag = training_log(
                losses_dict,
                losses_dict_sum,
                optimizer.param_groups[0]["lr"],
                iteration,
                optimizer.get_loss_scale().item(),
                report_memory_flag,
                skipped_iter,
                grad_norm,
                params_norm,
                num_zeros_in_grad
            )

            # Autoresume: periodically check whether the cluster scheduler
            # wants us to checkpoint and exit.
            if args.adlr_autoresume and \
               iteration % args.adlr_autoresume_interval == 0:
                check_adlr_autoresume_termination(iteration, model, optimizer,
                                                  opt_param_scheduler)

            # Checkpointing at the configured interval.
            if args.save and args.save_interval and \
               iteration % args.save_interval == 0:
                save_checkpoint(iteration, model, optimizer,
                                opt_param_scheduler)

            # Evaluation at the configured interval. The final False
            # disables verbose/test-mode output in
            # evaluate_and_print_results.
            if args.eval_interval and iteration % args.eval_interval == 0:
                prefix = "iteration {}".format(iteration)
                evaluate_and_print_results(
                    prefix,
                    forward_step,
                    valid_dataloader,
                    model,
                    iteration,
                    process_non_loss_data_func,
                    False,
                )

        # Callback at the end of each epoch (e.g. full validation).
        if end_of_epoch_callback is not None:
            end_of_epoch_callback(model, epoch)
|
|
|
|
|
def finetune(
    train_valid_datasets_provider,
    model_provider,
    forward_step,
    model_type=ModelType.encoder_or_decoder,
    process_non_loss_data_func=None,
    end_of_epoch_callback_provider=None,
):
    """Main finetune function used across all tasks.

    Orchestrates the full finetuning pipeline: dataloader construction,
    optional end-of-epoch callback, model/optimizer setup, pretrained
    checkpoint loading, and either training (``args.epochs > 0``) or
    evaluation-only mode.

    Args:
        train_valid_datasets_provider: zero-arg callable returning
            ``(train_dataset, valid_dataset)``.
        model_provider: callable passed to ``setup_model_and_optimizer``.
        forward_step: per-iteration forward/loss callable.
        model_type: Megatron ``ModelType`` for model setup.
        process_non_loss_data_func: optional hook forwarded to evaluation.
        end_of_epoch_callback_provider: optional zero-arg callable returning
            the per-epoch callback (also used for eval-only mode).

    Raises:
        Exception: if ``args.pretrained_checkpoint_type`` is unrecognized.
    """
    args = get_args()
    timers = get_timers()

    # Train and validation data loaders. Only built when actually training;
    # in eval-only mode (epochs == 0) they are never referenced.
    timers("train/valid/test dataset/dataloder").start()
    if args.epochs > 0:
        train_dataset, valid_dataset = train_valid_datasets_provider()
        train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
            train_dataset, valid_dataset
        )
    timers("train/valid/test dataset/dataloder").stop()

    # Build calback function (e.g. end-of-epoch evaluation).
    timers("callback function").start()
    end_of_epoch_callback = None
    if end_of_epoch_callback_provider is not None:
        end_of_epoch_callback = end_of_epoch_callback_provider()
    timers("callback function").stop()

    # Build model, optimizer and learning rate scheduler. Parameters whose
    # name contains ".head." (the task head) get their LR scaled by
    # args.head_lr_mult relative to the backbone.
    timers("model and optimizer").start()
    model, optimizer, opt_param_scheduler = \
        setup_model_and_optimizer(
            model_provider,
            model_type,
            scale_lr_cond=lambda name, param: ".head." in name,
            lr_mult=args.head_lr_mult)
    timers("model and optimizer").stop()

    # If pretrained checkpoint is provided and we have not trained for
    # any iteration (i.e., iteration is zero), then load the pretrained
    # checkpoint.
    timers("pretrained checkpoint").start()
    if args.iteration == 0 and args.pretrained_checkpoint is not None:
        if args.pretrained_checkpoint_type == 'default':
            # Standard Megatron checkpoint: temporarily point args.load at
            # the pretrained path, load weights only (no optimizer/
            # scheduler, strict=False), then restore args.load.
            original_load = args.load
            args.load = args.pretrained_checkpoint
            _ = load_checkpoint(model, None, None, strict=False)
            args.load = original_load
        elif args.pretrained_checkpoint_type == 'external':
            # Raw torch state_dict for the backbone only.
            # NOTE(review): torch.load on an untrusted checkpoint path
            # unpickles arbitrary data — only use trusted files.
            unwrap_model = utils.unwrap_model(model)
            state_dict = torch.load(args.pretrained_checkpoint,
                                    map_location="cpu")
            unwrap_model[0].module.backbone.load_state_dict(state_dict,
                                                            strict=False)
        elif args.pretrained_checkpoint_type == 'constrastive':
            # Contrastive-pretraining checkpoint (key spelled 'constrastive'
            # historically — callers rely on this exact string): extract the
            # teacher backbone weights and strip their prefix before loading.
            unwrap_model = utils.unwrap_model(model)
            state_dict = torch.load(args.pretrained_checkpoint,
                                    map_location="cpu")
            state_dict = state_dict["model"]
            state_dict = {k.replace("teacher.backbone.", ""): v
                          for k, v in state_dict.items()
                          if k.startswith("teacher.backbone.")}
            unwrap_model[0].module.backbone.load_state_dict(state_dict,
                                                            strict=False)
        else:
            raise Exception("pretrained checkpoint type {} not supported".format(args.pretrained_checkpoint_type))

        # This is critical when only model is loaded. We should make sure
        # master parameters are also updated from the freshly loaded
        # model weights.
        optimizer.reload_model_params()

    timers("pretrained checkpoint").stop()

    # Print setup timing.
    print_rank_0("done with setups ...")
    timers.log(
        [
            "train/valid/test dataset/dataloder",
            "callback function",
            "model and optimizer",
            "pretrained checkpoint",
        ]
    )
    print_rank_0("training ...")

    # Finetune the model.
    if args.epochs > 0:
        _train(
            model,
            optimizer,
            opt_param_scheduler,
            forward_step,
            train_dataloader,
            valid_dataloader,
            end_of_epoch_callback,
            process_non_loss_data_func,
        )
    # Or just evaluate: epoch -1 signals eval-only mode to the callback.
    else:
        if end_of_epoch_callback is not None:
            print_rank_0("evaluation only mode, setting epoch to -1")
            end_of_epoch_callback(model, epoch=-1)

    print_rank_0("done :-)")
|
|
|
|