Hunyuan-Avatar / hymm_sp /modules /parallel_states.py
rahul7star's picture
Upload 99 files
357c94c verified
import os
import torch
import datetime
import torch.distributed as dist
from typing import Any, Tuple
from torch import Tensor
from flash_attn.flash_attn_interface import flash_attn_varlen_func
class COMM_INFO:
    """Mutable record of this process's position in the sequence-parallel layout.

    A fresh instance describes a single-process setup: no process group,
    a group size of one and every rank/index set to zero. The
    initialization helpers of this module overwrite the fields later.
    """

    def __init__(self):
        # Process-group handle and the number of ranks inside one group.
        self.group, self.sp_size = None, 1
        # Global rank, rank inside the group, and index of the group itself.
        self.global_rank = self.rank_within_group = self.group_id = 0
# Module-wide communication record read and written by the helpers below.
nccl_info = COMM_INFO()
# Set to True by initialize_sequence_parallel_state when the requested
# sequence-parallel size is greater than one.
_SEQUENCE_PARALLEL_STATE = False
def get_cu_seqlens(text_mask, img_len, device="cuda"):
    """Calculate the cumulative sequence lengths used for varlen flash attention.

    Every sample owns a fixed slot of ``text_mask.shape[1] + img_len``
    positions in the flattened batch. Each slot contributes two boundaries:
    the end of its valid tokens (image tokens plus the text tokens counted
    by the mask) and the end of the slot itself, so the remainder of the
    slot forms a separate padding segment.

    Args:
        text_mask (torch.Tensor): text mask with the batch on dim 0 and the
            text positions on dim 1; the sum over dim 1 is used as the
            number of valid text tokens of each sample.
        img_len (int): the length of the image sequence.
        device: device on which the result is allocated. Defaults to
            ``"cuda"``, which is the value that used to be hard-coded.

    Returns:
        torch.Tensor: int32 tensor of length ``2 * batch_size + 1`` laid out
        as ``[0, valid_end_0, slot_end_0, valid_end_1, slot_end_1, ...]``.
    """
    batch_size = text_mask.shape[0]
    max_len = text_mask.shape[1] + img_len
    valid_len = text_mask.sum(dim=1) + img_len
    # Offset of the first position of every sample's slot.
    slot_start = torch.arange(batch_size, device=valid_len.device) * max_len

    cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device=device)
    # Odd entries: end of the valid tokens; even entries: end of the slot.
    cu_seqlens[1::2] = slot_start + valid_len
    cu_seqlens[2::2] = slot_start + max_len
    return cu_seqlens
def initialize_sequence_parallel_state(sequence_parallel_size):
    """Enable or disable sequence parallelism for this process.

    Args:
        sequence_parallel_size (int): number of ranks per sequence-parallel
            group. A value greater than one turns the feature on and builds
            the process groups; any other value records a single-rank layout
            in ``nccl_info``.
    """
    global _SEQUENCE_PARALLEL_STATE
    if sequence_parallel_size > 1:
        _SEQUENCE_PARALLEL_STATE = True
        initialize_sequence_parallel_group(sequence_parallel_size)
        return

    # Clear the flag explicitly: without this, a process that was previously
    # initialized with a size > 1 would keep reporting sequence parallelism
    # as enabled while nccl_info describes a single-rank layout.
    _SEQUENCE_PARALLEL_STATE = False
    global_rank = int(os.getenv("RANK", "0"))
    nccl_info.sp_size = 1
    nccl_info.global_rank = global_rank
    nccl_info.rank_within_group = 0
    # Each rank forms its own group, so the group id is the global rank.
    nccl_info.group_id = global_rank
def get_sequence_parallel_state():
    """Return the module-level flag telling whether sequence parallelism is enabled."""
    is_enabled = _SEQUENCE_PARALLEL_STATE
    return is_enabled
def initialize_sequence_parallel_group(sequence_parallel_size):
    """Create the sequence-parallel process groups and record this rank's one.

    The world (taken from the ``WORLD_SIZE`` environment variable) is cut
    into consecutive blocks of ``sequence_parallel_size`` ranks. A group is
    created for every block; the one containing this process (``RANK``
    environment variable) is stored in ``nccl_info`` together with the
    rank's position inside it and the block index.

    Args:
        sequence_parallel_size (int): number of ranks per group; must
            divide the world size.
    """
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    assert (
        world_size % sequence_parallel_size == 0
    ), "world_size must be divisible by sequence_parallel_size, but got world_size: {}, sequence_parallel_size: {}".format(
        world_size, sequence_parallel_size)

    nccl_info.global_rank = rank
    nccl_info.sp_size = sequence_parallel_size

    group_count: int = world_size // sequence_parallel_size
    for group_index in range(group_count):
        first_rank = group_index * sequence_parallel_size
        members = range(first_rank, first_rank + sequence_parallel_size)
        # The group is created on every iteration, including for blocks
        # this rank does not belong to.
        candidate = dist.new_group(members)
        if rank not in members:
            continue
        nccl_info.group = candidate
        nccl_info.rank_within_group = rank - first_rank
        nccl_info.group_id = group_index
