| |
| |
|
|
| import torch |
| import torch.distributed as dist |
|
|
|
|
| |
| |
| |
def _all_to_all(
    input_: torch.Tensor,
    world_size: int,
    group: dist.ProcessGroup,
    scatter_dim: int,
    gather_dim: int,
) -> torch.Tensor:
    """Exchange equal slices of ``input_`` between the ranks of ``group``.

    ``input_`` is cut into ``world_size`` pieces along ``scatter_dim``, the
    pieces are handed to ``dist.all_to_all``, and the pieces received in
    return are concatenated along ``gather_dim``.

    Args:
        input_: local tensor to redistribute.
        world_size: number of pieces to cut ``input_`` into (one per rank).
        group: process group passed to ``dist.all_to_all``.
        scatter_dim: dimension of ``input_`` that is split.
        gather_dim: dimension along which the received pieces are concatenated.

    Returns:
        A contiguous tensor built from the ``world_size`` received pieces.

    Raises:
        ValueError: if the size of ``scatter_dim`` is not a multiple of
            ``world_size``.
    """
    scatter_size = input_.size(scatter_dim)
    # ``tensor_split`` would silently produce uneven pieces, and the receive
    # buffers below are sized from the local pieces, so send and receive sizes
    # would no longer agree. Fail early with a clear message instead.
    # (A non-positive ``world_size`` is left for ``tensor_split`` to reject.)
    if world_size > 0 and scatter_size % world_size != 0:
        raise ValueError(
            f"The dimension to scatter ({scatter_size}) is not a multiple of "
            f"world size ({world_size}), cannot exchange equal slices"
        )

    # The collective is given contiguous send buffers.
    input_list = [piece.contiguous() for piece in torch.tensor_split(input_, world_size, scatter_dim)]
    # All pieces share one shape after the check above.
    output_list = [torch.empty_like(piece) for piece in input_list]
    dist.all_to_all(output_list, input_list, group=group)
    return torch.cat(output_list, dim=gather_dim).contiguous()
|
|
|
|
class _AllToAll(torch.autograd.Function):
    """Differentiable all-to-all communication.

    The forward pass runs ``_all_to_all`` with the given scatter and gather
    dimensions; the backward pass runs the same exchange on the incoming
    gradient with the two dimensions swapped.

    Args:
        input_: input matrix
        process_group: communication group
        scatter_dim: scatter dimension
        gather_dim: gather dimension
    """

    @staticmethod
    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
        # Remember everything ``backward`` needs to run the reverse exchange.
        ctx.process_group = process_group
        ctx.scatter_dim, ctx.gather_dim = scatter_dim, gather_dim
        ctx.world_size = dist.get_world_size(process_group)
        return _all_to_all(input_, ctx.world_size, process_group, scatter_dim, gather_dim)

    @staticmethod
    def backward(ctx, grad_output):
        # Scatter along the forward gather dimension and gather along the
        # forward scatter dimension.
        grad_input = _all_to_all(
            grad_output,
            ctx.world_size,
            ctx.process_group,
            ctx.gather_dim,
            ctx.scatter_dim,
        )
        # One slot per forward argument; only ``input_`` receives a gradient.
        return grad_input, None, None, None
|
|
|
|
def all_to_all(
    input_: torch.Tensor,
    process_group: dist.ProcessGroup,
    scatter_dim: int = 2,
    gather_dim: int = 1,
):
    """Run the autograd-aware all-to-all exchange on ``input_``.

    Args:
        input_: tensor to redistribute.
        process_group: communication group.
        scatter_dim: dimension that is split across ranks (default ``2``).
        gather_dim: dimension along which received pieces are concatenated
            (default ``1``).

    Returns:
        The result of ``_AllToAll.apply`` for these arguments.
    """
    exchange = _AllToAll.apply
    return exchange(input_, process_group, scatter_dim, gather_dim)
