# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

"""Distributed helpers."""

import functools
import logging
import pickle
import torch
import torch.distributed as dist

_LOCAL_PROCESS_GROUP = None


def all_gather(tensors):
    """
    All gathers the provided tensors from all processes across machines.
    Args:
        tensors (list): tensors to perform all gather across all processes in
            all machines.
    """
    gather_list = []
    output_tensor = []
    world_size = dist.get_world_size()
    for tensor in tensors:
        tensor_placeholder = [
            torch.ones_like(tensor) for _ in range(world_size)
        ]
        dist.all_gather(tensor_placeholder, tensor, async_op=False)
        gather_list.append(tensor_placeholder)
    for gathered_tensor in gather_list:
        output_tensor.append(torch.cat(gathered_tensor, dim=0))
    return output_tensor


def all_reduce(tensors, average=True):
    """
    All reduce the provided tensors from all processes across machines.
    Args:
        tensors (list): tensors to perform all reduce across all processes in
            all machines.
        average (bool): scales the reduced tensor by the number of overall
            processes across all machines.
    """
    for tensor in tensors:
        dist.all_reduce(tensor, async_op=False)
    if average:
        world_size = dist.get_world_size()
        for tensor in tensors:
            tensor.mul_(1.0 / world_size)
    return tensors


def init_process_group(
    local_rank,
    local_world_size,
    shard_id,
    num_shards,
    init_method,
    dist_backend="nccl",
):
    """
    Initializes the default process group.
    Args:
        local_rank (int): the rank on the current local machine.
        local_world_size (int): the world size (number of processes running) on
            the current local machine.
        shard_id (int): the shard index (machine rank) of the current machine.
        num_shards (int): number of shards for distributed training.
        init_method (string): supporting two different methods for initializing
            process groups:
            "file": use shared file system to initialize the groups across
            different processes.
            "tcp": use tcp address to initialize the groups across different
            processes.
        dist_backend (string): backend to use for distributed training. Options
            include gloo, mpi and nccl; the details can be found here:
            https://pytorch.org/docs/stable/distributed.html
    """
    # Sets the GPU to use.
    torch.cuda.set_device(local_rank)
    # Initialize the process group.
    proc_rank = local_rank + shard_id * local_world_size
    world_size = local_world_size * num_shards
    dist.init_process_group(
        backend=dist_backend,
        init_method=init_method,
        world_size=world_size,
        rank=proc_rank,
    )


def is_master_proc(num_gpus=8):
    """
    Determines if the current process is the master process.
    """
    if torch.distributed.is_initialized():
        return dist.get_rank() % num_gpus == 0
    else:
        return True


def is_root_proc():
    """
    Determines if the current process is the root process.
    """
    if torch.distributed.is_initialized():
        return dist.get_rank() == 0
    else:
        return True


def get_world_size():
    """
    Get the size of the world.
    """
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    """
    Get the rank of the current process.
    """
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training.
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()
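

# ---------------------------------------------------------------------------
# Illustrative usage sketch (an assumption, not part of the original module):
# how a single worker process might combine the helpers above. The function
# name `_example_worker` and its arguments are hypothetical placeholders.
# ---------------------------------------------------------------------------
def _example_worker(local_rank, local_world_size, shard_id, num_shards, init_method):
    """
    Sketch of one process (one GPU) joining the global group, averaging a
    per-process scalar, and waiting at a barrier. Assumes a CUDA device is
    available, since `init_process_group` pins one GPU per process.
    """
    init_process_group(
        local_rank, local_world_size, shard_id, num_shards, init_method
    )
    # Average a per-process scalar (e.g. a loss value) across all processes;
    # the reduction happens in place and the list is returned for convenience.
    loss = torch.tensor([0.0], device="cuda")
    loss = all_reduce([loss], average=True)[0]
    # Block until every process reaches this point before continuing.
    synchronize()
    return loss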


@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a process group based on the gloo backend, containing all the
    ranks. The result is cached.
    Returns:
        (group): pytorch dist group.
    """
    if dist.get_backend() == "nccl":
        return dist.new_group(backend="gloo")
    else:
        return dist.group.WORLD


def _serialize_to_tensor(data, group):
    """
    Serialize arbitrary picklable data to a ByteTensor. Note that only the
    `gloo` and `nccl` backends are supported.
    Args:
        data (data): data to be serialized.
        group (group): pytorch dist group.
    Returns:
        tensor (ByteTensor): the serialized tensor.
    """
    backend = dist.get_backend(group)
    assert backend in ["gloo", "nccl"]
    device = torch.device("cpu" if backend == "gloo" else "cuda")

    buffer = pickle.dumps(data)
    if len(buffer) > 1024 ** 3:
        logger = logging.getLogger(__name__)
        logger.warning(
            "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
                get_rank(), len(buffer) / (1024 ** 3), device
            )
        )
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to(device=device)
    return tensor


def _pad_to_largest_tensor(tensor, group):
    """
    Pad the local tensor so that every rank holds a tensor of the largest
    size across the group.
    Args:
        tensor (tensor): tensor to pad.
        group (group): pytorch dist group.
    Returns:
        list[int]: size of the tensor on each rank.
        Tensor: padded tensor that has the max size.
    """
    world_size = dist.get_world_size(group=group)
    assert (
        world_size >= 1
    ), "comm.gather/all_gather must be called from ranks within the given group!"
    local_size = torch.tensor(
        [tensor.numel()], dtype=torch.int64, device=tensor.device
    )
    size_list = [
        torch.zeros([1], dtype=torch.int64, device=tensor.device)
        for _ in range(world_size)
    ]
    dist.all_gather(size_list, local_size, group=group)
    size_list = [int(size.item()) for size in size_list]

    max_size = max(size_list)

    # We pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes.
    if local_size != max_size:
        padding = torch.zeros(
            (max_size - local_size,), dtype=torch.uint8, device=tensor.device
        )
        tensor = torch.cat((tensor, padding), dim=0)
    return size_list, tensor


def all_gather_unaligned(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: list of data gathered from each rank
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return [data]

    tensor = _serialize_to_tensor(data, group)
    size_list, tensor = _pad_to_largest_tensor(tensor, group)
    max_size = max(size_list)

    # Receiving tensors from all ranks.
    tensor_list = [
        torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
        for _ in size_list
    ]
    dist.all_gather(tensor_list, tensor, group=group)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list


def init_distributed_training(cfg):
    """
    Initialize variables needed for distributed training.
    """
    if cfg.NUM_GPUS <= 1:
        return
    num_gpus_per_machine = cfg.NUM_GPUS
    num_machines = dist.get_world_size() // num_gpus_per_machine
    for i in range(num_machines):
        ranks_on_i = list(
            range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine)
        )
        pg = dist.new_group(ranks_on_i)
        if i == cfg.SHARD_ID:
            global _LOCAL_PROCESS_GROUP
            _LOCAL_PROCESS_GROUP = pg
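

# ---------------------------------------------------------------------------
# Illustrative usage sketch (an assumption, not part of the original module):
# gathering arbitrary picklable objects such as per-rank metric dicts with
# `all_gather_unaligned`. The helper name `_example_gather_metrics` and the
# "top1_err" key are hypothetical placeholders.
# ---------------------------------------------------------------------------
def _example_gather_metrics(local_metrics):
    """
    Gather a picklable per-rank dict, e.g. {"top1_err": 12.3}, from every
    rank and average one of its entries. Returns the same value on all ranks.
    """
    # Unlike `all_gather`, the payload does not need to be a tensor and its
    # serialized size may differ across ranks; padding is handled internally.
    gathered = all_gather_unaligned(local_metrics)
    return sum(m["top1_err"] for m in gathered) / len(gathered)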
""" if not dist.is_available(): return 1 if not dist.is_initialized(): return 1 return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) def get_local_rank() -> int: """ Returns: The rank of the current process within the local (per-machine) process group. """ if not dist.is_available(): return 0 if not dist.is_initialized(): return 0 assert _LOCAL_PROCESS_GROUP is not None return dist.get_rank(group=_LOCAL_PROCESS_GROUP)