# mshukor
# init
# 3eb682b
# raw
# history blame
# 8.86 kB
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""Distributed helpers."""
import functools
import logging
import pickle
import torch
import torch.distributed as dist
# Per-machine process group. init_distributed_training() assigns it the group
# whose machine index equals cfg.SHARD_ID; get_local_size() and get_local_rank()
# read it. It stays None until init_distributed_training() creates the groups.
_LOCAL_PROCESS_GROUP = None
def all_gather(tensors):
    """
    All-gather every tensor in ``tensors`` across all processes on all machines.

    Args:
        tensors (list): tensors to gather from every rank of the default
            process group.
    Returns:
        list: for each input tensor, the copies received from all ranks
            concatenated along dimension 0.
    """
    num_procs = dist.get_world_size()
    gathered = []
    for local_tensor in tensors:
        # dist.all_gather fills caller-allocated buffers, one per rank, each
        # shaped like the local input.
        buckets = [torch.ones_like(local_tensor) for _ in range(num_procs)]
        dist.all_gather(buckets, local_tensor, async_op=False)
        gathered.append(buckets)
    return [torch.cat(buckets, dim=0) for buckets in gathered]
def all_reduce(tensors, average=True):
    """
    All-reduce each tensor in place across all processes on all machines.

    Args:
        tensors (list): tensors to reduce; each is overwritten with the
            reduced value (dist.all_reduce's default op).
        average (bool): when True, scale every reduced tensor by
            1 / world_size, where world_size is the number of processes
            across all machines.
    Returns:
        list: the same ``tensors`` list, now holding the reduced values.
    """
    for t in tensors:
        dist.all_reduce(t, async_op=False)
    if not average:
        return tensors
    scale = 1.0 / dist.get_world_size()
    for t in tensors:
        t.mul_(scale)
    return tensors
def init_process_group(
    local_rank,
    local_world_size,
    shard_id,
    num_shards,
    init_method,
    dist_backend="nccl",
):
    """
    Bind this process to its CUDA device and initialize the default process group.

    Args:
        local_rank (int): rank of this process on the current machine; also
            passed to torch.cuda.set_device as the device index.
        local_world_size (int): number of processes running on the current
            machine.
        shard_id (int): index (machine rank) of the current machine.
        num_shards (int): number of machines (shards) used for training.
        init_method (string): URL passed to torch.distributed.init_process_group
            that tells the processes how to find each other, e.g.
            "tcp://<host>:<port>" (TCP address), "file:///<path>" (a file on a
            shared file system) or "env://" (environment variables).
        dist_backend (string): backend to use for distributed training. Options
            include gloo, mpi and nccl; details can be found here:
            https://pytorch.org/docs/stable/distributed.html
    """
    torch.cuda.set_device(local_rank)
    # Global rank/size assume every machine runs local_world_size processes,
    # with ranks laid out machine by machine.
    global_rank = shard_id * local_world_size + local_rank
    total_procs = num_shards * local_world_size
    dist.init_process_group(
        backend=dist_backend,
        init_method=init_method,
        world_size=total_procs,
        rank=global_rank,
    )
def is_master_proc(num_gpus=8):
    """
    Report whether the current process is a master process.

    A process is a master when its global rank is divisible by ``num_gpus``.
    With the rank layout used by init_process_group, that is presumably the
    first process on each machine when ``num_gpus`` equals the processes per
    machine. Without an initialized process group every process is a master.
    """
    if not dist.is_initialized():
        return True
    return dist.get_rank() % num_gpus == 0
def is_root_proc():
    """
    Report whether the current process is the root (global rank 0) process.

    Without an initialized process group the current process counts as root.
    """
    return (not dist.is_initialized()) or dist.get_rank() == 0
def get_world_size():
    """
    Return the number of processes in the default process group.

    Falls back to 1 when torch.distributed is unavailable or not initialized.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
def get_rank():
    """
    Return the global rank of the current process.

    Falls back to 0 when torch.distributed is unavailable or not initialized.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
def synchronize():
    """
    Block until every process reaches this point (a barrier).

    This is a no-op when torch.distributed is unavailable, not initialized,
    or running with a single process.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return
    if dist.get_world_size() > 1:
        dist.barrier()
@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a process group containing all ranks, intended for CPU (gloo) traffic.

    If the default backend is NCCL, a new gloo group over all ranks is
    created. Otherwise the default WORLD group is returned unchanged, so it
    is a gloo group only when the default backend is gloo. lru_cache stores
    the first result, so later calls get the same group object.

    Returns:
        (group): pytorch dist group.
    """
    uses_nccl = dist.get_backend() == "nccl"
    return dist.new_group(backend="gloo") if uses_nccl else dist.group.WORLD
def _serialize_to_tensor(data, group):
    """
    Serialize ``data`` with pickle into a ByteTensor. Only the `gloo` and
    `nccl` backends are supported.

    The tensor is placed on the CPU for gloo and on the (current) CUDA device
    otherwise. Payloads larger than 1024**3 bytes trigger a logged warning.

    Args:
        data (data): picklable object to serialize.
        group (group): pytorch dist group; its backend selects the device.
    Returns:
        tensor (ByteTensor): tensor holding the pickled bytes of ``data``.
    """
    backend = dist.get_backend(group)
    assert backend in ["gloo", "nccl"]
    device = torch.device("cuda" if backend != "gloo" else "cpu")
    payload = pickle.dumps(data)
    one_gib = 1024 ** 3
    if len(payload) > one_gib:
        logging.getLogger(__name__).warning(
            "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
                get_rank(), len(payload) / one_gib, device
            )
        )
    byte_storage = torch.ByteStorage.from_buffer(payload)
    return torch.ByteTensor(byte_storage).to(device=device)
def _pad_to_largest_tensor(tensor, group):
    """
    Zero-pad a 1-D tensor so that every rank in ``group`` holds the same length.

    torch.distributed.all_gather needs equally shaped tensors. The ranks
    therefore first exchange their element counts, then the local tensor is
    padded at the end up to the largest count.

    Args:
        tensor (tensor): 1-D tensor to pad (e.g. the uint8 output of
            _serialize_to_tensor).
        group (group): pytorch dist group.
    Returns:
        list[int]: number of elements of the tensor on each rank.
        Tensor: ``tensor`` padded with zeros (same dtype) to the max size.
    """
    # dist.get_world_size(group) returns -1 on ranks that are not in `group`.
    world_size = dist.get_world_size(group=group)
    assert (
        world_size >= 1
    ), "comm.gather/all_gather must be called from ranks within the given group!"
    local_numel = tensor.numel()
    local_size = torch.tensor(
        [local_numel], dtype=torch.int64, device=tensor.device
    )
    size_list = [
        torch.zeros([1], dtype=torch.int64, device=tensor.device)
        for _ in range(world_size)
    ]
    dist.all_gather(size_list, local_size, group=group)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)
    # Pad because torch all_gather does not support gathering tensors of
    # different shapes. Compare and size the padding with the plain Python int
    # rather than the 1-element `local_size` tensor, so no implicit
    # tensor->bool / tensor->shape conversion is needed. Match the input dtype
    # so torch.cat does not mix dtypes.
    if local_numel < max_size:
        padding = torch.zeros(
            (max_size - local_numel,), dtype=tensor.dtype, device=tensor.device
        )
        tensor = torch.cat((tensor, padding), dim=0)
    return size_list, tensor
def all_gather_unaligned(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Each rank pickles ``data`` into a byte tensor. The tensors are padded to a
    common length and all-gathered. Each received buffer is trimmed to its
    original length and unpickled.

    Args:
        data: any picklable object
        group: a torch process group. Defaults to the group returned by
            _get_global_gloo_group() (all ranks; gloo-based when the default
            backend is gloo or nccl).
    Returns:
        list[data]: list of data gathered from each rank
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return [data]

    local_bytes = _serialize_to_tensor(data, group)
    size_list, local_bytes = _pad_to_largest_tensor(local_bytes, group)
    longest = max(size_list)
    # One receive buffer per rank, sized to the padded length.
    received = [
        torch.empty((longest,), dtype=torch.uint8, device=local_bytes.device)
        for _ in size_list
    ]
    dist.all_gather(received, local_bytes, group=group)
    # NOTE(review): pickle.loads can execute arbitrary code; acceptable only
    # because the payloads come from peer ranks of the same job.
    return [
        pickle.loads(buf.cpu().numpy().tobytes()[:num_bytes])
        for num_bytes, buf in zip(size_list, received)
    ]
def init_distributed_training(cfg):
    """
    Create one process group per machine and keep the one for this machine.

    Ranks are grouped in contiguous blocks of cfg.NUM_GPUS: machine i owns
    ranks [i * cfg.NUM_GPUS, (i + 1) * cfg.NUM_GPUS). Every rank creates every
    machine's group, because dist.new_group must be entered by all processes.
    Only the group whose index equals cfg.SHARD_ID is stored in the
    module-level _LOCAL_PROCESS_GROUP. Does nothing when cfg.NUM_GPUS <= 1.
    """
    global _LOCAL_PROCESS_GROUP
    if cfg.NUM_GPUS <= 1:
        return
    per_machine = cfg.NUM_GPUS
    machine_count = dist.get_world_size() // per_machine
    for machine_idx in range(machine_count):
        first_rank = machine_idx * per_machine
        machine_group = dist.new_group(
            list(range(first_rank, first_rank + per_machine))
        )
        if machine_idx == cfg.SHARD_ID:
            _LOCAL_PROCESS_GROUP = machine_group
def get_local_size() -> int:
    """
    Returns:
        The size of the per-machine process group, i.e. the number of
        processes per machine; 1 when torch.distributed is unavailable or
        not initialized.
    """
    if dist.is_available() and dist.is_initialized():
        # NOTE(review): if init_distributed_training() never set
        # _LOCAL_PROCESS_GROUP, group=None selects the default (global) group,
        # so this returns the global world size instead of failing the way
        # get_local_rank() does.
        return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
    return 1
def get_local_rank() -> int:
    """
    Returns:
        The rank of the current process within the local (per-machine)
        process group; 0 when torch.distributed is unavailable or not
        initialized.
    Raises:
        AssertionError: if distributed is initialized but
            init_distributed_training() has not set _LOCAL_PROCESS_GROUP.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return 0
    assert _LOCAL_PROCESS_GROUP is not None
    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)