|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Megatron initialization.""" |
|
|
|
import random |
|
import os |
|
import time |
|
|
|
import numpy as np |
|
import torch |
|
from datetime import timedelta |
|
|
|
from megatron import fused_kernels |
|
from megatron import get_adlr_autoresume |
|
from megatron import get_args |
|
from megatron import get_tensorboard_writer |
|
from megatron import mpu |
|
from megatron.arguments import (parse_args, validate_args) |
|
from megatron.checkpointing import load_args_from_checkpoint |
|
from megatron.global_vars import set_global_variables |
|
from megatron.mpu import (set_tensor_model_parallel_rank, |
|
set_tensor_model_parallel_world_size) |
|
from megatron.model.transformer import bias_dropout_add_fused_train |
|
from megatron.model.fused_bias_gelu import bias_gelu |
|
|
|
|
|
def initialize_megatron(extra_args_provider=None, args_defaults=None,
                        ignore_unknown_args=False, allow_no_cuda=False):
    """Set global variables, initialize distributed, and set autoresume
    and random seeds.

    Args:
        extra_args_provider: forwarded unchanged to `parse_args`.
        args_defaults: optional dict of argument defaults. It is consulted
            for the 'use_checkpoint_args' key and then forwarded to
            `validate_args`. `None` (the default) is treated as an empty
            dict.
        ignore_unknown_args: forwarded unchanged to `parse_args`.
        allow_no_cuda: skip the CUDA availability check. Should not be set
            unless using megatron for cpu only data processing. In general
            this arg should not be set unless you know what you are doing.

    Returns:
        When `args.lazy_mpu_init` is set, a function that finalizes the
        distributed env initialization (it initializes torch.distributed /
        mpu and sets the random seeds); the caller is expected to invoke
        it later. Otherwise initialization is completed here and `None` is
        returned.

    Raises:
        AssertionError: if CUDA is required but unavailable, or if
            checkpoint args are requested without `args.load` being set.
    """
    # Build a fresh dict per call rather than sharing a single mutable
    # default object across all invocations.
    if args_defaults is None:
        args_defaults = {}

    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), 'Megatron requires CUDA.'

    # Parse arguments.
    args = parse_args(extra_args_provider, ignore_unknown_args)

    if args.use_checkpoint_args or args_defaults.get('use_checkpoint_args', False):
        assert args.load is not None, '--use-checkpoint-args requires --load argument'
        load_args_from_checkpoint(args)

    validate_args(args, args_defaults)

    # Publish the validated args through the global-variable registry so
    # that get_args() works from here on.
    set_global_variables(args)

    def finish_mpu_init():
        """Initialize torch.distributed / mpu and seed the RNGs."""
        args = get_args()
        _initialize_distributed()

        if args.rank == 0:
            print('> setting random seeds to {} ...'.format(args.seed))
        _set_random_seed(args.seed, args.data_parallel_random_init)

    args = get_args()
    if args.lazy_mpu_init:
        # Lazy path: only record the tensor-model-parallel world size and
        # rank now, and hand the rest of the setup back to the caller.
        args.use_cpu_initialization = True
        set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
        set_tensor_model_parallel_rank(args.rank)
        return finish_mpu_init

    # Eager path: complete the initialization right away.
    finish_mpu_init()
    _init_autoresume()
    _compile_dependencies()

    # No continuation function.
    return None
|
|
|
|
|
def _compile_dependencies():
    """Compile the dataset index builder and load the fused kernels.

    Rank 0 (per torch.distributed) compiles the dataset helper. A warning
    is printed on `args.rank == 0` when the fused softmax constraints are
    not met. The fused kernels are then loaded with rank 0 going first:
    rank 0 calls `fused_kernels.load` before the barrier, every other rank
    calls it after the barrier. A final barrier makes all ranks wait for
    each other before returning.
    """
    args = get_args()

    # Dataset index builder: compiled on rank 0 only.
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print('> compiling dataset index builder ...')
        from megatron.data.dataset_utils import compile_helper
        compile_helper()
        elapsed = time.time() - start_time
        print('>>> done with dataset index builder. Compilation time: {:.3f} '
              'seconds'.format(elapsed), flush=True)

    # Check whether the fused softmax kernel constraints are satisfied.
    seq_len = args.seq_length
    heads_per_partition = \
        args.num_attention_heads / args.tensor_model_parallel_size
    attn_batch_size = heads_per_partition * args.micro_batch_size

    custom_kernel_constraint = (
        16 < seq_len <= 4096
        and seq_len % 4 == 0
        and attn_batch_size % 4 == 0
    )
    fused_softmax_ok = (
        (args.fp16 or args.bf16)
        and custom_kernel_constraint
        and args.masked_softmax_fusion
    )
    if not fused_softmax_ok and args.rank == 0:
        print('WARNING: constraints for invoking optimized'
              ' fused softmax kernel are not met. We default'
              ' back to unfused kernel invocations.', flush=True)

    # Rank 0 loads first; the remaining ranks wait at the barrier and
    # load afterwards.
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print('> compiling and loading fused kernels ...', flush=True)
        fused_kernels.load(args)
        torch.distributed.barrier()
    else:
        torch.distributed.barrier()
        fused_kernels.load(args)

    # Wait until every rank is past the loading phase before moving on.
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        elapsed = time.time() - start_time
        print('>>> done with compiling and loading fused kernels. '
              'Compilation time: {:.3f} seconds'.format(elapsed), flush=True)
|
|
|
|
|
|
|
def _initialize_distributed():
    """Initialize torch.distributed and the mpu model-parallel state.

    If torch.distributed is already initialized, `args.rank` and
    `args.world_size` are refreshed from it. Otherwise the CUDA device is
    selected (when any device is visible) and the process group is
    created. Model parallelism is then initialized through `mpu` when at
    least one CUDA device is visible and it has not been set up already.
    """
    args = get_args()

    num_devices = torch.cuda.device_count()
    has_devices = num_devices > 0

    if not torch.distributed.is_initialized():
        if args.rank == 0:
            print('> initializing torch distributed ...', flush=True)

        # Manually pick the device for this process.
        if has_devices:
            local_device = args.rank % num_devices
            if args.local_rank is None:
                args.local_rank = local_device
            else:
                assert args.local_rank == local_device, \
                    'expected local-rank to be the same as rank % device-count.'
            torch.cuda.set_device(local_device)

        # Create the process group.
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            world_size=args.world_size,
            rank=args.rank,
            timeout=timedelta(minutes=10))
    else:
        if args.rank == 0:
            print('torch distributed is already initialized, '
                  'skipping initialization ...', flush=True)
        # Adopt the rank / world size of the existing process group.
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()

    # Model-parallel setup is only attempted when a CUDA device is visible.
    if not has_devices:
        return

    if mpu.model_parallel_is_initialized():
        print('model parallel is already initialized')
        return

    mpu.initialize_model_parallel(
        args.tensor_model_parallel_size,
        args.pipeline_model_parallel_size,
        args.virtual_pipeline_model_parallel_size,
        args.pipeline_model_parallel_split_rank)
|
|
|
|
|
def _init_autoresume():
    """Initialize the ADLR autoresume object, if one is configured.

    Does nothing when `get_adlr_autoresume()` returns a falsy value.
    Otherwise `init()` is called on it, bracketed by two
    torch.distributed barriers.
    """
    adlr_autoresume = get_adlr_autoresume()
    if not adlr_autoresume:
        return

    torch.distributed.barrier()
    adlr_autoresume.init()
    torch.distributed.barrier()
|
|
|
|
|
def _set_random_seed(seed_, data_parallel_random_init=False): |
|
"""Set random seed for reproducability.""" |
|
if seed_ is not None and seed_ > 0: |
|
|
|
seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank()) |
|
|
|
if data_parallel_random_init: |
|
seed = seed + (10 * mpu.get_data_parallel_rank()) |
|
random.seed(seed) |
|
np.random.seed(seed) |
|
torch.manual_seed(seed) |
|
if torch.cuda.device_count() > 0: |
|
mpu.model_parallel_cuda_manual_seed(seed) |
|
else: |
|
raise ValueError('Seed ({}) should be a positive integer.'.format(seed)) |
|
|
|
|
|
def write_args_to_tensorboard():
    """Write every argument to tensorboard as a text entry.

    Each attribute of the global args is logged under its own name, with
    `args.iteration` as the global step. Does nothing when no tensorboard
    writer is available.
    """
    args = get_args()
    tb_writer = get_tensorboard_writer()
    if not tb_writer:
        return

    step = args.iteration
    for name in vars(args):
        value_text = str(getattr(args, name))
        tb_writer.add_text(name, value_text, global_step=step)
|
|
|
|
|
def set_jit_fusion_options():
    """Set PyTorch JIT layer fusion options and warm up the JIT functions.

    The flags chosen depend on whether the installed torch version is at
    least 1.10. Afterwards `_warmup_jit_function()` is called.
    """
    version_fields = torch.__version__.split('.')
    major = int(version_fields[0])
    minor = int(version_fields[1])

    jit = torch._C
    if (major, minor) >= (1, 10):
        # torch >= 1.10: enable nvfuser and switch the other fusers off.
        jit._jit_set_profiling_executor(True)
        jit._jit_set_profiling_mode(True)
        jit._jit_override_can_fuse_on_cpu(False)
        jit._jit_override_can_fuse_on_gpu(False)
        jit._jit_set_texpr_fuser_enabled(False)
        jit._jit_set_nvfuser_enabled(True)
        jit._debug_set_autodiff_subgraph_inlining(False)
    else:
        # Older torch: profiling off, fusion allowed on both cpu and gpu.
        jit._jit_set_profiling_mode(False)
        jit._jit_set_profiling_executor(False)
        jit._jit_override_can_fuse_on_cpu(True)
        jit._jit_override_can_fuse_on_gpu(True)

    _warmup_jit_function()
|
|
|
|
|
def _warmup_jit_function():
    """Compile JIT functions before the main training steps.

    Runs `bias_gelu` and `bias_dropout_add_fused_train` five times for
    each requires_grad configuration on random CUDA tensors sized from the
    global args, then frees the tensors and empties the CUDA cache.
    """
    args = get_args()

    # bf16 takes precedence over fp16; fp32 is the fallback.
    dtype = torch.float32
    if args.bf16:
        dtype = torch.bfloat16
    elif args.fp16:
        dtype = torch.float16

    # Warm up the fused bias+gelu function.
    ffn_per_partition = args.ffn_hidden_size // args.tensor_model_parallel_size
    gelu_bias = torch.rand(ffn_per_partition, dtype=dtype, device='cuda')
    gelu_input = torch.rand(
        (args.seq_length, args.micro_batch_size, ffn_per_partition),
        dtype=dtype, device='cuda')

    # (bias requires_grad, input requires_grad) configurations.
    for bias_grad, input_grad in ((True, False), (True, True)):
        gelu_bias.requires_grad = bias_grad
        gelu_input.requires_grad = input_grad
        for _ in range(5):
            output = bias_gelu(gelu_bias, gelu_input)
    del gelu_bias, gelu_input, output

    # Warm up the fused bias+dropout+add function.
    seq_length = (
        args.seq_length // mpu.get_tensor_model_parallel_world_size()
        if args.sequence_parallel
        else args.seq_length
    )
    activation_shape = (seq_length, args.micro_batch_size, args.hidden_size)
    hidden = torch.rand(activation_shape, dtype=dtype, device='cuda')
    residual = torch.rand(activation_shape, dtype=dtype, device='cuda')
    dropout_bias = torch.rand(
        (args.hidden_size), dtype=dtype, device='cuda').expand_as(residual)
    dropout_rate = 0.1

    # (input, bias, residual) requires_grad configurations.
    grad_configs = ((False, True, True), (True, True, True))
    for input_grad, bias_grad, residual_grad in grad_configs:
        hidden.requires_grad = input_grad
        dropout_bias.requires_grad = bias_grad
        residual.requires_grad = residual_grad
        for _ in range(5):
            output = bias_dropout_add_fused_train(
                hidden, dropout_bias, residual, dropout_rate)
    del dropout_bias, hidden, residual, output

    torch.cuda.empty_cache()
|
|