# Copyright (c) 2024, EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model and data parallel groups."""
import torch
from .utils import ensure_divisibility
# Model parallel group that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
# Pipeline parallel group that the current rank belongs to.
_PIPE_PARALLEL_GROUP = None
# A group used to sync during the IO process. Usually this is data_parallel_group(),
# but with pipeline parallelism it must also involve the last stage (which is not in the
# DP group of rank 0)
_IO_PARALLEL_GROUP = None
# These values enable us to change the mpu sizes on the fly.
_MPU_WORLD_SIZE = None
_MPU_RANK = None
# Used to query 3D topology
_MPU_TOPOLOGY = None
# fp32 allreduce flag (set by initialize_model_parallel).
_FP32_ALLREDUCE = None
def is_unitialized():
"""Useful for code segments that may be accessed with or without mpu initialization"""
return _DATA_PARALLEL_GROUP is None
def initialize_model_parallel(model_parallel_size, topology=None, fp32_allreduce=False):
"""
Initialize model data parallel groups.
Arguments:
model_parallel_size: number of GPUs used to parallelize model.
Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
use 2 GPUs to parallelize the model. The present function will
create 4 model parallel groups and 2 data parallel groups as:
4 model parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7]
2 data parallel groups:
[g0, g2, g4, g6], [g1, g3, g5, g7]
Note that for efficiency, the caller should make sure adjacent ranks
are on the same DGX box. For example if we are using 2 DGX-1 boxes
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.
"""
if torch.distributed.get_rank() == 0:
print("> initializing model parallel with size {}".format(model_parallel_size))
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size()
if world_size < model_parallel_size:
raise ValueError("world size cannot be smaller than model parallel size")
ensure_divisibility(world_size, model_parallel_size)
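    # e.g. world_size=8 with model_parallel_size=3 fails this check, since the
    # ranks could not be split into equal-sized model parallel groups.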
rank = torch.distributed.get_rank()
global _MPU_TOPOLOGY
if topology:
_MPU_TOPOLOGY = topology
# Build the data parallel groups.
global _DATA_PARALLEL_GROUP
assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
if topology:
for dp_group in topology.get_axis_comm_lists("data"):
group = torch.distributed.new_group(ranks=dp_group)
if rank == 0:
print(f"MPU DP:", dp_group)
if rank in dp_group:
_DATA_PARALLEL_GROUP = group
else:
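        # Without an explicit topology, data parallel groups stride through the
        # ranks with stride model_parallel_size, e.g. world_size=8 and
        # model_parallel_size=2 gives groups (0, 2, 4, 6) and (1, 3, 5, 7);
        # each rank keeps only the group that contains it.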
for i in range(model_parallel_size):
ranks = range(i, world_size, model_parallel_size)
group = torch.distributed.new_group(ranks)
if i == (rank % model_parallel_size):
_DATA_PARALLEL_GROUP = group
# Build pipeline parallel group
if topology is not None:
global _PIPE_PARALLEL_GROUP
for pp_group in topology.get_axis_comm_lists("pipe"):
group = torch.distributed.new_group(ranks=pp_group)
if rank == 0:
print(f"MPU PP:", pp_group)
if rank in pp_group:
_PIPE_PARALLEL_GROUP = group
# Build IO group
global _IO_PARALLEL_GROUP
if topology and topology.get_dim("pipe") > 1:
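        # With pipeline parallelism, IO must be synced with the last pipeline
        # stage as well as the first (the last stage is not in the data
        # parallel group of rank 0), so the group is built from the model
        # parallel rank 0 ranks of both end stages.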
io_stages = [0, topology.get_dim("pipe") - 1]
io_group = []
for stage in io_stages:
io_group.extend(topology.filter_match(pipe=stage, model=0))
if rank == 0:
print(f"MPU IO:", io_group)
group = torch.distributed.new_group(ranks=io_group)
if rank in io_group:
_IO_PARALLEL_GROUP = group
else:
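        # Without pipeline parallelism (or without a topology), the data
        # parallel group is sufficient for syncing IO, per the module-level
        # comment on _IO_PARALLEL_GROUP.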
_IO_PARALLEL_GROUP = get_data_parallel_group()
# Build the model parallel groups.
global _MODEL_PARALLEL_GROUP
assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
    if topology:
        # Short circuit case without model parallelism: every rank forms its
        # own model parallel group of size 1.
        # TODO: it would be nice to avoid this branching case?
        if model_parallel_size == 1:
            for group_rank in range(world_size):
                group = torch.distributed.new_group(ranks=[group_rank])
                if rank == 0:
                    print("MPU MP:", [group_rank])
                if rank == group_rank:
                    _MODEL_PARALLEL_GROUP = group
        else:
            for mp_group in topology.get_axis_comm_lists("model"):
                group = torch.distributed.new_group(ranks=mp_group)
                if rank == 0:
                    print("MPU MP:", mp_group)
                if rank in mp_group:
                    _MODEL_PARALLEL_GROUP = group
else:
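        # Without an explicit topology, model parallel groups are contiguous
        # blocks of model_parallel_size ranks, e.g. world_size=8 and
        # model_parallel_size=2 gives groups (0, 1), (2, 3), (4, 5), (6, 7).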
for i in range(world_size // model_parallel_size):
ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size)
group = torch.distributed.new_group(ranks)
if i == (rank // model_parallel_size):
_MODEL_PARALLEL_GROUP = group
global _FP32_ALLREDUCE
assert _FP32_ALLREDUCE is None, "fp32_allreduce is already initialized"
_FP32_ALLREDUCE = fp32_allreduce
def model_parallel_is_initialized():
"""Check if model and data parallel groups are initialized."""
if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None:
return False
return True
def get_model_parallel_group():
"""Get the model parallel group the caller rank belongs to."""
assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized"
return _MODEL_PARALLEL_GROUP
def get_data_parallel_group():
"""Get the data parallel group the caller rank belongs to."""
assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized"
return _DATA_PARALLEL_GROUP
def get_io_parallel_group():
"""Get the IO parallel group the caller rank belongs to."""
assert _IO_PARALLEL_GROUP is not None, "IO parallel group is not initialized"
return _IO_PARALLEL_GROUP
def set_model_parallel_world_size(world_size):
"""Set the model parallel size"""
global _MPU_WORLD_SIZE
_MPU_WORLD_SIZE = world_size
def get_model_parallel_world_size():
"""Return world size for the model parallel group."""
global _MPU_WORLD_SIZE
if _MPU_WORLD_SIZE is not None:
return _MPU_WORLD_SIZE
return torch.distributed.get_world_size(group=get_model_parallel_group())
def set_model_parallel_rank(rank):
"""Set model parallel rank."""
global _MPU_RANK
_MPU_RANK = rank
def get_model_parallel_rank():
"""Return my rank for the model parallel group."""
global _MPU_RANK
if _MPU_RANK is not None:
return _MPU_RANK
return torch.distributed.get_rank(group=get_model_parallel_group())
def get_model_parallel_src_rank():
"""Calculate the global rank corresponding to a local rank zero
in the model parallel group."""
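    # e.g. with a model parallel world size of 2, global ranks 4 and 5 both
    # map to source rank (5 // 2) * 2 == 4.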
global_rank = torch.distributed.get_rank()
local_world_size = get_model_parallel_world_size()
return (global_rank // local_world_size) * local_world_size
def get_data_parallel_src_rank():
"""Calculate the global rank corresponding to a local rank zero
in the data parallel group."""
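    # e.g. with 2-way model parallelism and no pipeline topology, global rank 5
    # maps to source rank 5 % 2 == 1, the first rank of its data parallel
    # group [1, 3, 5, 7].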
global_rank = torch.distributed.get_rank()
topo = get_topology()
    if topo is None:
        # We are only using model parallelism (no pipeline topology).
        return global_rank % get_model_parallel_world_size()
    else:
        # We are using pipeline parallelism: find the data parallel
        # communication list containing this rank and return its first member.
        for comm_list in topo.get_axis_comm_lists("data"):
            if global_rank in comm_list:
                return comm_list[0]
def get_data_parallel_world_size():
"""Return world size for the data parallel group."""
return torch.distributed.get_world_size(group=get_data_parallel_group())
def get_data_parallel_rank():
"""Return my rank for the data parallel group."""
return torch.distributed.get_rank(group=get_data_parallel_group())
def get_topology():
return _MPU_TOPOLOGY
def get_pipe_parallel_group():
"""Get the pipe parallel group the caller rank belongs to."""
    assert _PIPE_PARALLEL_GROUP is not None, "pipe parallel group is not initialized"
return _PIPE_PARALLEL_GROUP
def get_pipe_parallel_rank():
"""Return my rank for the pipe parallel group."""
return torch.distributed.get_rank(group=get_pipe_parallel_group())
def get_pipe_parallel_world_size():
"""Return world size for the pipe parallel group."""
return torch.distributed.get_world_size(group=get_pipe_parallel_group())
def set_tensor_model_parallel_world_size(world_size):
"""Set the tensor model parallel size"""
set_model_parallel_world_size(world_size)
def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to."""
return get_model_parallel_group()
def get_tensor_model_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the tensor model parallel group."""
    return get_model_parallel_src_rank()
# Needed for MoE. TODO: implement true tensor parallelism.
def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
return get_model_parallel_world_size()
def set_tensor_model_parallel_rank(rank):
"""Set tensor model parallel rank."""
set_model_parallel_rank(rank)
def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
return get_model_parallel_rank()
def destroy_model_parallel():
"""Set the groups to none."""
global _MODEL_PARALLEL_GROUP
_MODEL_PARALLEL_GROUP = None
global _DATA_PARALLEL_GROUP
_DATA_PARALLEL_GROUP = None
global _PIPE_PARALLEL_GROUP
_PIPE_PARALLEL_GROUP = None
global _IO_PARALLEL_GROUP
_IO_PARALLEL_GROUP = None
global _MPU_WORLD_SIZE
global _MPU_RANK
_MPU_WORLD_SIZE = None
_MPU_RANK = None
global _MPU_TOPOLOGY
_MPU_TOPOLOGY = None
global _FP32_ALLREDUCE
_FP32_ALLREDUCE = None
def get_fp32_allreduce():
"""Get the fp32 allreduce flag"""
assert _FP32_ALLREDUCE is not None, "fp32_allreduce is not Initialized"
return _FP32_ALLREDUCE