|
|
|
|
def _gather(
    input_: torch.Tensor,
    world_size: int,
    group: dist.ProcessGroup,
    gather_dim: int,
):
    """Collect ``input_`` from every rank of ``group`` into a list.

    NOTE(review): a second ``_gather(input_, pg, dim=-1)`` is defined later in
    this file and rebinds the name, so this version looks unreachable through
    the module-level name -- confirm whether it is still needed.

    Args:
        input_: local tensor contributed by this rank.
        world_size: number of ranks in ``group``; one receive buffer is
            allocated per rank.
        group: process group passed to ``dist.all_gather``.
        gather_dim: kept for signature compatibility. The tensors are returned
            un-concatenated, so it is not used here.

    Returns:
        A list of ``world_size`` tensors shaped like ``input_``.
    """
    local = input_.contiguous()
    # Always allocate the receive buffers: the previous code tested a
    # ``gather_list`` local before it was ever assigned, which raised
    # ``UnboundLocalError`` on every call.
    gather_list = [torch.empty_like(local) for _ in range(world_size)]
    # ``all_gather`` fills the list on every rank, matching the fact that
    # every rank returns it. ``dist.gather`` only fills it on the destination
    # rank and has no ``gather_dim`` keyword.
    dist.all_gather(gather_list, local, group=group)
    return gather_list
|
|
|
|
| |
| |
| |
|
|
|
|
def _split(input_, pg: dist.ProcessGroup, dim=-1):
    """Return the slice of ``input_`` along ``dim`` that belongs to this rank.

    Args:
        input_: tensor to slice.
        pg: process group used to look up the world size and this rank.
        dim: dimension to slice along (default: last).

    Returns:
        ``input_`` itself when the group has a single rank; otherwise a
        contiguous copy of the chunk whose index equals this rank.

    Raises:
        AssertionError: if the size of ``dim`` is not a multiple of the
            world size.
    """
    n_ranks = dist.get_world_size(pg)
    my_rank = dist.get_rank(pg)

    # A single-rank group owns the whole tensor.
    if n_ranks == 1:
        return input_

    extent = input_.size(dim)
    assert extent % n_ranks == 0, (
        f"The dimension to split ({extent}) is not a multiple of world size ({n_ranks}), "
        f"cannot split tensor evenly"
    )

    # Equal-sized chunks; keep only the one indexed by this rank.
    chunk_len = extent // n_ranks
    pieces = input_.split(chunk_len, dim=dim)
    return pieces[my_rank].contiguous()
|
|
|
|
def _gather(input_, pg: dist.ProcessGroup, dim=-1):
    """All-gather ``input_`` across ``pg`` and concatenate along ``dim``.

    Args:
        input_: local tensor contributed by this rank.
        pg: process group used for the world-size lookup and the collective.
        dim: dimension along which the gathered tensors are concatenated
            (default: last).

    Returns:
        A contiguous copy/view of ``input_`` when the group has a single rank;
        otherwise the contiguous concatenation of every rank's tensor.

    Raises:
        AssertionError: if more than one rank is involved and ``input_`` is
            not on a CUDA device.
    """
    input_ = input_.contiguous()
    world_size = dist.get_world_size(pg)

    # A single-rank group has nothing to gather.
    if world_size == 1:
        return input_

    # Check the device before allocating any receive buffers.
    assert input_.device.type == "cuda", "_gather expects a CUDA tensor when world size > 1"

    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    dist.all_gather(tensor_list, input_, group=pg)

    return torch.cat(tensor_list, dim=dim).contiguous()
