# Copyright (c) Facebook, Inc. and its affiliates.
"""
This file contains primitives for multi-gpu communication.
This is useful when doing distributed training.
"""
import functools
import logging
import numpy as np
import pickle
import torch
import torch.distributed as dist
import diffdist
from torch import nn
from torch.nn import functional as F
_LOCAL_PROCESS_GROUP = None
"""
A torch process group which only includes processes that are on the same machine as the current process.
This variable is set when processes are spawned by `launch()` in "engine/launch.py".
"""
def get_world_size() -> int:
if not dist.is_available():
return 1
if not dist.is_initialized():
return 1
return dist.get_world_size()
def get_rank() -> int:
if not dist.is_available():
return 0
if not dist.is_initialized():
return 0
return dist.get_rank()
def get_local_rank() -> int:
"""
Returns:
The rank of the current process within the local (per-machine) process group.
"""
if not dist.is_available():
return 0
if not dist.is_initialized():
return 0
assert _LOCAL_PROCESS_GROUP is not None
return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
def get_local_size() -> int:
"""
Returns:
The size of the per-machine process group,
i.e. the number of processes per machine.
"""
if not dist.is_available():
return 1
if not dist.is_initialized():
return 1
return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
def is_main_process() -> bool:
return get_rank() == 0
def synchronize():
"""
Helper function to synchronize (barrier) among all processes when
using distributed training
"""
if not dist.is_available():
return
if not dist.is_initialized():
return
world_size = dist.get_world_size()
if world_size == 1:
return
dist.barrier()
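# Usage sketch (illustrative): have all ranks wait until the main process has
# finished a rank-0-only step, e.g. writing files, before training continues.
def _example_synchronize_usage():
    if is_main_process():
        pass  # rank-0-only work (e.g. saving a checkpoint) would go here
    synchronize()  # every rank blocks here until all ranks arrive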
@functools.lru_cache()
def _get_global_gloo_group():
"""
    Return a process group based on the gloo backend, containing all the ranks.
The result is cached.
"""
if dist.get_backend() == "nccl":
return dist.new_group(backend="gloo")
else:
return dist.group.WORLD
def _serialize_to_tensor(data, group):
backend = dist.get_backend(group)
assert backend in ["gloo", "nccl"]
device = torch.device("cpu" if backend == "gloo" else "cuda")
buffer = pickle.dumps(data)
if len(buffer) > 1024 ** 3:
logger = logging.getLogger(__name__)
logger.warning(
"Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
get_rank(), len(buffer) / (1024 ** 3), device
)
)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to(device=device)
return tensor
def _pad_to_largest_tensor(tensor, group):
"""
Returns:
list[int]: size of the tensor, on each rank
Tensor: padded tensor that has the max size
"""
world_size = dist.get_world_size(group=group)
assert (
world_size >= 1
), "comm.gather/all_gather must be called from ranks within the given group!"
local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
size_list = [
torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
]
dist.all_gather(size_list, local_size, group=group)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
if local_size != max_size:
padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
tensor = torch.cat((tensor, padding), dim=0)
return size_list, tensor
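# Worked example (hypothetical sizes): if three ranks serialize payloads of
# 3, 5, and 4 bytes, every rank pads its uint8 tensor to max_size = 5 so that
# dist.all_gather sees identically-shaped tensors; the callers below then slice
# each gathered buffer back to its true length using the returned size_list.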
def all_gather(data, group=None):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors).
Args:
data: any picklable object
group: a torch process group. By default, will use a group which
contains all ranks on gloo backend.
Returns:
list[data]: list of data gathered from each rank
"""
if get_world_size() == 1:
return [data]
if group is None:
group = _get_global_gloo_group()
if dist.get_world_size(group) == 1:
return [data]
tensor = _serialize_to_tensor(data, group)
size_list, tensor = _pad_to_largest_tensor(tensor, group)
max_size = max(size_list)
# receiving Tensor from all ranks
tensor_list = [
torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
]
dist.all_gather(tensor_list, tensor, group=group)
data_list = []
for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
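# Usage sketch (illustrative; `local_predictions` is a hypothetical per-rank
# list): gather every rank's results on all ranks, e.g. to merge evaluation
# outputs.
def _example_all_gather_usage(local_predictions):
    gathered = all_gather(local_predictions)  # one entry per rank
    merged = []
    for part in gathered:
        merged.extend(part)
    return merged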
def gather(data, dst=0, group=None):
"""
Run gather on arbitrary picklable data (not necessarily tensors).
Args:
data: any picklable object
dst (int): destination rank
group: a torch process group. By default, will use a group which
contains all ranks on gloo backend.
Returns:
list[data]: on dst, a list of data gathered from each rank. Otherwise,
an empty list.
"""
if get_world_size() == 1:
return [data]
if group is None:
group = _get_global_gloo_group()
if dist.get_world_size(group=group) == 1:
return [data]
rank = dist.get_rank(group=group)
tensor = _serialize_to_tensor(data, group)
size_list, tensor = _pad_to_largest_tensor(tensor, group)
# receiving Tensor from all ranks
if rank == dst:
max_size = max(size_list)
tensor_list = [
torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
]
dist.gather(tensor, tensor_list, dst=dst, group=group)
data_list = []
for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
else:
dist.gather(tensor, [], dst=dst, group=group)
return []
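# Usage sketch (illustrative; `local_results` is a hypothetical per-rank list):
# collect results on the main process only, e.g. before writing a single
# output file.
def _example_gather_usage(local_results):
    all_results = gather(local_results, dst=0)  # non-empty only on rank 0
    if not is_main_process():
        return None
    merged = []
    for part in all_results:
        merged.extend(part)
    return merged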
def shared_random_seed():
"""
Returns:
int: a random number that is the same across all workers.
If workers need a shared RNG, they can use this shared seed to
create one.
All workers must call this function, otherwise it will deadlock.
"""
ints = np.random.randint(2 ** 31)
all_ints = all_gather(ints)
return all_ints[0]
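# Usage sketch (illustrative): build an RNG whose stream agrees across workers,
# e.g. for synchronized data augmentation or shuffling.
def _example_shared_rng():
    seed = shared_random_seed()         # same value on every worker
    return np.random.RandomState(seed)  # identical RNG state on every worker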
def reduce_dict(input_dict, average=True):
"""
Reduce the values in the dictionary from all processes so that process with rank
0 has the reduced results.
Args:
input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
average (bool): whether to do average or sum
Returns:
a dict with the same keys as input_dict, after reduction.
"""
world_size = get_world_size()
if world_size < 2:
return input_dict
with torch.no_grad():
names = []
values = []
# sort the keys so that they are consistent across processes
for k in sorted(input_dict.keys()):
names.append(k)
values.append(input_dict[k])
values = torch.stack(values, dim=0)
dist.reduce(values, dst=0)
if dist.get_rank() == 0 and average:
            # only the main process accumulates the reduced sum, so only
            # divide by world_size in that case
values /= world_size
reduced_dict = {k: v for k, v in zip(names, values)}
return reduced_dict
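# Usage sketch (illustrative; `loss_dict` is a hypothetical dict of scalar CUDA
# tensors): average per-loss values across ranks for logging; only rank 0
# receives the reduced values.
def _example_reduce_dict_usage(loss_dict):
    reduced = reduce_dict(loss_dict, average=True)
    if is_main_process():
        return sum(v.item() for v in reduced.values())
    return None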
def gather_tensors(tensor, method=""):
"""
Performs all_gather operation on the provided tensors.
*** Warning ***: torch.distributed.all_gather has no gradient.
"""
world_size = get_world_size()
rank = get_rank()
if world_size <= 1:
return tensor, tensor.shape[0]
batch_size = torch.tensor(tensor.shape[0], device=tensor.device)
batch_size_full = [torch.zeros_like(batch_size)
for _ in range(world_size)]
dist.all_gather(batch_size_full, batch_size)
    # cut every rank's tensor to the minimum batch size across all GPUs
min_bs = min([bs.item() for bs in batch_size_full])
if min_bs < batch_size:
tensor = tensor[:min_bs]
if "svd" in method:
        # currently, svd does not support half-precision, so
        # convert the tensor back to full-precision
with torch.cuda.amp.autocast(enabled=False):
U, Sig, V = torch.svd_lowrank(tensor.cpu(), q=int(method.split("_")[1]))
        # gather the low-rank factors U, Sigma, V from all ranks
Us = _gather_tensor(U.to(tensor.device), world_size) # N x B x LR
Sigs = _gather_tensor(torch.diag(Sig.to(tensor.device)), world_size) # N x LR x LR
Vs = _gather_tensor(V.to(tensor.device).T, world_size) # N x LR x D
# perform batch mm
# outputs = []
# for k in range(Us.shape[0]):
# temp = torch.mm(Us[k], Sigs[k]) # B x LR
# output = torch.mm(temp, Vs[k]) # B x D
# outputs.append(output)
# outputs[rank] = tensor
# output = torch.cat(outputs, 0)
output = torch.bmm(torch.bmm(Us, Sigs), Vs)
output[rank] = tensor
output = output.view(-1, output.shape[-1])
elif "pca" in method:
raise NotImplementedError
else:
tensors_gather = [
torch.ones_like(tensor)
for _ in range(world_size)
]
# dist.all_gather(tensors_gather, tensor, async_op=False)
# need to do this to restore propagation of the gradients
# tensors_gather[rank] = tensor
tensors_gather = diffdist.functional.all_gather(tensors_gather, tensor)
output = torch.cat(tensors_gather, dim=0)
return output, min_bs
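# NOTE: `_gather_tensor`, referenced in the svd branch of gather_tensors, is
# not defined in this file. Below is a minimal sketch under the assumption
# that it all-gathers a tensor and stacks the per-rank copies along a new
# leading dimension (N x ...); the gradient of the local rank's slice is
# restored afterwards via `output[rank] = tensor`.
def _gather_tensor(tensor, world_size):
    tensors_gather = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensors_gather, tensor)
    return torch.stack(tensors_gather, dim=0)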
class SoftTargetCrossEntropy(nn.Module):
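    """
    Cross-entropy with soft (probabilistic) targets: the negative log-softmax
    of the logits is weighted by the target distribution and normalized by the
    target mass along `dim`.
    """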
def __init__(self):
super(SoftTargetCrossEntropy, self).__init__()
def forward(self, x, target, dim=-1):
loss = torch.sum(-target * F.log_softmax(x, dim=dim), dim=dim) / (torch.sum(target, dim=dim) + 1e-6)
return loss.mean()
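# Usage sketch (illustrative; shapes are hypothetical): soft targets such as
# label-smoothed or mixed-up distributions over 10 classes.
def _example_soft_target_ce():
    criterion = SoftTargetCrossEntropy()
    logits = torch.randn(4, 10)        # batch of 4 samples, 10 classes
    target = torch.full((4, 10), 0.1)  # uniform soft targets (rows sum to 1)
    return criterion(logits, target)   # scalar loss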
class MILCrossEntropy(nn.Module):
"""
    Multi-instance learning (MIL) cross-entropy loss: the probability mass of
    all positive targets in each row is aggregated (summed, or averaged when
    avg_positives=True) before taking the negative log.
"""
def __init__(self):
super(MILCrossEntropy, self).__init__()
def forward(self, x, target, dim=-1, weights=None, avg_positives=False):
# for numerical stability
logits_max, _ = torch.max(x, dim=1, keepdim=True)
logits = x - logits_max.detach()
exp_logits = torch.exp(logits)
# get non-zero entries off-diagonal
# identity = torch.eye(target.shape[0]).type_as(target)
# laplacian = 1 - (target - identity)
probs = exp_logits / (exp_logits).sum(dim=dim, keepdim=True)
if avg_positives: # average the logits over positive targets
loss = -torch.log(torch.sum(target * probs, dim=dim) / (torch.sum(target, dim=dim) + 1e-6))
else: # sum the logits over positive targets
loss = -torch.log(torch.sum(target * probs, dim=dim))
if weights is not None:
return (loss * weights).mean()
return loss.mean()
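# Usage sketch (illustrative; shapes are hypothetical): a binary target matrix
# marks several positives per row, as in image-text matching where an image
# can have multiple valid captions.
def _example_mil_ce():
    criterion = MILCrossEntropy()
    logits = torch.randn(4, 4)  # pairwise similarity scores
    target = torch.eye(4)       # positives on the diagonal
    target[0, 1] = 1.0          # row 0 has an additional positive
    return criterion(logits, target, avg_positives=True)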