Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | |
# | |
# This source code is licensed under the BSD license found in the | |
# LICENSE file in the root directory of this source tree. | |
r"""Zero Redundancy Optimizer.""" | |
import collections | |
import copy | |
import enum | |
import inspect | |
import io | |
import logging | |
from itertools import chain | |
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union | |
import torch | |
import torch.distributed as dist | |
from torch.distributed.algorithms.join import Join, Joinable, JoinHook | |
from torch.distributed.optim.utils import functional_optim_map | |
from torch.optim import Optimizer | |
logger = logging.getLogger(__name__) | |
__all__ = ["ZeroRedundancyOptimizer"] | |
# Credits: classy_vision/generic/distributed_util.py | |
def _recursive_copy_to_device( | |
value: Any, | |
non_blocking: bool, | |
device: torch.device, | |
) -> Any: | |
r""" | |
Recursively searches lists, tuples, dicts and copies tensors to device if possible. | |
Non-tensor values are passed as-is in the result. | |
.. note: These are all copies, so if there are two objects that reference | |
the same object, then after this call, there will be two different objects | |
referenced on the device. | |
""" | |
if isinstance(value, torch.Tensor): | |
return value.to(device, non_blocking=non_blocking) | |
if isinstance(value, (list, tuple)): | |
values = [ | |
_recursive_copy_to_device(val, non_blocking=non_blocking, device=device) | |
for val in value | |
] | |
return values if isinstance(value, list) else tuple(values) | |
if isinstance(value, collections.abc.Mapping): | |
return { | |
key: _recursive_copy_to_device( | |
val, non_blocking=non_blocking, device=device | |
) | |
for key, val in value.items() | |
} | |
return value | |
def _is_trainable(param: torch.Tensor) -> bool: | |
r"""Return if a parameter is trainable, where trainability is equivalent to requiring a gradient.""" | |
return param.requires_grad | |
def _broadcast_object(
    obj: Any,
    src_rank: int,
    group: object = dist.group.WORLD,
    device: torch.device = torch.device("cpu"),
) -> Any:
    r"""
    Broadcast a Python object from ``src_rank`` to every rank in ``group``.

    The caller acts as the sender if its rank (``dist.get_rank()``) equals
    ``src_rank`` and as a receiver otherwise. The object is serialized with
    :func:`torch.save`; its byte length is broadcast first so that receivers
    can allocate a buffer, and the serialized bytes are broadcast second.

    Arguments:
        obj: object to broadcast; only used if called on the source rank.
        src_rank (int): source rank.
        group (``ProcessGroup``, optional): group used for the broadcast
            (default: ``dist.group.WORLD``).
        device (``torch.device``, optional): device to send from or receive
            to (default: ``torch.device("cpu")``).

    Returns:
        On the source rank, ``obj`` itself; on every other rank, the object
        deserialized with :func:`torch.load` using ``map_location=device``.
    """
    if dist.get_rank() != src_rank:
        # Receiver: learn the payload size, then receive the payload itself
        size_tensor = torch.LongTensor([0]).to(device)
        dist.broadcast(size_tensor, src=src_rank, group=group, async_op=False)
        num_bytes = int(size_tensor.item())
        payload = torch.empty([num_bytes], dtype=torch.uint8, device=device)
        dist.broadcast(payload, src=src_rank, group=group, async_op=False)
        stream = io.BytesIO(payload.cpu().numpy())
        # NOTE(review): `torch.load()` deserializes whatever the source rank
        # sent; this is only appropriate between mutually trusted ranks
        return torch.load(stream, map_location=device)

    # Sender: serialize once, then announce the size followed by the bytes
    stream = io.BytesIO()
    torch.save(obj, stream)
    serialized = bytearray(stream.getbuffer())
    size_tensor = torch.LongTensor([len(serialized)]).to(device)
    payload = torch.ByteTensor(serialized).to(device)
    dist.broadcast(size_tensor, src=src_rank, group=group, async_op=False)
    dist.broadcast(payload, src=src_rank, group=group, async_op=False)
    return obj
class _ZeROJoinHook(JoinHook):
    r"""Join hook that runs the wrapped ZeRO optimizer's :meth:`step` in its main hook."""

    def __init__(self, zero):
        r"""
        Store the optimizer that the hook drives.

        Arguments:
            zero (ZeroRedundancyOptimizer): the optimizer instance whose
                :meth:`step` is invoked by :meth:`main_hook`.
        """
        is_zero_instance = isinstance(zero, ZeroRedundancyOptimizer)
        assert is_zero_instance, (
            "ZeRO join hook requires passing in a ZeroRedundancyOptimizer "
            "instance as the state"
        )
        self.zero = zero
        super().__init__()

    def main_hook(self):
        """
        Perform an optimizer step.

        This step updates the joined process's shard of
        the parameters and broadcasts those parameters.
        """
        self.zero.step()
