# Copyright (c) OpenMMLab. All rights reserved.
from itertools import chain

from torch.nn.parallel import DataParallel

from .scatter_gather import scatter_kwargs


class MMDataParallel(DataParallel):
    """The DataParallel module that supports DataContainer.

    MMDataParallel has two main differences from PyTorch DataParallel:

    - It supports a custom type :class:`DataContainer` which allows more
      flexible control of input data during both GPU and CPU inference.
    - It implements two more APIs ``train_step()`` and ``val_step()``.

    Args:
        module (:class:`nn.Module`): Module to be encapsulated.
        device_ids (list[int]): Device IDs of modules to be scattered to.
            Defaults to None when GPU is not available.
        output_device (str | int): Device ID for output. Defaults to None.
        dim (int): Dimension used to scatter the data. Defaults to 0.
    """

    def __init__(self, *args, dim=0, **kwargs):
        super(MMDataParallel, self).__init__(*args, dim=dim, **kwargs)
        self.dim = dim

    def forward(self, *inputs, **kwargs):
        """Override the original forward function.

        The main difference lies in the CPU inference where the data in
        :class:`DataContainer` will still be gathered.
        """
        if not self.device_ids:
            # Scatter to [-1] so that DataContainer inputs are gathered and
            # converted on CPU, just as in GPU inference.
            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
            return self.module(*inputs[0], **kwargs[0])
        else:
            return super().forward(*inputs, **kwargs)

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def train_step(self, *inputs, **kwargs):
        if not self.device_ids:
            # Scatter to [-1] so that DataContainer inputs are gathered and
            # converted on CPU, just as in GPU inference.
            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
            return self.module.train_step(*inputs[0], **kwargs[0])

        assert len(self.device_ids) == 1, \
            ('MMDataParallel only supports single GPU training; if you need '
             'to train with multiple GPUs, please use '
             'MMDistributedDataParallel instead.')

        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError(
                    'module must have its parameters and buffers '
                    f'on device {self.src_device_obj} (device_ids[0]) but '
                    f'found one of them on device: {t.device}')

        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        return self.module.train_step(*inputs[0], **kwargs[0])

    def val_step(self, *inputs, **kwargs):
        if not self.device_ids:
            # Scatter to [-1] so that DataContainer inputs are gathered and
            # converted on CPU, just as in GPU inference.
            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
            return self.module.val_step(*inputs[0], **kwargs[0])

        assert len(self.device_ids) == 1, \
            ('MMDataParallel only supports single GPU training; if you need '
             'to train with multiple GPUs, please use '
             'MMDistributedDataParallel instead.')

        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError(
                    'module must have its parameters and buffers '
                    f'on device {self.src_device_obj} (device_ids[0]) but '
                    f'found one of them on device: {t.device}')

        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        return self.module.val_step(*inputs[0], **kwargs[0])
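

# Usage sketch (illustrative only, not part of the public module): how the
# wrapper is typically applied for single-GPU or CPU-only runs. ``ToyModel``
# and its ``train_step``/``val_step`` signatures are hypothetical stand-ins
# for a real model that returns a loss dict; adapt them to your own code.
if __name__ == '__main__':
    import torch
    import torch.nn as nn

    class ToyModel(nn.Module):
        """A minimal model exposing ``train_step`` and ``val_step``."""

        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(2, 1)

        def forward(self, x):
            return self.fc(x)

        def train_step(self, x, optimizer=None):
            # A real model would compute and return a dict of losses here.
            return dict(loss=self(x).mean())

        def val_step(self, x):
            return dict(loss=self(x).mean())

    if torch.cuda.is_available():
        # Single-GPU wrapping; multi-GPU training should use
        # MMDistributedDataParallel instead (see the asserts above).
        model = MMDataParallel(ToyModel().cuda(), device_ids=[0])
    else:
        # CPU-only: device_ids is empty, so forward/train_step/val_step take
        # the scatter-to-[-1] path that still unwraps DataContainer inputs.
        model = MMDataParallel(ToyModel())

    outputs = model.train_step(torch.randn(4, 2))
    print(outputs['loss'].item())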