# Copyright (c) 2025 FoundationVision
# SPDX-License-Identifier: MIT
import torch
import torch.nn as nn
import torch.distributed as dist

from .comm.pg_utils import ProcessGroupManager
from .comm.comm import set_sp_comm_group, split_sequence, gather_sequence, all_to_all_comm
from .comm.operation import gather_forward_split_backward


class SequenceParallelManager:
    _SP_GROUP = None
    _SP_SIZE = 0

    @staticmethod
    def sp_on():
        return SequenceParallelManager._SP_GROUP is not None

    @staticmethod
    def init_sp(sp_size):
        if SequenceParallelManager._SP_GROUP is not None:
            print("WARN: sequence parallel group is already initialized")
            return
        if sp_size <= 1:
            print(f"WARN: sequence parallel size must be > 1 but got {sp_size}")
            return
        world_size = dist.get_world_size()
        assert world_size % sp_size == 0, f"world_size {world_size} must be divisible by sp_size ({sp_size})"
        SequenceParallelManager._SP_SIZE = sp_size
        pm = ProcessGroupManager(
            world_size // sp_size,
            sp_size,
            dp_axis=0,
            sp_axis=1,
        )
        pm_group = pm.sp_group
        set_sp_comm_group(pm_group)
        SequenceParallelManager._SP_GROUP = pm_group
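
    # Layout sketch (assuming ProcessGroupManager enumerates the dp x sp grid
    # with sp as the fastest-varying axis, which matches the rank arithmetic
    # in get_sp_rank below): with world_size=8 and sp_size=2, init_sp(2)
    # builds 4 SP groups of consecutive ranks, {0,1}, {2,3}, {4,5}, {6,7},
    # and each pair cooperatively shards one sequence.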

    @staticmethod
    def get_sp_group():
        return SequenceParallelManager._SP_GROUP

    @staticmethod
    def get_sp_size():
        return SequenceParallelManager._SP_SIZE

    @staticmethod
    def get_sp_group_nums():
        # e.g. with sp_size=2 and 8 ranks, there are 4 SP groups
        if SequenceParallelManager.sp_on():
            world_size = dist.get_world_size()
            return world_size // SequenceParallelManager._SP_SIZE
        else:
            return 0

    @staticmethod
    def get_sp_rank():
        if SequenceParallelManager.sp_on():
            global_rank = dist.get_rank()
            return global_rank % SequenceParallelManager._SP_SIZE
        else:
            return 0

    @staticmethod
    def get_sp_group_rank():
        if SequenceParallelManager.sp_on():
            global_rank = dist.get_rank()
            return global_rank // SequenceParallelManager._SP_SIZE
        else:
            return 0
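
# Note: with the rank arithmetic above, every rank satisfies
#   global_rank == get_sp_group_rank() * sp_size + get_sp_rank(),
# i.e. each SP group is a block of sp_size consecutive global ranks.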


def sp_split_sequence_by_dim(seq, seqlen_dim=1) -> torch.Tensor:
    """
    Split the full sequence along seqlen_dim across the SP group.
    """
    return split_sequence(seq, SequenceParallelManager.get_sp_group(), seqlen_dim, 'down')
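
# Shape sketch (assuming seqlen_dim divides evenly by sp_size): with sp_size=4,
# a [B, L, D] tensor split on seqlen_dim=1 becomes a [B, L // 4, D] shard on
# each rank of the SP group; sp_gather_sequence_by_dim below is its inverse.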


def sp_gather_sequence_by_dim(seq, seqlen_dim=1) -> torch.Tensor:
    """
    Gather along seqlen_dim to recover the full sequence.
    """
    return gather_sequence(seq, SequenceParallelManager.get_sp_group(), seqlen_dim, 'up')


def sp_all_to_all(ts, scatter_dim, gather_dim):
    """
    Redistribute a tensor across the SP group, e.g. from
    [raw_seq_len / sp_size, hidden_dim] to [raw_seq_len, hidden_dim / sp_size].
    scatter_dim: the dimension to split and scatter across ranks
    gather_dim: the dimension along which the received chunks are concatenated
    """
    return all_to_all_comm(ts, SequenceParallelManager.get_sp_group(), scatter_dim, gather_dim)
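
# Minimal usage sketch (illustration only, not part of this module's API; it
# assumes torch.distributed.init_process_group has already run, e.g. under
# torchrun, and that seq_len and num_heads are divisible by sp_size):
#
#   SequenceParallelManager.init_sp(sp_size=2)
#   x = torch.randn(batch, seq_len, hidden, device="cuda")
#   x = sp_split_sequence_by_dim(x, seqlen_dim=1)       # [B, L/2, D] per rank
#
#   # Ulysses-style attention on a hypothetical [B, L/2, H, Dh] projection q:
#   q = sp_all_to_all(q, scatter_dim=2, gather_dim=1)   # [B, L, H/2, Dh]
#   #   ... attend over the full sequence with the local heads, giving o ...
#   o = sp_all_to_all(o, scatter_dim=1, gather_dim=2)   # [B, L/2, H, Dh]
#
#   x = sp_gather_sequence_by_dim(x, seqlen_dim=1)      # recover [B, L, D]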