class _DDPBucketAssignment: | |
r""" | |
Represent a :class:`DistributedDataParallel` bucket assignment. | |
This means that a (possibly non-strict) subset of the parameters corresponding to | |
a DDP bucket assigned to a rank to update. | |
Attributes: | |
bucket_index (int): index of the bucket determined by the DDP gradient | |
bucket all-reduce order. | |
parameters (List[torch.Tensor]): model parameters in the bucket | |
assigned to this rank. | |
offset (int): offset into the :class:`GradBucket` 's :meth:`parameters` | |
giving the index of the first element in the passed-in | |
``parameters``; this equivalently indexes into the | |
:class:`GradBucket` 's :meth:`gradients`. | |
device (torch.device): device on which the parameters are stored. | |
tensor (torch.Tensor): flattened tensor giving the data of the | |
parameter subset assigned to the rank. | |
""" | |
def __init__( | |
self, | |
bucket_index: int, | |
parameters: List[torch.Tensor], | |
offset: int, | |
): | |
self.bucket_index = bucket_index | |
self.parameters = parameters | |
self.offset = offset | |
if len(self.parameters) == 0: | |
raise ValueError("Empty bucket assignment") | |
# DDP guarantees all parameters in the bucket have the same device | |
self.device: torch.device = self.parameters[0].device | |
self.tensor: Optional[torch.Tensor] = None | |
class _OverlapStatus(enum.IntEnum): | |
r""" | |
Define possible statuses that :class:`ZeroRedundancyOptimizer` can be in when overlapping with :class:`DistributedDataParallel`. | |
Attributes: | |
``UNINITIALIZED``: The ZeRO instance is effectively uninitialized and | |
is waiting for DDP to finalize its bucketing. | |
``DDP_HAS_REBUILT_BUCKETS``: DDP has rebuilt its buckets, meaning that | |
its bucketing is finalized. The ZeRO instance can now collect the | |
necessary information about the DDP bucketing. | |
``INITIALIZED``: The ZeRO instance is fully initialized and can now | |
optimize parameters. | |
""" | |
UNINITIALIZED = 0 | |
DDP_HAS_REBUILT_BUCKETS = 1 | |
INITIALIZED = 2 | |
class _OverlapInfo: | |
r""" | |
Information needed by :class:`ZeroRedundancyOptimizer` to overlap with :class:`DistributedDataParallel`. | |
Arguments: | |
world_size (int): world size of the process group being used. | |
Attributes: | |
shard_buckets (bool): if ``True``, then the assignment of each | |
:class:`DistributedDataParallel` bucket is partitioned across | |
possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e. | |
across possibly multiple ranks) to approximate uniformity following | |
a threshold given by the total parameter size divided by the world | |
size; if ``False``, then each bucket is wholly assigned to a single | |
:class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank); | |
this should be set to the value passed into the hook constructor. | |
status (_OverlapStatus): current status; see :class:`_OverlapStatus` | |
for more information. | |
params_per_bucket (List[List[torch.Tensor]]): ``params_per_bucket[i]`` | |
gives the model parameters in the ``i``th bucket. | |
params_per_rank (List[List[torch.Tensor]]): ``params_per_rank[i]`` | |
gives the model parameters assigned to the ``i``th rank, where the | |
parameters are grouped by increasing bucket indices. | |
offsets (Dict[int, int]): maps from bucket index to the offset in | |
``self.params_per_rank[rank]`` giving the index of the first | |
parameter in that bucket, where ``rank`` is this process's own | |
rank; the keys of this :class:`dict` are the bucket indices | |
assigned to this rank. | |
num_bucket_assignments (int): total number of bucket assignments across | |
all ranks; this is equal to the number of | |
:class:`DistributedDataParallel` gradient buckets if | |
``shard_buckets=False`` and possibly greater otherwise. | |
total_size (int, optional): total size of all buckets (i.e. sum of | |
``param.numel()`` for all ``param`` across all buckets) if | |
``shard_buckets=True``; otherwise, ``None``. | |
broadcast_handles (List[Work]): :class:`list` of async work handles for | |
the parameter broadcasts. | |
bucket_index_to_future (Dict[int, torch.futures.Future]): | |
:class:`dict` mapping bucket index to the corresponding all-reduce | |
future. | |
bucket_index_to_bucket (Dict[int, dist.GradBucket]): :class:`dict` | |
mapping bucket index to the corresponding bucket. | |
bucket_indices_seen (List[int]): :class:`list` of the bucket indices | |
seen on this iteration. | |
""" | |
def __init__(self, world_size) -> None: | |
self.status: _OverlapStatus = _OverlapStatus.UNINITIALIZED | |
self.shard_buckets: bool = False | |
# Modified per bucket reconstruction | |
self.params_per_bucket: List[List[torch.Tensor]] = [] | |
self.params_per_rank: List[List[torch.Tensor]] = [[] for _ in range(world_size)] | |
self.offsets: Dict[int, int] = {} | |
# Group Ranks | |
self.assigned_ranks_per_bucket: List[Set[int]] = [] | |
self.num_bucket_assignments: int = 0 | |
self.total_size: Optional[int] = None | |
# Modified per iteration | |
self.broadcast_handles: List[Any] = [] | |
self.bucket_indices_seen: List[int] = [] | |
# Used by `hook_with_zero_step()` | |
self.bucket_index_to_future: Dict[int, torch.futures.Future] = {} | |
self.bucket_index_to_bucket: Dict[int, dist.GradBucket] = {} | |
def wait_for_broadcasts(self) -> None: | |
r""" | |
Wait for all parameter broadcasts. | |
This function should be called once all broadcasts have been scheduled, | |
meaning ``self.broadcast_handles`` is filled. This clears ``self.broadcast_handles`` | |
in preparation for the next iteration. | |
""" | |
assert ( | |
len(self.broadcast_handles) == self.num_bucket_assignments | |
), f"Missing at least one broadcast handle on rank {dist.get_rank()}" | |
_ = [x.wait() for x in self.broadcast_handles] | |
self.broadcast_handles.clear() | |
def clear_per_iter_info(self) -> None: | |
r""" | |
Clear the data structures that are modified per-iteration. | |
This function should be called at the end of an iteration. | |
""" | |
self.bucket_indices_seen.clear() | |
self.bucket_index_to_future.clear() | |
self.bucket_index_to_bucket.clear() | |
class ZeroRedundancyOptimizer(Optimizer, Joinable): | |
r""" | |
Wrap an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>` and shard its states across ranks in the group. | |
The sharding is done as described by ZeRO_. | |
The local optimizer instance in each rank is only | |
responsible for updating approximately ``1 / world_size`` parameters and | |
hence only needs to keep ``1 / world_size`` optimizer states. After | |
parameters are updated locally, each rank will broadcast its parameters to | |
all other peers to keep all model replicas in the same state. | |
``ZeroRedundancyOptimizer`` can be used in conjunction with | |
:class:`torch.nn.parallel.DistributedDataParallel` to reduce per-rank peak | |
memory consumption. | |
``ZeroRedundancyOptimizer`` uses a sorted-greedy algorithm to pack a number | |
of parameters at each rank. Each parameter belongs to a single rank and is | |
not divided among ranks. The partition is arbitrary and might not match the | |
parameter registration or usage order. | |
Arguments: | |
params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s | |
or :class:`dict` s giving all parameters, which will be sharded | |
across ranks. | |
Keyword Args: | |
optimizer_class (:class:`torch.optim.Optimizer`): the class of the local | |
optimizer. | |
process_group (``ProcessGroup``, optional): ``torch.distributed`` | |
``ProcessGroup`` (default: ``dist.group.WORLD`` initialized by | |
:meth:`torch.distributed.init_process_group`). | |
parameters_as_bucket_view (bool, optional): if ``True``, parameters are | |
packed into buckets to speed up communication, and ``param.data`` | |
fields point to bucket views at different offsets; if ``False``, | |
each individual parameter is communicated separately, and each | |
``params.data`` stays intact (default: ``False``). | |
overlap_with_ddp (bool, optional): if ``True``, :meth:`step` is | |
overlapped with :class:`DistributedDataParallel` 's gradient | |
synchronization; this requires (1) either a functional optimizer | |
for the ``optimizer_class`` argument or one with a functional | |
equivalent and (2) registering a DDP communication hook | |
constructed from one of the functions in ``ddp_zero_hook.py``; | |
parameters are packed into buckets matching those in | |
:class:`DistributedDataParallel`, meaning that the | |
``parameters_as_bucket_view`` argument is ignored. | |
If ``False``, :meth:`step` runs disjointly after the backward pass | |
(per normal). | |
(default: ``False``) | |
**defaults: any trailing arguments, which are forwarded to the local | |
optimizer. | |
Example:: | |
>>> # xdoctest: +SKIP | |
>>> import torch.nn as nn | |
>>> from torch.distributed.optim import ZeroRedundancyOptimizer | |
>>> from torch.nn.parallel import DistributedDataParallel as DDP | |
>>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)]) | |
>>> ddp = DDP(model, device_ids=[rank]) | |
>>> opt = ZeroRedundancyOptimizer( | |
>>> ddp.parameters(), | |
>>> optimizer_class=torch.optim.Adam, | |
>>> lr=0.01 | |
>>> ) | |
>>> ddp(inputs).sum().backward() | |
>>> opt.step() | |
.. warning:: | |
Currently, ``ZeroRedundancyOptimizer`` requires that all of the | |
passed-in parameters are the same dense type. | |
.. warning:: | |
If you pass ``overlap_with_ddp=True``, be wary of the following: Given | |
the way that overlapping :class:`DistributedDataParallel` with | |
:class:`ZeroRedundancyOptimizer` is currently implemented, the first | |
two or three training iterations do not perform parameter updates in | |
the optimizer step, depending on if ``static_graph=False`` or | |
``static_graph=True``, respectively. This is because it needs | |
information about the gradient bucketing strategy used by | |
:class:`DistributedDataParallel`, which is not finalized until the | |
second forward pass if ``static_graph=False`` or until the third | |
forward pass if ``static_graph=True``. To adjust for this, one option | |
is to prepend dummy inputs. | |
.. warning:: ZeroRedundancyOptimizer is experimental and subject to change. | |
.. _ZeRO: https://arxiv.org/abs/1910.02054 | |
""" | |
def __init__(
    self,
    params,
    optimizer_class: Type[Optimizer],
    process_group: Optional[Any] = None,
    parameters_as_bucket_view: bool = False,
    overlap_with_ddp: bool = False,
    **defaults: Any,
):
    r"""
    Set up the sharded optimizer state for this rank.

    Arguments:
        params: parameters (or parameter groups) to optimize; they are
            validated before being handed to the parent constructor.
        optimizer_class: class of the local optimizer.
        process_group: process group to use; ``dist.group.WORLD`` is used
            if ``None``.
        parameters_as_bucket_view: whether to pack parameters into buckets;
            a warning is logged and the flag has no effect on the bucketing
            strategy if ``overlap_with_ddp=True``.
        overlap_with_ddp: if ``True``, construction of the local optimizer
            is deferred until the required DDP information is available.
        **defaults: forwarded to the local optimizer.
    """
    # Perform type and assumption checks on the input parameters
    params = self._verify_and_init_params(params)
    self._verify_same_dense_param_type()

    # NOTE: The parent constructor uses `add_param_group()` which is
    # partially overloaded in ZeroRedundancyOptimizer, so the `initialized`
    # flag is what lets `add_param_group()` tell apart a call made by the
    # parent constructor from one made after construction.
    self.initialized = False
    Optimizer.__init__(self, params, defaults)
    Joinable.__init__(self)
    # From here on, all parameters are held in both `self._all_params` and
    # `self.param_groups`

    # Lazily evaluated partition information (see `_clear_cache()`)
    self._param_to_rank_cache: Dict[torch.Tensor, int] = {}
    self._param_to_index_cache: Dict[torch.Tensor, int] = {}
    self._partition_parameters_cache: List[List[Dict]] = []
    self._index_to_param_cache: List[torch.Tensor] = []
    self._device_to_params_per_rank_cache: Dict[
        torch.device, List[List[torch.Tensor]]
    ] = {}
    self._bucket_assignments_per_rank_cache: List[
        Dict[int, _DDPBucketAssignment]
    ] = []
    self._is_trainable_mask = self._get_is_trainable_mask()

    # Default device for collective communication and buckets
    self._default_device = self._all_params[0].device

    # Resolve the process group and this process's position in it
    if process_group is None:
        process_group = dist.group.WORLD
    self.process_group = process_group
    self.world_size: int = dist.get_world_size(self.process_group)
    self.rank: int = dist.get_rank(self.process_group)
    self.global_rank: int = dist.distributed_c10d.get_global_rank(
        self.process_group, self.rank
    )

    self._overlap_with_ddp: bool = overlap_with_ddp
    self._optim_defaults = defaults
    self._optim_constructor = self._get_optimizer_constructor(optimizer_class)

    if overlap_with_ddp:
        # Local optimizer initialization is delayed to run time, after the
        # necessary information has been collected
        self._overlap_info: _OverlapInfo = _OverlapInfo(self.world_size)
        if parameters_as_bucket_view:
            logger.warning(
                "`parameters_as_bucket_view=True` will be ignored since "
                "`overlap_with_ddp=True`; instead, a different bucketing "
                "strategy will be used"
            )
    else:
        self._init_local_optimizer()

    # `self._buckets` is used if `parameters_as_bucket_view=True`, in
    # which case parameter data is flattened into contiguous bucket tensors
    self.parameters_as_bucket_view = parameters_as_bucket_view
    self._buckets: List[List[torch.Tensor]] = []
    self._build_param_buckets()

    # Optional consolidated optimizer state, only populated if this rank
    # is the target in `consolidate_state_dict()`
    self._all_state_dicts: List[Dict[str, Any]] = []

    self.initialized = True
