"""Megatron arguments.""" |
|
|
|
import argparse |
|
import os |
|
|
|
import torch |
|
|
|
def parse_args(extra_args_provider=None, ignore_unknown_args=False): |
|
"""Parse all arguments.""" |
|
parser = argparse.ArgumentParser(description='Megatron-LM Arguments', |
|
allow_abbrev=False) |
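
    # Standard argument groups.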
|
|
|
|
|
parser = _add_network_size_args(parser) |
|
parser = _add_regularization_args(parser) |
|
parser = _add_training_args(parser) |
|
parser = _add_initialization_args(parser) |
|
parser = _add_learning_rate_args(parser) |
|
parser = _add_checkpointing_args(parser) |
|
parser = _add_mixed_precision_args(parser) |
|
parser = _add_distributed_args(parser) |
|
parser = _add_validation_args(parser) |
|
parser = _add_data_args(parser) |
|
parser = _add_autoresume_args(parser) |
|
parser = _add_biencoder_args(parser) |
|
parser = _add_vision_args(parser) |
|
parser = _add_logging_args(parser) |
|
parser = _add_inference_args(parser) |
|
|
|
|
|
if extra_args_provider is not None: |
|
parser = extra_args_provider(parser) |
|
|
|
|
|
if ignore_unknown_args: |
|
args, _ = parser.parse_known_args() |
|
else: |
|
args = parser.parse_args() |
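
    # Rank and world size come from the distributed launcher environment.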
|
|
|
|
|
args.rank = int(os.getenv('RANK', '0')) |
|
args.world_size = int(os.getenv("WORLD_SIZE", '1')) |
|
|
|
return args |
|
|
|
def validate_args(args, defaults={}): |
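    """Validate parsed arguments and derive dependent settings."""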
|
|
|
args.tensor_model_parallel_size = min( |
|
args.tensor_model_parallel_size, args.world_size) |
|
assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\ |
|
' ({}) is not divisible by tensor model parallel size ({})'.format( |
|
args.world_size, args.tensor_model_parallel_size) |
|
|
|
args.pipeline_model_parallel_size = min( |
|
args.pipeline_model_parallel_size, |
|
(args.world_size // args.tensor_model_parallel_size)) |
|
args.transformer_pipeline_model_parallel_size = ( |
|
args.pipeline_model_parallel_size - 1 |
|
if args.standalone_embedding_stage else |
|
args.pipeline_model_parallel_size |
|
) |
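
    # Total model-parallel size; the remaining ranks form the data-parallel group.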
|
|
|
model_parallel_size = args.pipeline_model_parallel_size * \ |
|
args.tensor_model_parallel_size |
|
    assert args.world_size % model_parallel_size == 0, 'world size ({}) is ' \
        'not divisible by tensor parallel size ({}) times pipeline parallel ' \
        'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
                           args.pipeline_model_parallel_size)
|
args.data_parallel_size = args.world_size // model_parallel_size |
|
if args.rank == 0: |
|
print('using world size: {}, data-parallel-size: {}, ' |
|
'tensor-model-parallel size: {}, ' |
|
'pipeline-model-parallel size: {} '.format( |
|
args.world_size, args.data_parallel_size, |
|
args.tensor_model_parallel_size, |
|
args.pipeline_model_parallel_size), flush=True) |
|
if args.pipeline_model_parallel_size > 1: |
|
if args.pipeline_model_parallel_split_rank is not None: |
|
assert args.pipeline_model_parallel_split_rank < \ |
|
args.pipeline_model_parallel_size, 'split rank needs'\ |
|
' to be less than pipeline model parallel size ({})'.format( |
|
args.pipeline_model_parallel_size) |
|
if args.data_path: |
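        # Scan each supplied directory for matching .idx/.bin file pairs and
        # build an alternating (weight, prefix) list, giving every dataset a
        # weight of '1'; the original directories are kept in raw_data_path.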
|
|
|
data_path = args.data_path |
|
processed_data_path = [] |
|
for path in data_path: |
|
files = os.listdir(path) |
|
idx_files = [fn[:-4] for fn in files if fn.endswith('.idx')] |
|
bin_files = [fn[:-4] for fn in files if fn.endswith('.bin')] |
|
for idx_fn in idx_files: |
|
if idx_fn in bin_files: |
|
|
|
processed_data_path.append('1') |
|
processed_data_path.append(os.path.join(path, idx_fn)) |
|
args.raw_data_path = data_path |
|
args.data_path = processed_data_path |
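
    # Arguments that have been renamed: fail loudly if the old flags are used,
    # then drop them from the namespace.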
|
|
|
|
|
|
|
assert args.batch_size is None, '--batch-size argument is no longer ' \ |
|
'valid, use --micro-batch-size instead' |
|
del args.batch_size |
|
assert args.warmup is None, '--warmup argument is no longer valid, use ' \ |
|
'--lr-warmup-fraction instead' |
|
del args.warmup |
|
assert args.model_parallel_size is None, '--model-parallel-size is no ' \ |
|
'longer valid, use --tensor-model-parallel-size instead' |
|
del args.model_parallel_size |
|
|
|
if args.checkpoint_activations: |
|
args.recompute_granularity = 'full' |
|
args.recompute_method = 'uniform' |
|
if args.rank == 0: |
|
print('--checkpoint-activations is no longer valid, ' |
|
'use --recompute-granularity and --recompute-method instead. ' |
|
'Defaulting to recompute-granularity=full and recompute-method=uniform.') |
|
del args.checkpoint_activations |
|
|
|
if args.recompute_activations: |
|
args.recompute_granularity = 'selective' |
|
del args.recompute_activations |
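
    # Apply caller-supplied defaults only for arguments left unset on the
    # command line; explicitly provided values take precedence.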
|
|
|
|
|
for key in defaults: |
|
|
|
|
|
|
|
if getattr(args, key) is not None: |
|
if args.rank == 0: |
|
                print('WARNING: overriding default arguments for '
                      '{key}:{v} with {key}:{v2}'.format(
                          key=key, v=defaults[key], v2=getattr(args, key)),
                      flush=True)
|
else: |
|
setattr(args, key, defaults[key]) |
|
|
|
|
|
assert args.micro_batch_size is not None |
|
assert args.micro_batch_size > 0 |
|
if args.global_batch_size is None: |
|
args.global_batch_size = args.micro_batch_size * args.data_parallel_size |
|
if args.rank == 0: |
|
print('setting global batch size to {}'.format( |
|
args.global_batch_size), flush=True) |
|
assert args.global_batch_size > 0 |
|
if args.num_layers_per_virtual_pipeline_stage is not None: |
|
assert args.pipeline_model_parallel_size > 2, \ |
|
'pipeline-model-parallel size should be greater than 2 with ' \ |
|
'interleaved schedule' |
|
assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \ |
|
'number of layers is not divisible by number of layers per virtual ' \ |
|
'pipeline stage' |
|
args.virtual_pipeline_model_parallel_size = \ |
|
(args.num_layers // args.transformer_pipeline_model_parallel_size) // \ |
|
args.num_layers_per_virtual_pipeline_stage |
|
else: |
|
args.virtual_pipeline_model_parallel_size = None |
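
    # Parameter dtype: fp32 unless --fp16 or --bf16 is given.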
|
|
|
|
|
args.params_dtype = torch.float |
|
if args.fp16: |
|
assert not args.bf16 |
|
args.params_dtype = torch.half |
|
if args.bf16: |
|
assert not args.fp16 |
|
args.params_dtype = torch.bfloat16 |
|
|
|
|
|
    # bf16 training accumulates and all-reduces gradients in fp32.
    if args.bf16 and not args.accumulate_allreduce_grads_in_fp32:
|
args.accumulate_allreduce_grads_in_fp32 = True |
|
if args.rank == 0: |
|
print('accumulate and all-reduce gradients in fp32 for ' |
|
'bfloat16 data type.', flush=True) |
|
|
|
if args.rank == 0: |
|
print('using {} for parameters ...'.format(args.params_dtype), |
|
flush=True) |
|
|
|
|
|
|
|
if args.accumulate_allreduce_grads_in_fp32: |
|
assert args.DDP_impl == 'local' |
|
assert args.use_contiguous_buffers_in_local_ddp |
|
else: |
|
if args.gradient_accumulation_fusion: |
|
args.gradient_accumulation_fusion = False |
|
if args.rank == 0: |
|
print('Gradient accumulation fusion to linear layer weight ' |
|
'gradient computation is supported only with fp32 ' |
|
'gradient accumulation. Setting gradient_accumulation_fusion ' |
|
'to False', flush=True) |
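
    # The distributed optimizer requires local DDP with contiguous buffers.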
|
|
|
|
|
|
|
if args.use_distributed_optimizer: |
|
assert args.DDP_impl == 'local' |
|
assert args.use_contiguous_buffers_in_local_ddp |
|
|
|
|
|
if args.DDP_impl == 'torch': |
|
args.use_contiguous_buffers_in_local_ddp = False |
|
|
|
if args.dataloader_type is None: |
|
args.dataloader_type = 'single' |
|
|
|
|
|
args.consumed_train_samples = 0 |
|
args.consumed_valid_samples = 0 |
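
    # Iteration-based and sample-based training are mutually exclusive; the
    # checks below make sure only one family of schedule arguments is used.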
|
|
|
|
|
if args.train_iters: |
|
|
|
|
|
assert args.train_samples is None, \ |
|
'expected iteration-based training' |
|
assert args.lr_decay_samples is None, \ |
|
'expected iteration-based learning rate decay' |
|
assert args.lr_warmup_samples == 0, \ |
|
'expected iteration-based learning rate warmup' |
|
assert args.rampup_batch_size is None, \ |
|
'expected no batch-size rampup for iteration-based training' |
|
if args.lr_warmup_fraction is not None: |
|
assert args.lr_warmup_iters == 0, \ |
|
'can only specify one of lr-warmup-fraction and lr-warmup-iters' |
|
|
|
|
|
if args.train_samples: |
|
|
|
|
|
assert args.train_iters is None, \ |
|
'expected sample-based training' |
|
assert args.lr_decay_iters is None, \ |
|
'expected sample-based learning rate decay' |
|
assert args.lr_warmup_iters == 0, \ |
|
            'expected sample-based learning rate warmup'
|
if args.lr_warmup_fraction is not None: |
|
assert args.lr_warmup_samples == 0, \ |
|
'can only specify one of lr-warmup-fraction ' \ |
|
'and lr-warmup-samples' |
|
|
|
|
|
required_args = ['num_layers', 'hidden_size', 'num_attention_heads', |
|
'max_position_embeddings'] |
|
for req_arg in required_args: |
|
_check_arg_is_not_none(args, req_arg) |
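
    # Derive sizes that were not given explicitly.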
|
|
|
|
|
if args.ffn_hidden_size is None: |
|
args.ffn_hidden_size = 4 * args.hidden_size |
|
|
|
if args.kv_channels is None: |
|
assert args.hidden_size % args.num_attention_heads == 0 |
|
args.kv_channels = args.hidden_size // args.num_attention_heads |
|
|
|
if args.seq_length is not None: |
|
assert args.encoder_seq_length is None |
|
args.encoder_seq_length = args.seq_length |
|
else: |
|
assert args.encoder_seq_length is not None |
|
args.seq_length = args.encoder_seq_length |
|
|
|
if args.seq_length is not None: |
|
assert args.max_position_embeddings >= args.seq_length |
|
if args.decoder_seq_length is not None: |
|
assert args.max_position_embeddings >= args.decoder_seq_length |
|
if args.lr is not None: |
|
assert args.min_lr <= args.lr |
|
if args.save is not None: |
|
assert args.save_interval is not None |
|
|
|
if args.fp16_lm_cross_entropy: |
|
        assert args.fp16, 'lm cross entropy in fp16 is only supported in fp16 mode.'
|
if args.fp32_residual_connection: |
|
assert args.fp16 or args.bf16, \ |
|
'residual connection in fp32 only supported when using fp16 or bf16.' |
|
|
|
if args.weight_decay_incr_style == 'constant': |
|
assert args.start_weight_decay is None |
|
assert args.end_weight_decay is None |
|
args.start_weight_decay = args.weight_decay |
|
args.end_weight_decay = args.weight_decay |
|
else: |
|
assert args.start_weight_decay is not None |
|
assert args.end_weight_decay is not None |
|
|
|
TORCH_MAJOR = int(torch.__version__.split('.')[0]) |
|
TORCH_MINOR = int(torch.__version__.split('.')[1]) |
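
    # Persistent fused layer norm kernels require PyTorch >= 1.11.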
|
|
|
if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11): |
|
args.no_persist_layer_norm = True |
|
if args.rank == 0: |
|
print('Persistent fused layer norm kernel is supported from ' |
|
'pytorch v1.11 (nvidia pytorch container paired with v1.11). ' |
|
'Defaulting to no_persist_layer_norm=True') |
|
|
|
|
|
if args.distribute_saved_activations: |
|
assert args.tensor_model_parallel_size > 1, 'can distribute ' \ |
|
'recomputed activations only across tensor model ' \ |
|
'parallel groups' |
|
assert args.recompute_granularity == 'full', \ |
|
'distributed recompute activations is only '\ |
|
            'applicable to full recompute granularity'
|
assert args.recompute_method is not None, \ |
|
'for distributed recompute activations to work you '\ |
|
'need to use a recompute method ' |
|
        assert TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10), \
|
'distributed recompute activations are supported for pytorch ' \ |
|
'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \ |
|
'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR) |
|
|
|
if args.recompute_granularity == 'selective': |
|
assert args.recompute_method is None, \ |
|
'recompute method is not yet supported for ' \ |
|
'selective recomputing granularity' |
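
    # Sequence parallelism only applies with tensor model parallelism, and it
    # cannot be combined with the async tensor-model-parallel all-reduce.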
|
|
|
|
|
|
|
|
|
if args.tensor_model_parallel_size == 1: |
|
args.sequence_parallel = False |
|
|
|
|
|
|
|
if args.sequence_parallel: |
|
args.async_tensor_model_parallel_allreduce = False |
|
|
|
_print_args(args) |
|
return args |
|
|
|
|
|
def _print_args(args): |
|
"""Print arguments.""" |
|
if args.rank == 0: |
|
print('------------------------ arguments ------------------------', |
|
flush=True) |
|
str_list = [] |
|
for arg in vars(args): |
|
dots = '.' * (48 - len(arg)) |
|
str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg))) |
|
for arg in sorted(str_list, key=lambda x: x.lower()): |
|
print(arg, flush=True) |
|
print('-------------------- end of arguments ---------------------', |
|
flush=True) |
|
|
|
|
|
def _check_arg_is_not_none(args, arg): |
|
assert getattr(args, arg) is not None, '{} argument is None'.format(arg) |
|
|
|
|
|
def _add_inference_args(parser): |
|
group = parser.add_argument_group(title='inference') |
|
|
|
group.add_argument('--inference-batch-times-seqlen-threshold', |
|
type=int, default=512, |
|
help='During inference, if batch-size times ' |
|
'sequence-length is smaller than this threshold ' |
|
'then we will not use pipelining, otherwise we will.') |
|
|
|
return parser |
|
|
|
|
|
def _add_network_size_args(parser): |
|
group = parser.add_argument_group(title='network size') |
|
|
|
group.add_argument('--num-layers', type=int, default=None, |
|
help='Number of transformer layers.') |
|
group.add_argument('--num-layers-decoder', type=int, default=None, |
|
help='Number of transformer layers decoder.') |
|
group.add_argument('--hidden-size', type=int, default=None, |
|
                       help='Transformer hidden size.')
|
group.add_argument('--ffn-hidden-size', type=int, default=None, |
|
help='Transformer Feed-Forward Network hidden size. ' |
|
'This is set to 4*hidden-size if not provided') |
|
group.add_argument('--num-attention-heads', type=int, default=None, |
|
help='Number of transformer attention heads.') |
|
group.add_argument('--kv-channels', type=int, default=None, |
|
help='Projection weights dimension in multi-head ' |
|
'attention. This is set to ' |
|
' args.hidden_size // args.num_attention_heads ' |
|
'if not provided.') |
|
group.add_argument('--max-position-embeddings', type=int, default=None, |
|
help='Maximum number of position embeddings to use. ' |
|
'This is the size of position embedding.') |
|
group.add_argument('--make-vocab-size-divisible-by', type=int, default=128, |
|
help='Pad the vocab size to be divisible by this value.' |
|
                       ' This is added for computational efficiency reasons.')
|
group.add_argument('--layernorm-epsilon', type=float, default=1e-5, |
|
help='Layer norm epsilon.') |
|
group.add_argument('--apply-residual-connection-post-layernorm', |
|
action='store_true', |
|
                       help='If set, use original BERT residual connection '
|
'ordering.') |
|
group.add_argument('--openai-gelu', action='store_true', |
|
                       help='Use OpenAI\'s GeLU implementation. This option '
                       'should not be used unless for backward compatibility '
                       'reasons.')
|
group.add_argument('--onnx-safe', type=bool, required=False, |
|
help='Use workarounds for known problems with ' |
|
'Torch ONNX exporter') |
|
group.add_argument('--bert-no-binary-head', action='store_false', |
|
help='Disable BERT binary head.', |
|
dest='bert_binary_head') |
|
group.add_argument('--num-experts', type=int, default=None, |
|
help='Number of Experts in Switch Transformer (None means no Switch)') |
|
return parser |
|
|
|
|
|
def _add_logging_args(parser): |
|
group = parser.add_argument_group(title='logging') |
|
|
|
group.add_argument('--log-params-norm', action='store_true', |
|
help='If set, calculate and log parameters norm.') |
|
group.add_argument('--log-num-zeros-in-grad', action='store_true', |
|
help='If set, calculate and log the number of zeros in gradient.') |
|
group.add_argument('--tensorboard-log-interval', type=int, default=1, |
|
help='Report to tensorboard interval.') |
|
group.add_argument('--tensorboard-queue-size', type=int, default=1000, |
|
help='Size of the tensorboard queue for pending events ' |
|
                       'and summaries before one of the \'add\' calls forces a '
|
'flush to disk.') |
|
group.add_argument('--log-timers-to-tensorboard', action='store_true', |
|
help='If set, write timers to tensorboard.') |
|
group.add_argument('--log-batch-size-to-tensorboard', action='store_true', |
|
help='If set, write batch-size to tensorboard.') |
|
group.add_argument('--no-log-learnig-rate-to-tensorboard', |
|
action='store_false', |
|
help='Disable learning rate logging to tensorboard.', |
|
dest='log_learning_rate_to_tensorboard') |
|
group.add_argument('--no-log-loss-scale-to-tensorboard', |
|
action='store_false', |
|
help='Disable loss-scale logging to tensorboard.', |
|
dest='log_loss_scale_to_tensorboard') |
|
group.add_argument('--log-validation-ppl-to-tensorboard', |
|
action='store_true', |
|
help='If set, write validation perplexity to ' |
|
'tensorboard.') |
|
group.add_argument('--log-memory-to-tensorboard', |
|
action='store_true', |
|
help='Enable memory logging to tensorboard.') |
|
group.add_argument('--log-world-size-to-tensorboard', |
|
action='store_true', |
|
help='Enable world size logging to tensorboard.') |
|
|
|
return parser |
|
|
|
|
|
def _add_regularization_args(parser): |
|
group = parser.add_argument_group(title='regularization') |
|
|
|
group.add_argument('--attention-dropout', type=float, default=0.1, |
|
help='Post attention dropout probability.') |
|
group.add_argument('--hidden-dropout', type=float, default=0.1, |
|
help='Dropout probability for hidden state transformer.') |
|
group.add_argument('--weight-decay', type=float, default=0.01, |
|
help='Weight decay coefficient for L2 regularization.') |
|
group.add_argument('--start-weight-decay', type=float, |
|
help='Initial weight decay coefficient for L2 regularization.') |
|
group.add_argument('--end-weight-decay', type=float, |
|
help='End of run weight decay coefficient for L2 regularization.') |
|
group.add_argument('--weight-decay-incr-style', type=str, default='constant', |
|
choices=['constant', 'linear', 'cosine'], |
|
help='Weight decay increment function.') |
|
group.add_argument('--clip-grad', type=float, default=1.0, |
|
help='Gradient clipping based on global L2 norm.') |
|
group.add_argument('--adam-beta1', type=float, default=0.9, |
|
help='First coefficient for computing running averages ' |
|
'of gradient and its square') |
|
group.add_argument('--adam-beta2', type=float, default=0.999, |
|
help='Second coefficient for computing running averages ' |
|
'of gradient and its square') |
|
group.add_argument('--adam-eps', type=float, default=1e-08, |
|
                       help='Term added to the denominator to improve '
|
'numerical stability') |
|
group.add_argument('--sgd-momentum', type=float, default=0.9, |
|
help='Momentum factor for sgd') |
|
|
|
return parser |
|
|
|
|
|
def _add_training_args(parser): |
|
group = parser.add_argument_group(title='training') |
|
|
|
group.add_argument('--micro-batch-size', type=int, default=None, |
|
help='Batch size per model instance (local batch size). ' |
|
'Global batch size is local batch size times data ' |
|
'parallel size times number of micro batches.') |
|
group.add_argument('--batch-size', type=int, default=None, |
|
help='Old batch size parameter, do not use. ' |
|
'Use --micro-batch-size instead') |
|
group.add_argument('--global-batch-size', type=int, default=None, |
|
help='Training batch size. If set, it should be a ' |
|
'multiple of micro-batch-size times data-parallel-size. ' |
|
'If this value is None, then ' |
|
'use micro-batch-size * data-parallel-size as the ' |
|
'global batch size. This choice will result in 1 for ' |
|
'number of micro-batches.') |
|
group.add_argument('--rampup-batch-size', nargs='*', default=None, |
|
help='Batch size ramp up with the following values:' |
|
' --rampup-batch-size <start batch size> ' |
|
                       ' <batch size increment> '
|
' <ramp-up samples> ' |
|
'For example:' |
|
                       ' --rampup-batch-size 16 8 300000 \\ '
|
' --global-batch-size 1024' |
|
'will start with global batch size 16 and over ' |
|
                       ' (1024 - 16) / 8 = 126 intervals will increase '
                       'the batch size linearly to 1024. In each interval '
                       'we will use approximately 300000 / 126 = 2380 samples.')
|
group.add_argument('--recompute-activations', action='store_true', |
|
help='recompute activation to allow for training ' |
|
'with larger models, sequences, and batch sizes.') |
|
group.add_argument('--recompute-granularity', type=str, default=None, |
|
choices=['full', 'selective'], |
|
help='Checkpoint activations to allow for training ' |
|
'with larger models, sequences, and batch sizes. ' |
|
'It is supported at two granularities 1) full: ' |
|
'whole transformer layer is recomputed, ' |
|
'2) selective: core attention part of the transformer ' |
|
'layer is recomputed.') |
|
group.add_argument('--distribute-saved-activations', |
|
action='store_true', |
|
help='If set, distribute recomputed activations ' |
|
'across model parallel group.') |
|
group.add_argument('--recompute-method', type=str, default=None, |
|
choices=['uniform', 'block'], |
|
help='1) uniform: uniformly divide the total number of ' |
|
'Transformer layers and recompute the input activation of ' |
|
'each divided chunk at specified granularity, ' |
|
'2) recompute the input activations of only a set number of ' |
|
'individual Transformer layers per pipeline stage and do the ' |
|
                       'rest without any recomputing at specified granularity, '
                       'default) do not apply activations recompute to any layers')
|
group.add_argument('--recompute-num-layers', type=int, default=1, |
|
help='1) uniform: the number of Transformer layers in each ' |
|
'uniformly divided recompute unit, ' |
|
'2) block: the number of individual Transformer layers ' |
|
'to recompute within each pipeline stage.') |
|
|
|
|
|
group.add_argument('--checkpoint-activations', action='store_true', |
|
help='Checkpoint activation to allow for training ' |
|
'with larger models, sequences, and batch sizes.') |
|
group.add_argument('--train-iters', type=int, default=None, |
|
help='Total number of iterations to train over all ' |
|
'training runs. Note that either train-iters or ' |
|
'train-samples should be provided.') |
|
group.add_argument('--train-samples', type=int, default=None, |
|
help='Total number of samples to train over all ' |
|
'training runs. Note that either train-iters or ' |
|
'train-samples should be provided.') |
|
group.add_argument('--log-interval', type=int, default=100, |
|
help='Report loss and timing interval.') |
|
group.add_argument('--exit-interval', type=int, default=None, |
|
help='Exit the program after the iteration is divisible ' |
|
'by this value.') |
|
group.add_argument('--exit-duration-in-mins', type=int, default=None, |
|
help='Exit the program after this many minutes.') |
|
group.add_argument('--exit-signal-handler', action='store_true', |
|
help='Dynamically save the checkpoint and shutdown the ' |
|
'training if SIGTERM is received') |
|
group.add_argument('--tensorboard-dir', type=str, default=None, |
|
help='Write TensorBoard logs to this directory.') |
|
group.add_argument('--no-masked-softmax-fusion', |
|
action='store_false', |
|
help='Disable fusion of query_key_value scaling, ' |
|
'masking, and softmax.', |
|
dest='masked_softmax_fusion') |
|
group.add_argument('--no-bias-gelu-fusion', action='store_false', |
|
help='Disable bias and gelu fusion.', |
|
dest='bias_gelu_fusion') |
|
group.add_argument('--no-bias-dropout-fusion', action='store_false', |
|
help='Disable bias and dropout fusion.', |
|
dest='bias_dropout_fusion') |
|
group.add_argument('--optimizer', type=str, default='adam', |
|
choices=['adam', 'sgd'], |
|
help='Optimizer function') |
|
group.add_argument('--dataloader-type', type=str, default=None, |
|
choices=['single', 'cyclic'], |
|
help='Single pass vs multiple pass data loader') |
|
group.add_argument('--no-async-tensor-model-parallel-allreduce', |
|
action='store_false', |
|
help='Disable asynchronous execution of ' |
|
'tensor-model-parallel all-reduce with weight ' |
|
                       'gradient computation of a column-linear layer.',
|
dest='async_tensor_model_parallel_allreduce') |
|
group.add_argument('--no-persist-layer-norm', action='store_true', |
|
help='Disable using persistent fused layer norm kernel. ' |
|
'This kernel supports only a set of hidden sizes. Please ' |
|
'check persist_ln_hidden_sizes if your hidden ' |
|
'size is supported.') |
|
group.add_argument('--sequence-parallel', action='store_true', |
|
help='Enable sequence parallel optimization.') |
|
group.add_argument('--no-gradient-accumulation-fusion', |
|
action='store_false', |
|
help='Disable fusing gradient accumulation to weight ' |
|
'gradient computation of linear layers', |
|
dest='gradient_accumulation_fusion') |
|
return parser |
|
|
|
|
|
def _add_initialization_args(parser): |
|
group = parser.add_argument_group(title='initialization') |
|
|
|
group.add_argument('--seed', type=int, default=1234, |
|
help='Random seed used for python, numpy, ' |
|
'pytorch, and cuda.') |
|
group.add_argument('--data-parallel-random-init', action='store_true', |
|
help='Enable random initialization of params ' |
|
'across data parallel ranks') |
|
group.add_argument('--init-method-std', type=float, default=0.02, |
|
help='Standard deviation of the zero mean normal ' |
|
'distribution used for weight initialization.') |
|
group.add_argument('--init-method-xavier-uniform', action='store_true', |
|
help='Enable Xavier uniform parameter initialization') |
|
|
|
return parser |
|
|
|
|
|
def _add_learning_rate_args(parser): |
|
group = parser.add_argument_group(title='learning rate') |
|
|
|
group.add_argument('--lr', type=float, default=None, |
|
help='Initial learning rate. Depending on decay style ' |
|
                       'and initial warmup, the learning rate at each '
|
'iteration would be different.') |
|
group.add_argument('--lr-decay-style', type=str, default='linear', |
|
choices=['constant', 'linear', 'cosine'], |
|
help='Learning rate decay function.') |
|
group.add_argument('--lr-decay-iters', type=int, default=None, |
|
help='number of iterations to decay learning rate over,' |
|
' If None defaults to `--train-iters`') |
|
group.add_argument('--lr-decay-samples', type=int, default=None, |
|
help='number of samples to decay learning rate over,' |
|
' If None defaults to `--train-samples`') |
|
group.add_argument('--lr-warmup-fraction', type=float, default=None, |
|
help='fraction of lr-warmup-(iters/samples) to use ' |
|
'for warmup (as a float)') |
|
group.add_argument('--lr-warmup-iters', type=int, default=0, |
|
help='number of iterations to linearly warmup ' |
|
'learning rate over.') |
|
group.add_argument('--lr-warmup-samples', type=int, default=0, |
|
help='number of samples to linearly warmup ' |
|
'learning rate over.') |
|
group.add_argument('--warmup', type=int, default=None, |
|
                       help='Old lr warmup argument, do not use. Use one of the '
|
'--lr-warmup-* arguments above') |
|
group.add_argument('--min-lr', type=float, default=0.0, |
|
                       help='Minimum value for learning rate. The scheduler '
                       'clips values below this threshold.')
|
group.add_argument('--override-opt_param-scheduler', action='store_true', |
|
                       help='Reset the values of the scheduler (learning rate, '
                       'warmup iterations, minimum learning rate, maximum '
                       'number of iterations, and decay style) from input '
                       'arguments and ignore values from checkpoints. Note '
                       'that all the above values will be reset.')
|
group.add_argument('--use-checkpoint-opt_param-scheduler', action='store_true', |
|
help='Use checkpoint to set the values of the scheduler ' |
|
'(learning rate, warmup iterations, minimum learning ' |
|
'rate, maximum number of iterations, and decay style ' |
|
'from checkpoint and ignore input arguments.') |
|
|
|
return parser |
|
|
|
|
|
def _add_checkpointing_args(parser): |
|
group = parser.add_argument_group(title='checkpointing') |
|
|
|
group.add_argument('--save', type=str, default=None, |
|
help='Output directory to save checkpoints to.') |
|
group.add_argument('--save-interval', type=int, default=None, |
|
help='Number of iterations between checkpoint saves.') |
|
group.add_argument('--no-save-optim', action='store_true', default=None, |
|
help='Do not save current optimizer.') |
|
group.add_argument('--no-save-rng', action='store_true', default=None, |
|
help='Do not save current rng state.') |
|
group.add_argument('--load', type=str, default=None, |
|
help='Directory containing a model checkpoint.') |
|
group.add_argument('--no-load-optim', action='store_true', default=None, |
|
help='Do not load optimizer when loading checkpoint.') |
|
group.add_argument('--no-load-rng', action='store_true', default=None, |
|
help='Do not load rng state when loading checkpoint.') |
|
group.add_argument('--finetune', action='store_true', |
|
help='Load model for finetuning. Do not load optimizer ' |
|
'or rng state from checkpoint and set iteration to 0. ' |
|
'Assumed when loading a release checkpoint.') |
|
group.add_argument('--no-initialization', action='store_false', |
|
help='Do not perform initialization when building model, ' |
|
'can reduce startup time when definitely loading from a ' |
|
'checkpoint', |
|
dest='perform_initialization') |
|
group.add_argument('--use-checkpoint-args', action='store_true', |
|
help='Override any command line arguments with arguments ' |
|
'from the checkpoint') |
|
|
|
return parser |
|
|
|
|
|
def _add_mixed_precision_args(parser): |
|
group = parser.add_argument_group(title='mixed precision') |
|
|
|
group.add_argument('--fp16', action='store_true', |
|
help='Run model in fp16 mode.') |
|
group.add_argument('--bf16', action='store_true', |
|
help='Run model in bfloat16 mode.') |
|
group.add_argument('--loss-scale', type=float, default=None, |
|
help='Static loss scaling, positive power of 2 ' |
|
                       'values can improve fp16 convergence. If None, dynamic '
|
'loss scaling is used.') |
|
group.add_argument('--initial-loss-scale', type=float, default=2**32, |
|
help='Initial loss-scale for dynamic loss scaling.') |
|
group.add_argument('--min-loss-scale', type=float, default=1.0, |
|
help='Minimum loss scale for dynamic loss scale.') |
|
group.add_argument('--loss-scale-window', type=float, default=1000, |
|
help='Window over which to raise/lower dynamic scale.') |
|
group.add_argument('--hysteresis', type=int, default=2, |
|
help='hysteresis for dynamic loss scaling') |
|
group.add_argument('--fp32-residual-connection', action='store_true', |
|
help='Move residual connections to fp32.') |
|
group.add_argument('--no-query-key-layer-scaling', action='store_false', |
|
help='Do not scale Q * K^T by 1 / layer-number.', |
|
dest='apply_query_key_layer_scaling') |
|
group.add_argument('--attention-softmax-in-fp32', action='store_true', |
|
help='Run attention masking and softmax in fp32. ' |
|
'This flag is ignored unless ' |
|
'--no-query-key-layer-scaling is specified.') |
|
group.add_argument('--accumulate-allreduce-grads-in-fp32', |
|
action='store_true', |
|
help='Gradient accumulation and all-reduce in fp32.') |
|
group.add_argument('--fp16-lm-cross-entropy', action='store_true', |
|
                       help='Move the cross entropy unreduced loss calculation '
|
'for lm head to fp16.') |
|
|
|
return parser |
|
|
|
|
|
def _add_distributed_args(parser): |
|
group = parser.add_argument_group(title='distributed') |
|
|
|
group.add_argument('--tensor-model-parallel-size', type=int, default=1, |
|
help='Degree of tensor model parallelism.') |
|
group.add_argument('--pipeline-model-parallel-size', type=int, default=1, |
|
help='Degree of pipeline model parallelism.') |
|
group.add_argument('--pipeline-model-parallel-split-rank', |
|
type=int, default=None, |
|
help='Rank where encoder and decoder should be split.') |
|
group.add_argument('--model-parallel-size', type=int, default=None, |
|
help='Old model parallel argument, do not use. Use ' |
|
'--tensor-model-parallel-size instead.') |
|
group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None, |
|
help='Number of layers per virtual pipeline stage') |
|
group.add_argument('--distributed-backend', default='nccl', |
|
choices=['nccl', 'gloo'], |
|
help='Which backend to use for distributed training.') |
|
group.add_argument('--DDP-impl', default='local', |
|
choices=['local', 'torch'], |
|
help='which DistributedDataParallel implementation ' |
|
'to use.') |
|
group.add_argument('--no-contiguous-buffers-in-local-ddp', |
|
                       action='store_false', help='If set, do not use '
|
'contiguous buffer in local DDP.', |
|
dest='use_contiguous_buffers_in_local_ddp') |
|
group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false', |
|
                       help='If set, do not use scatter/gather to optimize '
                       'communication of tensors in pipeline.',
|
dest='scatter_gather_tensors_in_pipeline') |
|
group.add_argument('--local_rank', type=int, default=None, |
|
help='local rank passed from distributed launcher.') |
|
group.add_argument('--lazy-mpu-init', type=bool, required=False, |
|
help='If set to True, initialize_megatron() ' |
|
'skips DDP initialization and returns function to ' |
|
                       'complete it instead. Also turns on '
|
'--use-cpu-initialization flag. This is for ' |
|
'external DDP manager.' ) |
|
group.add_argument('--use-cpu-initialization', action='store_true', |
|
default=None, help='If set, affine parallel weights ' |
|
'initialization uses CPU' ) |
|
group.add_argument('--empty-unused-memory-level', default=0, type=int, |
|
choices=[0, 1, 2], |
|
help='Call torch.cuda.empty_cache() each iteration ' |
|
                       '(training and eval), to reduce fragmentation. '
|
'0=off, 1=moderate, 2=aggressive.') |
|
group.add_argument('--standalone-embedding-stage', action='store_true', |
|
default=False, help='If set, *input* embedding layer ' |
|
'is placed on its own pipeline stage, without any ' |
|
'transformer layers. (For T5, this flag currently only ' |
|
'affects the encoder embedding.)') |
|
group.add_argument('--use-distributed-optimizer', action='store_true', |
|
help='Use distributed optimizer.') |
|
|
|
return parser |
|
|
|
|
|
def _add_validation_args(parser): |
|
group = parser.add_argument_group(title='validation') |
|
|
|
group.add_argument('--eval-iters', type=int, default=100, |
|
                       help='Number of iterations to run for evaluation on '
                       'the validation/test sets.')
|
group.add_argument('--eval-interval', type=int, default=1000, |
|
help='Interval between running evaluation on ' |
|
'validation set.') |
|
|
|
return parser |
|
|
|
|
|
def _add_data_args(parser): |
|
group = parser.add_argument_group(title='data and dataloader') |
|
|
|
group.add_argument('--data-path', nargs='*', default=None, |
|
                       help='Path to the training dataset. Accepted format: '
                       '1) a single data path, 2) multiple datasets in the '
                       'form: dataset1-weight dataset1-path dataset2-weight '
|
'dataset2-path ...') |
|
group.add_argument('--split', type=str, default='969, 30, 1', |
|
help='Comma-separated list of proportions for training,' |
|
' validation, and test split. For example the split ' |
|
'`90,5,5` will use 90%% of data for training, 5%% for ' |
|
'validation and 5%% for test.') |
|
group.add_argument('--vocab-file', type=str, default=None, |
|
help='Path to the vocab file.') |
|
group.add_argument('--merge-file', type=str, default=None, |
|
help='Path to the BPE merge file.') |
|
group.add_argument('--vocab-extra-ids', type=int, default=0, |
|
help='Number of additional vocabulary tokens. ' |
|
'They are used for span masking in the T5 model') |
|
group.add_argument('--seq-length', type=int, default=None, |
|
help='Maximum sequence length to process.') |
|
group.add_argument('--encoder-seq-length', type=int, default=None, |
|
                       help='Maximum encoder sequence length to process. '
|
'This should be exclusive of --seq-length') |
|
group.add_argument('--decoder-seq-length', type=int, default=None, |
|
help="Maximum decoder sequence length to process.") |
|
group.add_argument('--retriever-seq-length', type=int, default=256, |
|
help='Maximum sequence length for the biencoder model ' |
|
                       'for retriever')
|
group.add_argument('--sample-rate', type=float, default=1.0, |
|
help='sample rate for training data. Supposed to be 0 ' |
|
                       '< sample_rate < 1')
|
group.add_argument('--mask-prob', type=float, default=0.15, |
|
help='Probability of replacing a token with mask.') |
|
group.add_argument('--short-seq-prob', type=float, default=0.1, |
|
help='Probability of producing a short sequence.') |
|
group.add_argument('--mmap-warmup', action='store_true', |
|
help='Warm up mmap files.') |
|
group.add_argument('--num-workers', type=int, default=2, |
|
help="Dataloader number of workers.") |
|
group.add_argument('--tokenizer-type', type=str, |
|
default=None, |
|
choices=['BertWordPieceLowerCase', |
|
'BertWordPieceCase', |
|
'GPT2BPETokenizer'], |
|
help='What type of tokenizer to use.') |
|
group.add_argument('--data-impl', type=str, default='infer', |
|
choices=['lazy', 'cached', 'mmap', 'infer'], |
|
help='Implementation of indexed datasets.') |
|
group.add_argument('--reset-position-ids', action='store_true', |
|
                       help='Reset position ids after end-of-document token.')
|
group.add_argument('--reset-attention-mask', action='store_true', |
|
                       help='Reset self attention mask after '
|
'end-of-document token.') |
|
group.add_argument('--eod-mask-loss', action='store_true', |
|
help='Mask loss for the end of document tokens.') |
|
|
|
return parser |
|
|
|
|
|
def _add_autoresume_args(parser): |
|
group = parser.add_argument_group(title='autoresume') |
|
|
|
group.add_argument('--adlr-autoresume', action='store_true', |
|
help='Enable autoresume on adlr cluster.') |
|
group.add_argument('--adlr-autoresume-interval', type=int, default=1000, |
|
                       help='Intervals over which to check for autoresume '
|
'termination signal') |
|
|
|
return parser |
|
|
|
|
|
def _add_biencoder_args(parser): |
|
group = parser.add_argument_group(title='biencoder') |
|
|
|
|
|
group.add_argument('--ict-head-size', type=int, default=None, |
|
help='Size of block embeddings to be used in ICT and ' |
|
'REALM (paper default: 128)') |
|
group.add_argument('--biencoder-projection-dim', type=int, default=0, |
|
help='Size of projection head used in biencoder (paper' |
|
' default: 128)') |
|
group.add_argument('--biencoder-shared-query-context-model', action='store_true', |
|
help='Whether to share the parameters of the query ' |
|
'and context models or not') |
|
|
|
|
|
group.add_argument('--ict-load', type=str, default=None, |
|
help='Directory containing an ICTBertModel checkpoint') |
|
group.add_argument('--bert-load', type=str, default=None, |
|
                       help='Directory containing a BertModel checkpoint '
|
'(needed to start ICT and REALM)') |
|
|
|
|
|
group.add_argument('--titles-data-path', type=str, default=None, |
|
help='Path to titles dataset used for ICT') |
|
group.add_argument('--query-in-block-prob', type=float, default=0.1, |
|
help='Probability of keeping query in block for ' |
|
'ICT dataset') |
|
group.add_argument('--use-one-sent-docs', action='store_true', |
|
help='Whether to use one sentence documents in ICT') |
|
group.add_argument('--evidence-data-path', type=str, default=None, |
|
                       help='Path to Wikipedia Evidence from the DPR paper')
|
|
|
|
|
group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int, |
|
default=[], help="Which top-k accuracies to report " |
|
"(e.g. '1 5 20')") |
|
group.add_argument('--retriever-score-scaling', action='store_true', |
|
help='Whether to scale retriever scores by inverse ' |
|
'square root of hidden size') |
|
|
|
|
|
group.add_argument('--block-data-path', type=str, default=None, |
|
help='Where to save/load BlockData to/from') |
|
group.add_argument('--embedding-path', type=str, default=None, |
|
help='Where to save/load Open-Retrieval Embedding' |
|
' data to/from') |
|
|
|
|
|
group.add_argument('--indexer-batch-size', type=int, default=128, |
|
                       help='Batch size to use when running indexing jobs')
|
group.add_argument('--indexer-log-interval', type=int, default=1000, |
|
help='After how many batches should the indexer ' |
|
'report progress') |
|
return parser |
|
|
|
|
|
def _add_vision_args(parser): |
|
group = parser.add_argument_group(title="vision") |
|
|
|
|
|
group.add_argument('--num-classes', type=int, default=1000, |
|
                       help='Number of classes in vision classification task')
|
group.add_argument('--img-h', type=int, default=224, |
|
help='Image height for vision classification task') |
|
group.add_argument('--img-w', type=int, default=224, |
|
                       help='Image width for vision classification task')
|
group.add_argument('--num-channels', type=int, default=3, |
|
help='Number of channels in input image data') |
|
group.add_argument('--patch-dim', type=int, default=16, |
|
help='patch dimension') |
|
group.add_argument('--classes-fraction', type=float, default=1.0, |
|
help='training with fraction of classes.') |
|
group.add_argument('--data-per-class-fraction', type=float, default=1.0, |
|
help='training with fraction of data per class.') |
|
group.add_argument('--no-data-sharding', action='store_false', |
|
help='Disable data sharding.', |
|
dest='data_sharding') |
|
group.add_argument('--head-lr-mult', type=float, default=1.0, |
|
help='learning rate multiplier for head during finetuning') |
|
|
|
|
|
group.add_argument('--vision-pretraining', action='store_true', |
|
help='flag to indicate vision pretraining') |
|
group.add_argument('--vision-pretraining-type', type=str, default='classify', |
|
choices=['classify', 'inpaint', 'dino'], |
|
help='pretraining objectives') |
|
group.add_argument('--vision-backbone-type', type=str, default='vit', |
|
choices=['vit', 'mit', 'swin'], |
|
                       help='vision backbone type')
|
group.add_argument('--swin-backbone-type', type=str, default='tiny', |
|
choices=['tiny', 'base', 'h3'], |
|
                       help='swin backbone type')
|
|
|
|
|
group.add_argument('--mask-type', type=str, default='random', |
|
choices=['random', 'row'], |
|
help='mask types') |
|
group.add_argument('--mask-factor', type=float, default=1.0, |
|
help='mask size scaling parameter') |
|
|
|
|
|
group.add_argument('--iter-per-epoch', type=int, default=1250, |
|
help='iterations per epoch') |
|
group.add_argument('--dino-local-img-size', type=int, default=96, |
|
                       help='Image size for dino local crops')
|
group.add_argument('--dino-local-crops-number', type=int, default=10, |
|
help='Number of local crops') |
|
group.add_argument('--dino-head-hidden-size', type=int, default=2048, |
|
help='Hidden dimension size in dino head') |
|
group.add_argument('--dino-bottleneck-size', type=int, default=256, |
|
                       help='Bottleneck dimension in dino head')
|
group.add_argument('--dino-freeze-last-layer', type=float, default=1, |
|
help='Freezing last layer weights') |
|
group.add_argument('--dino-norm-last-layer', action='store_true', |
|
                       help='If set, normalize the last layer of the dino head.')
|
group.add_argument('--dino-warmup-teacher-temp', type=float, default=0.04, |
|
                       help='warmup teacher temperature')
|
group.add_argument('--dino-teacher-temp', type=float, default=0.07, |
|
help='teacher temperature') |
|
group.add_argument('--dino-warmup-teacher-temp-epochs', type=int, default=30, |
|
                       help='warmup teacher temperature epochs')
|
|
|
return parser |
|
|