import math |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
import torch.nn.init as init |
|
from torch.nn.parameter import Parameter |
|
|
|
from .initialize import get_tensor_model_parallel_rank |
|
from .initialize import get_tensor_model_parallel_world_size |
|
from .initialize import get_tensor_model_parallel_group |
|
from .mappings import copy_to_tensor_model_parallel_region |
|
from .mappings import gather_from_tensor_model_parallel_region |
|
from .mappings import gather_from_sequence_parallel_region |
|
from .mappings import reduce_from_tensor_model_parallel_region |
|
from .mappings import scatter_to_tensor_model_parallel_region |
|
from .mappings import reduce_scatter_to_sequence_parallel_region |
|
|
|
from .random import get_cuda_rng_tracker |
|
from .utils import divide |
|
from .utils import split_tensor_along_last_dim |
|
from .utils import VocabUtility |
|
from megatron import get_args, get_global_memory_buffer |
|
|
|
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False, |
|
'partition_dim': -1, |
|
'partition_stride': 1} |
|
|
|
def param_is_not_tensor_parallel_duplicate(param): |
|
return (hasattr(param, 'tensor_model_parallel') and |
|
param.tensor_model_parallel) or ( |
|
get_tensor_model_parallel_rank() == 0) |
|
|
|
|
|
def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride): |
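    # Make sure none of the parallelism attributes are already set, then set
    # tensor_model_parallel, partition_dim and partition_stride.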
|
|
|
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: |
|
assert not hasattr(tensor, attribute) |
|
|
|
setattr(tensor, 'tensor_model_parallel', is_parallel) |
|
setattr(tensor, 'partition_dim', dim) |
|
setattr(tensor, 'partition_stride', stride) |
|
|
|
|
|
def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor): |
|
def maybe_set(attribute, value): |
|
if not hasattr(tensor, attribute): |
|
setattr(tensor, attribute, value) |
|
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: |
|
maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute]) |
|
|
|
|
|
def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor): |
|
def maybe_copy(attribute): |
|
if hasattr(source_tensor, attribute): |
|
setattr(destination_tensor, attribute, |
|
getattr(source_tensor, attribute)) |
|
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: |
|
maybe_copy(attribute) |
|
|
|
|
|
def _initialize_affine_weight_gpu(weight, init_method, |
|
partition_dim, stride=1): |
|
"""Initialize affine weight for model parallel on GPU.""" |
|
|
|
set_tensor_model_parallel_attributes(tensor=weight, |
|
is_parallel=True, |
|
dim=partition_dim, |
|
stride=stride) |
|
|
|
with get_cuda_rng_tracker().fork(): |
|
init_method(weight) |
|
|
|
|
|
def _initialize_affine_weight_cpu(weight, output_size, input_size, |
|
per_partition_size, partition_dim, |
|
init_method, stride=1, |
|
return_master_weight=False): |
|
"""Initialize affine weight for model parallel. |
|
|
|
Build the master weight on all processes and scatter |
|
the relevant chunk.""" |
|
|
|
set_tensor_model_parallel_attributes(tensor=weight, |
|
is_parallel=True, |
|
dim=partition_dim, |
|
stride=stride) |
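    # Initialize the full (master) weight in fp32 on every rank, then cast it to
    # the training dtype.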
|
|
|
|
|
master_weight = torch.empty(output_size, input_size, |
|
dtype=torch.float, |
|
requires_grad=False) |
|
init_method(master_weight) |
|
args = get_args() |
|
master_weight = master_weight.to(dtype=args.params_dtype) |
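    # Split the master weight along the partition dimension and copy this rank's
    # (strided) chunks into the local weight.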
|
|
|
|
|
per_partition_per_stride_size = divide(per_partition_size, stride) |
|
weight_list = torch.split(master_weight, per_partition_per_stride_size, |
|
dim=partition_dim) |
|
rank = get_tensor_model_parallel_rank() |
|
world_size = get_tensor_model_parallel_world_size() |
|
my_weight_list = weight_list[rank::world_size] |
|
|
|
with torch.no_grad(): |
|
torch.cat(my_weight_list, dim=partition_dim, out=weight) |
|
if return_master_weight: |
|
return master_weight |
|
return None |
|
|
|
|
|
class VocabParallelEmbedding(torch.nn.Module): |
|
"""Embedding parallelized in the vocabulary dimension. |
|
|
|
This is mainly adapted from torch.nn.Embedding and all the default |
|
values are kept. |
|
Arguments: |
|
num_embeddings: vocabulary size. |
|
embedding_dim: size of hidden state. |
|
init_method: method to initialize weights. |
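
    Example (illustrative sketch, not from the original code; the sizes and the
    token_ids tensor are placeholders, and megatron args / tensor-model-parallel
    groups are assumed to be initialized):

        embedding = VocabParallelEmbedding(50304, 1024,
                                           init_method=init.xavier_normal_)
        hidden_states = embedding(token_ids)  # token_ids: LongTensor [batch, seq]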
|
""" |
|
|
|
def __init__(self, num_embeddings, embedding_dim, |
|
init_method=init.xavier_normal_): |
|
super(VocabParallelEmbedding, self).__init__() |
|
|
|
self.num_embeddings = num_embeddings |
|
self.embedding_dim = embedding_dim |
|
|
|
self.padding_idx = None |
|
self.max_norm = None |
|
self.norm_type = 2. |
|
self.scale_grad_by_freq = False |
|
self.sparse = False |
|
self._weight = None |
|
self.tensor_model_parallel_size = get_tensor_model_parallel_world_size() |
|
|
|
self.vocab_start_index, self.vocab_end_index = \ |
|
VocabUtility.vocab_range_from_global_vocab_size( |
|
self.num_embeddings, get_tensor_model_parallel_rank(), |
|
self.tensor_model_parallel_size) |
|
self.num_embeddings_per_partition = self.vocab_end_index - \ |
|
self.vocab_start_index |
|
|
|
|
|
args = get_args() |
|
if args.use_cpu_initialization: |
|
self.weight = Parameter(torch.empty( |
|
self.num_embeddings_per_partition, self.embedding_dim, |
|
dtype=args.params_dtype)) |
|
if args.perform_initialization: |
|
_initialize_affine_weight_cpu( |
|
self.weight, self.num_embeddings, self.embedding_dim, |
|
self.num_embeddings_per_partition, 0, init_method) |
|
else: |
|
self.weight = Parameter(torch.empty( |
|
self.num_embeddings_per_partition, self.embedding_dim, |
|
device=torch.cuda.current_device(), dtype=args.params_dtype)) |
|
if args.perform_initialization: |
|
_initialize_affine_weight_gpu(self.weight, init_method, |
|
partition_dim=0, stride=1) |
|
|
|
def forward(self, input_): |
|
if self.tensor_model_parallel_size > 1: |
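            # Build a mask of tokens that fall outside this rank's vocab range and
            # shift the remaining indices into the local partition.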
|
|
|
input_mask = (input_ < self.vocab_start_index) | \ |
|
(input_ >= self.vocab_end_index) |
|
|
|
masked_input = input_.clone() - self.vocab_start_index |
|
masked_input[input_mask] = 0 |
|
else: |
|
masked_input = input_ |
|
|
|
output_parallel = F.embedding(masked_input, self.weight, |
|
self.padding_idx, self.max_norm, |
|
self.norm_type, self.scale_grad_by_freq, |
|
self.sparse) |
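        # Zero the embedding rows of out-of-range tokens so the all-reduce below
        # reconstructs the full lookup across ranks.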
|
|
|
if self.tensor_model_parallel_size > 1: |
|
output_parallel[input_mask, :] = 0.0 |
|
|
|
output = reduce_from_tensor_model_parallel_region(output_parallel) |
|
return output |
|
|
|
|
|
class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): |
|
""" |
|
Linear layer execution with asynchronous communication and gradient accumulation |
|
fusion in backprop. |
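
    In forward, output = input @ weight.t() (+ bias), with an optional all-gather
    of the sequence-parallel input beforehand. In backward, the input-gradient
    all-reduce (or reduce-scatter) is overlapped with the weight-gradient GEMM,
    and the fp32 weight-gradient accumulation can be fused into weight.main_grad.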
|
""" |
|
|
|
@staticmethod |
|
def forward(ctx, input, weight, bias, gradient_accumulation_fusion, |
|
async_grad_allreduce, sequence_parallel): |
|
ctx.save_for_backward(input, weight) |
|
ctx.use_bias = bias is not None |
|
ctx.gradient_accumulation_fusion = gradient_accumulation_fusion |
|
ctx.async_grad_allreduce = async_grad_allreduce |
|
ctx.sequence_parallel = sequence_parallel |
|
|
|
if sequence_parallel: |
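            # The input is sharded along the sequence (first) dimension under
            # sequence parallelism, so all-gather it into a shared buffer before
            # the GEMM.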
|
world_size = get_tensor_model_parallel_world_size() |
|
dim_size = list(input.size()) |
|
dim_size[0] = dim_size[0] * world_size |
|
|
|
all_gather_buffer = \ |
|
get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu") |
|
torch.distributed._all_gather_base( |
|
all_gather_buffer, |
|
input, |
|
group=get_tensor_model_parallel_group()) |
|
total_input = all_gather_buffer |
|
else: |
|
total_input = input |
|
|
|
output = torch.matmul(total_input, weight.t()) |
|
if bias is not None: |
|
output = output + bias |
|
return output |
|
|
|
@staticmethod |
|
def backward(ctx, grad_output): |
|
input, weight = ctx.saved_tensors |
|
use_bias = ctx.use_bias |
|
|
|
if ctx.sequence_parallel: |
|
world_size = get_tensor_model_parallel_world_size() |
|
dim_size = list(input.size()) |
|
dim_size[0] = dim_size[0] * world_size |
|
|
|
all_gather_buffer = \ |
|
get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu") |
|
handle = torch.distributed._all_gather_base( |
|
all_gather_buffer, |
|
input, |
|
group=get_tensor_model_parallel_group(), async_op=True) |
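            # The dummy op below launches a tiny kernel so that the asynchronous
            # all-gather is scheduled before the grad_input GEMM, overlapping
            # communication with computation.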
|
|
|
|
|
|
|
_ = torch.empty(1, device=grad_output.device) + 1 |
|
total_input = all_gather_buffer |
|
else: |
|
total_input = input |
|
grad_input = grad_output.matmul(weight) |
|
|
|
if ctx.sequence_parallel: |
|
handle.wait() |
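        # Collapse the leading two dimensions so the weight-gradient GEMMs below
        # operate on 2D tensors.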
|
|
|
|
|
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], |
|
grad_output.shape[2]) |
|
total_input = total_input.view(total_input.shape[0] * total_input.shape[1], |
|
total_input.shape[2]) |
|
|
|
if ctx.async_grad_allreduce: |
|
|
|
handle = torch.distributed.all_reduce( |
|
grad_input, group=get_tensor_model_parallel_group(), async_op=True) |
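            # Asynchronous all-reduce of grad_input; the dummy op below delays the
            # weight-gradient GEMM slightly so the all-reduce is scheduled first.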
|
|
|
|
|
_ = torch.empty(1, device=grad_output.device) + 1 |
|
|
|
if ctx.sequence_parallel: |
|
assert not ctx.async_grad_allreduce |
|
dim_size = list(input.size()) |
|
sub_grad_input = torch.empty(dim_size, dtype=input.dtype, |
|
device=torch.cuda.current_device(), |
|
requires_grad=False) |
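            # Reduce-scatter grad_input along the sequence dimension (the backward
            # of the forward all-gather).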
|
|
|
handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input, |
|
group=get_tensor_model_parallel_group(), |
|
async_op=True) |
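            # Dummy op so the asynchronous reduce-scatter is scheduled before the
            # weight-gradient GEMM below.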
|
|
|
|
|
_ = torch.empty(1, device=grad_output.device) + 1 |
|
|
|
|
|
if ctx.gradient_accumulation_fusion: |
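            # The fused CUDA kernel accumulates the fp32 weight gradient directly
            # into weight.main_grad, so no separate grad_weight tensor is created.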
|
import fused_dense_cuda |
|
fused_dense_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.main_grad) |
|
grad_weight = None |
|
else: |
|
grad_weight = grad_output.t().matmul(total_input) |
|
grad_bias = grad_output.sum(dim=0) if use_bias else None |
|
|
|
if ctx.sequence_parallel: |
|
handle.wait() |
|
return sub_grad_input, grad_weight, grad_bias, None, None, None |
|
|
|
if ctx.async_grad_allreduce: |
|
handle.wait() |
|
|
|
return grad_input, grad_weight, grad_bias, None, None, None |
|
|
|
|
|
class ColumnParallelLinear(torch.nn.Module): |
|
"""Linear layer with column parallelism. |
|
|
|
The linear layer is defined as Y = XA + b. A is parallelized along |
|
its second dimension as A = [A_1, ..., A_p]. |
|
|
|
Arguments: |
|
input_size: first dimension of matrix A. |
|
output_size: second dimension of matrix A. |
|
        bias: If true, add bias.

        gather_output: If true, call all-gather on the output and make Y available

                       to all GPUs; otherwise, every GPU will only have its own

                       partition of the output, Y_i = XA_i.
|
init_method: method to initialize weights. Note that bias is always set |
|
to zero. |
|
stride: For the strided linear layers. |
|
keep_master_weight_for_test: This was added for testing and should be |
|
set to False. It returns the master weights |
|
used for initialization. |
|
        skip_bias_add: This was added to enable performance optimizations where bias

                       can be fused with other elementwise operations. We skip

                       adding bias but instead return it.
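
    Example (illustrative sketch, not from the original code; sizes are
    placeholders and megatron args / parallel groups are assumed initialized):

        fc1 = ColumnParallelLinear(1024, 4096, gather_output=False)
        intermediate_parallel, bias = fc1(hidden_states)
        # each rank holds 4096 // tensor_model_parallel_world_size output features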
|
""" |
|
|
|
def __init__(self, input_size, output_size, bias=True, gather_output=True, |
|
init_method=init.xavier_normal_, stride=1, |
|
keep_master_weight_for_test=False, |
|
skip_bias_add=False): |
|
super(ColumnParallelLinear, self).__init__() |
|
|
|
|
|
self.input_size = input_size |
|
self.output_size = output_size |
|
self.gather_output = gather_output |
|
|
|
world_size = get_tensor_model_parallel_world_size() |
|
self.output_size_per_partition = divide(output_size, world_size) |
|
self.skip_bias_add = skip_bias_add |
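        # Note: the weight is stored transposed, [output_size_per_partition,
        # input_size], because the forward pass computes input @ weight.t().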
|
|
|
|
|
|
|
|
|
|
|
args = get_args() |
|
if args.use_cpu_initialization: |
|
self.weight = Parameter(torch.empty(self.output_size_per_partition, |
|
self.input_size, |
|
dtype=args.params_dtype)) |
|
if args.perform_initialization: |
|
self.master_weight = _initialize_affine_weight_cpu( |
|
self.weight, self.output_size, self.input_size, |
|
self.output_size_per_partition, 0, init_method, |
|
stride=stride, return_master_weight=keep_master_weight_for_test) |
|
else: |
|
self.weight = Parameter(torch.empty( |
|
self.output_size_per_partition, self.input_size, |
|
device=torch.cuda.current_device(), dtype=args.params_dtype)) |
|
if args.perform_initialization: |
|
_initialize_affine_weight_gpu(self.weight, init_method, |
|
partition_dim=0, stride=stride) |
|
|
|
if bias: |
|
if args.use_cpu_initialization: |
|
self.bias = Parameter(torch.empty( |
|
self.output_size_per_partition, dtype=args.params_dtype)) |
|
else: |
|
self.bias = Parameter(torch.empty( |
|
self.output_size_per_partition, |
|
device=torch.cuda.current_device(), |
|
dtype=args.params_dtype)) |
|
set_tensor_model_parallel_attributes(self.bias, True, 0, stride) |
|
|
|
with torch.no_grad(): |
|
self.bias.zero_() |
|
else: |
|
self.register_parameter('bias', None) |
|
self.async_tensor_model_parallel_allreduce = ( |
|
args.async_tensor_model_parallel_allreduce and |
|
world_size > 1) |
|
self.sequence_parallel = ( |
|
args.sequence_parallel and |
|
world_size > 1) |
|
assert not self.async_tensor_model_parallel_allreduce or \ |
|
not self.sequence_parallel |
|
self.gradient_accumulation_fusion = args.gradient_accumulation_fusion |
|
|
|
def forward(self, input_): |
|
bias = self.bias if not self.skip_bias_add else None |
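        # When async tensor-parallel all-reduce or sequence parallelism is enabled,
        # skip the explicit copy-to-tensor-parallel-region op: the backward
        # all-reduce / reduce-scatter of the input gradient is handled inside
        # LinearWithGradAccumulationAndAsyncCommunication instead.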
|
|
|
if self.async_tensor_model_parallel_allreduce or \ |
|
self.sequence_parallel: |
|
input_parallel = input_ |
|
else: |
|
input_parallel = copy_to_tensor_model_parallel_region(input_) |
|
|
|
output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply( |
|
input_parallel, self.weight, bias, self.gradient_accumulation_fusion, |
|
self.async_tensor_model_parallel_allreduce, self.sequence_parallel) |
|
if self.gather_output: |
|
|
|
assert not self.sequence_parallel |
|
output = gather_from_tensor_model_parallel_region(output_parallel) |
|
else: |
|
output = output_parallel |
|
output_bias = self.bias if self.skip_bias_add else None |
|
return output, output_bias |
|
|
|
|
|
class RowParallelLinear(torch.nn.Module): |
|
"""Linear layer with row parallelism. |
|
|
|
The linear layer is defined as Y = XA + b. A is parallelized along |
|
its first dimension and X along its second dimension as: |
|
               -   -
              | A_1 |
              | .   |
          A = | .   |          X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
|
Arguments: |
|
input_size: first dimension of matrix A. |
|
output_size: second dimension of matrix A. |
|
bias: If true, add bias. Note that bias is not parallelized. |
|
input_is_parallel: If true, we assume that the input is already |
|
split across the GPUs and we do not split |
|
again. |
|
init_method: method to initialize weights. Note that bias is always set |
|
to zero. |
|
stride: For the strided linear layers. |
|
keep_master_weight_for_test: This was added for testing and should be |
|
set to False. It returns the master weights |
|
used for initialization. |
|
        skip_bias_add: This was added to enable a performance optimization where bias

                       can be fused with other elementwise operations. We skip

                       adding bias but instead return it.
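
    Example (illustrative sketch, not from the original code): typically paired
    with a ColumnParallelLinear whose gather_output=False output is fed in
    directly as the already-split input.

        fc2 = RowParallelLinear(4096, 1024, input_is_parallel=True)
        output, bias = fc2(intermediate_parallel)
        # output is all-reduced (or reduce-scattered under sequence parallelism)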
|
""" |
|
|
|
def __init__(self, input_size, output_size, bias=True, |
|
input_is_parallel=False, |
|
init_method=init.xavier_normal_, stride=1, |
|
keep_master_weight_for_test=False, |
|
skip_bias_add=False): |
|
super(RowParallelLinear, self).__init__() |
|
|
|
|
|
self.input_size = input_size |
|
self.output_size = output_size |
|
self.input_is_parallel = input_is_parallel |
|
|
|
world_size = get_tensor_model_parallel_world_size() |
|
self.input_size_per_partition = divide(input_size, world_size) |
|
self.skip_bias_add = skip_bias_add |
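        # The weight is stored as [output_size, input_size_per_partition]: each
        # rank holds a slice of A along the input dimension (partition_dim=1 of
        # the stored, transposed weight).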
|
|
|
|
|
|
|
|
|
|
|
args = get_args() |
|
if args.use_cpu_initialization: |
|
self.weight = Parameter(torch.empty(self.output_size, |
|
self.input_size_per_partition, |
|
dtype=args.params_dtype)) |
|
if args.perform_initialization: |
|
self.master_weight = _initialize_affine_weight_cpu( |
|
self.weight, self.output_size, self.input_size, |
|
self.input_size_per_partition, 1, init_method, |
|
stride=stride, return_master_weight=keep_master_weight_for_test) |
|
else: |
|
self.weight = Parameter(torch.empty( |
|
self.output_size, self.input_size_per_partition, |
|
device=torch.cuda.current_device(), dtype=args.params_dtype)) |
|
if args.perform_initialization: |
|
_initialize_affine_weight_gpu(self.weight, init_method, |
|
partition_dim=1, stride=stride) |
|
if bias: |
|
if args.use_cpu_initialization: |
|
self.bias = Parameter(torch.empty(self.output_size, |
|
dtype=args.params_dtype)) |
|
else: |
|
self.bias = Parameter(torch.empty( |
|
self.output_size, device=torch.cuda.current_device(), |
|
dtype=args.params_dtype)) |
|
setattr(self.bias, 'sequence_parallel', args.sequence_parallel) |
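            # Always initialize the bias to zero.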
|
|
|
|
|
with torch.no_grad(): |
|
self.bias.zero_() |
|
else: |
|
self.register_parameter('bias', None) |
|
self.sequence_parallel = args.sequence_parallel |
|
self.gradient_accumulation_fusion = args.gradient_accumulation_fusion |
|
|
|
|
|
|
|
def forward(self, input_): |
|
|
|
if self.input_is_parallel: |
|
input_parallel = input_ |
|
else: |
|
assert not self.sequence_parallel |
|
input_parallel = scatter_to_tensor_model_parallel_region(input_) |
|
|
|
output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply( |
|
input_parallel, self.weight, None, |
|
self.gradient_accumulation_fusion, None, None) |
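        # Reduce the partial outputs: all-reduce across tensor-parallel ranks, or
        # reduce-scatter along the sequence dimension under sequence parallelism.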
|
|
|
if self.sequence_parallel: |
|
output_ = reduce_scatter_to_sequence_parallel_region(output_parallel) |
|
else: |
|
output_ = reduce_from_tensor_model_parallel_region(output_parallel) |
|
if not self.skip_bias_add: |
|
output = output_ + self.bias if self.bias is not None else output_ |
|
output_bias = None |
|
else: |
|
output = output_ |
|
output_bias = self.bias |
|
return output, output_bias |
|
|