def _clear_cache(self) -> None: | |
r"""Clear the cached data structures giving partition information.""" | |
self._partition_parameters_cache.clear() | |
self._param_to_rank_cache.clear() | |
self._index_to_param_cache.clear() | |
self._param_to_index_cache.clear() | |
self._device_to_params_per_rank_cache.clear() | |
self._bucket_assignments_per_rank_cache.clear() | |
def add_param_group(self, param_group: Dict[str, Any]) -> None:
    r"""
    Add a parameter group to the :class:`Optimizer` 's ``param_groups``.

    This can be useful when fine tuning a pre-trained network, as frozen
    layers can be made trainable and added to the :class:`Optimizer` as
    training progresses.

    Arguments:
        param_group (dict): specifies the parameters to be optimized and
            group-specific optimization options.

    Raises:
        RuntimeError: if this instance is already initialized and was
            constructed with ``overlap_with_ddp=True``.

    .. warning:: This method handles updating the shards on all partitions
        but needs to be called on all ranks. Calling this on a subset of
        the ranks will cause the training to hang because communication
        primitives are called depending on the managed parameters and
        expect all the ranks to participate on the same set of parameters.
    """
    if self.initialized and self._overlap_with_ddp:
        raise RuntimeError(
            "ZeroRedundancyOptimizer with `overlap_with_ddp=True` only "
            "supports a single parameter group"
        )

    super().add_param_group(param_group)

    # While the constructor is still running, the parent's behavior is all
    # that is needed
    if not self.initialized:
        return

    # NOTE: Everything below assumes that the parent's `add_param_group()`
    # appends the new parameter group and preserves the previous
    # parameter-group ordering

    # Force a re-partitioning of the parameters
    self._clear_cache()
    local_param_groups = self._partition_parameters()[self.rank]

    # NOTE: All parameters in the old parameter groups should be assigned
    # to the same ranks so that the local optimizers do not need to be
    # reinitialized

    # Hand this rank's share of the new parameter group, if any, to the
    # local optimizer
    expected_num_groups = len(self.optim.param_groups) + 1
    if len(local_param_groups) == expected_num_groups:
        self.optim.add_param_group(local_param_groups[-1])

    # Update the bucketing strategy accordingly
    if self.parameters_as_bucket_view:
        self._build_param_buckets()
def consolidate_state_dict(self, to: int = 0) -> None:
    r"""
    Gather every rank's local ``state_dict`` on the target rank.

    After the call, ``self._all_state_dicts`` on rank ``to`` holds one
    CPU copy of the local optimizer state per rank, in rank order; on
    every other rank it is an empty list.

    Arguments:
        to (int): the rank that receives the optimizer states (default: 0).

    Raises:
        RuntimeError: if ``overlap_with_ddp=True`` and this method is
            called before this :class:`ZeroRedundancyOptimizer` instance
            has been fully initialized, which happens once
            :class:`DistributedDataParallel` gradient buckets have been
            rebuilt.

    .. warning:: This needs to be called on all ranks.
    """
    self._check_overlap_initialized()

    # Sync the exposed `param_groups` attributes to the local optimizer in
    # case they have been updated
    self._sync_param_groups(self.param_groups, self.optim.param_groups)

    # Placeholder passed as the object on ranks that only receive
    empty_messenger = torch.tensor(
        [0], dtype=torch.uint8, device=self._default_device
    )
    cpu_device = torch.device("cpu")
    is_target = self.rank == to

    # NOTE: We wastefully use `broadcast()` (e.g. instead of `gather()`)
    # due to compatibility issues with NCCL backend; a possible follow-up
    # is to move all sharded state management to RPC RRef
    self._all_state_dicts = []
    for rank in range(self.world_size):
        global_rank = dist.distributed_c10d.get_global_rank(
            self.process_group, rank
        )
        if is_target:
            # Collect the state for `rank`: our own state is read directly,
            # any other rank's state arrives through a broadcast
            if rank == self.rank:
                rank_state_dict = self.optim.state_dict()
            else:
                rank_state_dict = _broadcast_object(
                    empty_messenger,
                    src_rank=global_rank,
                    group=self.process_group,
                    device=self._default_device,
                )
            # Store on CPU to save GPU memory
            self._all_state_dicts.append(
                _recursive_copy_to_device(
                    rank_state_dict,
                    non_blocking=True,
                    device=cpu_device,
                )
            )
        elif rank == self.rank:
            # Our turn to send the local optimizer state to the target rank
            _broadcast_object(
                self.optim.state_dict(),
                src_rank=self.global_rank,
                group=self.process_group,
                device=self._default_device,
            )
        elif rank != to:
            # Take part in another rank's broadcast and discard the
            # received object; `broadcast()` is used for compatibility
            # reasons
            _broadcast_object(
                empty_messenger,
                src_rank=global_rank,
                group=self.process_group,
                device=self._default_device,
            )