def initialize_distributed(seed):
    """Initialize the NCCL process group, seed torch and set up sequence parallelism.

    The global rank and world size are read from the ``RANK`` and
    ``WORLD_SIZE`` environment variables. The sequence-parallel size passed
    on is the full world size.

    Args:
        seed (int): seed applied to the CPU generator and to all CUDA
            generators.
    """
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    # The CUDA device index must be the per-node rank. Using the global RANK
    # selects a non-existent device on every node but the first in multi-node
    # jobs. Fall back to RANK when the launcher does not export LOCAL_RANK so
    # single-node setups behave as before.
    local_rank = int(os.getenv("LOCAL_RANK", rank))
    torch.cuda.set_device(local_rank)
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        # Largest 32-bit number of seconds, so the timeout is effectively off.
        timeout=datetime.timedelta(seconds=2**31 - 1),
        world_size=world_size,
        rank=rank,
    )
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    initialize_sequence_parallel_state(world_size)
def _all_to_all_4D(input: torch.tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.tensor:
"""
all-to-all for QKV
Args:
input (torch.tensor): a tensor sharded along dim scatter dim
scatter_idx (int): default 1
gather_idx (int): default 2
group : torch process group
Returns:
torch.tensor: resharded tensor (bs, seqlen/P, hc, hs)
"""
assert (input.dim() == 4), f"input must be 4D tensor, got {input.dim()} and shape {input.shape}"
seq_world_size = dist.get_world_size(group)
if scatter_idx == 2 and gather_idx == 1:
# input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs)
bs, shard_seqlen, hc, hs = input.shape
seqlen = shard_seqlen * seq_world_size
shard_hc = hc // seq_world_size
# transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
# (bs, seqlen/P, hc, hs) -reshape-> (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)-> (P, seq_len/P, bs, hc/P, hs)
input_t = (input.reshape(bs, shard_seqlen, seq_world_size, shard_hc, hs).transpose(0, 2).contiguous())
output = torch.empty_like(input_t)
# https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
# (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head
if seq_world_size > 1:
dist.all_to_all_single(output, input_t, group=group)
torch.cuda.synchronize()
else:
output = input_t
# if scattering the seq-dim, transpose the heads back to the original dimension
output = output.reshape(seqlen, bs, shard_hc, hs)
# (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs)
output = output.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)
return output
elif scatter_idx == 1 and gather_idx == 2:
# input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs)
bs, seqlen, shard_hc, hs = input.shape
hc = shard_hc * seq_world_size
shard_seqlen = seqlen // seq_world_size
seq_world_size = dist.get_world_size(group)
# transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
# (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)-> (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs)
input_t = (input.reshape(bs, seq_world_size, shard_seqlen, shard_hc,
hs).transpose(0,
3).transpose(0,
1).contiguous().reshape(seq_world_size, shard_hc,
shard_seqlen, bs, hs))
output = torch.empty_like(input_t)
# https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
# (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head
if seq_world_size > 1:
dist.all_to_all_single(output, input_t, group=group)
torch.cuda.synchronize()
else:
output = input_t
# if scattering the seq-dim, transpose the heads back to the original dimension
output = output.reshape(hc, shard_seqlen, bs, hs)
# (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs)
output = output.transpose(0, 2).contiguous().reshape(bs, shard_seqlen, hc, hs)
return output
else:
raise RuntimeError("scatter_idx must be 1 or 2 and gather_idx must be 1 or 2")
class SeqAllToAll4D(torch.autograd.Function):
    """Autograd wrapper around ``_all_to_all_4D``.

    The forward pass performs the exchange; the backward pass applies the
    same operation to the incoming gradient with the scatter and gather
    indices swapped. Only the tensor argument receives a gradient.
    """

    @staticmethod
    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:
        # Remember everything the backward pass needs to undo the exchange.
        ctx.group = group
        ctx.scatter_idx, ctx.gather_idx = scatter_idx, gather_idx
        return _all_to_all_4D(input, scatter_idx, gather_idx, group=group)

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        # Swapped indices: what was gathered in forward is scattered here.
        grad_input = SeqAllToAll4D.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx)
        return None, grad_input, None, None
