Spaces:
Paused
Paused
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch | |
import torch.distributed as dist | |
import torch.nn as nn | |
from torch._utils import (_flatten_dense_tensors, _take_tensors, | |
_unflatten_dense_tensors) | |
from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version | |
from .registry import MODULE_WRAPPERS | |
from .scatter_gather import scatter_kwargs | |
class MMDistributedDataParallel(nn.Module):
    """Lightweight distributed wrapper that syncs rank-0 state on creation.

    When constructed, the wrapped module's ``state_dict`` tensors are broadcast
    from rank 0. If ``broadcast_buffers`` is true, its buffers are broadcast
    too. Tensors are grouped into size-limited buckets so that each bucket
    costs a single collective call. ``forward``, ``train_step`` and
    ``val_step`` scatter their arguments onto the current CUDA device and then
    delegate to the wrapped module.

    NOTE(review): ``MODULE_WRAPPERS`` is imported by this file but not used
    anywhere visible. If this wrapper is meant to be discoverable through that
    registry, it presumably needs ``@MODULE_WRAPPERS.register_module()``.
    Confirm before relying on registry-based unwrapping.

    Args:
        module (nn.Module): Module to wrap; exposed as ``self.module``.
        dim (int): Forwarded as ``dim`` to ``scatter_kwargs``. This is
            presumably the dimension inputs are split along.
        broadcast_buffers (bool): Whether to also broadcast the module's
            buffers from rank 0 during construction.
        bucket_cap_mb (int | float): Bucket size in megabytes. It is stored
            in bytes as ``broadcast_bucket_size``.
    """

    def __init__(self,
                 module,
                 dim=0,
                 broadcast_buffers=True,
                 bucket_cap_mb=25):
        super().__init__()
        self.module = module
        self.dim = dim
        self.broadcast_buffers = broadcast_buffers
        # _take_tensors limits buckets in bytes, so convert MB -> bytes.
        self.broadcast_bucket_size = bucket_cap_mb * 1024 * 1024
        # Start every rank from rank 0's weights before any training step.
        self._sync_params()

    def _dist_broadcast_coalesced(self, tensors, buffer_size):
        """Overwrite ``tensors`` in place with rank 0's values.

        ``_take_tensors`` groups the tensors into buckets limited by
        ``buffer_size`` bytes. Each bucket is flattened into one contiguous
        tensor, so a single ``dist.broadcast`` call covers it. The received
        values are then copied back into the original tensors.

        Args:
            tensors (list[Tensor]): Tensors to synchronise in place.
            buffer_size (int): Bucket size limit in bytes.
        """
        for bucket in _take_tensors(tensors, buffer_size):
            flat = _flatten_dense_tensors(bucket)
            dist.broadcast(flat, 0)
            for tensor, synced in zip(
                    bucket, _unflatten_dense_tensors(flat, bucket)):
                tensor.copy_(synced)

    def _sync_params(self):
        """Broadcast the wrapped module's state (and optionally buffers)."""
        module_states = list(self.module.state_dict().values())
        if module_states:
            self._dist_broadcast_coalesced(module_states,
                                           self.broadcast_bucket_size)
        if self.broadcast_buffers:
            # NOTE: state_dict() already contains persistent buffers, so those
            # are broadcast a second time here. This is redundant but harmless.
            if (TORCH_VERSION != 'parrots'
                    and digit_version(TORCH_VERSION) < digit_version('1.0')):
                # Legacy branch for PyTorch < 1.0 builds, which this code
                # handles through the private ``_all_buffers()`` API.
                buffers = [b.data for b in self.module._all_buffers()]
            else:
                buffers = [b.data for b in self.module.buffers()]
            if buffers:
                self._dist_broadcast_coalesced(buffers,
                                               self.broadcast_bucket_size)

    def scatter(self, inputs, kwargs, device_ids):
        """Scatter positional and keyword arguments onto ``device_ids``."""
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def _scatter_to_current_device(self, inputs, kwargs):
        """Scatter arguments to the current CUDA device.

        Returns:
            tuple: ``(args, kwargs)``, the chunk placed on that device.

        If ``scatter`` yields no per-device chunks (presumably the case for
        a call with no arguments), fall back to an empty call. This mirrors
        ``torch.nn.DataParallel``, so a zero-argument call no longer raises
        ``IndexError`` on ``inputs[0]``.
        """
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        if not inputs and not kwargs:
            inputs, kwargs = ((),), ({},)
        return inputs[0], kwargs[0]

    def forward(self, *inputs, **kwargs):
        """Run the wrapped module on arguments moved to the current device."""
        args, kw = self._scatter_to_current_device(inputs, kwargs)
        return self.module(*args, **kw)

    def train_step(self, *inputs, **kwargs):
        """Delegate to ``self.module.train_step`` on the current device."""
        args, kw = self._scatter_to_current_device(inputs, kwargs)
        return self.module.train_step(*args, **kw)

    def val_step(self, *inputs, **kwargs):
        """Delegate to ``self.module.val_step`` on the current device."""
        args, kw = self._scatter_to_current_device(inputs, kwargs)
        return self.module.val_step(*args, **kw)