def _verify_params_per_rank( | |
self, | |
params_per_rank: List[List[torch.Tensor]], | |
) -> None: | |
r""" | |
Verify ``params_per_rank`` for :meth:`_partition_parameters`. | |
The verification is done by checking that ``params_per_rank`` has length equal | |
to the world size and that it does not contain any parameters not passed into the | |
:class:`ZeroRedundancyOptimizer` constructor. | |
The parameters in ``params_per_rank`` being a strict subset of those | |
passed into the constructor is valid since some parameters may be | |
frozen. | |
Raises: | |
ValueError: if ``params_per_rank`` does not have length equal to | |
the world size or if it contains a parameter that was not | |
passed into the :class:`ZeroRedundancyOptimizer` constructor. | |
""" | |
if len(params_per_rank) != self.world_size: | |
raise ValueError( | |
"`params_per_rank` must have length equal to the world size" | |
) | |
all_params_set = set(self._all_params) | |
for params in params_per_rank: | |
for param in params: | |
if param not in all_params_set: | |
raise ValueError( | |
"Passing a new parameter in `params_per_rank` that " | |
"was not passed into the ZeroRedundancyOptimizer " | |
"constructor" | |
) | |
def _partition_param_group( | |
self, param_group: Dict[str, Any], params_per_rank: List[List[torch.Tensor]] | |
) -> None: | |
r""" | |
Partition the parameter group ``param_group`` according to ``params_per_rank``. | |
The partition will modify the ``self._partition_parameters_cache``. This method should | |
only be used as a subroutine for :meth:`_partition_parameters`. | |
Arguments: | |
param_group (dict[str, Any]): a parameter group as normally defined | |
in an optimizer state. | |
params_per_rank (list[list[torch.Tensor]]): a :class:`list` of | |
length world size containing :class:`list` s of parameters to | |
assign to each rank. | |
""" | |
for rank, params in enumerate(params_per_rank): | |
rank_param_group = copy.copy(param_group) | |
rank_param_group["params"] = params | |
self._partition_parameters_cache[rank].append(rank_param_group) | |
def _partition_parameters( | |
self, | |
params_per_rank: Optional[List[List[torch.Tensor]]] = None, | |
) -> List[List[Dict]]: | |
r""" | |
Partitions parameters across distributed data parallel ranks. | |
Arguments: | |
params_per_rank (list[list[torch.Tensor]], optional): a | |
:class:`list` of length world size containing :class:`list` s | |
of parameters to assign to each rank; this provides a way to | |
specify a partition manually. | |
If ``None``, the parameters are partitioned according to an | |
internal algorithm. | |
(default: ``None``) | |
Returns: | |
A :class:`list` where each element of the list contains the | |
``param_groups`` for a rank (which itself is a :class:`list` of | |
:class:`dict`); element 0 corresponds to rank 0, etc.; each rank | |
stores the ``param_groups`` for all ranks for the collective | |
communication in :meth:`step`. | |
Raises: | |
ValueError: see :meth:`_validate_params_per_rank`. | |
RuntimeError: if ``params_per_rank`` is not ``None`` and this | |
:class:`ZeroRedundancyOptimizer` instance is using more than | |
one parameter group. | |
""" | |
if params_per_rank is None: | |
# Partition the parameters optimizing for uniformity | |
if len(self._partition_parameters_cache) == 0: | |
self._partition_parameters_cache = [[] for _ in range(self.world_size)] | |
sizes = [0] * self.world_size | |
for param_group in self.param_groups: | |
param_group_params_per_rank: List[List] = [ | |
[] for _ in range(self.world_size) | |
] | |
# Sort the parameters by size (largest first) | |
params_sorted = sorted( | |
param_group["params"], key=lambda t: t.numel(), reverse=True | |
) | |
for param in params_sorted: | |
# Greedily add the parameter to rank with smallest size so far | |
rank = self._get_min_index(sizes) | |
param_group_params_per_rank[rank].append(param) | |
sizes[rank] += param.numel() | |
# Apply the constructed partition of the parameter group | |
self._partition_param_group( | |
param_group, param_group_params_per_rank | |
) | |
return self._partition_parameters_cache | |
# Partition the parameters according to `params_per_rank` | |
assert len(self._partition_parameters_cache) == 0, ( | |
"Specifying `params_per_rank` should only be done when the " | |
"parameters have not been partitioned yet" | |
) | |
if len(self.param_groups) != 1: | |
raise RuntimeError( | |
"Specifying `params_per_rank` only supports a single parameter group" | |
) | |
self._verify_params_per_rank(params_per_rank) | |
self._partition_parameters_cache = [[] for _ in range(self.world_size)] | |
# Apply the passed-in partition of the parameter group | |
param_group = self.param_groups[0] | |
self._partition_param_group(param_group, params_per_rank) | |
return self._partition_parameters_cache | |
def _param_to_rank(self) -> Dict[torch.Tensor, int]: | |
r""":class:`dict` mapping parameters to their assigned data parallel rank in the partition.""" | |
if len(self._param_to_rank_cache) == 0: | |
for rank, param_groups in enumerate(self._partition_parameters()): | |
for param_group in param_groups: | |
for param in param_group["params"]: | |
self._param_to_rank_cache[param] = rank | |
return self._param_to_rank_cache | |
def _param_to_index(self) -> Dict[torch.Tensor, int]: | |
r""" | |
:class:`dict` mapping parameters to their indices in the global optimizer state. | |
NOTE: This assumes that the global optimizer state's indexing (in | |
``state_dict``) follows a linear ordering over the parameter groups. | |
""" | |
if len(self._param_to_index_cache) == 0: | |
self._param_to_index_cache = { | |
p: i | |
for i, p in enumerate(chain(*(g["params"] for g in self.param_groups))) | |
} | |
return self._param_to_index_cache | |
def _index_to_param(self) -> List[torch.Tensor]: | |
r"""List mapping parameter indices in the global optimizer scheme to the actual params.""" | |
if len(self._index_to_param_cache) == 0: | |
self._index_to_param_cache = list( | |
chain(*(g["params"] for g in self.param_groups)) | |
) | |
return self._index_to_param_cache | |
def _broadcast_params_from_rank(self, rank: int):
    r"""
    Asynchronously broadcast the parameter shard owned by ``rank`` to all ranks.

    If ``parameters_as_bucket_view=True``, one broadcast is issued per
    entry of ``self._buckets`` (the bucket belonging to ``rank``);
    otherwise, one broadcast is issued per parameter assigned to ``rank``.

    Arguments:
        rank (int): the source rank.

    Returns:
        A :class:`list` of async work handles for the ``broadcast()`` s
        performed to synchronize the parameters.
    """
    assert not self._overlap_with_ddp, (
        "`_broadcast_params_from_rank()` should not be used if "
        "`overlap_with_ddp=True`; instead, the broadcasting should "
        "happen in the DDP communication hook"
    )

    if self.parameters_as_bucket_view:
        bucket_handles = []
        for device_buckets in self._buckets:
            rank_bucket = device_buckets[rank]
            # `broadcast()` takes the global rank as its source
            src = dist.distributed_c10d.get_global_rank(self.process_group, rank)
            work = dist.broadcast(
                tensor=rank_bucket,
                src=src,
                group=self.process_group,
                async_op=True,
            )
            bucket_handles.append(work)
        return bucket_handles

    rank_param_groups = self._partition_parameters()[rank]
    src = dist.distributed_c10d.get_global_rank(self.process_group, rank)
    return [
        dist.broadcast(
            tensor=param.data,
            src=src,
            group=self.process_group,
            async_op=True,
        )
        for param_group in rank_param_groups
        for param in param_group["params"]
    ]
