# Page residue from the hosting site (file path | user | commit | size):
# bytetrack/yolox/utils/allreduce_norm.py | AK391 | commit 7734d5b | 2.85 kB
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
import torch
from torch import distributed as dist
from torch import nn
import pickle
from collections import OrderedDict
from .dist import _get_global_gloo_group, get_world_size
# Normalization layer types whose state (parameters and buffers) is collected
# by ``get_async_norm_states`` and averaged across processes by
# ``all_reduce_norm``.
ASYNC_NORM = (
    nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
    nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
)
# Public API of this module.
__all__ = [
    "get_async_norm_states", "pyobj2tensor", "tensor2pyobj",
    "all_reduce", "all_reduce_norm",
]
def get_async_norm_states(module):
    """
    Collect the state of every ``ASYNC_NORM`` layer inside ``module``.

    Args:
        module (nn.Module): model to scan (the module itself is included).

    Returns:
        OrderedDict: maps each layer's state_dict entries (parameters and
        buffers, e.g. ``weight``, ``running_mean``) to their tensors, keyed
        with the layer's ``named_modules()`` path as prefix so the keys can be
        fed back to ``module.load_state_dict``.
    """
    async_norm_states = OrderedDict()
    for name, child in module.named_modules():
        if not isinstance(child, ASYNC_NORM):
            continue
        for k, v in child.state_dict().items():
            # The root module is reported with an empty name; joining it would
            # give ".weight", which load_state_dict(strict=False) silently
            # ignores, so use the bare key in that case.
            key = ".".join([name, k]) if name else k
            async_norm_states[key] = v
    return async_norm_states
def pyobj2tensor(pyobj, device="cuda"):
    """
    Serialize a picklable python object into a 1-D uint8 tensor.

    Args:
        pyobj: any picklable python object.
        device (str or torch.device): device of the returned tensor.

    Returns:
        torch.Tensor: uint8 tensor holding the pickled bytes, on ``device``.
    """
    payload = pickle.dumps(pyobj)
    if hasattr(torch, "frombuffer"):
        # Replaces the deprecated ByteStorage.from_buffer/ByteTensor(storage)
        # path; bytearray gives frombuffer a writable buffer to wrap.
        tensor = torch.frombuffer(bytearray(payload), dtype=torch.uint8)
    else:
        # Fallback for torch versions that predate torch.frombuffer.
        tensor = torch.tensor(list(payload), dtype=torch.uint8)
    return tensor.to(device=device)
def tensor2pyobj(tensor):
    """
    Rebuild the python object whose pickled bytes are stored in ``tensor``.

    The inverse of ``pyobj2tensor``. The tensor is copied to host memory
    first, so it may live on any device.
    """
    host_view = tensor.cpu().numpy()
    raw_bytes = host_view.tobytes()
    # NOTE(review): unpickling is only safe for data from a trusted source
    # (here presumably a peer rank of the same job) - confirm for new callers.
    return pickle.loads(raw_bytes)
def _get_reduce_op(op_name):
    """
    Translate an op name ("sum" or "mean", any letter case) to a ``dist.ReduceOp``.

    Both names map to SUM; for "mean" the caller divides the summed result by
    the number of ranks itself. Unknown names raise ``KeyError``.
    """
    supported_ops = {"sum": dist.ReduceOp.SUM, "mean": dist.ReduceOp.SUM}
    normalized_name = op_name.lower()
    return supported_ops[normalized_name]
def all_reduce(py_dict, op="sum", group=None):
    """
    Apply all reduce function for python dict object.
    NOTE: make sure that every py_dict has the same keys and values are in the same shape.

    Args:
        py_dict (dict): dict of tensors to apply all reduce op.
        op (str): operator, could be "sum" or "mean" (case-insensitive).
        group (ProcessGroup, optional): process group to reduce over.
            Defaults to the group returned by ``_get_global_gloo_group()``.

    Returns:
        dict: ``py_dict`` unchanged when there is only one process (globally
        or in ``group``); otherwise an OrderedDict, ordered like rank 0's
        keys, holding the reduced tensors.
    """
    world_size = get_world_size()
    if world_size == 1:
        return py_dict
    if group is None:
        group = _get_global_gloo_group()
    group_size = dist.get_world_size(group)
    if group_size == 1:
        return py_dict

    # Adopt rank 0's key order on every rank so the flattened buffers line up.
    # NOTE(review): src is a global rank; rank 0 must be a member of `group`.
    py_key = list(py_dict.keys())
    py_key_tensor = pyobj2tensor(py_key)
    dist.broadcast(py_key_tensor, src=0, group=group)
    py_key = tensor2pyobj(py_key_tensor)
    if not py_key:
        # torch.cat rejects an empty list. After the broadcast every rank holds
        # the same empty key list, so all ranks return here together.
        return OrderedDict()

    tensor_shapes = [py_dict[k].shape for k in py_key]
    tensor_numels = [py_dict[k].numel() for k in py_key]

    # A single collective over one flat buffer instead of one per tensor.
    flatten_tensor = torch.cat([py_dict[k].flatten() for k in py_key])
    dist.all_reduce(flatten_tensor, op=_get_reduce_op(op), group=group)
    if op.lower() == "mean":
        # Average over the ranks that took part in this reduction. The division
        # is out-of-place so integer-only buffers are promoted to floating point
        # instead of raising.
        flatten_tensor = flatten_tensor / group_size

    split_tensors = [
        x.reshape(shape)
        for x, shape in zip(torch.split(flatten_tensor, tensor_numels), tensor_shapes)
    ]
    return OrderedDict(zip(py_key, split_tensors))
def all_reduce_norm(module):
    """
    Average the state of every ``ASYNC_NORM`` layer of ``module`` across devices.

    Collects the norm-layer entries, mean-reduces them with ``all_reduce``, and
    loads the result back into ``module`` in place with ``strict=False``, so
    entries for other layers are left untouched. Returns None.
    """
    averaged_states = all_reduce(get_async_norm_states(module), op="mean")
    module.load_state_dict(averaged_states, strict=False)