from contextlib import ExitStack, contextmanager
from typing import Dict, Union

import torch
import torch.nn as nn
from torch.nn.parallel.distributed import DistributedDataParallel

from mmengine.device import get_device
from mmengine.optim import OptimWrapperDict
from mmengine.registry import MODEL_WRAPPERS
from .distributed import MMDistributedDataParallel


@MODEL_WRAPPERS.register_module()
class MMSeparateDistributedDataParallel(DistributedDataParallel):
| | """A DistributedDataParallel wrapper for models in MMGeneration. |
| | |
| | In MMedting and MMGeneration there is a need to wrap different modules in |
| | the models with separate DistributedDataParallel. Otherwise, it will cause |
| | errors for GAN training. For example, the GAN model, usually has two |
| | submodules: generator and discriminator. If we wrap both of them in one |
| | standard DistributedDataParallel, it will cause errors during training, |
| | because when we update the parameters of the generator (or discriminator), |
| | the parameters of the discriminator (or generator) is not updated, which is |
| | not allowed for DistributedDataParallel. So we design this wrapper to |
| | separately wrap DistributedDataParallel for generator and discriminator. |
| | In this wrapper, we perform two operations: |
| | |
| | 1. Wraps each module in the models with separate MMDistributedDataParallel. |
| | Note that only modules with parameters will be wrapped. |
| | 2. Calls ``train_step``, ``val_step`` and ``test_step`` of submodules to |
| | get losses and predictions. |
| | |
    Args:
        module (nn.Module): Model containing multiple submodules that have
            separate updating strategies.
        broadcast_buffers (bool): Same as that in
            ``torch.nn.parallel.distributed.DistributedDataParallel``.
            Defaults to False.
        find_unused_parameters (bool): Same as that in
            ``torch.nn.parallel.distributed.DistributedDataParallel``.
            If ``True``, traverses the autograd graph of all tensors
            contained in the return value of the wrapped module's forward
            function. Defaults to False.
        **kwargs: Keyword arguments passed to ``MMDistributedDataParallel``.

            - device_ids (List[int] or torch.device, optional): CUDA devices
              for module.
            - output_device (int or torch.device, optional): Device location
              of output for single-device CUDA modules.
            - dim (int): Defaults to 0.
            - process_group (ProcessGroup, optional): The process group to be
              used for distributed data all-reduction.
            - bucket_cap_mb (int): Bucket size in MegaBytes (MB). Defaults
              to 25.
            - check_reduction (bool): This argument is deprecated. Defaults
              to False.
            - gradient_as_bucket_view (bool): Defaults to False.
            - static_graph (bool): Defaults to False.

    See more information about arguments in
    :class:`torch.nn.parallel.DistributedDataParallel`.
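
    Examples:
        A minimal usage sketch. The ``ToyGAN`` model and the optimizer
        configuration below are hypothetical and only illustrate how each
        submodule gets its own DDP wrapper and its own optimizer wrapper; a
        real run additionally requires an initialized distributed process
        group (e.g. via ``torch.distributed.init_process_group``).

        >>> import torch
        >>> import torch.nn as nn
        >>> from mmengine.optim import OptimWrapper, OptimWrapperDict
        >>>
        >>> class ToyGAN(nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.generator = nn.Linear(2, 2)
        ...         self.discriminator = nn.Linear(2, 1)
        ...
        ...     def train_step(self, data, optim_wrapper):
        ...         # Alternately update generator and discriminator with
        ...         # their own optimizer wrappers.
        ...         ...
        >>>
        >>> # Each submodule with parameters is wrapped separately.
        >>> model = MMSeparateDistributedDataParallel(ToyGAN())
        >>> optim_wrapper = OptimWrapperDict(
        ...     generator=OptimWrapper(
        ...         torch.optim.SGD(model.module.generator.parameters(),
        ...                         lr=0.1)),
        ...     discriminator=OptimWrapper(
        ...         torch.optim.SGD(model.module.discriminator.parameters(),
        ...                         lr=0.1)))
        >>> # log_vars = model.train_step(data_batch, optim_wrapper)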
| | """ |

    def __init__(self,
                 module: nn.Module,
                 broadcast_buffers: bool = False,
                 find_unused_parameters: bool = False,
                 **kwargs):
        # Intentionally skip ``DistributedDataParallel.__init__`` and only
        # initialize ``nn.Module``, since every submodule is wrapped
        # individually below.
        super(DistributedDataParallel, self).__init__()
        self.module = module
        device = get_device()
        # Wrap each submodule of ``module`` that has trainable parameters
        # with ``MMDistributedDataParallel``.
        for name, sub_module in module._modules.items():
            # Submodule without parameters: simply move it to the target
            # device.
            if next(sub_module.parameters(), None) is None:
                sub_module = sub_module.to(device)
            # Submodule whose parameters are all frozen: no gradient
            # synchronization is needed, so skip the DDP wrapper.
            elif all(not p.requires_grad for p in sub_module.parameters()):
                sub_module = sub_module.to(device)
            else:
                sub_module = MMDistributedDataParallel(
                    module=sub_module.to(device),
                    broadcast_buffers=broadcast_buffers,
                    find_unused_parameters=find_unused_parameters,
                    **kwargs)
            module._modules[name] = sub_module

    def train_step(self, data: Union[dict, tuple, list],
                   optim_wrapper: OptimWrapperDict) -> Dict[str, torch.Tensor]:
        """Interface for model forward, backward and parameter updating
        during the training process.

        Args:
            data (dict or tuple or list): Data sampled from the dataset.
            optim_wrapper (OptimWrapperDict): A wrapper of optimizers to
                update parameters.

        Returns:
            Dict[str, torch.Tensor]: A dict of tensors for logging.
        """
        return self.module.train_step(data, optim_wrapper)

    def val_step(self, data: Union[dict, tuple, list]) -> list:
        """Gets the predictions of the module during the validation process.

        Args:
            data (dict or tuple or list): Data sampled from the dataset.

        Returns:
            list: The predictions of the given data.
        """
        return self.module.val_step(data)

    def test_step(self, data: Union[dict, tuple, list]) -> list:
        """Gets the predictions of the module during the testing process.

        Args:
            data (dict or tuple or list): Data sampled from the dataset.

        Returns:
            list: The predictions of the given data.
        """
        return self.module.test_step(data)

    @contextmanager
    def no_sync(self):
        """Enables the ``no_sync`` context of all sub
        ``MMDistributedDataParallel`` modules.
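
        A minimal sketch of typical usage (``model`` is assumed to be an
        ``MMSeparateDistributedDataParallel`` instance and ``loss`` a tensor
        computed from it); gradients produced inside the context are
        accumulated locally and not synchronized across processes:

        >>> with model.no_sync():
        ...     loss.backward()
        """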
        with ExitStack() as stack:
            for sub_ddp_model in self.module._modules.values():
                stack.enter_context(sub_ddp_model.no_sync())
            yield

    def train(self, mode: bool = True) -> 'MMSeparateDistributedDataParallel':
        """Sets the module in training mode.

        In order to make the DDP wrapper inheritance hierarchy more uniform,
        ``MMSeparateDistributedDataParallel`` inherits from
        ``DistributedDataParallel`` but does not call its constructor.
        Since the attributes of ``DistributedDataParallel`` have not been
        initialized, calling the ``train`` method of
        ``DistributedDataParallel`` will raise an error if the PyTorch
        version is <= 1.9. Therefore, this method is overridden to call the
        ``train`` method of the submodules instead.

        Args:
            mode (bool): Whether to set training mode (``True``) or
                evaluation mode (``False``). Defaults to ``True``.

        Returns:
            Module: self.
        """
        self.training = mode
        self.module.train(mode)
        return self