# Source path: mm/src/genmo/mochi_preview/dit/joint_model/context_parallel.py
# Uploaded by nruto ("Upload 31 files"), commit d0bfdd6 (verified), 5.14 kB.
import torch
import torch.distributed as dist
from einops import rearrange
# Module-level context-parallel (CP) state, filled in once by set_cp_group().
# A None group means context parallelism is not configured in this process.
_CONTEXT_PARALLEL_GROUP: "dist.ProcessGroup | None" = None  # the CP process group
_CONTEXT_PARALLEL_RANK: "int | None" = None  # this process's rank inside the CP group
_CONTEXT_PARALLEL_GROUP_SIZE: "int | None" = None  # number of ranks in the CP group
# Global ranks of the group members as passed to set_cp_group() (any sequence
# supporting .index()/len() is accepted there).
_CONTEXT_PARALLEL_GROUP_RANKS: "list[int] | None" = None
def get_cp_rank_size():
    """Return ``(rank, world_size)`` of this process within the CP group.

    When no CP group is configured, the single-rank view ``(0, 1)`` is returned.
    """
    if not _CONTEXT_PARALLEL_GROUP:
        return 0, 1
    return _CONTEXT_PARALLEL_RANK, _CONTEXT_PARALLEL_GROUP_SIZE
def local_shard(x: torch.Tensor, dim: int = 2) -> torch.Tensor:
    """Return this rank's slice of ``x`` along ``dim``.

    ``x`` is cut into ``cp_size`` pieces with ``tensor_split``, so piece sizes
    may differ by one when ``x.size(dim)`` is not divisible by ``cp_size``.
    The piece at index ``cp_rank`` is returned. Without a CP group, ``x``
    itself is returned unchanged.
    """
    if _CONTEXT_PARALLEL_GROUP:
        rank, world = get_cp_rank_size()
        pieces = x.tensor_split(world, dim=dim)
        return pieces[rank]
    return x
def set_cp_group(cp_group, ranks, global_rank):
    """Register the process group used for context parallelism (CP).

    May be called at most once per process. The group is queried and every
    consistency check runs *before* any module state is written. A failed call
    therefore leaves CP uninitialized rather than half-configured, and the call
    can be retried.

    Args:
        cp_group: process group passed to ``dist.get_rank`` /
            ``dist.get_world_size`` to obtain this process's CP rank and size.
        ranks: global ranks that make up ``cp_group``, in group-rank order.
        global_rank: global rank of the calling process; must appear in
            ``ranks`` at index equal to its rank within ``cp_group``.

    Raises:
        RuntimeError: if a CP group is already registered, or if ``ranks`` /
            ``global_rank`` disagree with what ``cp_group`` reports.
    """
    global _CONTEXT_PARALLEL_GROUP, _CONTEXT_PARALLEL_RANK, _CONTEXT_PARALLEL_GROUP_SIZE, _CONTEXT_PARALLEL_GROUP_RANKS
    if _CONTEXT_PARALLEL_GROUP is not None:
        raise RuntimeError("CP group already initialized.")

    # Query and validate first; module state is only committed once all checks
    # pass. Explicit raises (not asserts) so the checks survive `python -O`.
    cp_rank = dist.get_rank(cp_group)
    cp_size = dist.get_world_size(cp_group)
    if global_rank not in ranks or ranks.index(global_rank) != cp_rank:
        raise RuntimeError(
            f"Rank mismatch: {global_rank} in {ranks} does not have position {cp_rank} "
        )
    if cp_size != len(ranks):
        raise RuntimeError(f"Group size mismatch: {cp_size} != len({ranks})")

    _CONTEXT_PARALLEL_GROUP = cp_group
    _CONTEXT_PARALLEL_RANK = cp_rank
    _CONTEXT_PARALLEL_GROUP_SIZE = cp_size
    _CONTEXT_PARALLEL_GROUP_RANKS = ranks
def get_cp_group():
    """Return the registered CP process group.

    Raises:
        RuntimeError: if no CP group has been registered via ``set_cp_group``.
    """
    group = _CONTEXT_PARALLEL_GROUP
    if group is not None:
        return group
    raise RuntimeError("CP group not initialized")
def is_cp_active():
    """Report whether context parallelism is configured in this process.

    Returns:
        ``True`` iff a process group has been registered via ``set_cp_group``.
    """
    active = _CONTEXT_PARALLEL_GROUP is not None
    return active
