|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Pretrain utilities.""" |
|
|
|
from datetime import datetime |
|
import math |
|
import sys |
|
import time |
|
|
|
_TRAIN_START_TIME = time.time() |
|
import torch |
|
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP |
|
from torch.utils.tensorboard import SummaryWriter |
|
|
|
from megatron import get_args |
|
from megatron import get_signal_handler |
|
from megatron import get_timers |
|
from megatron import get_tensorboard_writer |
|
from megatron import get_current_global_batch_size |
|
from megatron import get_num_microbatches |
|
from megatron import is_last_rank |
|
from megatron import update_num_microbatches |
|
from megatron import mpu |
|
from megatron import print_rank_0 |
|
from megatron import print_rank_last |
|
from megatron.checkpointing import load_checkpoint |
|
from megatron.checkpointing import save_checkpoint |
|
from megatron.model import Float16Module |
|
from megatron.model import ModelType |
|
from megatron.optimizer import get_megatron_optimizer |
|
from megatron.initialize import initialize_megatron |
|
from megatron.initialize import write_args_to_tensorboard |
|
from megatron.initialize import set_jit_fusion_options |
|
from megatron.optimizer_param_scheduler import OptimizerParamScheduler |
|
from megatron.model import DistributedDataParallel as LocalDDP |
|
from megatron.utils import check_adlr_autoresume_termination |
|
from megatron.utils import unwrap_model |
|
from megatron.data.data_samplers import build_pretraining_data_loader |
|
from megatron.utils import calc_params_l2_norm |
|
from megatron.schedules import get_forward_backward_func |
|
from megatron.utils import report_memory |
|
from megatron.model.vision.knn_monitor import compute_feature_bank |
|
|
|
|
|
def print_datetime(string): |
|
"""Note that this call will sync across all ranks.""" |
|
torch.distributed.barrier() |
|
time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S') |
|
print_rank_0('[' + string + '] datetime: {} '.format(time_str)) |
|
|
|
|
|
def pretrain(train_valid_test_dataset_provider, |
|
model_provider, |
|
model_type, |
|
forward_step_func, |
|
process_non_loss_data_func=None, |
|
extra_args_provider=None, |
|
args_defaults={}): |
|
"""Main training program. |
|
|
|
    This function will run the following in the order provided:
|
1) initialize Megatron. |
|
2) setup model, optimizer and lr schedule using the model_provider. |
|
        3) call train_valid_test_dataset_provider to get train/valid/test datasets.
|
        4) train the model using the forward_step_func.
|
|
|
Arguments: |
|
train_valid_test_dataset_provider: a function that takes the size of |
|
train/valid/test dataset and returns `train, valid, test` datasets. |
|
model_provider: a function that returns a vanilla version of the |
|
model. By vanilla we mean a simple model on cpu with no fp16 or ddp. |
|
model_type: an enum that specifies the type of model being trained. |
|
forward_step_func: a function that takes a `data iterator` and `model`, |
|
and returns a `loss` scalar with a dictionary with key:values being |
|
the info we would like to monitor during training, for example |
|
`lm-loss: value`. We also require that this function add |
|
`batch generator` to the timers class. |
|
process_non_loss_data_func: a function to post process outputs of the |
|
            network. It can be used for dumping output tensors (e.g. images) to
            tensorboard. It takes `collected data` (list of tensors),
|
`current iteration index` and `tensorboard writer` as arguments. |
|
extra_args_provider: a function that takes a parser and adds arguments |
|
to it. It is used for programs to add their own arguments. |
|
        args_defaults: a dictionary from argument-name to argument-value. It
            is used to set already-parsed arguments.
|
""" |
|
|
|
|
|
initialize_megatron(extra_args_provider=extra_args_provider, |
|
args_defaults=args_defaults) |
|
|
|
set_jit_fusion_options() |
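    # Take the earliest recorded start time across ranks so the reported
    # initialization time reflects the full wait, including the slowest rank.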
|
|
|
|
|
|
|
|
|
global _TRAIN_START_TIME |
|
start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME]) |
|
torch.distributed.all_reduce(start_time_tensor, |
|
op=torch.distributed.ReduceOp.MIN) |
|
_TRAIN_START_TIME = start_time_tensor.item() |
|
print_rank_0('time to initialize megatron (seconds): {:.3f}'.format( |
|
time.time() - _TRAIN_START_TIME)) |
|
print_datetime('after megatron is initialized') |
|
|
|
args = get_args() |
|
timers = get_timers() |
|
|
|
|
|
timers('model-and-optimizer-setup').start() |
|
model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, |
|
model_type) |
|
timers('model-and-optimizer-setup').stop() |
|
print_datetime('after model, optimizer, and learning rate ' |
|
'scheduler are built') |
|
|
|
|
|
timers('train/valid/test-data-iterators-setup').start() |
|
if args.virtual_pipeline_model_parallel_size is not None: |
|
all_data_iterators = [ |
|
build_train_valid_test_data_iterators(train_valid_test_dataset_provider) |
|
for _ in range(len(model)) |
|
] |
|
train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators] |
|
valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators] |
|
test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators] |
|
else: |
|
train_data_iterator, valid_data_iterator, test_data_iterator \ |
|
= build_train_valid_test_data_iterators( |
|
train_valid_test_dataset_provider) |
|
timers('train/valid/test-data-iterators-setup').stop() |
|
print_datetime('after dataloaders are built') |
|
|
|
|
|
print_rank_0('done with setup ...') |
|
timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup']) |
|
print_rank_0('training ...') |
|
|
|
iteration = 0 |
|
if args.do_train and args.train_iters > 0: |
|
iteration = train(forward_step_func, |
|
model, optimizer, opt_param_scheduler, |
|
train_data_iterator, valid_data_iterator, |
|
process_non_loss_data_func) |
|
print_datetime('after training is done') |
|
|
|
if args.do_valid: |
|
prefix = 'the end of training for val data' |
|
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, process_non_loss_data_func,
                                   None, False)
|
|
|
if args.save and iteration != 0: |
|
save_checkpoint(iteration, model, optimizer, opt_param_scheduler) |
|
|
|
if args.do_test: |
|
|
|
prefix = 'the end of training for test data' |
|
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, process_non_loss_data_func,
                                   None, True)
|
|
|
def update_train_iters(args): |
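    """Derive args.train_iters from args.train_samples, accounting for batch-size ramp-up."""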
|
|
|
|
|
if args.train_iters: |
|
return |
|
|
|
|
|
if args.rampup_batch_size is None: |
|
args.train_iters = args.train_samples // args.global_batch_size |
|
|
|
else: |
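        # Ramp-up phase: walk the batch-size schedule one iteration at a time
        # until the ramp-up sample budget (rampup_batch_size[2]) is consumed,
        # then count the remaining iterations at the full global batch size.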
|
|
|
iterations = 0 |
|
consumed_samples = 0 |
|
|
|
while consumed_samples <= int(args.rampup_batch_size[2]): |
|
update_num_microbatches(consumed_samples, consistency_check=False) |
|
consumed_samples += get_current_global_batch_size() |
|
iterations += 1 |
|
|
|
update_num_microbatches(0, consistency_check=False) |
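        # Reset the microbatch calculator back to the start of the schedule.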
|
|
|
|
|
iterations += (args.train_samples - consumed_samples) // \ |
|
args.global_batch_size |
|
args.train_iters = iterations |
|
|
|
print_rank_0('setting training iterations to {}'.format(args.train_iters)) |
|
|
|
|
|
def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True): |
|
"""Build the model.""" |
|
args = get_args() |
|
args.model_type = model_type |
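    # With interleaved (virtual) pipeline parallelism each pipeline rank builds
    # one model chunk per virtual stage, so `model` becomes a list of modules.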
|
|
|
|
|
if mpu.get_pipeline_model_parallel_world_size() > 1 and \ |
|
args.virtual_pipeline_model_parallel_size is not None: |
|
assert model_type != ModelType.encoder_and_decoder, \ |
|
"Interleaved schedule not supported for model with both encoder and decoder" |
|
model = [] |
|
for i in range(args.virtual_pipeline_model_parallel_size): |
|
mpu.set_virtual_pipeline_model_parallel_rank(i) |
|
|
|
pre_process = mpu.is_pipeline_first_stage() |
|
post_process = mpu.is_pipeline_last_stage() |
|
this_model = model_provider_func( |
|
pre_process=pre_process, |
|
post_process=post_process |
|
) |
|
this_model.model_type = model_type |
|
model.append(this_model) |
|
else: |
|
pre_process = mpu.is_pipeline_first_stage() |
|
post_process = mpu.is_pipeline_last_stage() |
|
add_encoder = True |
|
add_decoder = True |
|
if model_type == ModelType.encoder_and_decoder: |
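            # With pipeline parallelism, stages before pipeline_model_parallel_split_rank
            # hold the encoder and the stages after it hold the decoder; the first stage
            # of each half does pre-processing and the last stage of each half does
            # post-processing.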
|
if mpu.get_pipeline_model_parallel_world_size() > 1: |
|
assert args.pipeline_model_parallel_split_rank is not None, \ |
|
"Split rank needs to be specified for model with both encoder and decoder" |
|
rank = mpu.get_pipeline_model_parallel_rank() |
|
split_rank = args.pipeline_model_parallel_split_rank |
|
world_size = mpu.get_pipeline_model_parallel_world_size() |
|
pre_process = rank == 0 or rank == split_rank |
|
post_process = (rank == (split_rank - 1)) or ( |
|
rank == (world_size - 1)) |
|
add_encoder = mpu.is_pipeline_stage_before_split() |
|
add_decoder = mpu.is_pipeline_stage_after_split() |
|
model = model_provider_func( |
|
pre_process=pre_process, |
|
post_process=post_process, |
|
add_encoder=add_encoder, |
|
add_decoder=add_decoder) |
|
else: |
|
model = model_provider_func( |
|
pre_process=pre_process, |
|
post_process=post_process |
|
) |
|
model.model_type = model_type |
|
|
|
if not isinstance(model, list): |
|
model = [model] |
|
|
|
|
|
|
|
|
|
|
|
for model_module in model: |
|
for param in model_module.parameters(): |
|
mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param) |
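    # Parameters created without explicit tensor-model-parallel attributes get
    # the defaults here, so the optimizer and gradient reductions can treat all
    # parameters uniformly.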
|
|
|
|
|
if mpu.get_data_parallel_rank() == 0: |
|
print(' > number of parameters on (tensor, pipeline) ' |
|
'model parallel rank ({}, {}): {}'.format( |
|
mpu.get_tensor_model_parallel_rank(), |
|
mpu.get_pipeline_model_parallel_rank(), |
|
sum([sum([p.nelement() for p in model_module.parameters()]) |
|
for model_module in model])), flush=True) |
|
|
|
|
|
for model_module in model: |
|
model_module.cuda(torch.cuda.current_device()) |
|
|
|
|
|
if args.fp16 or args.bf16: |
|
model = [Float16Module(model_module, args) for model_module in model] |
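    # Float16Module keeps the wrapped module's parameters in fp16/bf16 and casts
    # inputs/outputs between fp32 and the reduced precision at the pipeline
    # boundaries.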
|
|
|
if wrap_with_ddp: |
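        # Wrap each model chunk with either torch's native DistributedDataParallel
        # or Megatron's local DDP implementation, depending on args.DDP_impl.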
|
if args.DDP_impl == 'torch': |
|
i = torch.cuda.current_device() |
|
model = [torchDDP(model_module, device_ids=[i], output_device=i, |
|
process_group=mpu.get_data_parallel_group()) |
|
for model_module in model] |
|
|
|
elif args.DDP_impl == 'local': |
|
model = [LocalDDP(model_module, |
|
args.accumulate_allreduce_grads_in_fp32, |
|
args.use_contiguous_buffers_in_local_ddp) |
|
for model_module in model] |
|
|
|
if args.data_parallel_random_init: |
|
for model_module in model: |
|
model_module.broadcast_params() |
|
else: |
|
raise NotImplementedError('Unknown DDP implementation specified: ' |
|
'{}. Exiting.'.format(args.DDP_impl)) |
|
|
|
return model |
|
|
|
|
|
def get_optimizer_param_scheduler(optimizer): |
|
"""Build the learning rate scheduler.""" |
|
args = get_args() |
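    # The scheduler steps in units of consumed samples, so iteration-based
    # settings are converted to sample counts using the global batch size, while
    # sample-based settings are used directly.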
|
|
|
|
|
if args.train_iters: |
|
if args.lr_decay_iters is None: |
|
args.lr_decay_iters = args.train_iters |
|
lr_decay_steps = args.lr_decay_iters * args.global_batch_size |
|
wd_incr_steps = args.train_iters * args.global_batch_size |
|
if args.lr_warmup_fraction is not None: |
|
lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps |
|
else: |
|
lr_warmup_steps = args.lr_warmup_iters * args.global_batch_size |
|
|
|
elif args.train_samples: |
|
|
|
|
|
|
|
update_train_iters(args) |
|
if args.lr_decay_samples is None: |
|
args.lr_decay_samples = args.train_samples |
|
lr_decay_steps = args.lr_decay_samples |
|
wd_incr_steps = args.train_samples |
|
if args.lr_warmup_fraction is not None: |
|
lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps |
|
else: |
|
lr_warmup_steps = args.lr_warmup_samples |
|
else: |
|
raise Exception( |
|
'either train-iters or train-samples should be provided.') |
|
|
|
opt_param_scheduler = OptimizerParamScheduler( |
|
optimizer, |
|
max_lr=args.lr, |
|
min_lr=args.min_lr, |
|
lr_warmup_steps=lr_warmup_steps, |
|
lr_decay_steps=lr_decay_steps, |
|
lr_decay_style=args.lr_decay_style, |
|
start_wd=args.start_weight_decay, |
|
end_wd=args.end_weight_decay, |
|
wd_incr_steps=wd_incr_steps, |
|
wd_incr_style=args.weight_decay_incr_style, |
|
use_checkpoint_opt_param_scheduler=args.use_checkpoint_opt_param_scheduler, |
|
override_opt_param_scheduler=args.override_opt_param_scheduler) |
|
|
|
return opt_param_scheduler |
|
|
|
|
|
def setup_model_and_optimizer(model_provider_func, |
|
model_type, |
|
no_wd_decay_cond=None, |
|
scale_lr_cond=None, |
|
lr_mult=1.0): |
|
"""Setup model and optimizer.""" |
|
args = get_args() |
|
|
|
model = get_model(model_provider_func, model_type) |
|
unwrapped_model = unwrap_model(model, |
|
(torchDDP, LocalDDP, Float16Module)) |
|
|
|
optimizer = get_megatron_optimizer(model, no_wd_decay_cond, |
|
scale_lr_cond, lr_mult) |
|
opt_param_scheduler = get_optimizer_param_scheduler(optimizer) |
|
|
|
if args.load is not None: |
|
timers = get_timers() |
|
|
|
|
|
torch.distributed.barrier() |
|
timers('load-checkpoint').start() |
|
args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler) |
|
torch.distributed.barrier() |
|
timers('load-checkpoint').stop() |
|
timers.log(['load-checkpoint']) |
|
else: |
|
args.iteration = 0 |
|
|
|
|
|
if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1: |
|
assert args.DDP_impl == 'local' |
|
|
|
|
|
if args.iteration == 0 and len(unwrapped_model) == 1 \ |
|
and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'): |
|
print_rank_0("Initializing ICT from pretrained BERT model") |
|
unwrapped_model[0].init_state_dict_from_bert() |
|
if args.fp16: |
|
optimizer.reload_model_params() |
|
|
|
return model, optimizer, opt_param_scheduler |
|
|
|
|
|
def train_step(forward_step_func, data_iterator, |
|
model, optimizer, opt_param_scheduler): |
|
"""Single training step.""" |
|
args = get_args() |
|
timers = get_timers() |
|
|
|
|
|
if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_local_ddp: |
|
for partition in model: |
|
partition.zero_grad_buffer() |
|
optimizer.zero_grad() |
|
|
|
|
|
forward_backward_func = get_forward_backward_func() |
|
losses_reduced = forward_backward_func( |
|
forward_step_func, data_iterator, model, |
|
optimizer, timers, forward_only=False) |
|
|
|
|
|
if args.empty_unused_memory_level >= 1: |
|
torch.cuda.empty_cache() |
|
|
|
|
|
timers('backward-reduce-model-grads').start() |
|
optimizer.reduce_model_grads(args, timers) |
|
timers('backward-reduce-model-grads').stop() |
|
|
|
|
|
if args.vision_pretraining and args.vision_pretraining_type == "dino": |
|
unwrapped_model = unwrap_model(model[0], |
|
(torchDDP, LocalDDP, Float16Module)) |
|
unwrapped_model.cancel_gradients_last_layer(args.curr_iteration) |
|
|
|
|
|
timers('optimizer').start() |
|
update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers) |
|
timers('optimizer').stop() |
|
|
|
|
|
if update_successful: |
|
timers('backward-gather-model-params').start() |
|
optimizer.gather_model_params(args, timers) |
|
timers('backward-gather-model-params').stop() |
|
|
|
|
|
if args.vision_pretraining and args.vision_pretraining_type == "dino": |
|
unwrapped_model = unwrap_model(model[0], |
|
(torchDDP, LocalDDP, Float16Module)) |
|
unwrapped_model.update_momentum(args.curr_iteration) |
|
|
|
|
|
if update_successful: |
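        # Advance the scheduler by the number of samples consumed this step:
        # num_microbatches * micro_batch_size * data_parallel_size.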
|
increment = get_num_microbatches() * \ |
|
args.micro_batch_size * \ |
|
args.data_parallel_size |
|
opt_param_scheduler.step(increment=increment) |
|
skipped_iter = 0 |
|
else: |
|
skipped_iter = 1 |
|
|
|
|
|
if args.empty_unused_memory_level >= 2: |
|
torch.cuda.empty_cache() |
|
|
|
if mpu.is_pipeline_last_stage(ignore_virtual=True): |
|
|
|
loss_reduced = {} |
|
for key in losses_reduced[0]: |
|
if key == "describe": |
|
continue |
|
losses_reduced_for_key = [x[key] for x in losses_reduced] |
|
loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key) |
|
return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad |
|
return {}, skipped_iter, grad_norm, num_zeros_in_grad |
|
|
|
|
|
def training_log(loss_dict, total_loss_dict, learning_rate, iteration, |
|
loss_scale, report_memory_flag, skipped_iter, |
|
grad_norm, params_norm, num_zeros_in_grad, my_writer): |
|
"""Log training information such as losses, timing, ....""" |
|
args = get_args() |
|
timers = get_timers() |
|
writer = get_tensorboard_writer() |
|
|
|
|
|
advanced_iters_key = 'advanced iterations' |
|
skipped_iters_key = 'skipped iterations' |
|
nan_iters_key = 'nan iterations' |
|
|
|
if not skipped_iter: |
|
total_loss_dict[advanced_iters_key] = total_loss_dict.get( |
|
advanced_iters_key, 0) + 1 |
|
else: |
|
if advanced_iters_key not in total_loss_dict: |
|
total_loss_dict[advanced_iters_key] = 0 |
|
|
|
total_loss_dict[skipped_iters_key] = total_loss_dict.get( |
|
skipped_iters_key, 0) + skipped_iter |
|
|
|
got_nan = False |
|
for key in loss_dict: |
|
if not skipped_iter: |
|
total_loss_dict[key] = total_loss_dict.get( |
|
key, torch.cuda.FloatTensor([0.0])) + loss_dict[key] |
|
else: |
|
value = loss_dict[key].float().sum().item() |
|
is_nan = value == float('inf') or \ |
|
value == -float('inf') or \ |
|
value != value |
|
got_nan = got_nan or is_nan |
|
total_loss_dict[nan_iters_key] = total_loss_dict.get( |
|
nan_iters_key, 0) + int(got_nan) |
|
|
|
|
|
timers_to_log = [] |
|
|
|
def add_to_logging(name): |
|
if name in timers.timers: |
|
timers_to_log.append(name) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
batch_size = args.micro_batch_size * args.data_parallel_size * \ |
|
get_num_microbatches() |
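    # The value above is the current global batch size, i.e. samples consumed
    # per iteration.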
|
|
|
total_iterations = total_loss_dict[advanced_iters_key] + \ |
|
total_loss_dict[skipped_iters_key] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if my_writer and iteration % args.log_interval == 0 and is_last_rank(): |
|
|
|
my_writer.add_scalar('train/learning-rate', learning_rate, iteration) |
|
|
|
|
|
total_train_loss = 0 |
|
for key in loss_dict: |
|
my_writer.add_scalar("train/" + key , loss_dict[key], iteration) |
|
ppl = math.exp(min(20, loss_dict[key])) |
|
my_writer.add_scalar("train/" + key + "_ppl" , ppl, iteration) |
|
total_train_loss += loss_dict[key] |
|
my_writer.add_scalar("train/total-loss", total_train_loss, iteration) |
|
|
|
my_writer.add_scalar('train/loss-scale', loss_scale, iteration) |
|
|
|
if iteration % args.log_interval == 0: |
|
elapsed_time = timers('interval-time').elapsed() |
|
elapsed_time_per_iteration = elapsed_time / total_iterations |
|
if writer: |
|
if args.log_timers_to_tensorboard: |
|
writer.add_scalar('iteration-time', |
|
elapsed_time_per_iteration, iteration) |
|
log_string = ' iteration {:8d}/{:8d} |'.format( |
|
iteration, args.train_iters) |
|
log_string += ' consumed samples: {:12d} |'.format( |
|
args.consumed_train_samples) |
|
log_string += ' Task: {} |'.format( |
|
args.task) |
|
log_string += ' elapsed time per iteration (ms): {:.1f} |'.format( |
|
elapsed_time_per_iteration * 1000.0) |
|
log_string += ' learning rate: {:.3E} |'.format(learning_rate) |
|
log_string += ' global batch size: {:5d} |'.format(batch_size) |
|
for key in total_loss_dict: |
|
if key not in [advanced_iters_key, skipped_iters_key, |
|
nan_iters_key]: |
|
avg = total_loss_dict[key].item() / \ |
|
float(max(1, total_loss_dict[advanced_iters_key])) |
|
if avg > 0.0: |
|
log_string += ' {}: {:.6E} |'.format(key, avg) |
|
total_loss_dict[key] = torch.cuda.FloatTensor([0.0]) |
|
log_string += ' loss scale: {:.1f} |'.format(loss_scale) |
|
if grad_norm is not None: |
|
log_string += ' grad norm: {:.3f} |'.format(grad_norm) |
|
if num_zeros_in_grad is not None: |
|
log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad) |
|
if params_norm is not None: |
|
log_string += ' params norm: {:.3f} |'.format(params_norm) |
|
log_string += ' number of skipped iterations: {:3d} |'.format( |
|
total_loss_dict[skipped_iters_key]) |
|
log_string += ' number of nan iterations: {:3d} |'.format( |
|
total_loss_dict[nan_iters_key]) |
|
total_loss_dict[advanced_iters_key] = 0 |
|
total_loss_dict[skipped_iters_key] = 0 |
|
total_loss_dict[nan_iters_key] = 0 |
|
print_rank_last(log_string) |
|
if report_memory_flag and learning_rate > 0.: |
|
|
|
report_memory('(after {} iterations)'.format(iteration)) |
|
report_memory_flag = False |
|
timers.log(timers_to_log, normalizer=args.log_interval) |
|
|
|
return report_memory_flag |
|
|
|
|
|
def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler): |
|
timers = get_timers() |
|
|
|
|
|
torch.distributed.barrier() |
|
timers('save-checkpoint').start() |
|
save_checkpoint(iteration, model, optimizer, opt_param_scheduler) |
|
torch.distributed.barrier() |
|
timers('save-checkpoint').stop() |
|
timers.log(['save-checkpoint']) |
|
|
|
|
|
def train(forward_step_func, model, optimizer, opt_param_scheduler, |
|
train_data_iterator, valid_data_iterator, |
|
process_non_loss_data_func): |
|
"""Train the model function.""" |
|
args = get_args() |
|
timers = get_timers() |
|
|
|
|
|
write_args_to_tensorboard() |
|
|
|
|
|
for model_module in model: |
|
model_module.train() |
|
|
|
|
|
total_loss_dict = {} |
|
|
|
|
|
iteration = args.iteration |
|
|
|
timers('interval-time').start() |
|
print_datetime('before the start of training step') |
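    # Only the last rank creates the auxiliary SummaryWriter (my_writer); other
    # ranks pass None so the logging helpers can skip it.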
|
if is_last_rank(): |
|
my_writer = SummaryWriter(args.save + "/tb_res") |
|
else: |
|
my_writer = None |
|
report_memory_flag = True |
|
while iteration < args.train_iters: |
|
update_num_microbatches(args.consumed_train_samples) |
|
args.curr_iteration = iteration |
|
loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \ |
|
train_step(forward_step_func, |
|
train_data_iterator, |
|
model, |
|
optimizer, |
|
opt_param_scheduler) |
|
iteration += 1 |
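        # Each step consumes one global batch:
        # data_parallel_size * micro_batch_size * num_microbatches samples.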
|
args.consumed_train_samples += mpu.get_data_parallel_world_size() * \ |
|
args.micro_batch_size * \ |
|
get_num_microbatches() |
|
|
|
|
|
loss_scale = optimizer.get_loss_scale().item() |
|
params_norm = None |
|
if args.log_params_norm: |
|
params_norm = calc_params_l2_norm(model) |
|
report_memory_flag = training_log(loss_dict, total_loss_dict, |
|
optimizer.param_groups[0]['lr'], |
|
iteration, loss_scale, |
|
report_memory_flag, skipped_iter, |
|
grad_norm, params_norm, num_zeros_in_grad, my_writer) |
|
|
|
|
|
if args.adlr_autoresume and \ |
|
(iteration % args.adlr_autoresume_interval == 0): |
|
check_adlr_autoresume_termination(iteration, model, optimizer, |
|
opt_param_scheduler) |
|
|
|
|
|
if args.eval_interval and iteration % args.eval_interval == 0 and \ |
|
args.do_valid: |
|
prefix = 'iteration {}'.format(iteration) |
|
evaluate_and_print_results(prefix, forward_step_func, |
|
valid_data_iterator, model, |
|
iteration, process_non_loss_data_func, my_writer, |
|
False) |
|
|
|
|
|
saved_checkpoint = False |
|
if args.exit_signal_handler: |
|
signal_handler = get_signal_handler() |
|
if any(signal_handler.signals_received()): |
|
save_checkpoint_and_time(iteration, model, optimizer, |
|
opt_param_scheduler) |
|
print_datetime('exiting program after receiving SIGTERM.') |
|
sys.exit() |
|
|
|
if args.save and args.save_interval and \ |
|
iteration % args.save_interval == 0: |
|
save_checkpoint_and_time(iteration, model, optimizer, |
|
opt_param_scheduler) |
|
saved_checkpoint = True |
|
|
|
|
|
if args.exit_duration_in_mins: |
|
train_time = (time.time() - _TRAIN_START_TIME) / 60.0 |
|
done_cuda = torch.cuda.IntTensor( |
|
[train_time > args.exit_duration_in_mins]) |
|
torch.distributed.all_reduce( |
|
done_cuda, op=torch.distributed.ReduceOp.MAX) |
|
done = done_cuda.item() |
|
if done: |
|
if not saved_checkpoint: |
|
save_checkpoint_and_time(iteration, model, optimizer, |
|
opt_param_scheduler) |
|
print_datetime('exiting program after {} minutes'.format(train_time)) |
|
sys.exit() |
|
|
|
|
|
if args.exit_interval and iteration % args.exit_interval == 0: |
|
if not saved_checkpoint: |
|
save_checkpoint_and_time(iteration, model, optimizer, |
|
opt_param_scheduler) |
|
torch.distributed.barrier() |
|
print_datetime('exiting program at iteration {}'.format(iteration)) |
|
sys.exit() |
|
|
|
|
|
return iteration |
|
|
|
|
|
def evaluate(forward_step_func, |
|
data_iterator, |
|
model, |
|
process_non_loss_data_func, |
|
verbose=False): |
|
"""Evaluation.""" |
|
args = get_args() |
|
|
|
if args.vision_pretraining and args.vision_pretraining_type == "dino": |
|
compute_feature_bank(model) |
|
|
|
|
|
for model_module in model: |
|
model_module.eval() |
|
|
|
total_loss_dict = {} |
|
|
|
with torch.no_grad(): |
|
iteration = 0 |
|
while iteration < args.eval_iters: |
|
iteration += 1 |
|
if verbose and iteration % args.log_interval == 0: |
|
print_rank_0('Evaluating iter {}/{}'.format(iteration, |
|
args.eval_iters)) |
|
|
|
forward_backward_func = get_forward_backward_func() |
|
loss_dicts = forward_backward_func( |
|
forward_step_func, data_iterator, model, optimizer=None, |
|
timers=None, forward_only=True) |
|
|
|
|
|
if args.empty_unused_memory_level >= 1: |
|
torch.cuda.empty_cache() |
|
|
|
if mpu.is_pipeline_last_stage(ignore_virtual=True): |
|
|
|
for loss_dict in loss_dicts: |
|
for key in loss_dict: |
|
if key == "describe": |
|
continue |
|
total_loss_dict[key] = total_loss_dict.get( |
|
key, torch.cuda.FloatTensor([0.0])) + loss_dict[key] |
|
|
|
args.consumed_valid_samples += mpu.get_data_parallel_world_size() \ |
|
* args.micro_batch_size \ |
|
* get_num_microbatches() |
|
collected_non_loss_data = None |
|
if process_non_loss_data_func is not None and is_last_rank(): |
|
collected_non_loss_data = forward_backward_func( |
|
forward_step_func, data_iterator, model, optimizer=None, |
|
timers=None, forward_only=True, collect_non_loss_data=True) |
|
|
|
|
|
for model_module in model: |
|
model_module.train() |
|
|
|
for key in total_loss_dict: |
|
total_loss_dict[key] /= args.eval_iters * get_num_microbatches() |
|
if "describe" in loss_dict: |
|
total_loss_dict["describe"] = loss_dict["describe"] |
|
|
|
return total_loss_dict, collected_non_loss_data |
|
|
|
def evaluate_and_print_results(prefix, forward_step_func, |
|
data_iterator, model, |
|
iteration, process_non_loss_data_func, my_writer, |
|
verbose=False): |
|
"""Helper function to evaluate and dump results on screen.""" |
|
args = get_args() |
|
writer = get_tensorboard_writer() |
|
|
|
total_loss_dict, collected_non_loss_data = evaluate( |
|
forward_step_func, data_iterator, model, |
|
process_non_loss_data_func, verbose) |
|
string = ' validation loss at {} | '.format(prefix) |
|
total_val_loss = 0 |
|
for key in total_loss_dict: |
|
if key == "describe": |
|
if isinstance(total_loss_dict["describe"], str): |
|
string += total_loss_dict["describe"] |
|
continue |
|
elif isinstance(total_loss_dict["describe"], dict): |
|
continue |
|
else: |
|
raise "Not Imp" |
|
string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item()) |
|
|
|
|
|
|
|
if my_writer and is_last_rank(): |
|
my_writer.add_scalar('val/' + key, total_loss_dict[key].item(), iteration) |
|
            ppl = math.exp(min(20, total_loss_dict[key].item()))
            my_writer.add_scalar('val/' + key + '_ppl', ppl, iteration)
|
total_val_loss += total_loss_dict[key].item() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if process_non_loss_data_func is not None and writer and is_last_rank(): |
|
process_non_loss_data_func(collected_non_loss_data, iteration, writer) |
|
|
|
if my_writer and is_last_rank(): |
|
my_writer.add_scalar("val/total-loss", total_val_loss, iteration) |
|
|
|
length = len(string) + 1 |
|
print_rank_last('-' * length) |
|
print_rank_last(string) |
|
if "describe" in total_loss_dict and isinstance(total_loss_dict["describe"], dict): |
|
for k, v in total_loss_dict["describe"].items(): |
|
out_str = " : ".join([k, v]) |
|
print_rank_last(out_str) |
|
print_rank_last('-' * length) |
|
|
|
|
|
def cyclic_iter(iterable):
    while True:
        for x in iterable:
|
yield x |
|
|
|
def build_train_valid_test_data_iterators( |
|
build_train_valid_test_datasets_provider): |
|
"""XXX""" |
|
args = get_args() |
|
|
|
(train_dataloader, valid_dataloader, test_dataloader) = (None, None, None) |
|
|
|
print_rank_0('> building train, validation, and test datasets ...') |
|
|
|
|
|
if args.iteration > 0 and args.consumed_train_samples == 0: |
|
assert args.train_samples is None, \ |
|
            'only backward compatibility support for iteration-based training'
|
args.consumed_train_samples = args.iteration * args.global_batch_size |
|
if args.iteration > 0 and args.consumed_valid_samples == 0: |
|
if args.train_samples is None: |
|
args.consumed_valid_samples = (args.iteration // args.eval_interval) * \ |
|
args.eval_iters * args.global_batch_size |
|
|
|
|
|
if mpu.get_tensor_model_parallel_rank() == 0: |
|
|
|
|
|
if args.train_samples: |
|
train_samples = args.train_samples |
|
else: |
|
train_samples = args.train_iters * args.global_batch_size |
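        # Validation runs every eval_interval iterations plus once at the end of
        # training (hence the "+ 1"), each time for eval_iters iterations.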
|
eval_iters = (args.train_iters // args.eval_interval + 1) * \ |
|
args.eval_iters |
|
test_iters = args.eval_iters |
|
train_val_test_num_samples = [train_samples, |
|
eval_iters * args.global_batch_size, |
|
test_iters * args.global_batch_size] |
|
print_rank_0(' > datasets target sizes (minimum size):') |
|
print_rank_0(' train: {}'.format(train_val_test_num_samples[0])) |
|
print_rank_0(' validation: {}'.format(train_val_test_num_samples[1])) |
|
print_rank_0(' test: {}'.format(train_val_test_num_samples[2])) |
|
|
|
|
|
train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider( |
|
train_val_test_num_samples) |
|
|
|
|
|
train_dataloader = build_pretraining_data_loader( |
|
train_ds, args.consumed_train_samples) |
|
valid_dataloader = build_pretraining_data_loader( |
|
valid_ds, args.consumed_valid_samples) |
|
test_dataloader = build_pretraining_data_loader(test_ds, 0) |
|
|
|
|
|
do_train = train_dataloader is not None and args.train_iters > 0 |
|
do_valid = valid_dataloader is not None and args.eval_iters > 0 |
|
do_test = test_dataloader is not None and args.eval_iters > 0 |
|
|
|
flags = torch.cuda.LongTensor( |
|
[int(do_train), int(do_valid), int(do_test)]) |
|
else: |
|
flags = torch.cuda.LongTensor([0, 0, 0]) |
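    # Only tensor-model-parallel rank 0 built the dataloaders above; broadcast
    # the do_train/do_valid/do_test flags so all ranks agree on what to run.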
|
|
|
|
|
torch.distributed.broadcast(flags, |
|
mpu.get_tensor_model_parallel_src_rank(), |
|
group=mpu.get_tensor_model_parallel_group()) |
|
args.do_train = flags[0].item() |
|
args.do_valid = flags[1].item() |
|
args.do_test = flags[2].item() |
|
|
|
|
|
dl_type = args.dataloader_type |
|
assert dl_type in ['single', 'cyclic'] |
|
|
|
if train_dataloader is not None: |
|
train_data_iterator = iter(train_dataloader) if dl_type == 'single' \ |
|
else iter(cyclic_iter(train_dataloader)) |
|
else: |
|
train_data_iterator = None |
|
|
|
if valid_dataloader is not None: |
|
valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \ |
|
else iter(cyclic_iter(valid_dataloader)) |
|
else: |
|
valid_data_iterator = None |
|
|
|
if test_dataloader is not None: |
|
test_data_iterator = iter(test_dataloader) if dl_type == 'single' \ |
|
else iter(cyclic_iter(test_dataloader)) |
|
else: |
|
test_data_iterator = None |
|
|
|
return train_data_iterator, valid_data_iterator, test_data_iterator |
|
|