import torch
import torch.distributed as dist


# ====================
# All-To-All
# ====================
def _all_to_all(
    input_: torch.Tensor,
    world_size: int,
    group: dist.ProcessGroup,
    scatter_dim: int,
    gather_dim: int,
):
    # Split the local tensor into one chunk per rank along scatter_dim,
    # exchange chunks with every other rank, then concatenate the received
    # chunks along gather_dim.
    input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
    output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
    dist.all_to_all(output_list, input_list, group=group)
    return torch.cat(output_list, dim=gather_dim).contiguous()
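

# A minimal single-process sketch of the shape bookkeeping above, with no
# communication. The [B, S, H, D] layout, the sizes, and the helper name are
# illustrative assumptions, not part of this module's API.
def _example_all_to_all_shapes(world_size: int = 4):
    x = torch.randn(2, 16, 8, 64)  # [B, S // world_size, H, D], full heads
    # _all_to_all splits along scatter_dim=2 (heads) ...
    chunks = [t.contiguous() for t in torch.tensor_split(x, world_size, 2)]
    # ... and concatenates the chunks each rank receives along gather_dim=1
    # (sequence). Received chunks match the local chunk shapes, so the output
    # shape can be previewed without any communication:
    out = torch.cat(chunks, dim=1)
    assert out.shape == (2, 64, 2, 64)  # [B, S, H // world_size, D]
    return out.shape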


class _AllToAll(torch.autograd.Function):
    """All-to-all communication.

    Args:
        input_: input matrix
        process_group: communication group
        scatter_dim: scatter dimension
        gather_dim: gather dimension
    """

    @staticmethod
    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
        ctx.process_group = process_group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.world_size = dist.get_world_size(process_group)
        output = _all_to_all(input_, ctx.world_size, process_group, scatter_dim, gather_dim)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # The inverse of an all-to-all is an all-to-all with the scatter and
        # gather dimensions swapped.
        grad_output = _all_to_all(
            grad_output,
            ctx.world_size,
            ctx.process_group,
            ctx.gather_dim,
            ctx.scatter_dim,
        )
        return (
            grad_output,
            None,
            None,
            None,
        )


def all_to_all(
    input_: torch.Tensor,
    process_group: dist.ProcessGroup,
    scatter_dim: int = 2,
    gather_dim: int = 1,
):
    return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
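

# Hypothetical usage sketch for all_to_all in a sequence-parallel attention
# layer. It assumes torch.distributed is already initialized with a backend
# that supports dist.all_to_all (e.g. NCCL under torchrun), a [batch, seq,
# heads, head_dim] layout, and sizes divisible by the world size; all of
# these are illustrative assumptions, not requirements of this module.
def _example_all_to_all_usage():
    # Before attention: each rank holds a sequence shard with all heads.
    x = torch.randn(2, 16, 8, 64, device="cuda")  # [B, S / world_size, H, D]
    # Trade the sharded dimension: full sequence, 1 / world_size of the heads.
    x = all_to_all(x, dist.group.WORLD, scatter_dim=2, gather_dim=1)
    # ... attention over the full sequence would run here ...
    # Trade back: sequence shard with all heads.
    x = all_to_all(x, dist.group.WORLD, scatter_dim=1, gather_dim=2)
    return x.shape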


def _gather(
    input_: torch.Tensor,
    world_size: int,
    group: dist.ProcessGroup,
    gather_dim: int,
):
    # NOTE: dist.gather has no notion of a tensor dimension, so collect one
    # chunk per rank with all_gather and concatenate along gather_dim. This
    # helper is superseded by the _gather defined in the Gather-Split section
    # below.
    gather_list = [torch.empty_like(input_) for _ in range(world_size)]
    dist.all_gather(gather_list, input_.contiguous(), group=group)
    return torch.cat(gather_list, dim=gather_dim).contiguous()


# ====================
# Gather-Split
# ====================
def _split(input_, pg: dist.ProcessGroup, dim=-1):
    # Skip if only one rank is involved.
    world_size = dist.get_world_size(pg)
    rank = dist.get_rank(pg)
    if world_size == 1:
        return input_

    # Split along the given dimension and keep this rank's chunk.
    dim_size = input_.size(dim)
    assert dim_size % world_size == 0, (
        f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), "
        f"cannot split tensor evenly"
    )

    tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
    output = tensor_list[rank].contiguous()
    return output


def _gather(input_, pg: dist.ProcessGroup, dim=-1):
    # Skip if only one rank is involved.
    input_ = input_.contiguous()
    world_size = dist.get_world_size(pg)
    if world_size == 1:
        return input_

    # All-gather one chunk per rank.
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    assert input_.device.type == "cuda"
    dist.all_gather(tensor_list, input_, group=pg)

    # Concatenate along the given dimension.
    output = torch.cat(tensor_list, dim=dim).contiguous()
    return output
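

# Single-process sketch of _split's bookkeeping (hypothetical sizes): each
# rank keeps one contiguous 1 / world_size chunk along `dim`, and _gather is
# the inverse that restores the full tensor.
def _example_split_shapes(world_size: int = 4, rank: int = 0):
    x = torch.randn(2, 64, 1024)  # [B, S, C]
    chunk = torch.split(x, x.size(1) // world_size, dim=1)[rank].contiguous()
    assert chunk.shape == (2, 16, 1024)  # [B, S // world_size, C]
    return chunk.shape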


class _GatherForwardSplitBackward(torch.autograd.Function):
    """Gather the input from the model parallel region and concatenate.

    Args:
        input_: input matrix.
        process_group: process group.
        dim: dimension
    """

    @staticmethod
    def symbolic(graph, input_, process_group, dim, grad_scale):
        return _gather(input_, process_group, dim)

    @staticmethod
    def forward(ctx, input_, process_group, dim, grad_scale):
        ctx.mode = process_group
        ctx.dim = dim
        ctx.grad_scale = grad_scale
        return _gather(input_, process_group, dim)

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.grad_scale == "up":
            grad_output = grad_output * dist.get_world_size(ctx.mode)
        elif ctx.grad_scale == "down":
            grad_output = grad_output / dist.get_world_size(ctx.mode)
        return _split(grad_output, ctx.mode, ctx.dim), None, None, None


class _SplitForwardGatherBackward(torch.autograd.Function):
    """Split the input and keep only the chunk corresponding to the rank.

    Args:
        input_: input matrix.
        process_group: process group.
        dim: dimension
    """

    @staticmethod
    def symbolic(graph, input_, process_group, dim, grad_scale):
        return _split(input_, process_group, dim)

    @staticmethod
    def forward(ctx, input_, process_group, dim, grad_scale):
        ctx.mode = process_group
        ctx.dim = dim
        ctx.grad_scale = grad_scale
        return _split(input_, process_group, dim)

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.grad_scale == "up":
            grad_output = grad_output * dist.get_world_size(ctx.mode)
        elif ctx.grad_scale == "down":
            grad_output = grad_output / dist.get_world_size(ctx.mode)
        return _gather(grad_output, ctx.mode, ctx.dim), None, None, None


# grad_scale: "up" multiplies gradients by the world size, "down" divides
# them by it; any other value leaves gradients unscaled.
def split_forward_gather_backward(input_, process_group, dim, grad_scale=1.0):
    return _SplitForwardGatherBackward.apply(input_, process_group, dim, grad_scale)


def gather_forward_split_backward(input_, process_group, dim, grad_scale=None):
    return _GatherForwardSplitBackward.apply(input_, process_group, dim, grad_scale)
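

# Hypothetical end-to-end sketch: shard an activation along the sequence
# dimension for a sequence-parallel region, then gather it back. Assumes an
# initialized process group and CUDA tensors (the all_gather path above
# asserts a CUDA device); the shapes, the grad_scale choices, and the
# __main__ launch under torchrun are illustrative assumptions.
def _example_split_gather_usage(process_group=None):
    pg = process_group if process_group is not None else dist.group.WORLD
    x = torch.randn(2, 64, 1024, device="cuda", requires_grad=True)  # [B, S, C]
    # Forward: keep this rank's sequence chunk; backward: all-gather the grads.
    x_shard = split_forward_gather_backward(x, pg, dim=1, grad_scale="down")
    # ... sequence-parallel layers would operate on [B, S / world_size, C] ...
    # Forward: all-gather the full sequence; backward: split the grads again.
    x_full = gather_forward_split_backward(x_shard, pg, dim=1, grad_scale="up")
    return x_full.shape


if __name__ == "__main__":
    # e.g. torchrun --nproc_per_node=2 this_file.py
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    print(dist.get_rank(), _example_split_gather_usage())
    dist.destroy_process_group()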