class AllGatherIntoTensorFunction(torch.autograd.Function):
    """Autograd wrapper around ``dist.all_gather_into_tensor``.

    Every rank's ``x`` is written into one output tensor whose dim 0 is
    ``world_size * x.size(0)``. ``reduce_dtype``, ``group`` and the local batch
    size are stored on ``ctx``, but nothing in this class reads them back.

    NOTE(review): no ``backward`` is defined, so autograd cannot differentiate
    through this op. Presumably it is used for inference only; confirm before
    relying on gradients.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, reduce_dtype, group: dist.ProcessGroup):
        ctx.reduce_dtype = reduce_dtype
        ctx.group = group
        ctx.batch_size = x.size(0)

        world = dist.get_world_size(group)
        local = x.contiguous()  # hand the collective a contiguous input buffer
        gathered_shape = (world * local.size(0), *local.shape[1:])
        gathered = torch.empty(gathered_shape, dtype=local.dtype, device=local.device)
        dist.all_gather_into_tensor(gathered, local, group=group)
        return gathered
def all_gather(tensor: torch.Tensor) -> torch.Tensor:
    """Gather ``tensor`` from every CP rank along dim 0.

    Returns ``tensor`` itself when no CP group is configured. Otherwise it
    delegates to ``AllGatherIntoTensorFunction`` with ``torch.float32`` as the
    ``reduce_dtype`` argument.
    """
    group = _CONTEXT_PARALLEL_GROUP
    if group:
        return AllGatherIntoTensorFunction.apply(tensor, torch.float32, group)
    return tensor
@torch.compiler.disable()
def _all_to_all_single(output, input, group):
# Disable compilation since torch compile changes contiguity.
assert input.is_contiguous(), "Input tensor must be contiguous."
assert output.is_contiguous(), "Output tensor must be contiguous."
return dist.all_to_all_single(output, input, group=group)
class CollectTokens(torch.autograd.Function):
    """All-to-all exchanging this rank's token shard for a head shard (forward only)."""

    @staticmethod
    def forward(ctx, qkv: torch.Tensor, group: dist.ProcessGroup, num_heads: int):
        """Redistribute heads and receive tokens.

        Args:
            qkv: packed query/key/value projections of this rank's tokens; the
                last dim is laid out as (3, num_heads, head_dim).
                Shape: [B, M, 3 * num_heads * head_dim]
            group: process group to exchange with.
            num_heads: total number of attention heads; must be divisible by
                the group size.
        Returns:
            qkv: shape: [3, B, N, local_heads, head_dim]
            where M is the number of local tokens,
            N = cp_size * M is the number of global tokens,
            local_heads = num_heads // cp_size is the number of local heads.
        """
        ctx.group = group
        ctx.num_heads = num_heads
        world = dist.get_world_size(group)
        assert num_heads % world == 0
        heads_per_rank = num_heads // world
        ctx.local_heads = heads_per_rank

        # Put the destination axis G first: all_to_all_single splits dim 0 into
        # `world` equal chunks and sends chunk g (global heads g*h .. g*h+h-1,
        # for all local tokens) to rank g.
        send = rearrange(
            qkv,
            "B M (qkv G h d) -> G M h B (qkv d)",
            qkv=3,
            G=world,
            h=heads_per_rank,
        ).contiguous()
        recv = torch.empty_like(send)
        _all_to_all_single(recv, send, group=group)
        # recv[g] holds rank g's tokens for this rank's heads; concatenating over
        # g in rank order yields the full token sequence N = G * M.
        return rearrange(recv, "G M h B (qkv d) -> qkv B (G M) h d", qkv=3)
def all_to_all_collect_tokens(x: torch.Tensor, num_heads: int) -> torch.Tensor:
    """Split packed QKV into per-head Q, K, V, gathering tokens across CP ranks.

    Args:
        x: packed QKV projections, shape [B, M, 3 * num_heads * head_dim].
        num_heads: total number of attention heads.
    Returns:
        Tensor of shape [3, B, N, local_heads, head_dim]. Without CP this is a
        view of ``x`` with N = M and local_heads = num_heads. With CP it is the
        result of ``CollectTokens``.
    """
    if _CONTEXT_PARALLEL_GROUP:
        return CollectTokens.apply(x, _CONTEXT_PARALLEL_GROUP, num_heads)
    # Single-rank path: split the last dim as (3, H, d), then move QKV first.
    batch, tokens, _ = x.size()
    split = x.view(batch, tokens, 3, num_heads, -1)  # B M 3 H d
    return split.permute(2, 0, 1, 3, 4)  # 3 B M H d
class CollectHeads(torch.autograd.Function):
    """All-to-all exchanging this rank's head shard for a token shard (forward only)."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, group: dist.ProcessGroup):
        """Redistribute tokens and receive heads.

        Args:
            x: Output of attention. Shape: [B, N, local_heads, head_dim]
            group: process group to exchange with; N must be divisible by
                its size.
        Returns:
            Shape: [B, M, num_heads * head_dim]
        """
        ctx.group = group
        ctx.local_heads = x.shape[2]
        ctx.head_dim = x.shape[3]
        world = dist.get_world_size(group)

        # Split the N tokens into G blocks of M and put G first, so block g
        # (this rank's heads for rank g's tokens) is sent to rank g.
        send = rearrange(x, "B (G M) h D -> G h M B D", G=world).contiguous()
        recv = torch.empty_like(send)
        _all_to_all_single(recv, send, group=group)
        del send  # drop the send buffer before building the output
        # recv[g] holds rank g's heads for this rank's tokens; flatten heads in rank order.
        return rearrange(recv, "G h M B D -> B M (G h D)")
def all_to_all_collect_heads(x: torch.Tensor) -> torch.Tensor:
    """Merge attention heads, gathering them from all CP ranks when CP is active.

    Args:
        x: attention output, shape [B, N, local_heads, head_dim].
    Returns:
        Tensor of shape [B, M, num_heads * head_dim]. Without CP this is ``x``
        with its last two dims flattened (M = N, num_heads = local_heads).
    """
    if _CONTEXT_PARALLEL_GROUP is None:
        # reshape rather than view: x may be non-contiguous (e.g. a transposed
        # [B, H, N, D] attention output), and then view() cannot merge the head
        # and head_dim axes. reshape returns a view when possible, else a copy.
        return x.reshape(x.size(0), x.size(1), x.size(2) * x.size(3))
    return CollectHeads.apply(x, _CONTEXT_PARALLEL_GROUP)