#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.

import pickle
from collections import OrderedDict

import torch
from torch import distributed as dist
from torch import nn

from .dist import _get_global_gloo_group, get_world_size

# Normalization layers whose running statistics are synchronized lazily
# via all_reduce_norm (rather than per-forward, as SyncBatchNorm would).
ASYNC_NORM = (
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.BatchNorm3d,
    nn.InstanceNorm1d,
    nn.InstanceNorm2d,
    nn.InstanceNorm3d,
)

__all__ = [
    "get_async_norm_states",
    "pyobj2tensor",
    "tensor2pyobj",
    "all_reduce",
    "all_reduce_norm",
]


def get_async_norm_states(module):
    """Collect the state_dict entries of every ASYNC_NORM layer in *module*.

    Keys are fully qualified ("<module path>.<buffer name>") so the result
    can be fed back through ``module.load_state_dict(..., strict=False)``.
    """
    async_norm_states = OrderedDict()
    for name, child in module.named_modules():
        if isinstance(child, ASYNC_NORM):
            for k, v in child.state_dict().items():
                async_norm_states[".".join([name, k])] = v
    return async_norm_states


def pyobj2tensor(pyobj, device="cuda"):
    """Serialize a picklable python object to a 1-D byte tensor on *device*."""
    storage = torch.ByteStorage.from_buffer(pickle.dumps(pyobj))
    return torch.ByteTensor(storage).to(device=device)


def tensor2pyobj(tensor):
    """Deserialize a tensor produced by :func:`pyobj2tensor` back to a python object.

    NOTE: uses ``pickle.loads`` — only feed tensors received from trusted peers.
    """
    return pickle.loads(tensor.cpu().numpy().tobytes())


def _get_reduce_op(op_name):
    """Map an op name ("sum" / "mean", case-insensitive) to a ``dist.ReduceOp``.

    "mean" deliberately maps to SUM: the caller divides by world size
    afterwards, since an AVG reduce op is not available on all backends.
    """
    return {
        "sum": dist.ReduceOp.SUM,
        "mean": dist.ReduceOp.SUM,
    }[op_name.lower()]


def all_reduce(py_dict, op="sum", group=None):
    """
    Apply an all-reduce over the tensor values of a python dict.

    NOTE: make sure that every py_dict has the same keys and values are
    in the same shape. The key list is broadcast from rank 0 so that all
    ranks flatten the tensors in an identical order.

    Args:
        py_dict (dict): dict of tensors to apply all reduce op to.
        op (str): operator, could be "sum" or "mean" (case-insensitive).
        group: process group used only for the early world-size check;
            the collectives themselves run on the default group
            (the fetched gloo group cannot reduce CUDA tensors).

    Returns:
        OrderedDict: reduced tensors, keyed/ordered as on rank 0.
    """
    # Fix: normalize once so op="MEAN" both reduces AND divides below;
    # previously _get_reduce_op lowercased but the `== "mean"` check did
    # not, silently returning a sum for non-lowercase spellings.
    op = op.lower()

    world_size = get_world_size()
    if world_size == 1:
        return py_dict
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return py_dict

    # all reduce logic across different devices.
    py_key = list(py_dict.keys())
    py_key_tensor = pyobj2tensor(py_key)
    dist.broadcast(py_key_tensor, src=0)
    py_key = tensor2pyobj(py_key_tensor)

    tensor_shapes = [py_dict[k].shape for k in py_key]
    tensor_numels = [py_dict[k].numel() for k in py_key]

    # One collective on a single flat buffer instead of one per tensor.
    flatten_tensor = torch.cat([py_dict[k].flatten() for k in py_key])
    dist.all_reduce(flatten_tensor, op=_get_reduce_op(op))
    if op == "mean":
        # "mean" is implemented as SUM followed by division (see _get_reduce_op).
        flatten_tensor /= world_size

    split_tensors = [
        x.reshape(shape)
        for x, shape in zip(torch.split(flatten_tensor, tensor_numels), tensor_shapes)
    ]
    return OrderedDict(zip(py_key, split_tensors))


def all_reduce_norm(module):
    """
    All reduce (average) norm-layer statistics in different devices.
    """
    states = get_async_norm_states(module)
    states = all_reduce(states, op="mean")
    module.load_state_dict(states, strict=False)