"""Distributed helpers.""" |
|
|
|
import functools |
|
import logging |
|
import pickle |
|
import torch |
|
import torch.distributed as dist |
|
|
|
_LOCAL_PROCESS_GROUP = None |
def all_gather(tensors):
    """
    All gathers the provided tensors from all processes across machines.
    Args:
        tensors (list): tensors to perform all gather across all processes in
            all machines.
    Returns:
        (list): gathered tensors, one per input tensor, each concatenated
            along dim 0 across all processes.
    """

    gather_list = []
    output_tensor = []
    world_size = dist.get_world_size()
    for tensor in tensors:
        tensor_placeholder = [
            torch.ones_like(tensor) for _ in range(world_size)
        ]
        dist.all_gather(tensor_placeholder, tensor, async_op=False)
        gather_list.append(tensor_placeholder)
    for gathered_tensor in gather_list:
        output_tensor.append(torch.cat(gathered_tensor, dim=0))
    return output_tensor
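
# Example usage (an illustrative sketch; assumes the default process group has
# already been initialized and the inputs live on the current GPU):
#
#   preds = torch.rand(8, 400, device="cuda")
#   labels = torch.randint(0, 400, (8,), device="cuda")
#   all_preds, all_labels = all_gather([preds, labels])
#   # Each returned tensor is the concatenation across ranks along dim 0,
#   # i.e. it has world_size * 8 rows.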


def all_reduce(tensors, average=True):
    """
    All reduce the provided tensors from all processes across machines.
    Args:
        tensors (list): tensors to perform all reduce across all processes in
            all machines.
        average (bool): if True, divide the reduced tensors by the total
            number of processes across all machines.
    """

    for tensor in tensors:
        dist.all_reduce(tensor, async_op=False)
    if average:
        world_size = dist.get_world_size()
        for tensor in tensors:
            tensor.mul_(1.0 / world_size)
    return tensors
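
# Example usage (an illustrative sketch; the loss value is a stand-in for a
# per-process scalar computed elsewhere):
#
#   loss = torch.tensor(0.7, device="cuda")
#   [loss] = all_reduce([loss], average=True)
#   # `loss` now holds the mean of the per-process losses.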


def init_process_group(
    local_rank,
    local_world_size,
    shard_id,
    num_shards,
    init_method,
    dist_backend="nccl",
):
    """
    Initializes the default process group.
    Args:
        local_rank (int): the rank on the current local machine.
        local_world_size (int): the world size (number of processes running) on
            the current local machine.
        shard_id (int): the shard index (machine rank) of the current machine.
        num_shards (int): number of shards for distributed training.
        init_method (string): supporting two different methods for
            initializing process groups:
            "file": use shared file system to initialize the groups across
                different processes.
            "tcp": use tcp address to initialize the groups across different
                processes.
        dist_backend (string): backend to use for distributed training. Options
            include gloo, mpi and nccl, the details can be found here:
            https://pytorch.org/docs/stable/distributed.html
    """

    # Set the GPU used by the current local process.
    torch.cuda.set_device(local_rank)

    # The global rank is the local rank offset by the ranks of all preceding
    # shards (machines).
    proc_rank = local_rank + shard_id * local_world_size
    world_size = local_world_size * num_shards
    dist.init_process_group(
        backend=dist_backend,
        init_method=init_method,
        world_size=world_size,
        rank=proc_rank,
    )
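
# Example (an illustrative sketch; the TCP address and GPU counts are made up):
# on two machines with 8 GPUs each, every process on machine `shard_id` would
# call
#
#   init_process_group(
#       local_rank=local_rank,                # 0..7 within the machine
#       local_world_size=8,
#       shard_id=shard_id,                    # 0 or 1
#       num_shards=2,
#       init_method="tcp://192.0.2.1:9999",
#       dist_backend="nccl",
#   )
#
# which yields global ranks 0..15 and a world size of 16.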


def is_master_proc(num_gpus=8):
    """
    Determines if the current process is the master process.
    """
    if dist.is_initialized():
        return dist.get_rank() % num_gpus == 0
    else:
        return True
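
# Example (an illustrative sketch): gate side effects such as logging on the
# master process, i.e. any process whose global rank is a multiple of
# `num_gpus`.
#
#   if is_master_proc(num_gpus=8):
#       logging.getLogger(__name__).info("training started")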


def is_root_proc():
    """
    Determines if the current process is the root process.
    """
    if dist.is_initialized():
        return dist.get_rank() == 0
    else:
        return True


def get_world_size():
    """
    Get the size of the world.
    """
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    """
    Get the rank of the current process.
    """
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training.
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()
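
# Example (an illustrative sketch): wait for every process before starting a
# timer so the measurement is not skewed by stragglers.
#
#   synchronize()
#   start_time = time.time()  # `time` would need to be imported separately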


@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a process group based on gloo backend, containing all the ranks.
    The result is cached.
    Returns:
        (group): pytorch dist group.
    """
    if dist.get_backend() == "nccl":
        return dist.new_group(backend="gloo")
    else:
        return dist.group.WORLD


def _serialize_to_tensor(data, group):
    """
    Serialize arbitrary picklable data to a ByteTensor. Note that only the
    `gloo` and `nccl` backends are supported.
    Args:
        data (data): data to be serialized.
        group (group): pytorch dist group.
    Returns:
        tensor (ByteTensor): the serialized tensor.
    """

    backend = dist.get_backend(group)
    assert backend in ["gloo", "nccl"]
    device = torch.device("cpu" if backend == "gloo" else "cuda")

    buffer = pickle.dumps(data)
    if len(buffer) > 1024 ** 3:
        logger = logging.getLogger(__name__)
        logger.warning(
            "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
                get_rank(), len(buffer) / (1024 ** 3), device
            )
        )
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to(device=device)
    return tensor


def _pad_to_largest_tensor(tensor, group):
    """
    Pad the tensor on every rank to the size of the largest tensor in the
    group, so that a fixed-size all_gather can be used.
    Args:
        tensor (tensor): tensor to pad.
        group (group): pytorch dist group.
    Returns:
        list[int]: size of the tensor, on each rank
        Tensor: padded tensor that has the max size
    """
    world_size = dist.get_world_size(group=group)
    assert (
        world_size >= 1
    ), "comm.gather/all_gather must be called from ranks within the given group!"
    local_size = torch.tensor(
        [tensor.numel()], dtype=torch.int64, device=tensor.device
    )
    size_list = [
        torch.zeros([1], dtype=torch.int64, device=tensor.device)
        for _ in range(world_size)
    ]
    # Gather the number of elements held by every rank.
    dist.all_gather(size_list, local_size, group=group)
    size_list = [int(size.item()) for size in size_list]

    max_size = max(size_list)

    # Pad the local tensor with zeros so that every rank contributes a tensor
    # of the same (maximum) size to the subsequent all_gather.
    if local_size != max_size:
        padding = torch.zeros(
            (max_size - local_size,), dtype=torch.uint8, device=tensor.device
        )
        tensor = torch.cat((tensor, padding), dim=0)
    return size_list, tensor


def all_gather_unaligned(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: list of data gathered from each rank
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return [data]

    tensor = _serialize_to_tensor(data, group)

    size_list, tensor = _pad_to_largest_tensor(tensor, group)
    max_size = max(size_list)

    # Receive the padded tensors from all ranks.
    tensor_list = [
        torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
        for _ in size_list
    ]
    dist.all_gather(tensor_list, tensor, group=group)

    # Strip the padding and unpickle the payload from each rank.
    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list
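
# Example usage (an illustrative sketch; the dictionary contents are made up):
#
#   local_stats = {"rank": get_rank(), "num_samples": 1234}
#   all_stats = all_gather_unaligned(local_stats)
#   # `all_stats` is a list with one entry per rank, in rank order; entries do
#   # not need to have the same size or structure across ranks.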


def init_distributed_training(cfg):
    """
    Initialize variables needed for distributed training.
    """
    if cfg.NUM_GPUS <= 1:
        return
    num_gpus_per_machine = cfg.NUM_GPUS
    num_machines = dist.get_world_size() // num_gpus_per_machine
    for i in range(num_machines):
        ranks_on_i = list(
            range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine)
        )
        # dist.new_group is a collective call, so every process creates the
        # group for every machine but only keeps the one for its own shard.
        pg = dist.new_group(ranks_on_i)
        if i == cfg.SHARD_ID:
            global _LOCAL_PROCESS_GROUP
            _LOCAL_PROCESS_GROUP = pg
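
# Example (an illustrative sketch; `cfg` provides the NUM_GPUS and SHARD_ID
# fields used above): once the default group exists, create the local group
# and query per-machine information.
#
#   init_distributed_training(cfg)
#   print(get_local_rank(), get_local_size())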


def get_local_size() -> int:
    """
    Returns:
        The size of the per-machine process group,
        i.e. the number of processes per machine.
    """
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)


def get_local_rank() -> int:
    """
    Returns:
        The rank of the current process within the local (per-machine) process group.
    """
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    assert _LOCAL_PROCESS_GROUP is not None
    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)