def _sync_params(self): | |
r""" | |
Sync all parameter shards across the ranks. | |
This rank sends its shard of the parameters to all other ranks and | |
receives a shard from each other rank. This is done using | |
``broadcast()``. Parameters are sent bucket-by-bucket if | |
``parameters_as_bucket_view=True``and sent parameter-by-parameter | |
otherwise. | |
""" | |
handles = [] | |
for rank in range(self.world_size): | |
handles.extend(self._broadcast_params_from_rank(rank)) | |
_ = [x.wait() for x in handles] | |
def _device_to_params_per_rank( | |
self, | |
) -> Dict[torch.device, List[List[torch.Tensor]]]: | |
r""" | |
Return device parameters assigned per rank. | |
:class:`dict` mapping each device to a :class:`list` of the per-rank parameter | |
lists filtered to only include the parameters stored on that device. | |
Each per-rank parameter list gives the parameters assigned to that rank | |
to update. | |
This is used for constructing the parameter buckets if | |
``parameters_as_bucket_view=True``. | |
Let ``dev_i`` denote the ``i``th device for this rank. Then: | |
``dev_0`` maps to a list containing: | |
rank 0's assigned parameters stored on ``dev_0``, | |
rank 1's assigned parameters stored on ``dev_0``, | |
... | |
``dev_1`` maps to a list containing: | |
rank 0's assigned parameters stored on ``dev_1``, | |
rank 1's assigned parameters stored on ``dev_1``, | |
... | |
... | |
""" | |
assert self.parameters_as_bucket_view, ( | |
"`_device_to_params_per_rank` should only be used if " | |
"`parameters_as_bucket_view=True`" | |
) | |
if len(self._device_to_params_per_rank_cache) == 0: | |
for rank, param_groups in enumerate(self._partition_parameters()): | |
for param_group in param_groups: | |
for param in param_group["params"]: | |
device = param.device | |
if device not in self._device_to_params_per_rank_cache: | |
self._device_to_params_per_rank_cache[device] = [ | |
[] for _ in range(self.world_size) | |
] | |
self._device_to_params_per_rank_cache[device][rank].append( | |
param | |
) | |
return self._device_to_params_per_rank_cache | |
def _get_min_index( | |
self, | |
values: List[int], | |
disallowed_indices: Optional[Set[int]] = None, | |
) -> int: | |
r""" | |
Return ``values.index(min(values))``, except only uses one pass. | |
It also excludes any indices in ``disallowed_indices`` if provided. | |
Arguments: | |
values: (List[int]): :class:`list` of values. | |
disallowed_indices (Optional[Set[int]]): indices that are | |
disallowed from being the returned min index. | |
""" | |
min_index = -1 | |
min_value = float("inf") | |
for i, value in enumerate(values): | |
if disallowed_indices and i in disallowed_indices: | |
continue | |
if value < min_value: | |
min_value = value | |
min_index = i | |
assert min_index >= 0, "All indices are disallowed" | |
return min_index | |
def _assign_bucket_subset_to_rank( | |
self, | |
bucket_index: int, | |
bucket_params: List[torch.Tensor], | |
bucket_offset: int, | |
assigned_rank: int, | |
assigned_ranks_per_bucket: List[Set[int]], | |
) -> None: | |
r""" | |
Assign ``bucket_params`` to the rank with the least size assigned so far and collects relevant information. | |
The model parameters given by ``bucket_params`` represents a (possibly non-strict) | |
subset of the parameters corresponding to a :class:`DistributedDataParallel` bucket. | |
Arguments: | |
bucket_index (int): index of the :class:`DistributedDataParallel` | |
gradient bucket. | |
bucket_params (List[torch.Tensor]): subset of the parameters | |
corresponding to the bucket to assign. | |
bucket_offset (int): offset giving the index of the first element | |
in ``bucket_params`` in the bucket's full parameter list. | |
assigned_rank (int): group rank to assign to. | |
assigned_ranks_per_bucket (List[Set[int]]): :class:`set` of group ranks | |
assigned to each bucket. | |
""" | |
overlap_info = self._overlap_info | |
if len(bucket_params) == 0: | |
raise ValueError("Empty bucket assignment") | |
params_per_rank = overlap_info.params_per_rank | |
offsets = overlap_info.offsets | |
self._bucket_assignments_per_rank_cache[assigned_rank][ | |
bucket_index | |
] = _DDPBucketAssignment(bucket_index, bucket_params, bucket_offset) | |
if self.global_rank == assigned_rank: | |
offsets[bucket_index] = len(params_per_rank[assigned_rank]) | |
params_per_rank[assigned_rank].extend(bucket_params) | |
assigned_ranks_per_bucket[bucket_index].add(assigned_rank) | |
self._overlap_info.num_bucket_assignments += 1 | |
@property
def _bucket_assignments_per_rank(self) -> List[Dict[int, _DDPBucketAssignment]]:
    r"""
    Return DDP bucket parameters assigned per rank.

    :class:`list` of length world size consisting of :class:`dict` s
    mapping bucket indices to :class:`_DDPBucketAssignment` s for each
    rank.

    The result is computed on first access and memoized in
    ``self._bucket_assignments_per_rank_cache``. Computing it also fills
    ``self._overlap_info.assigned_ranks_per_bucket`` and, through
    :meth:`_assign_bucket_subset_to_rank`, the per-rank parameter lists.

    .. note::
        This is a property: ``_build_ddp_param_buckets`` iterates it and
        ``_init_local_optimizer`` subscripts it as an attribute, which
        fails on a plain bound method.
    """
    assert self._overlap_with_ddp, (
        "`_bucket_assignments_per_rank` only be used if `overlap_with_ddp=True`"
    )
    # Reuse the memoized assignments if they were already computed
    if len(self._bucket_assignments_per_rank_cache) > 0:
        return self._bucket_assignments_per_rank_cache

    overlap_info = self._overlap_info
    assert overlap_info.status == _OverlapStatus.INITIALIZED

    self._bucket_assignments_per_rank_cache = [{} for _ in range(self.world_size)]
    params_per_bucket = overlap_info.params_per_bucket

    if overlap_info.shard_buckets:
        # Define the assignment threshold to approximate uniformity
        assert overlap_info.total_size is not None, "`total_size` was not computed"
        threshold = overlap_info.total_size / self.world_size  # type: ignore[operator]
        size_per_rank = [0 for _ in range(self.world_size)]

    num_buckets = len(params_per_bucket)
    overlap_info.assigned_ranks_per_bucket = [set() for _ in range(num_buckets)]
    assigned_ranks_per_bucket = overlap_info.assigned_ranks_per_bucket

    if not overlap_info.shard_buckets:
        # Assign each DDP bucket entirely to a single rank
        for bucket_index, bucket_params in enumerate(params_per_bucket):
            assert len(bucket_params) > 0, "Empty bucket"
            assigned_rank = self._get_assigned_rank(bucket_index)
            self._assign_bucket_subset_to_rank(
                bucket_index,
                bucket_params,
                0,
                assigned_rank,
                assigned_ranks_per_bucket,
            )
    else:
        # Assign each DDP bucket to possibly multiple ranks
        # Specifically, sort the DDP buckets by increasing size, and for
        # each bucket, iteratively assign the maximal unassigned subset
        # with size less than `threshold` to the rank with the least total
        # size so far -- each such assignment is represented by a
        # `_DDPBucketAssignment` instance and only contains parameters from
        # a single DDP bucket
        params_per_bucket_enum = sorted(
            enumerate(params_per_bucket), key=lambda x: sum(p.numel() for p in x[1])
        )
        for bucket_index, bucket_params in params_per_bucket_enum:
            assert len(bucket_params) > 0, "Empty bucket"
            bucket_offset = 0
            assignment_size = 0
            for param_index, param in enumerate(bucket_params):
                param_numel = param.numel()
                if (
                    assignment_size + param_numel >= threshold
                    and param_index > bucket_offset
                ):
                    # Ranks already holding part of this bucket are excluded
                    assigned_rank = self._get_min_index(
                        size_per_rank, assigned_ranks_per_bucket[bucket_index]
                    )
                    # Include up to but not including the parameter that
                    # exceeded the threshold
                    self._assign_bucket_subset_to_rank(
                        bucket_index,
                        bucket_params[bucket_offset:param_index],
                        bucket_offset,
                        assigned_rank,
                        assigned_ranks_per_bucket,
                    )
                    size_per_rank[assigned_rank] += assignment_size
                    bucket_offset = param_index
                    assignment_size = 0
                assignment_size += param_numel
            # Assign the remainder of the bucket so that no assignment
            # spans across two buckets
            assigned_rank = self._get_min_index(
                size_per_rank, assigned_ranks_per_bucket[bucket_index]
            )
            self._assign_bucket_subset_to_rank(
                bucket_index,
                bucket_params[bucket_offset:],
                bucket_offset,
                assigned_rank,
                assigned_ranks_per_bucket,
            )
            size_per_rank[assigned_rank] += assignment_size

    return self._bucket_assignments_per_rank_cache
