Spaces:
Running
Running
""" Distributed training/validation utils

Hacked together by / Copyright 2020 Ross Wightman
"""
import torch | |
from torch import distributed as dist | |
from .model import unwrap_model | |
def reduce_tensor(tensor, n):
    """Sum ``tensor`` across the process group and divide the result by ``n``.

    The input is cloned first, so the caller's tensor is left unmodified.
    Requires ``torch.distributed`` to be initialized; every participating
    process must make the matching call because ``all_reduce`` is a collective.

    Args:
        tensor: tensor to reduce.
        n: divisor applied to the summed result (callers typically pass the
            world size to obtain a mean -- confirm at the call site).

    Returns:
        A new tensor holding the element-wise sum over all processes divided
        by ``n``.
    """
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    # In-place division keeps the result in the cloned buffer (and dtype).
    rt /= n
    return rt
def distribute_bn(model, world_size, reduce=False):
    """Make the running bn statistics identical on every process.

    Iterates over all buffers of the unwrapped model and synchronizes, in
    place, each buffer whose name contains ``running_mean`` or
    ``running_var``. Selection is purely by buffer name. Must be called on
    every process in the group since it issues collective operations.

    Args:
        model: model whose buffers are synchronized; it is passed through
            ``unwrap_model`` before its buffers are enumerated.
        world_size: number of processes; used as the divisor when averaging.
        reduce: if True, sum each buffer across the group and divide by
            ``world_size``; if False, overwrite each buffer with the value
            broadcast from rank 0.
    """
    # ensure every node has the same running bn stats
    for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True):
        if ('running_mean' in bn_name) or ('running_var' in bn_name):
            if reduce:
                # average bn stats across whole group
                dist.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
                bn_buf /= float(world_size)
            else:
                # broadcast bn stats from rank 0 to whole group
                dist.broadcast(bn_buf, 0)