def all_to_all_4D(input_: torch.Tensor, scatter_dim: int = 2, gather_dim: int = 1):
    """Run the differentiable 4D all-to-all over the group stored in ``nccl_info``.

    Args:
        input_ (torch.Tensor): tensor to reshard.
        scatter_dim (int): dimension split across ranks. Defaults to 2.
        gather_dim (int): dimension gathered from all ranks. Defaults to 1.

    Returns:
        The result of ``SeqAllToAll4D.apply``.
    """
    sp_group = nccl_info.group
    return SeqAllToAll4D.apply(sp_group, input_, scatter_dim, gather_dim)
def _all_to_all(
    input_: torch.Tensor,
    world_size: int,
    group: dist.ProcessGroup,
    scatter_dim: int,
    gather_dim: int,
):
    """Split ``input_`` along ``scatter_dim``, exchange the pieces, join along ``gather_dim``.

    Args:
        input_ (torch.Tensor): tensor to exchange.
        world_size (int): number of pieces to cut ``input_`` into.
        group (dist.ProcessGroup): group over which ``dist.all_to_all`` runs.
        scatter_dim (int): dimension along which ``input_`` is split.
        gather_dim (int): dimension along which received pieces are concatenated.

    Returns:
        torch.Tensor: contiguous concatenation of the received pieces.
    """
    send_chunks = []
    for piece in torch.tensor_split(input_, world_size, scatter_dim):
        send_chunks.append(piece.contiguous())

    # Every receive buffer is shaped like the first outgoing piece.
    template = send_chunks[0]
    recv_chunks = [torch.empty_like(template) for _ in range(world_size)]

    dist.all_to_all(recv_chunks, send_chunks, group=group)

    merged = torch.cat(recv_chunks, dim=gather_dim)
    return merged.contiguous()
class _AllToAll(torch.autograd.Function):
    """All-to-all communication with autograd support.

    Args:
        input_: input matrix
        process_group: communication group
        scatter_dim: scatter dimension
        gather_dim: gather dimension
    """

    @staticmethod
    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
        # Stash the arguments so backward can run the reverse exchange.
        ctx.process_group = process_group
        ctx.scatter_dim, ctx.gather_dim = scatter_dim, gather_dim
        ctx.world_size = dist.get_world_size(process_group)
        return _all_to_all(input_, ctx.world_size, process_group, scatter_dim, gather_dim)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse exchange: scatter along the forward gather dimension and
        # gather along the forward scatter dimension.
        grad_input = _all_to_all(
            grad_output,
            ctx.world_size,
            ctx.process_group,
            ctx.gather_dim,
            ctx.scatter_dim,
        )
        # Only the tensor argument receives a gradient.
        return grad_input, None, None, None
def all_to_all(input_: torch.Tensor, scatter_dim: int = 2, gather_dim: int = 1):
    """Run the differentiable all-to-all over the group stored in ``nccl_info``.

    Args:
        input_ (torch.Tensor): tensor to exchange.
        scatter_dim (int): dimension split across ranks. Defaults to 2.
        gather_dim (int): dimension along which received pieces are joined.
            Defaults to 1.

    Returns:
        The result of ``_AllToAll.apply``.
    """
    sp_group = nccl_info.group
    return _AllToAll.apply(input_, sp_group, scatter_dim, gather_dim)
