from typing import List

import torch
import torch.distributed as dist
from megatron.core import mpu, parallel_state
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.autograd import Function
from torch.distributed import broadcast, get_process_group_ranks
from transformer_engine.pytorch.jit import no_torch_dynamo
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.module.rmsnorm import RMSNorm as RMSNormTE
from transformer_engine.pytorch.module.rmsnorm import _RMSNorm

from cosmos_predict1.utils import log

def get_batch_on_this_cp_rank(inputs):
    """Slice a batch input along the sequence dimension into multiple chunks,
    which are parallelized across GPUs in a context parallel group.
    """
    cp_size = parallel_state.get_context_parallel_world_size()

    if cp_size > 1:
        cp_rank = mpu.get_context_parallel_rank()
        seq_dim = 1
        # Split the sequence into 2 * cp_size chunks. Each rank keeps chunk `cp_rank`
        # and its mirror chunk `2 * cp_size - cp_rank - 1` (zigzag assignment), which
        # balances the attention workload across ranks.
        inputs = inputs.view(
            *inputs.shape[0:seq_dim],
            2 * cp_size,
            inputs.shape[seq_dim] // (2 * cp_size),
            *inputs.shape[(seq_dim + 1) :],
        )
        index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda(
            non_blocking=True
        )
        inputs = inputs.index_select(seq_dim, index)
        # Merge the two selected chunks back into a single sequence dimension.
        inputs = inputs.view(*inputs.shape[0:seq_dim], -1, *inputs.shape[(seq_dim + 2) :])

    return inputs
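

# A minimal single-process sketch (not part of the original module) that mirrors the zigzag
# chunk selection above on CPU tensors, assuming a sequence length divisible by 2 * cp_size.
# The helper name `_example_cp_zigzag_split` is illustrative only; the real function above
# additionally requires CUDA and an initialized context parallel group.
def _example_cp_zigzag_split(x: torch.Tensor, cp_size: int, cp_rank: int) -> torch.Tensor:
    seq_dim = 1
    # [B, S, ...] -> [B, 2 * cp_size, S // (2 * cp_size), ...]
    x = x.view(*x.shape[:seq_dim], 2 * cp_size, x.shape[seq_dim] // (2 * cp_size), *x.shape[seq_dim + 1 :])
    # Keep chunk `cp_rank` and its mirror chunk `2 * cp_size - cp_rank - 1`.
    index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1])
    x = x.index_select(seq_dim, index)
    # Merge the two kept chunks: the result has sequence length S // cp_size.
    return x.view(*x.shape[:seq_dim], -1, *x.shape[seq_dim + 2 :])


# For example, with cp_size=2 a [B, 8, D] batch keeps positions [0, 1, 6, 7] on rank 0 and
# [2, 3, 4, 5] on rank 1, each of shape [B, 4, D].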
def gather_batch_from_cp_ranks(outputs):
    """
    Gather and reconstruct the full batch from chunks distributed across GPUs in a context parallel group.
    """
    cp_size = parallel_state.get_context_parallel_world_size()
    cp_rank = mpu.get_context_parallel_rank()

    if cp_size > 1:
        seq_dim = 1
        try:
            # Each rank holds two zigzag chunks; expose them as a separate dimension.
            chunk_size = outputs.shape[seq_dim] // 2
            outputs = outputs.view(*outputs.shape[:seq_dim], 2, chunk_size, *outputs.shape[seq_dim + 1 :])

            gathered_chunks = [torch.zeros_like(outputs) for _ in range(cp_size)]

            dist.barrier()
            dist.all_gather(gathered_chunks, outputs, group=parallel_state.get_context_parallel_group())
            dist.barrier()

            # Undo the zigzag assignment: rank i contributed chunks i and 2 * cp_size - 1 - i.
            reordered_chunks = [None] * (2 * cp_size)
            for i in range(cp_size):
                reordered_chunks[i] = gathered_chunks[i].select(seq_dim, 0)
                reordered_chunks[2 * cp_size - 1 - i] = gathered_chunks[i].select(seq_dim, 1)

            outputs = torch.cat(reordered_chunks, dim=seq_dim)
        except Exception as e:
            log.info(f"[Rank {cp_rank}] Error in gather_batch_from_cp_ranks: {str(e)}")
            raise

    return outputs
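

# A hedged round-trip sketch (not in the original module): under an initialized Megatron
# context parallel group, with `full_batch` on the current CUDA device and a sequence length
# divisible by 2 * cp_size, scattering with get_batch_on_this_cp_rank and gathering with
# gather_batch_from_cp_ranks should reproduce the original tensor on every rank. The helper
# name `_example_cp_round_trip` is illustrative only.
def _example_cp_round_trip(full_batch: torch.Tensor) -> None:
    local_chunk = gather_batch_input = get_batch_on_this_cp_rank(full_batch)  # [B, S / cp_size, ...]
    reconstructed = gather_batch_from_cp_ranks(gather_batch_input)            # [B, S, ...] on every rank
    assert torch.equal(reconstructed, full_batch)
    del local_chunk  # local shard no longer needed after reconstruction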
def broadcast_data_batch_in_tp_cp_group(data_batch):
    """
    Broadcast the data batch across the tensor model parallel and context parallel groups.
    """
    keys = sorted(data_batch.keys())
    tp_size = parallel_state.get_tensor_model_parallel_world_size()
    cp_size = parallel_state.get_context_parallel_world_size()
    tp_group = parallel_state.get_tensor_model_parallel_group() if tp_size > 1 else None
    cp_group = parallel_state.get_context_parallel_group() if cp_size > 1 else None
    tp_ranks = get_process_group_ranks(tp_group) if tp_size > 1 else None
    cp_ranks = get_process_group_ranks(cp_group) if cp_size > 1 else None
    if tp_size > 1 or cp_size > 1:
        for key in keys:
            tensor = data_batch[key]
            if isinstance(tensor, torch.Tensor):
                tensor = tensor.contiguous()
                # Write the (possibly copied) contiguous tensor back so the in-place
                # broadcast result is visible to the caller through data_batch.
                data_batch[key] = tensor
                # Broadcast in-place from the lowest global rank of each group.
                if tp_size > 1:
                    broadcast(tensor, min(tp_ranks), group=tp_group)
                if cp_size > 1:
                    broadcast(tensor, min(cp_ranks), group=cp_group)
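

# A hedged usage sketch (not in the original module): in a typical training step, every rank
# broadcasts the loaded batch so TP/CP replicas see identical inputs, then slices its own CP
# shard. The helper name `_example_broadcast_then_split` and the key "video" are illustrative.
def _example_broadcast_then_split(data_batch: dict, key: str = "video") -> dict:
    broadcast_data_batch_in_tp_cp_group(data_batch)          # in-place broadcast per tensor key
    data_batch[key] = get_batch_on_this_cp_rank(data_batch[key])
    return data_batch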
def allreduce_layernorm_grads(model: List[torch.nn.Module], tensor_model_parallel_size: int, sequence_parallel: bool):
    """
    All-reduce layernorm grads (for sequence parallelism).

    Note:
        - We skip QK normalization layers and the last normalization layer of the Transformer,
          since we use AllReduceBWDRMSNormTE for those layers, which already applies the all-reduce
          in the backward pass.
        - TransformerEngine's LayerNormLinear and LayerNormMLP modules have `*.layer_norm_weight`
          parameters that must also be all-reduced in the backward pass, so this function covers them.
    """
    if tensor_model_parallel_size > 1 and sequence_parallel:
        grads = []
        for model_chunk in model:
            for name, param in model_chunk.named_parameters():
                if not param.requires_grad:
                    continue
                if name.endswith(".layer_norm_weight"):
                    grad = param.grad
                    if grad is not None:
                        grads.append(grad.data)

        if grads:
            # Flatten into one buffer, all-reduce across the TP group, then copy back.
            coalesced = _flatten_dense_tensors(grads)
            torch.distributed.all_reduce(coalesced, group=parallel_state.get_tensor_model_parallel_group())
            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)
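

# A hedged usage sketch (not in the original module): with sequence parallelism enabled, call
# allreduce_layernorm_grads after the backward pass and before the optimizer step so the
# replicated `layer_norm_weight` gradients agree across the tensor parallel group. `loss`,
# `net`, and `optimizer` are illustrative placeholders.
def _example_sp_grad_sync(loss: torch.Tensor, net: torch.nn.Module, optimizer: torch.optim.Optimizer) -> None:
    loss.backward()
    allreduce_layernorm_grads(
        [net],
        tensor_model_parallel_size=parallel_state.get_tensor_model_parallel_world_size(),
        sequence_parallel=True,
    )
    optimizer.step()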
def sync_1d_parameters(model: torch.nn.Module, process_group=None):
    """
    Synchronize LayerNorm parameters (1D) across ranks by performing an all-reduce with the mean operation.

    LayerNorm parameters are identified by having ndim == 1.
    Note: if parameters other than LayerNorm are 1D, they will also be synchronized.

    Args:
        model (torch.nn.Module): The model containing the LayerNorm parameters.
        process_group (optional): The process group for the all-reduce.
            If None, the default process group is used.
    """
    if not torch.distributed.is_initialized():
        return

    for name, param in model.named_parameters():
        if param.ndim == 1 and param.requires_grad:
            torch.distributed.all_reduce(param.data, op=torch.distributed.ReduceOp.AVG, group=process_group)
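

# A hedged usage sketch (not in the original module): periodically averaging the 1D (norm)
# parameters keeps replicated copies from drifting apart numerically, for example every N
# steps across the tensor parallel group. `step`, `net`, and `every_n_steps` are illustrative.
def _example_periodic_norm_sync(step: int, net: torch.nn.Module, every_n_steps: int = 100) -> None:
    if step % every_n_steps == 0:
        sync_1d_parameters(net, process_group=parallel_state.get_tensor_model_parallel_group())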
class AllReduceBWD(Function):
    """
    Custom autograd Function that performs an all-reduce operation during the backward pass.

    Args:
        tensor (Tensor): The input tensor.
        process_group: The process group to perform the all-reduce operation in.

    Returns:
        Tensor: The input tensor in the forward pass, and the all-reduced gradient in the backward pass.
    """

    @staticmethod
    def forward(ctx, tensor, process_group):
        ctx.process_group = process_group
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        # Sum the incoming gradient across the process group; the second return value
        # corresponds to `process_group`, which needs no gradient.
        dist.all_reduce(grad_output, group=ctx.process_group)
        return grad_output, None
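

# A hedged usage sketch (not in the original module): wrapping a replicated parameter with
# AllReduceBWD.apply leaves the forward value unchanged but sums its gradient over the given
# process group in the backward pass. `weight`, `x`, and `group` are illustrative placeholders.
def _example_allreduce_bwd(weight: torch.nn.Parameter, x: torch.Tensor, group) -> None:
    y = x * AllReduceBWD.apply(weight, group)
    y.sum().backward()
    # weight.grad now holds the sum of the local gradients from every rank in `group`.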
class AllReduceBWDRMSNormTE(RMSNormTE):
    """
    A custom RMSNorm layer that applies an all-reduce operation during the backward pass.
    Used in tensor parallel training with Transformer Engine.

    Args:
        hidden_size (int): The size of the hidden dimension.
        process_group: Megatron Core's process group.
        **kwargs: Additional arguments to be passed to RMSNormTE.
    """

    def __init__(self, hidden_size, process_group, **kwargs):
        super().__init__(hidden_size, **kwargs)
        self.process_group = process_group

    @no_torch_dynamo()
    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """RMSNorm FWD"""
        TransformerEngineBaseModule.set_activation_dtype(self, inp)

        if torch.is_grad_enabled():
            fwd_fn = _RMSNorm.apply
            args = []
        else:
            fwd_fn = _RMSNorm.forward
            args = [None]

        # Wrap the weight with AllReduceBWD so its gradient is all-reduced across
        # `process_group` during the backward pass.
        args += (
            inp,
            AllReduceBWD.apply(self.weight, self.process_group),
            self.eps,
            self.fwd_rmsnorm_sm_margin,
            self.bwd_rmsnorm_sm_margin,
            self.inf_rmsnorm_sm_margin,
            self.zero_centered_gamma,
            torch.is_grad_enabled(),
            self.activation_dtype,
        )

        return fwd_fn(*args)
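

# A hedged usage sketch (not in the original module): constructing the gradient-synchronizing
# RMSNorm for a weight that is replicated across the tensor parallel group. `hidden_size` and
# the eps value are illustrative; RMSNormTE's own keyword arguments pass through via **kwargs.
def _example_build_tp_rmsnorm(hidden_size: int = 4096) -> AllReduceBWDRMSNormTE:
    return AllReduceBWDRMSNormTE(
        hidden_size,
        process_group=parallel_state.get_tensor_model_parallel_group(),
        eps=1e-6,
    )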