# python3.7
"""Contains the synchronizing operator."""

import torch
import torch.distributed as dist

__all__ = ['all_gather']


def all_gather(tensor):
    """Gathers `tensor` from all ranks and returns their element-wise mean.

    If `torch.distributed` is unavailable in this PyTorch build, or no process
    group has been initialized, `tensor` itself is returned unchanged (the
    same object, not a copy).

    Otherwise this is a collective call: every rank in the default process
    group must call it. Each rank's tensor is collected with
    `dist.all_gather`, stacked along a new leading dimension, and averaged
    over that dimension.

    Args:
        tensor: The local tensor to synchronize. Assumes every rank passes a
            tensor of the same shape and dtype (each receive buffer is shaped
            like the local tensor); confirm at call sites.

    Returns:
        A new tensor with the shape of `tensor` holding the mean across ranks.
        Integer and boolean inputs are averaged in the default floating-point
        dtype, because `torch.mean` rejects non-floating dtypes.

    NOTE: per the torch.distributed docs, `dist.all_gather` is not tracked by
    autograd, so the result should not be relied on to carry gradients.
    """
    # `dist.is_initialized` may be missing on builds compiled without
    # distributed support, so check availability first.
    if not (dist.is_available() and dist.is_initialized()):
        return tensor

    world_size = dist.get_world_size()
    # Receive buffers are fully overwritten by the collective, so their
    # initial contents are irrelevant; skip filling them.
    tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensor_list, tensor, async_op=False)

    stacked = torch.stack(tensor_list, dim=0)
    # torch.mean only supports floating-point and complex dtypes.
    if not (stacked.is_floating_point() or stacked.is_complex()):
        stacked = stacked.to(torch.get_default_dtype())
    return torch.mean(stacked, dim=0)