|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Model and data parallel groups.""" |
|
|
|
import torch |
|
|
|
from .utils import ensure_divisibility |
|
|
|
|
|
# Model parallel group that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
# Pipeline parallel group that the current rank belongs to.
_PIPE_PARALLEL_GROUP = None
# IO parallel group: the ranks that handle the pipeline's inputs and outputs.
_IO_PARALLEL_GROUP = None

# Cached world size and rank for the model parallel group.
_MPU_WORLD_SIZE = None
_MPU_RANK = None

# Process topology used to build the groups, if one was provided.
_MPU_TOPOLOGY = None

# Whether all-reduce operations should be carried out in fp32.
_FP32_ALLREDUCE = None
|
|
|
|
|
def is_uninitialized():
    """Useful for code segments that may be accessed with or without mpu initialization."""
    return _DATA_PARALLEL_GROUP is None


# Alias kept for compatibility with the original (misspelled) name.
is_unitialized = is_uninitialized
|
|
|
|
|
def initialize_model_parallel(model_parallel_size, topology=None, fp32_allreduce=False):
    """
    Initialize model and data parallel groups.

    Arguments:
        model_parallel_size: number of GPUs used to parallelize the model.
        topology: optional process topology object (exposing "data", "pipe",
            and "model" axes); when given, groups are built from its
            communication lists rather than from contiguous rank ranges.
        fp32_allreduce: whether all-reduces should be done in fp32; the flag
            is stored and exposed via get_fp32_allreduce().

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model. The present function will
    create 4 model parallel groups and 2 data parallel groups as:
        4 model parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 data parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example, if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, ranks 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
|
if torch.distributed.get_rank() == 0: |
|
print("> initializing model parallel with size {}".format(model_parallel_size)) |
|
|
|
assert torch.distributed.is_initialized() |
|
world_size = torch.distributed.get_world_size() |
|
    if world_size < model_parallel_size:
        raise ValueError(
            f"world size ({world_size}) cannot be smaller than model parallel size ({model_parallel_size})"
        )
|
ensure_divisibility(world_size, model_parallel_size) |
|
rank = torch.distributed.get_rank() |
|
|
|
global _MPU_TOPOLOGY |
|
if topology: |
|
_MPU_TOPOLOGY = topology |
|
|
|
|
|
    # Build the data parallel groups.
    global _DATA_PARALLEL_GROUP
    assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
    if topology:
        for dp_group in topology.get_axis_comm_lists("data"):
            group = torch.distributed.new_group(ranks=dp_group)
            if rank == 0:
                print(f"MPU DP: {dp_group}")
            if rank in dp_group:
                _DATA_PARALLEL_GROUP = group
    else:
        for i in range(model_parallel_size):
            ranks = range(i, world_size, model_parallel_size)
            group = torch.distributed.new_group(ranks)
            if i == (rank % model_parallel_size):
                _DATA_PARALLEL_GROUP = group
|
|
|
|
|
    # Build the pipeline parallel groups (only available when a topology is given).
    if topology is not None:
        global _PIPE_PARALLEL_GROUP
        for pp_group in topology.get_axis_comm_lists("pipe"):
            group = torch.distributed.new_group(ranks=pp_group)
            if rank == 0:
                print(f"MPU PP: {pp_group}")
            if rank in pp_group:
                _PIPE_PARALLEL_GROUP = group
|
|
|
|
|
    # Build the IO parallel group: the first and last pipeline stages (model
    # parallel rank 0 only), i.e. the ranks that handle the pipeline's inputs
    # and outputs. Without pipeline parallelism it is simply the data parallel group.
    global _IO_PARALLEL_GROUP
    if topology and topology.get_dim("pipe") > 1:
        io_stages = [0, topology.get_dim("pipe") - 1]
        io_group = []
        for stage in io_stages:
            io_group.extend(topology.filter_match(pipe=stage, model=0))
        if rank == 0:
            print(f"MPU IO: {io_group}")
        group = torch.distributed.new_group(ranks=io_group)
        if rank in io_group:
            _IO_PARALLEL_GROUP = group
    else:
        _IO_PARALLEL_GROUP = get_data_parallel_group()
|
|
|
|
|
    # Build the model parallel groups.
    global _MODEL_PARALLEL_GROUP
    assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
    if topology:
        if model_parallel_size == 1:
            # Short circuit: without model parallelism, every rank forms its
            # own (singleton) model parallel group.
            for group_rank in range(world_size):
                group = torch.distributed.new_group(ranks=[group_rank])
                if rank == 0:
                    print(f"MPU MP: {[group_rank]}")
                if rank == group_rank:
                    _MODEL_PARALLEL_GROUP = group
        else:
            for mp_group in topology.get_axis_comm_lists("model"):
                group = torch.distributed.new_group(ranks=mp_group)
                if rank == 0:
                    print(f"MPU MP: {mp_group}")
                if rank in mp_group:
                    _MODEL_PARALLEL_GROUP = group
|
|
|
else: |
|
for i in range(world_size // model_parallel_size): |
|
ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size) |
|
group = torch.distributed.new_group(ranks) |
|
if i == (rank // model_parallel_size): |
|
_MODEL_PARALLEL_GROUP = group |
|
|
|
global _FP32_ALLREDUCE |
|
assert _FP32_ALLREDUCE is None, "fp32_allreduce is already initialized" |
|
_FP32_ALLREDUCE = fp32_allreduce |
|
|
|
|
|
def model_parallel_is_initialized(): |
|
"""Check if model and data parallel groups are initialized.""" |
|
if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None: |
|
return False |
|
return True |
|
|
|
|
|
def get_model_parallel_group(): |
|
"""Get the model parallel group the caller rank belongs to.""" |
|
assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized" |
|
return _MODEL_PARALLEL_GROUP |
|
|
|
|
|
def get_data_parallel_group(): |
|
"""Get the data parallel group the caller rank belongs to.""" |
|
assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized" |
|
return _DATA_PARALLEL_GROUP |
|
|
|
|
|
def get_io_parallel_group(): |
|
"""Get the IO parallel group the caller rank belongs to.""" |
|
assert _IO_PARALLEL_GROUP is not None, "IO parallel group is not initialized" |
|
return _IO_PARALLEL_GROUP |
|
|
|
|
|
def set_model_parallel_world_size(world_size): |
|
"""Set the model parallel size""" |
|
global _MPU_WORLD_SIZE |
|
_MPU_WORLD_SIZE = world_size |
|
|
|
|
|
def get_model_parallel_world_size(): |
|
"""Return world size for the model parallel group.""" |
|
global _MPU_WORLD_SIZE |
|
if _MPU_WORLD_SIZE is not None: |
|
return _MPU_WORLD_SIZE |
|
return torch.distributed.get_world_size(group=get_model_parallel_group()) |
|
|
|
|
|
def set_model_parallel_rank(rank): |
|
"""Set model parallel rank.""" |
|
global _MPU_RANK |
|
_MPU_RANK = rank |
|
|
|
|
|
def get_model_parallel_rank(): |
|
"""Return my rank for the model parallel group.""" |
|
global _MPU_RANK |
|
if _MPU_RANK is not None: |
|
return _MPU_RANK |
|
return torch.distributed.get_rank(group=get_model_parallel_group()) |
|
|
|
|
|
def get_model_parallel_src_rank():
    """Calculate the global rank corresponding to local rank zero
    in the model parallel group."""
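    # Worked example (illustration only): with a model parallel world size of 2,
    # global rank 5 maps to (5 // 2) * 2 == 4, i.e. the first rank of its model
    # parallel group [4, 5].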
|
global_rank = torch.distributed.get_rank() |
|
local_world_size = get_model_parallel_world_size() |
|
return (global_rank // local_world_size) * local_world_size |
|
|
|
|
|
def get_data_parallel_src_rank():
    """Calculate the global rank corresponding to local rank zero
    in the data parallel group."""
|
global_rank = torch.distributed.get_rank() |
|
topo = get_topology() |
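    # Worked example (illustration only): with model_parallel_size=2, no
    # topology, and 8 ranks, global rank 5 belongs to data parallel group
    # [1, 3, 5, 7], so this function returns 5 % 2 == 1.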
|
    if topo is None:
        # Without a topology, data parallel groups are strided by the model
        # parallel size, so local rank zero sits at this offset.
        return global_rank % get_model_parallel_world_size()
    else:
        # With a topology, find the data parallel group containing this rank
        # and return its first member.
        for dp_group in topo.get_axis_comm_lists("data"):
            if global_rank in dp_group:
                return dp_group[0]
|
|
|
|
|
def get_data_parallel_world_size(): |
|
"""Return world size for the data parallel group.""" |
|
return torch.distributed.get_world_size(group=get_data_parallel_group()) |
|
|
|
|
|
def get_data_parallel_rank(): |
|
"""Return my rank for the data parallel group.""" |
|
return torch.distributed.get_rank(group=get_data_parallel_group()) |
|
|
|
|
|
def get_topology():
    """Return the process topology used to build the groups, or None."""
    return _MPU_TOPOLOGY
|
|
|
|
|
def get_pipe_parallel_group(): |
|
"""Get the pipe parallel group the caller rank belongs to.""" |
|
    assert _PIPE_PARALLEL_GROUP is not None, "pipe parallel group is not initialized"
|
return _PIPE_PARALLEL_GROUP |
|
|
|
|
|
def get_pipe_parallel_rank(): |
|
"""Return my rank for the pipe parallel group.""" |
|
return torch.distributed.get_rank(group=get_pipe_parallel_group()) |
|
|
|
|
|
def get_pipe_parallel_world_size(): |
|
"""Return world size for the pipe parallel group.""" |
|
return torch.distributed.get_world_size(group=get_pipe_parallel_group()) |
|
|
|
|
|
def set_tensor_model_parallel_world_size(world_size): |
|
"""Set the tensor model parallel size""" |
|
set_model_parallel_world_size(world_size) |
|
|
|
|
|
def get_tensor_model_parallel_group(): |
|
"""Get the tensor model parallel group the caller rank belongs to.""" |
|
return get_model_parallel_group() |
|
|
|
|
|
def get_tensor_model_parallel_src_rank(): |
|
"""Calculate the global rank corresponding to the first local rank |
|
in the tensor model parallel group.""" |
|
    return get_model_parallel_src_rank()
|
|
|
|
|
|
|
def get_tensor_model_parallel_world_size(): |
|
"""Return world size for the tensor model parallel group.""" |
|
return get_model_parallel_world_size() |
|
|
|
|
|
def set_tensor_model_parallel_rank(rank): |
|
"""Set tensor model parallel rank.""" |
|
set_model_parallel_rank(rank) |
|
|
|
|
|
def get_tensor_model_parallel_rank(): |
|
"""Return my rank for the tensor model parallel group.""" |
|
return get_model_parallel_rank() |
|
|
|
|
|
def destroy_model_parallel(): |
|
"""Set the groups to none.""" |
|
global _MODEL_PARALLEL_GROUP |
|
_MODEL_PARALLEL_GROUP = None |
|
global _DATA_PARALLEL_GROUP |
|
_DATA_PARALLEL_GROUP = None |
|
global _PIPE_PARALLEL_GROUP |
|
_PIPE_PARALLEL_GROUP = None |
|
global _IO_PARALLEL_GROUP |
|
_IO_PARALLEL_GROUP = None |
|
global _MPU_WORLD_SIZE |
|
global _MPU_RANK |
|
_MPU_WORLD_SIZE = None |
|
_MPU_RANK = None |
|
global _MPU_TOPOLOGY |
|
_MPU_TOPOLOGY = None |
|
global _FP32_ALLREDUCE |
|
_FP32_ALLREDUCE = None |
|
|
|
|
|
def get_fp32_allreduce(): |
|
"""Get the fp32 allreduce flag""" |
|
    assert _FP32_ALLREDUCE is not None, "fp32_allreduce is not initialized"
|
return _FP32_ALLREDUCE |
|
|