from typing import Optional

import torch
from torch import Tensor
from torch.distributed import ProcessGroup

# `all_gather_into_tensor` and `reduce_scatter_tensor` are the new names for
# `_all_gather_base` and `_reduce_scatter_base`. They require a recent version of
# PyTorch. The following 4 lines are for backward compatibility with older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base

# Raw operation, does not support autograd, but does support async
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    output = torch.empty(
        world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.all_gather_into_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle

# Raw operation, does not support autograd, but does support async
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    assert input_.shape[0] % world_size == 0
    output = torch.empty(
        input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.reduce_scatter_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle

# Raw operation, does not support autograd, but does support async
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    input_ = input_.contiguous()
    handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
    return input_, handle

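# Illustrative usage sketch (not part of the original module): with `async_op=True`
# the raw collectives above return a communication handle together with the output
# tensor, so callers can overlap communication with independent compute and wait on
# the handle before reading the result. The helper name below is hypothetical.
def _example_overlapped_all_gather(input_: Tensor, process_group: ProcessGroup) -> Tensor:
    output, handle = all_gather_raw(input_, process_group, async_op=True)
    # ... independent computation could run here while the all-gather is in flight ...
    handle.wait()  # block until `output` is fully written
    return output
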
class AllGatherFunc(torch.autograd.Function):
    """Gather the input from the sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_gather_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
all_gather = AllGatherFunc.apply

class ReduceScatterFunc(torch.autograd.Function):
    """Reduce-scatter the input from the sequence parallel region."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = reduce_scatter_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
reduce_scatter = ReduceScatterFunc.apply

class AllReduceFunc(torch.autograd.Function):
    """All-reduce the input from the sequence parallel region."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_reduce_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        return grad_output, None


# Supports autograd, but does not support async
all_reduce = AllReduceFunc.apply

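# Illustrative sketch (an assumption, not from the original file): `all_gather`
# participates in autograd; its backward pass reduce-scatters the incoming gradient
# so each rank only receives the slice belonging to its own shard (and vice versa
# for `reduce_scatter`). The hypothetical helper below only demonstrates that
# gradients flow back to the local shard.
def _example_autograd_all_gather(local_x: Tensor, process_group: ProcessGroup) -> Tensor:
    local_x = local_x.detach().requires_grad_(True)
    full_x = all_gather(local_x, process_group)    # shape[0] grows by world_size
    (full_x ** 2).sum().backward()                 # backward reduce-scatters the gradient
    return local_x.grad                            # same shape as the local shard
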
def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _shared_params=True in the same order,
    # as different ranks might have different numbers of parameters (e.g., only rank 0 has bias).
    params_shared = {
        name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
    }
    for _, p in sorted(params_shared.items()):
        with torch.no_grad():
            # Broadcast needs src to be the global rank, not the group rank
            torch.distributed.broadcast(
                p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
            )

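# Hypothetical usage sketch (not from the original file): parameters that must stay
# identical across ranks are tagged with the `_shared_params` attribute checked
# above, then broadcast from (group) rank 0. The helper name and the choice of
# which module to tag are assumptions for illustration.
def _example_mark_and_sync_shared(
    shared_module: torch.nn.Module, model: torch.nn.Module, process_group: ProcessGroup
) -> None:
    for p in shared_module.parameters():
        p._shared_params = True  # attribute read by sync_shared_params above
    sync_shared_params(model, process_group)
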
# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _sequence_parallel=True in the same order,
    # as different ranks might have different numbers of parameters (e.g., only rank 0 has bias).
    params_seqparallel = {
        name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
    }
    grads = [p.grad for _, p in sorted(params_seqparallel.items())]
    if grads:
        with torch.no_grad():
            coalesced = torch._utils._flatten_dense_tensors(grads)
            torch.distributed.all_reduce(coalesced, group=process_group)
            for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)

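# Hypothetical usage sketch (not from the original file): in a training loop, the
# all-reduce above would run after `backward()` and before the optimizer step, for
# parameters tagged with `_sequence_parallel=True` (e.g., norm weights replicated
# across the sequence parallel group). Names below are illustrative only.
def _example_training_step(
    model: torch.nn.Module,
    loss: Tensor,
    optimizer: torch.optim.Optimizer,
    process_group: ProcessGroup,
) -> None:
    loss.backward()
    allreduce_sequence_parallel_grad(model, process_group)  # sync tagged grads
    optimizer.step()
    optimizer.zero_grad()
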
def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
    """Get the dim for the local rank derived from splitting dim on world_size processes.

    The split may not be even across the world_size processes.
    """
    multiple = dim // multiple_of
    div = multiple // world_size
    mod = multiple % world_size
    local_multiple = div + int(local_rank < mod)
    return local_multiple * multiple_of

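# Worked example (added for illustration, not part of the original file): splitting
# dim=10 over 4 ranks in multiples of 2 gives multiple=5, div=1, mod=1, so rank 0
# receives (1 + 1) * 2 = 4 and ranks 1-3 receive 1 * 2 = 2 each (4 + 2 + 2 + 2 = 10).
def _example_uneven_split() -> None:
    assert get_dim_for_local_rank(10, 4, 0, multiple_of=2) == 4
    assert all(get_dim_for_local_rank(10, 4, r, multiple_of=2) == 2 for r in (1, 2, 3))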