def _local_step( | |
self, | |
gradients: Optional[List[Optional[torch.Tensor]]] = None, | |
closure: Optional[Callable[[], float]] = None, | |
**kwargs: Any, | |
) -> Optional[float]: | |
r""" | |
Perform a single optimizer step without syncing parameters across ranks. | |
Arguments: | |
gradients (list[Optional[torch.Tensor]], optional): a :class:`list` | |
of length equal to the number of parameters assigned to this | |
rank containing gradient tensors or ``None`` as its elements; | |
a ``None`` in the :class:`list` indicates that the | |
corresponding parameter should not be updated. | |
If the argument itself is ``None``, then all parameters are | |
updated, and the gradients are assumed to be already populated. | |
(default: ``None``) | |
closure (Callable): a closure that re-evaluates the model and | |
returns the loss; optional for most optimizers and should be | |
``None`` if ``gradients`` is not ``None``; (default: ``None``) | |
Returns: | |
Optional loss depending on the underlying local optimizer. | |
.. warning:: | |
The argument ``gradients`` should only be specified (i.e. not | |
``None``) if ``overlap_with_ddp=True``, in which case | |
:class:`ZeroRedundancyOptimizer` wraps a functional optimizer. | |
""" | |
Join.notify_join_context(self) | |
# Check if the model trainability has changed | |
is_trainable_mask = self._get_is_trainable_mask() | |
if is_trainable_mask != self._is_trainable_mask: | |
if self._overlap_with_ddp: | |
raise RuntimeError( | |
"ZeroRedundancyOptimizer with `overlap_with_ddp=True` " | |
"does not support changing parameter trainability at run " | |
"time" | |
) | |
logger.warning( | |
"ZeroRedundancyOptimizer detected that the trainable " | |
"parameters changed; rebuilding the parameter buckets if " | |
"enabled" | |
) | |
self._build_param_buckets() | |
self._is_trainable_mask = is_trainable_mask | |
# Sync the exposed `param_groups` attributes to the local optimizer in | |
# case they have been updated | |
self._sync_param_groups(self.param_groups, self.optim.param_groups) | |
# Run the optimizer step on this shard only | |
if gradients is None: | |
loss = ( | |
self.optim.step(**kwargs) | |
if closure is None | |
else self.optim.step(closure=closure, **kwargs) | |
) | |
else: | |
assert self._overlap_with_ddp, ( | |
"Specifying `gradients` should not " | |
"be used when `overlap_with_ddp=False`" | |
) | |
assert closure is None, ( | |
"`closure` is not supported when using a local functional optimizer" | |
) | |
loss = self.optim.step(gradients=gradients) | |
# Sync any updated attributes in the local optimizer to the exposed | |
# `param_groups` | |
self._sync_param_groups(self.optim.param_groups, self.param_groups) | |
return loss | |
def step(
    self,
    closure: Optional[Callable[[], float]] = None,
    **kwargs: Any,
) -> Optional[float]:
    r"""
    Perform one optimizer step and then sync parameters across all ranks.

    Arguments:
        closure (Callable): a closure that re-evaluates the model and
            returns the loss; optional for most optimizers.

    Returns:
        Optional loss depending on the underlying local optimizer; always
        ``None`` when ``overlap_with_ddp=True``, in which case only a
        warning is logged and no step is taken.

    .. note: Any extra parameters are passed to the base optimizer as-is.
    """
    if not self._overlap_with_ddp:
        # Update this rank's shard, then exchange the updated shards
        result = self._local_step(closure=closure, **kwargs)
        self._sync_params()
        return result

    logger.warning(
        "`step()` should not be included in the training loop when "
        "`overlap_with_ddp=True`"
    )
    return None
def join_hook(self, **kwargs):
    r"""
    Build and return the ZeRO join hook for this optimizer.

    The hook enables training on uneven inputs by shadowing the collective
    communications in the optimizer step.

    Gradients must be properly set before this hook is called.

    Arguments:
        kwargs (dict): a :class:`dict` containing any keyword arguments
            to modify the behavior of the join hook at run time; all
            :class:`Joinable` instances sharing the same join context
            manager are forwarded the same value for ``kwargs``.

    This hook does not support any keyword arguments; i.e. ``kwargs`` is
    unused.
    """
    zero_hook = _ZeROJoinHook(self)
    return zero_hook
@property
def join_device(self) -> torch.device:
    r"""
    Return default device.

    Returns ``self._default_device``. Exposed as a property because the
    :class:`Joinable` interface declares ``join_device`` as a property and
    the join context manager reads it as an attribute rather than calling it.
    """
    return self._default_device
@property
def join_process_group(self) -> Any:
    r"""
    Return process group.

    Returns ``self.process_group``. Exposed as a property because the
    :class:`Joinable` interface declares ``join_process_group`` as a
    property and the join context manager reads it as an attribute rather
    than calling it.
    """
    return self.process_group
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    r"""
    Load this rank's part of ``state_dict`` into the local optimizer.

    Entries of ``state_dict["state"]`` that belong to other ranks are
    replaced by ``None`` in the passed-in dictionary (the input is modified
    in place) before the parent class loads it.

    Arguments:
        state_dict (dict): optimizer state; should be an object returned
            from a call to :meth:`state_dict`.

    Raises:
        RuntimeError: if ``overlap_with_ddp=True`` and this method is
            called before this :class:`ZeroRedundancyOptimizer` instance
            has been fully initialized, which happens once
            :class:`DistributedDataParallel` gradient buckets have been
            rebuilt.
    """
    self._check_overlap_initialized()

    saved_state = state_dict["state"]
    for index, value in saved_state.items():
        param = self._index_to_param[index]
        if self._param_to_rank[param] == self.rank:
            # This rank owns the parameter: move its state next to it
            self.optim.state[param] = _recursive_copy_to_device(
                value, non_blocking=True, device=param.device
            )
            # Force zero-dimensional tensors (like Adam "step") on CPU
            param_state = self.optim.state[param]
            for state_name, state_value in param_state.items():
                if torch.is_tensor(state_value) and state_value.dim() == 0:
                    param_state[state_name] = state_value.cpu()
        else:
            # Drop state that another rank is responsible for
            saved_state[index] = None

    super().load_state_dict(state_dict)

    # Propagate the loaded group attributes: input -> exposed -> local
    self._sync_param_groups(state_dict["param_groups"], self.param_groups)
    self._sync_param_groups(self.param_groups, self.optim.param_groups)
def state_dict(self) -> Dict[str, Any]:
    r"""
    Return the last global optimizer state known to this rank.

    The parent class's state dict (global parameter indexing) is overlaid
    with the per-rank state dicts stored in ``self._all_state_dicts``,
    translating each rank's local parameter indices to global ones. The
    resulting ``"state"`` mapping is sorted by parameter index.

    .. warning:
        If the state has not been consolidated to this rank, this raises a
        runtime error, and even if it has, the state may not be up-to-date,
        depending on when :meth:`consolidate_state_dict` was last called.

    Raises:
        RuntimeError: if ``overlap_with_ddp=True`` and this method is
            called before this :class:`ZeroRedundancyOptimizer` instance
            has been fully initialized, which happens once
            :class:`DistributedDataParallel` gradient buckets have been
            rebuilt; or if this method is called without a preceding call
            to :meth:`consolidate_state_dict`.
    """
    self._check_overlap_initialized()

    if len(self._all_state_dicts) == 0:
        raise RuntimeError(
            "Optimizer state has not been consolidated on this rank. "
            f"Please call `consolidate_state_dict(to={self.rank})` on "
            "all ranks beforehand if you meant to save the global state."
        )

    # Possibly-stale global state, keyed by global parameter index
    state_dict = super().state_dict()
    merged_state = state_dict["state"]

    for rank, local_state_dict in enumerate(self._all_state_dicts):
        local_groups = local_state_dict["param_groups"]
        global_groups = self._partition_parameters()[rank]
        assert len(local_groups) == len(
            global_groups
        ), "Mismatch between number of local and global parameter groups"

        for local_group, global_group in zip(local_groups, global_groups):
            # The local group lists local indices; the global group lists
            # the parameter tensors themselves, in the same order
            local_indices = local_group["params"]
            global_params = global_group["params"]
            assert len(local_indices) == len(
                global_params
            ), "Mismatch between number of local and global parameters in parameter group"

            for local_index, global_param in zip(local_indices, global_params):
                local_state = local_state_dict["state"]
                # Parameters without any state are simply skipped
                if local_index in local_state:
                    global_index = self._param_to_index[global_param]
                    merged_state[global_index] = local_state[local_index]

    # Order the state entries by global parameter index
    state_dict["state"] = dict(sorted(merged_state.items()))
    return state_dict
