# python3.7
"""Contains the synchronizing operator."""

import torch
import torch.distributed as dist

__all__ = ['all_gather']


def all_gather(tensor):
    """Gathers tensor from all devices and does averaging."""
    if not dist.is_initialized():
        return tensor
    world_size = dist.get_world_size()
    tensor_list = [torch.ones_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensor_list, tensor, async_op=False)
    return torch.mean(torch.stack(tensor_list, dim=0), dim=0)
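

if __name__ == '__main__':
    # Illustrative usage sketch (added for documentation, not part of the
    # original module). It assumes the script is launched with `torchrun`,
    # which sets the RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT environment
    # variables consumed by the default `env://` init method, and uses the
    # CPU-friendly 'gloo' backend purely for demonstration.
    dist.init_process_group(backend='gloo')
    rank = dist.get_rank()
    # Hypothetical per-rank statistic: each rank contributes its own index.
    local_stat = torch.tensor([float(rank)])
    # Every rank prints the same averaged value, i.e. (0 + 1 + ... + N-1) / N.
    print(rank, all_gather(local_stat))
    dist.destroy_process_group()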