class _AllGather(torch.autograd.Function):
    """All-gather communication with autograd support.

    The group and its size are read from ``nccl_info``.

    Args:
        input_: input tensor
        dim: dimension along which to concatenate
    """

    @staticmethod
    def forward(ctx, input_, dim):
        ctx.dim = dim
        # Size of the local shard along `dim`; backward uses it to cut the
        # gradient back into per-rank pieces.
        ctx.input_size = input_.size(dim)

        gathered = [torch.empty_like(input_) for _ in range(nccl_info.sp_size)]
        dist.all_gather(gathered, input_.contiguous(), group=nccl_info.group)
        return torch.cat(gathered, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        # Cut the gradient into equally sized pieces and keep the piece at
        # this rank's index within the group.
        piece_sizes = [ctx.input_size] * nccl_info.sp_size
        pieces = torch.split(grad_output, piece_sizes, dim=ctx.dim)
        return pieces[nccl_info.rank_within_group], None
def all_gather(input_: torch.Tensor, dim: int = 1):
    """Gather ``input_`` from every rank and concatenate the pieces along ``dim``.

    Thin wrapper around ``_AllGather.apply`` so the operation takes part in
    autograd.

    Args:
        input_ (torch.Tensor): local tensor contributed by this rank.
        dim (int, optional): dimension along which to concatenate. Defaults to 1.

    Returns:
        torch.Tensor: the gathered tensors concatenated along ``dim``.
    """
    gathered = _AllGather.apply(input_, dim)
    return gathered
def parallel_attention(q, k, v, img_q_len, img_kv_len, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,):
    """Varlen flash attention over image + text tokens, optionally sequence-parallel.

    Each of ``q``, ``k`` and ``v`` is a pair ``(image_part, text_part)``.
    When sequence parallelism is enabled, the image parts are exchanged so
    that every rank holds the full sequence for a subset of heads, and the
    text parts are narrowed to the same subset of heads. After attention the
    image output is exchanged back and the text output is all-gathered over
    the head dimension.

    Example values from the original notes:
        img_q_len, img_kv_len: 32256
        query: [2, 32256, 24, 128]
        encoder_query: [2, 256, 24, 128]

    Args:
        q, k, v: ``(image_tensor, text_tensor)`` pairs; tensors are used as
            4D with the sequence on dim 1 and the heads on dim 2.
        img_q_len, img_kv_len: accepted for interface compatibility; not
            read by this function.
        cu_seqlens_q, cu_seqlens_kv: cumulative sequence lengths forwarded
            to ``flash_attn_varlen_func``.
        max_seqlen_q, max_seqlen_kv: maximum sequence lengths forwarded to
            ``flash_attn_varlen_func``; ``max_seqlen_q`` is also used as the
            per-sample length when un-flattening the output.

    Returns:
        tuple: ``(attn, None)`` where ``attn`` has the head and head-dim
        axes merged into one trailing dimension.
    """
    query, encoder_query = q
    key, encoder_key = k
    value, encoder_value = v

    # NOTE: the previously unused `torch.distributed.get_rank()` call was
    # dropped; it required an initialized process group even when sequence
    # parallelism is disabled.
    sequence_parallel = get_sequence_parallel_state()

    if sequence_parallel:
        # Gather the full sequence, keep only this rank's share of heads.
        query, key, value = [
            all_to_all_4D(x, scatter_dim=2, gather_dim=1)
            for x in (query, key, value)
        ]

        def shrink_head(encoder_state, dim):
            # Keep the contiguous block of heads owned by this rank.
            local_heads = encoder_state.shape[dim] // nccl_info.sp_size
            return encoder_state.narrow(dim, nccl_info.rank_within_group * local_heads, local_heads)

        encoder_query = shrink_head(encoder_query, dim=2)
        encoder_key = shrink_head(encoder_key, dim=2)
        encoder_value = shrink_head(encoder_value, dim=2)

    sequence_length = query.size(1)
    encoder_sequence_length = encoder_query.size(1)

    # Image tokens first, text tokens appended along the sequence dimension.
    query = torch.cat([query, encoder_query], dim=1)
    key = torch.cat([key, encoder_key], dim=1)
    value = torch.cat([value, encoder_value], dim=1)

    bsz = query.shape[0]
    head = query.shape[-2]
    head_dim = query.shape[-1]

    # flash_attn_varlen_func takes the batch flattened into the sequence.
    query, key, value = [
        x.view(x.shape[0] * x.shape[1], *x.shape[2:])
        for x in [query, key, value]
    ]

    hidden_states = flash_attn_varlen_func(
        query,
        key,
        value,
        cu_seqlens_q,
        cu_seqlens_kv,
        max_seqlen_q,
        max_seqlen_kv,
    )

    # Restore the batch dimension, then separate image and text outputs.
    hidden_states = hidden_states.view(bsz, max_seqlen_q, head, head_dim).contiguous()
    hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
        (sequence_length, encoder_sequence_length), dim=1)

    if sequence_parallel:
        # Undo the exchange for image tokens; collect all heads for text tokens.
        hidden_states = all_to_all_4D(hidden_states, scatter_dim=1, gather_dim=2)
        encoder_hidden_states = all_gather(encoder_hidden_states, dim=2).contiguous()

    hidden_states = hidden_states.to(query.dtype)
    encoder_hidden_states = encoder_hidden_states.to(query.dtype)

    attn = torch.cat([hidden_states, encoder_hidden_states], dim=1)
    b, s, _, _ = attn.shape
    attn = attn.reshape(b, s, -1)
    return attn, None