def _sync_param_groups( | |
src_param_groups: List[Dict[Any, Any]], | |
dst_param_groups: List[Dict[Any, Any]], | |
) -> None: | |
r""" | |
Sync the attributes from the source parameter groups to the destination parameter groups. | |
Example attributes include learning rate or scheduler attributes. The | |
two parameter groups should have the same length (i.e. same number of | |
parameter groups). | |
Arguments: | |
src_param_groups (list[dict]): parameter groups giving the | |
attribute settings to copy. | |
dst_param_groups (list[dict]): parameter groups giving the | |
attribute settings to set. | |
""" | |
assert len(src_param_groups) == len( | |
dst_param_groups | |
), "Mismatch between number of source and destination parameter groups" | |
for src_param_group, dst_param_group in zip(src_param_groups, dst_param_groups): | |
# Sync all attributes except the parameters | |
for attr in filter(lambda x: x != "params", src_param_group.keys()): | |
dst_param_group[attr] = src_param_group[attr] | |
def _build_param_buckets(self) -> None: | |
r""" | |
Build parameter buckets if ``parameters_as_bucket_view=True``. | |
For each device that stores this rank's parameters, there is a | |
bucket (represented as a tensor) containing all of the parameters on | |
that device that are assigned to a given rank in the parameter update | |
partition. | |
This method is called in the constructor and any time parameter | |
trainability is changed. | |
.. warning:: | |
The current implementation assumes that all of the parameters in a | |
bucket are of the same dense type when allocating the bucket's | |
tensor. | |
.. warning:: | |
If the model parameters are stored across more than one device, | |
then the storage partitioning must be the same across all | |
processes in order for parameter synchronization to work. | |
""" | |
if not self.parameters_as_bucket_view or self._overlap_with_ddp: | |
return | |
# `self._buckets[i][j]` are the parameters stored on device i and | |
# assigned to rank j | |
num_devices = len(self._device_to_params_per_rank) | |
self._buckets = [[] for _ in range(num_devices)] # type: ignore[assignment] | |
for dev_i, (device, params_per_rank) in enumerate( | |
self._device_to_params_per_rank.items() | |
): | |
for params in params_per_rank: | |
bucket_size = 0 | |
dtype = None | |
trainable_params = [] | |
for param in params: | |
if not _is_trainable(param): | |
# Clone in case the parameter was previously part of | |
# a bucket to avoid the data from being destroyed | |
param.data = param.data.detach().clone() | |
else: | |
bucket_size += param.numel() | |
trainable_params.append(param) | |
dtype = param.dtype # assumes all same dtype | |
if bucket_size == 0: | |
# Create a dummy bucket if there are no parameters | |
bucket = torch.zeros(1, device=device) | |
else: | |
# Construct the bucket (assuming all dense and same dtype) | |
bucket = torch.empty(bucket_size, dtype=dtype, device=device) | |
offset = 0 | |
for param in trainable_params: | |
offset_next = offset + param.numel() | |
bucket[offset:offset_next].copy_(param.data.flatten()) | |
param.data = bucket[offset:offset_next].view_as(param.data) | |
offset = offset_next | |
self._buckets[dev_i].append(bucket) # type: ignore[arg-type] | |
def _build_ddp_param_buckets(self) -> None: | |
r""" | |
Build the DDP bucket with parameters assigned to this rank. | |
For each DDP bucket with parameters assigned to this rank, flattens the | |
data of those parameters into a single tensor and saves the tensor to | |
the ``tensor`` attribute in the corresponding | |
:class:`_DDPBucketAssignment` instance stored in | |
``self._bucket_assignments_per_rank``. | |
:class:`DistributedDataParallel` guarantees that the parameters | |
corresponding to a gradient bucket have the same device and the same | |
dtype. | |
""" | |
for bucket_assignments in self._bucket_assignments_per_rank: | |
for bucket_assignment in bucket_assignments.values(): | |
params = bucket_assignment.parameters | |
bucket_size = 0 | |
dtype = None | |
for param in params: | |
assert _is_trainable(param), ( | |
"Model parameter " | |
"corresponding to a gradient in a DDP bucket should " | |
"require a gradient" | |
) | |
bucket_size += param.numel() | |
dtype = param.dtype # assumes all same dtype | |
assert bucket_size > 0, "Empty bucket" | |
# Construct the bucket tensor (assuming all dense and same dtype) | |
tensor = torch.empty( | |
bucket_size, dtype=dtype, device=bucket_assignment.device | |
) | |
offset = 0 | |
for param in params: | |
offset_next = offset + param.numel() | |
tensor[offset:offset_next].copy_(param.data.flatten()) | |
param.data = tensor[offset:offset_next].view_as(param.data) | |
offset = offset_next | |
bucket_assignment.tensor = tensor | |
def _verify_and_init_params( | |
self, | |
params: Any, | |
) -> Union[List[torch.Tensor], List[dict]]: | |
r""" | |
Verify the type of ``params`` and initializes ``self._all_params`` as a :class:`list` of all parameters. | |
The initializagtion will first make sure that provided ``params`` is valid. | |
Arguments: | |
params (Any): Candidate parameter list or parameter groups to verify. | |
Raises: | |
TypeError: ``params`` has an invalid type. | |
ValueError: ``params`` is empty. | |
Returns: | |
The persistent form of ``params`` to be passed into the parent | |
:class:`Optimizer` constructor -- i.e. returns ``params`` as a | |
:class:`list` to ensure that it can be iterated over again. | |
""" | |
if isinstance(params, torch.Tensor): | |
raise TypeError( | |
"`params` argument should be an iterable of " | |
f"Tensors, but got {torch.typename(params)}" | |
) | |
try: | |
all_params = list(params) | |
except TypeError as e: | |
raise TypeError( | |
"`params` argument should be an iterable of Tensors" | |
f" or dicts, but got {torch.typename(params)}" | |
) from e | |
if len(all_params) == 0: | |
raise ValueError("ZeroRedundancyOptimizer got an empty parameter list") | |
all_tensors = True | |
all_dicts = True | |
for param in all_params: | |
all_tensors &= isinstance(param, torch.Tensor) | |
all_dicts &= isinstance(param, dict) | |
if not all_tensors and not all_dicts: | |
raise TypeError( | |
"`params` argument should be an iterable of Tensors or dicts" | |
) | |
# Ensure that `self._all_params` contains a list of all parameters | |
if all_tensors: | |
self._all_params = all_params | |
elif all_dicts: | |
self._all_params = [] | |
# `all_params` contains parameter groups (not parameters) | |
for param_group in all_params: | |
if "params" not in param_group: | |
raise ValueError( | |
"Each parameter group passed-in via `params` must " | |
"have a 'params' key mapping to the parameters in " | |
"the group" | |
) | |
self._all_params.extend(param_group["params"]) | |
return all_params | |
def _verify_same_dense_param_type(self) -> None: | |
r""" | |
Verify that all parameters are of the same dense type. | |
The method assumes that ``self._all_params`` has been initialized | |
and is non-empty. | |
Raises: | |
ValueError: ``params`` contains sparse parameters or parameters | |
of varying dense types. | |
NOTE: This method can be removed once support for sparse parameters | |
and varying parameter types is added. | |
""" | |
typename = torch.typename(self._all_params[0]) | |
if self._all_params[0].is_sparse: | |
raise ValueError( | |
"ZeroRedundancyOptimizer only supports using " | |
"the same dense type for all parameters but got " | |
f"{typename}" | |
) | |
for param in self._all_params[1:]: | |
other_typename = torch.typename(param) | |
if other_typename != typename: | |
raise ValueError( | |
"ZeroRedundancyOptimizer only supports " | |
"using the same dense type for all " | |
f"parameters but got both {typename} and " | |
f"{other_typename}" | |
) | |
def _get_is_trainable_mask(self) -> List[bool]: | |
r"""Return a boolean mask indicating if each parameter is trainable (``requires_grad``) or not.""" | |
return list(map(_is_trainable, self._all_params)) | |
def _init_local_optimizer(self) -> None: | |
r""" | |
Initialize this rank's local optimizer, responsible for its subset of the parameters. | |
The local optimizer is saved in ``self.optim``. | |
""" | |
assert ( | |
self._optim_constructor is not None | |
), "The local optimizer class has not been set" | |
param_groups = self._partition_parameters()[self.rank] | |
# `overlap_with_ddp=True` requires a local functional optimizer | |
if self._overlap_with_ddp: | |
# Functional optimizers only support a single parameter group and | |
# require passing in the parameters as a list | |
assert len(param_groups) == 1, ( | |
"Initializing the local " | |
"functional optimizer with more than one parameter group" | |
) | |
params = param_groups[0]["params"] | |
# Try to pass `_allow_empty_param_list=True` to avoid erroring | |
if ( | |
"_allow_empty_param_list" | |
in inspect.signature(self._optim_constructor).parameters | |
): | |
self.optim: Any = self._optim_constructor( | |
params, **self._optim_defaults, _allow_empty_param_list=True | |
) | |
else: | |
logger.warning( | |
"%s does not support the argument " | |
"`_allow_empty_param_list`; ZeroRedundancyOptimizer may " | |
"error due to an empty parameter list", | |
self._optim_constructor | |
) | |
self.optim: Any = self._optim_constructor(params, **self._optim_defaults) # type: ignore[no-redef] | |
# Log information about the DDP and ZeRO bucketing | |
if dist.get_debug_level() != dist.DebugLevel.OFF: | |
local_numel = sum(p.numel() for p in params) | |
num_assigned_buckets = len( | |
self._bucket_assignments_per_rank[self.global_rank] | |
) | |
logger.info( | |
"rank %s with %s parameters " | |
"across %s buckets", | |
self.global_rank, local_numel, num_assigned_buckets | |
) | |
if self.global_rank == 0: | |
logger.info( | |
"%s DDP " | |
"buckets and " | |
"%s bucket " | |
"assignments", | |
len(self._overlap_info.params_per_bucket), self._overlap_info.num_bucket_assignments | |
) | |
else: | |
# NOTE: Passing `param_groups` into the local optimizer constructor | |
# bypasses the empty parameter list check | |
self.optim: Optimizer = self._optim_constructor(param_groups, **self._optim_defaults) # type: ignore[no-redef] | |
# TODO: Manually add `self.param_groups` if using a functional | |
# optimizer; remove this if/when the functional optimizers support | |
# multiple parameter groups | |
if self._overlap_with_ddp and not hasattr(self.optim, "param_groups"): | |
assert hasattr(self.optim, "param_group"), ( | |
"The functional optimizer should set at least one of the " | |
"attributes `param_group` or `param_groups`" | |
) | |
self.optim.param_groups = [self.optim.param_group] # type: ignore[attr-defined] | |
self._sync_param_groups(self.optim.param_groups, self.param_groups) | |
def _init_zero_for_overlap(self) -> None: | |
r"""Perform a delayed initialization of the local optimizer and the supporting data structures.""" | |
assert self._overlap_with_ddp, ( | |
"`_init_zero_for_overlap()` should only be called when " | |
"`overlap_with_ddp=True`" | |
) | |
self._overlap_info.status = _OverlapStatus.INITIALIZED | |
self._clear_cache() | |
self._partition_parameters(self._overlap_info.params_per_rank) | |
self._build_ddp_param_buckets() | |
self._init_local_optimizer() | |
def _get_assigned_rank(self, bucket_index: int) -> int: | |
r""" | |
Return the single rank assigned to a :class:`DistributedDataParallel` gradient bucket. | |
Arguments: | |
bucket_index (int): index of the :class:`DistributedDataParallel` | |
bucket for which to get the assigned rank. | |
""" | |
assert not self._overlap_info.shard_buckets, ( | |
"The bucket assignment requires global bucket information and " | |
"will be computed later; there should be no need to use this " | |
"method" | |
) | |
return bucket_index % self.world_size | |
def _check_overlap_initialized(self): | |
r""" | |
Check the delayed initialization depending on the value of ``overlap_with_ddp``. | |
The delayed initialization has occurred (see | |
:meth:`_init_zero_for_overlap`) if ``overlap_with_ddp=True``, and | |
raises a ``RuntimeError`` if not. This should preface methods that | |
should not be run before that delayed initialization. | |
Raises: | |
RuntimeError: if ``overlap_with_ddp=True`` and | |
:meth:`_init_zero_for_overlap` has not been called. | |
""" | |
if ( | |
self._overlap_with_ddp | |
and self._overlap_info.status != _OverlapStatus.INITIALIZED | |
): | |
raise RuntimeError( | |
"This method should not be called until this " | |
"ZeroRedundancyOptimizer instance has been fully " | |
"initialized" | |
) | |
def _get_optimizer_constructor(self, optimizer_class: Any) -> Any: | |
r""" | |
Return the optimizer constructor using validation and transformation depending on ``overlap_with_ddp``. | |
Returns: | |
- ``optimizer_class`` if ``overlap_with_ddp=False`` and | |
``optimizer_class`` is not a functional optimizer. | |
- ``optimizer_class`` if ``overlap_with_ddp=True`` and | |
``optimizer_class`` is already a functional optimizer. | |
- The functional equivalent of ``optimizer_class`` if | |
``overlap_with_ddp=True`` and ``optimizer_class`` is not | |
already a functional optimizer (assuming the equivalent | |
exists). | |
Raises: | |
ValueError: | |
- if ``overlap_with_ddp=True`` but ``optimizer_class`` is | |
neither a functional optimizer nor translatable to a | |
functional optimizer. | |
- if ``overlap_with_ddp=False`` and ``optimizer_class`` is a | |
functional optimizer. | |
""" | |
functional_optims = functional_optim_map.values() | |
if not self._overlap_with_ddp: | |
if optimizer_class in functional_optims: | |
# Using a functional optimizer is only supported when | |
# `overlap_with_ddp=True` | |
raise ValueError( | |
f"Passing in a functional optimizer {optimizer_class} " | |
"when `overlap_with_ddp=False`" | |
) | |
else: | |
return optimizer_class | |
else: | |
if optimizer_class in functional_optims: | |
# Already a functional optimizer | |
return optimizer_class | |
elif optimizer_class in functional_optim_map: | |
# Translate the passed-in optimizer class to its functional | |
# equivalent if `overlap_with_ddp=True` | |
optim_constructor = functional_optim_map[optimizer_class] | |
logger.info( | |
"Using the functional optimizer %s " | |
"instead of %s since " | |
"`overlap_with_ddp=True`", | |
optim_constructor, optimizer_class | |
) | |
return optim_constructor | |
else: | |
raise ValueError( | |
"Using `ddp_with_overlap=True` requires using a " | |
"functional optimizer, but there is no supported functional " | |
f"optimizer equivalent for {optimizer_class}" | |
) | |