|
|
|
|
class _GatherForwardSplitBackward(torch.autograd.Function):
    """Gather in the forward pass, split in the backward pass.

    Forward concatenates the input from all ranks of the group via
    ``_gather``. Backward optionally rescales the incoming gradient and keeps
    only this rank's slice via ``_split``.

    Args:
        input_: input matrix.
        process_group: parallel mode.
        dim: dimension
        grad_scale: ``"up"`` multiplies the gradient by the world size,
            ``"down"`` divides it; any other value leaves it unscaled.
    """

    @staticmethod
    def symbolic(graph, input_):
        # NOTE(review): the module-level ``_gather(input_, pg, dim=-1)`` in this
        # file requires a process group that is not passed here -- confirm
        # whether this hook is still used.
        return _gather(input_)

    @staticmethod
    def forward(ctx, input_, process_group, dim, grad_scale):
        # ``ctx.mode`` holds the process group for the backward pass.
        ctx.mode, ctx.dim, ctx.grad_scale = process_group, dim, grad_scale
        gathered = _gather(input_, process_group, dim)
        return gathered

    @staticmethod
    def backward(ctx, grad_output):
        scale_mode = ctx.grad_scale
        if scale_mode == "up":
            grad_output = grad_output * dist.get_world_size(ctx.mode)
        elif scale_mode == "down":
            grad_output = grad_output / dist.get_world_size(ctx.mode)

        grad_input = _split(grad_output, ctx.mode, ctx.dim)
        # One slot per forward argument; only ``input_`` receives a gradient.
        return grad_input, None, None, None
|
|
|
|
class _SplitForwardGatherBackward(torch.autograd.Function):
    """Split in the forward pass, gather in the backward pass.

    Forward keeps only the chunk of the input that corresponds to this rank
    via ``_split``. Backward optionally rescales the incoming gradient and
    gathers it from all ranks via ``_gather``.

    Args:
        input_: input matrix.
        process_group: parallel mode.
        dim: dimension
        grad_scale: ``"up"`` multiplies the gradient by the world size,
            ``"down"`` divides it; any other value leaves it unscaled.
    """

    @staticmethod
    def symbolic(graph, input_):
        # NOTE(review): ``_split(input_, pg, dim=-1)`` in this file requires a
        # process group that is not passed here -- confirm whether this hook
        # is still used.
        return _split(input_)

    @staticmethod
    def forward(ctx, input_, process_group, dim, grad_scale):
        # ``ctx.mode`` holds the process group for the backward pass.
        ctx.mode, ctx.dim, ctx.grad_scale = process_group, dim, grad_scale
        local_chunk = _split(input_, process_group, dim)
        return local_chunk

    @staticmethod
    def backward(ctx, grad_output):
        scale_mode = ctx.grad_scale
        if scale_mode == "up":
            grad_output = grad_output * dist.get_world_size(ctx.mode)
        elif scale_mode == "down":
            grad_output = grad_output / dist.get_world_size(ctx.mode)

        grad_input = _gather(grad_output, ctx.mode, ctx.dim)
        # One slot per forward argument; only ``input_`` receives a gradient.
        return grad_input, None, None, None
|
|
|
|
def split_forward_gather_backward(input_, process_group, dim, grad_scale=None):
    """Split ``input_`` across the group in forward, gather gradients in backward.

    Args:
        input_: tensor to split.
        process_group: process group used for the split and the gather.
        dim: dimension to split along.
        grad_scale: ``"up"`` multiplies the backward gradient by the world
            size, ``"down"`` divides it; any other value (including the
            default ``None``) leaves it unscaled. The default used to be
            ``1.0``, which suggested a numeric factor but was treated the same
            as no scaling.

    Returns:
        The result of ``_SplitForwardGatherBackward.apply`` for these arguments.
    """
    return _SplitForwardGatherBackward.apply(input_, process_group, dim, grad_scale)
|
|
|
|
def gather_forward_split_backward(input_, process_group, dim, grad_scale=None):
    """Gather ``input_`` across the group in forward, split gradients in backward.

    Args:
        input_: local tensor to gather.
        process_group: process group used for the gather and the split.
        dim: dimension to concatenate along.
        grad_scale: forwarded to ``_GatherForwardSplitBackward``; ``"up"``
            multiplies the backward gradient by the world size, ``"down"``
            divides it, and any other value leaves it unscaled.

    Returns:
        The result of ``_GatherForwardSplitBackward.apply`` for these arguments.
    """
    gather_then_split = _GatherForwardSplitBackward.apply
    return gather_then_split(input_, process_group, dim, grad_scale)