import operator
import torch
import warnings
from itertools import chain
from typing import Any, Dict, Generic, List, Optional, Sequence, Tuple, TypeVar, Union

from ..modules import Module
from .scatter_gather import scatter_kwargs, gather
from .replicate import replicate
from .parallel_apply import parallel_apply
from torch._utils import (
    _get_all_device_indices,
    _get_available_device_type,
    _get_device_index,
    _get_devices_properties,
)

__all__ = ['DataParallel', 'data_parallel']


def _check_balance(device_ids: Sequence[Union[int, torch.device]]) -> None:
    imbalance_warn = """
    There is an imbalance between your GPUs. You may want to exclude GPU {} which
    has less than 75% of the memory or cores of GPU {}. You can do so by setting
    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
    environment variable."""
    device_ids = [_get_device_index(x, True) for x in device_ids]
    dev_props = _get_devices_properties(device_ids)

    def warn_imbalance(get_prop):
        values = [get_prop(props) for props in dev_props]
        min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1))
        max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1))
        if min_val / max_val < 0.75:
            warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos]))
            return True
        return False

    if warn_imbalance(lambda props: props.total_memory):
        return
    if warn_imbalance(lambda props: props.multi_processor_count):
        return
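

# Illustration of the 75% threshold above (made-up device properties, not from
# the source): pairing a 12 GiB card with a 24 GiB card gives a ratio of
# 12 / 24 = 0.5 < 0.75, so _check_balance warns and suggests excluding the
# smaller device via device_ids or CUDA_VISIBLE_DEVICES.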


T = TypeVar("T", bound=Module)


class DataParallel(Module, Generic[T]):
r"""Implements data parallelism at the module level. | |
This container parallelizes the application of the given :attr:`module` by | |
splitting the input across the specified devices by chunking in the batch | |
dimension (other objects will be copied once per device). In the forward | |
pass, the module is replicated on each device, and each replica handles a | |
portion of the input. During the backwards pass, gradients from each replica | |
are summed into the original module. | |
The batch size should be larger than the number of GPUs used. | |
.. warning:: | |
It is recommended to use :class:`~torch.nn.parallel.DistributedDataParallel`, | |
instead of this class, to do multi-GPU training, even if there is only a single | |
node. See: :ref:`cuda-nn-ddp-instead` and :ref:`ddp`. | |
Arbitrary positional and keyword inputs are allowed to be passed into | |
DataParallel but some types are specially handled. tensors will be | |
**scattered** on dim specified (default 0). tuple, list and dict types will | |
be shallow copied. The other types will be shared among different threads | |
and can be corrupted if written to in the model's forward pass. | |
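
    For example, with ``device_ids=[0, 1]`` a batch of 64 inputs is scattered
    into two chunks of 32 along dim 0 (a sketch; ``model`` is a stand-in for
    any module already on GPU 0)::

        >>> # xdoctest: +SKIP
        >>> net = torch.nn.DataParallel(model, device_ids=[0, 1])
        >>> inp = torch.randn(64, 10)   # batch dimension is dim 0
        >>> out = net(inp)              # each replica sees a (32, 10) chunk
        >>> out.size(0)                 # outputs are gathered on device_ids[0]
        64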

    The parallelized :attr:`module` must have its parameters and buffers on
    ``device_ids[0]`` before running this :class:`~torch.nn.DataParallel`
    module.

    .. warning::
        In each forward, :attr:`module` is **replicated** on each device, so any
        updates to the running module in ``forward`` will be lost. For example,
        if :attr:`module` has a counter attribute that is incremented in each
        ``forward``, it will always stay at the initial value because the update
        is done on the replicas which are destroyed after ``forward``. However,
        :class:`~torch.nn.DataParallel` guarantees that the replica on
        ``device[0]`` will have its parameters and buffers sharing storage with
        the base parallelized :attr:`module`. So **in-place** updates to the
        parameters or buffers on ``device[0]`` will be recorded. E.g.,
        :class:`~torch.nn.BatchNorm2d` and :func:`~torch.nn.utils.spectral_norm`
        rely on this behavior to update the buffers.
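
    A sketch of the lost-update behavior (``CountingModel`` is a hypothetical
    module that does ``self.n_calls += 1`` inside ``forward``)::

        >>> # xdoctest: +SKIP
        >>> net = torch.nn.DataParallel(CountingModel(), device_ids=[0, 1])
        >>> _ = net(inp)
        >>> net.module.n_calls  # unchanged: the increments happened on replicas
        0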

    .. warning::
        Forward and backward hooks defined on :attr:`module` and its submodules
        will be invoked ``len(device_ids)`` times, each with inputs located on
        a particular device. Particularly, the hooks are only guaranteed to be
        executed in correct order with respect to operations on corresponding
        devices. For example, it is not guaranteed that hooks set via
        :meth:`~torch.nn.Module.register_forward_pre_hook` be executed before
        `all` ``len(device_ids)`` :meth:`~torch.nn.Module.forward` calls, but
        that each such hook be executed before the corresponding
        :meth:`~torch.nn.Module.forward` call of that device.

    .. warning::
        When :attr:`module` returns a scalar (i.e., 0-dimensional tensor) in
        :func:`forward`, this wrapper will return a vector of length equal to
        number of devices used in data parallelism, containing the result from
        each device.
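
    For example, if ``forward`` returns ``loss.mean()`` (a 0-dim tensor), the
    wrapper returns one entry per device, which usually needs a final
    reduction (a sketch, assuming two devices)::

        >>> # xdoctest: +SKIP
        >>> per_device = net(inp)     # shape: torch.Size([2])
        >>> loss = per_device.mean()  # reduce the per-device scalars yourself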

    .. note::
        There is a subtlety in using the
        ``pack sequence -> recurrent network -> unpack sequence`` pattern in a
        :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`.
        See :ref:`pack-rnn-unpack-with-data-parallelism` section in FAQ for
        details, and the sketch below.
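
    A minimal sketch of the workaround described there: pass ``total_length``
    to :func:`~torch.nn.utils.rnn.pad_packed_sequence` so every replica pads
    its output chunk to the same length and the chunks can be gathered
    (``self.rnn`` and the shapes are illustrative)::

        >>> # xdoctest: +SKIP
        >>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
        >>> def forward(self, padded, lengths):
        ...     total_length = padded.size(1)  # padded length of the full batch
        ...     packed = pack_padded_sequence(padded, lengths, batch_first=True)
        ...     packed_out, _ = self.rnn(packed)
        ...     out, _ = pad_packed_sequence(packed_out, batch_first=True,
        ...                                  total_length=total_length)
        ...     return out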

    Args:
        module (Module): module to be parallelized
        device_ids (list of int or torch.device): CUDA devices (default: all devices)
        output_device (int or torch.device): device location of output (default: device_ids[0])

    Attributes:
        module (Module): the module to be parallelized

    Example::

        >>> # xdoctest: +SKIP
        >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
        >>> output = net(input_var)  # input_var can be on any device, including CPU
    """

    # TODO: update notes/cuda.rst when this class handles 8+ GPUs well
    def __init__(
        self,
        module: T,
        device_ids: Optional[Sequence[Union[int, torch.device]]] = None,
        output_device: Optional[Union[int, torch.device]] = None,
        dim: int = 0,
    ) -> None:
        super().__init__()
        torch._C._log_api_usage_once("torch.nn.parallel.DataParallel")
        device_type = _get_available_device_type()
        if device_type is None:
            # no accelerator available: fall back to running the module as-is
            self.module = module
            self.device_ids = []
            return

        if device_ids is None:
            device_ids = _get_all_device_indices()

        if device_ids is None:
            raise RuntimeError("no available devices were found")

        if output_device is None:
            output_device = device_ids[0]

        self.dim = dim
        self.module = module
        self.device_ids = [_get_device_index(x, True) for x in device_ids]
        self.output_device = _get_device_index(output_device, True)
        self.src_device_obj = torch.device(device_type, self.device_ids[0])

        if device_type == "cuda":
            _check_balance(self.device_ids)

        if len(self.device_ids) == 1:
            self.module.to(self.src_device_obj)

    def forward(self, *inputs: Any, **kwargs: Any) -> Any:
        with torch.autograd.profiler.record_function("DataParallel.forward"):
            if not self.device_ids:
                return self.module(*inputs, **kwargs)

            for t in chain(self.module.parameters(), self.module.buffers()):
                if t.device != self.src_device_obj:
                    raise RuntimeError("module must have its parameters and buffers "
                                       f"on device {self.src_device_obj} (device_ids[0]) but found one of "
                                       f"them on device: {t.device}")

            inputs, module_kwargs = self.scatter(inputs, kwargs, self.device_ids)
            # for a forward call without any inputs, create an empty tuple and dict
            # so the module can still be executed on the first device in device_ids
            if not inputs and not module_kwargs:
                inputs = ((),)
                module_kwargs = ({},)

            if len(self.device_ids) == 1:
                return self.module(*inputs[0], **module_kwargs[0])
            replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
            outputs = self.parallel_apply(replicas, inputs, module_kwargs)
            return self.gather(outputs, self.output_device)

    def replicate(self, module: T, device_ids: Sequence[Union[int, torch.device]]) -> List[T]:
        return replicate(module, device_ids, not torch.is_grad_enabled())

    def scatter(
        self,
        inputs: Tuple[Any, ...],
        kwargs: Optional[Dict[str, Any]],
        device_ids: Sequence[Union[int, torch.device]],
    ) -> Any:
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def parallel_apply(self, replicas: Sequence[T], inputs: Sequence[Any], kwargs: Any) -> List[Any]:
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

    def gather(self, outputs: Any, output_device: Union[int, torch.device]) -> Any:
        return gather(outputs, output_device, dim=self.dim)


def data_parallel(
    module: Module,
    inputs: Any,
    device_ids: Optional[Sequence[Union[int, torch.device]]] = None,
    output_device: Optional[Union[int, torch.device]] = None,
    dim: int = 0,
    module_kwargs: Optional[Any] = None,
) -> torch.Tensor:
r"""Evaluate module(input) in parallel across the GPUs given in device_ids. | |
This is the functional version of the DataParallel module. | |
Args: | |
module (Module): the module to evaluate in parallel | |
inputs (Tensor): inputs to the module | |
device_ids (list of int or torch.device): GPU ids on which to replicate module | |
output_device (list of int or torch.device): GPU location of the output Use -1 to indicate the CPU. | |
(default: device_ids[0]) | |
Returns: | |
a Tensor containing the result of module(input) located on | |
output_device | |
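
    Example (a sketch, assuming ``model`` and ``inp`` already live on GPU 0)::

        >>> # xdoctest: +SKIP
        >>> from torch.nn.parallel import data_parallel
        >>> out = data_parallel(model, inp, device_ids=[0, 1], output_device=0)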
""" | |
    if not isinstance(inputs, tuple):
        inputs = (inputs,) if inputs is not None else ()

    device_type = _get_available_device_type()

    if device_type is None:
        raise RuntimeError("device type could not be determined")

    if device_ids is None:
        device_ids = _get_all_device_indices()

    if device_ids is None:
        raise RuntimeError("no available devices were found")

    if output_device is None:
        output_device = device_ids[0]

    device_ids = [_get_device_index(x, True) for x in device_ids]
    output_device = _get_device_index(output_device, True)
    src_device_obj = torch.device(device_type, device_ids[0])

    for t in chain(module.parameters(), module.buffers()):
        if t.device != src_device_obj:
            raise RuntimeError("module must have its parameters and buffers "
                               f"on device {src_device_obj} (device_ids[0]) but found one of "
                               f"them on device: {t.device}")

    inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
    # for a module without any inputs, create an empty tuple and dict
    # so the module can still be executed on the first device in device_ids
    if not inputs and not module_kwargs:
        inputs = ((),)
        module_kwargs = ({},)

    assert module_kwargs is not None

    if len(device_ids) == 1:
        return module(*inputs[0], **module_kwargs[0])
    used_device_ids = device_ids[:len(inputs)]
    replicas = replicate(module, used_device_ids)
    outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
    return gather(outputs, output_device, dim)