import torch
from torch.nn.parallel.distributed import (DistributedDataParallel,
                                           _find_tensors)

from annotator.uniformer.mmcv import print_log
from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version

from .scatter_gather import scatter_kwargs


class MMDistributedDataParallel(DistributedDataParallel):
    """The DDP module that supports DataContainer.

    MMDDP has two main differences with PyTorch DDP:

    - It supports a custom type :class:`DataContainer`, which allows more
      flexible control of input data.
    - It implements two APIs, ``train_step()`` and ``val_step()``.
    """

    def to_kwargs(self, inputs, kwargs, device_id):
        # PyTorch >= 1.8 calls ``to_kwargs`` instead of ``scatter`` to move
        # every tensor in ``inputs``/``kwargs`` to ``device_id``.
        return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim)

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
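
    # Behavioral sketch of ``scatter`` (illustrative; the exact nesting comes
    # from ``scatter_kwargs``): given N target devices, it returns one args
    # tuple and one kwargs dict per device, with tensors (and the payloads of
    # DataContainers) split along ``self.dim`` and moved to those devices.
    # ``ddp_model`` and ``batch`` are hypothetical names.
    #
    #   scattered_inputs, scattered_kwargs = ddp_model.scatter(
    #       (batch,), {}, device_ids=[0, 1])
    #   # len(scattered_inputs) == len(scattered_kwargs) == 2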

    def train_step(self, *inputs, **kwargs):
        """train_step() API for module wrapped by DistributedDataParallel.

        This method is basically the same as
        ``DistributedDataParallel.forward()``, while replacing
        ``self.module.forward()`` with ``self.module.train_step()``.
        It is compatible with PyTorch 1.1 - 1.5.
        """

        # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
        # end of a backward pass to the beginning of forward, so it has to be
        # called here as well.
        if ('parrots' not in TORCH_VERSION
                and digit_version(TORCH_VERSION) >= digit_version('1.7')
                and self.reducer._rebuild_buckets()):
            print_log(
                'Reducer buckets have been rebuilt in this iteration.',
                logger='mmcv')
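
        # For reference, ``digit_version`` converts a version string into a
        # comparable tuple of integers, so a check like
        # ``digit_version('1.7.1') >= digit_version('1.7')`` evaluates to
        # True; that is how the gate above selects PyTorch >= 1.7 behavior.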

        if getattr(self, 'require_forward_param_sync', True):
            self._sync_params()
        if self.device_ids:
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            if len(self.device_ids) == 1:
                output = self.module.train_step(*inputs[0], **kwargs[0])
            else:
                outputs = self.parallel_apply(
                    self._module_copies[:len(inputs)], inputs, kwargs)
                output = self.gather(outputs, self.output_device)
        else:
            output = self.module.train_step(*inputs, **kwargs)

        if torch.is_grad_enabled() and getattr(
                self, 'require_backward_grad_sync', True):
            if self.find_unused_parameters:
                # ``_find_tensors`` collects every tensor in the (possibly
                # nested) output so the reducer knows which grads to expect.
                self.reducer.prepare_for_backward(list(_find_tensors(output)))
            else:
                self.reducer.prepare_for_backward([])
        else:
            if ('parrots' not in TORCH_VERSION
                    and digit_version(TORCH_VERSION) > digit_version('1.2')):
                self.require_forward_param_sync = False
        return output
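
    # Note: ``require_backward_grad_sync`` is the flag that DDP's ``no_sync()``
    # context manager toggles, so inside ``with ddp_model.no_sync(): ...`` the
    # branch above skips ``prepare_for_backward`` and gradients accumulate
    # locally without an all-reduce (``ddp_model`` is a hypothetical instance).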

    def val_step(self, *inputs, **kwargs):
        """val_step() API for module wrapped by DistributedDataParallel.

        This method is basically the same as
        ``DistributedDataParallel.forward()``, while replacing
        ``self.module.forward()`` with ``self.module.val_step()``.
        It is compatible with PyTorch 1.1 - 1.5.
        """

        # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
        # end of a backward pass to the beginning of forward, so it has to be
        # called here as well.
        if ('parrots' not in TORCH_VERSION
                and digit_version(TORCH_VERSION) >= digit_version('1.7')
                and self.reducer._rebuild_buckets()):
            print_log(
                'Reducer buckets have been rebuilt in this iteration.',
                logger='mmcv')

        if getattr(self, 'require_forward_param_sync', True):
            self._sync_params()
        if self.device_ids:
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            if len(self.device_ids) == 1:
                output = self.module.val_step(*inputs[0], **kwargs[0])
            else:
                outputs = self.parallel_apply(
                    self._module_copies[:len(inputs)], inputs, kwargs)
                output = self.gather(outputs, self.output_device)
        else:
            output = self.module.val_step(*inputs, **kwargs)

        if torch.is_grad_enabled() and getattr(
                self, 'require_backward_grad_sync', True):
            if self.find_unused_parameters:
                self.reducer.prepare_for_backward(list(_find_tensors(output)))
            else:
                self.reducer.prepare_for_backward([])
        else:
            if ('parrots' not in TORCH_VERSION
                    and digit_version(TORCH_VERSION) > digit_version('1.2')):
                self.require_forward_param_sync = False
        return output