|
|
|
|
|
from contextlib import contextmanager |
|
from typing import Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from mmengine.device import (is_cuda_available, is_mlu_available, |
|
is_musa_available, is_npu_available) |
|
from mmengine.registry import OPTIM_WRAPPERS |
|
from mmengine.utils import digit_version |
|
from mmengine.utils.dl_utils import TORCH_VERSION |
|
from mmengine.optim.optimizer import OptimWrapper |
|
|
|
|
|
if is_npu_available(): |
|
from torch.npu.amp import GradScaler |
|
elif is_mlu_available(): |
|
from torch.mlu.amp import GradScaler |
|
else: |
|
from torch.cuda.amp import GradScaler |
|
|
|
|
|
|
|
@OPTIM_WRAPPERS.register_module()
class AmpOptimWrapper(OptimWrapper):
|
"""A subclass of :class:`OptimWrapper` that supports automatic mixed |
|
precision training based on torch.cuda.amp. |
|
|
|
``AmpOptimWrapper`` provides a unified interface with |
|
``OptimWrapper``, so ``AmpOptimWrapper`` can be used in the same way |
|
as ``OptimWrapper``. |
|
|
|
Warnings: |
|
``AmpOptimWrapper`` requires PyTorch >= 1.6. |
|
|
|
Args: |
|
        loss_scale (float or str or dict): The initial configuration of
            `torch.cuda.amp.GradScaler`. See the detailed argument
            descriptions in `PyTorch AMP <https://pytorch.org/docs/stable/amp.html?highlight=gradscaler#torch.cuda.amp.GradScaler>`_ # noqa: E501
            Defaults to ``dynamic``.
|
|
|
- "dynamic": Initialize GradScale without any arguments. |
|
- float: Initialize GradScaler with ``init_scale``. |
|
- dict: Initialize GradScaler with more detail configuration. |
|
|
|
dtype (str or torch.dtype, optional): The data type to autocast in amp. |
|
If a ``str`` is given, it will be converted to ``torch.dtype``. |
|
            Valid ``str`` values are `'float16'`, `'bfloat16'`, `'float32'` and
            `'float64'`. If set to ``None``, the default data type will be used.
|
Defaults to None. |
|
`New in version 0.6.1.` |
|
        use_fsdp (bool): Use ``ShardedGradScaler`` when set to True. It should
            be enabled when training with ``FullyShardedDataParallel``.
|
Defaults to False. |
|
`New in version 0.8.0.` |
|
**kwargs: Keyword arguments passed to OptimWrapper. |
|
|
|
Warnings: |
|
        The ``dtype`` argument is only available with PyTorch >= 1.10.0. With
        an older version of PyTorch, it will be ignored.
|
|
|
Note: |
|
If you use ``IterBasedRunner`` and enable gradient accumulation, |
|
the original `max_iters` should be multiplied by |
|
``accumulative_counts``. |
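
    Examples:
        A minimal usage sketch on a CUDA device. The toy ``nn.Linear`` model,
        the SGD optimizer and the hyperparameters below are illustrative
        placeholders; any optimizer can be used because extra keyword
        arguments such as ``optimizer`` are forwarded to ``OptimWrapper``.

        >>> import torch
        >>> import torch.nn as nn
        >>> from torch.optim import SGD
        >>> model = nn.Linear(2, 2).cuda()
        >>> optimizer = SGD(model.parameters(), lr=0.01)
        >>> optim_wrapper = AmpOptimWrapper(
        ...     optimizer=optimizer, loss_scale='dynamic', dtype='float16')
        >>> with optim_wrapper.optim_context(model):
        ...     loss = model(torch.ones(2, 2).cuda()).sum()
        >>> optim_wrapper.update_params(loss)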
|
""" |
|
|
|
valid_dtypes = ('float16', 'bfloat16', 'float32', 'float64') |
|
|
|
def __init__(self, |
|
loss_scale: str = 'dynamic', |
|
dtype: Union[str, torch.dtype] = None, |
|
use_fsdp: bool = False, |
|
**kwargs): |
|
assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), ( |
|
'`torch.cuda.amp` is only available when pytorch version >= 1.6') |
|
assert is_cuda_available() or is_npu_available() or is_mlu_available( |
|
) or is_musa_available(), ( |
|
            '``AmpOptimWrapper`` is only available for training '
            'on gpu, npu, mlu or musa')
|
super().__init__(**kwargs) |
|
self._scale_update_param = None |
|
|
|
if use_fsdp: |
|
if digit_version(torch.__version__) >= digit_version('2.0.0'): |
|
from torch.distributed.fsdp.sharded_grad_scaler import \ |
|
ShardedGradScaler |
|
scaler_type = ShardedGradScaler |
|
else: |
|
raise RuntimeError( |
|
                    'PyTorch>=2.0.0 is required when setting `use_fsdp=True`')
|
else: |
|
scaler_type = GradScaler |
|
|
|
        # ``dtype`` may still be a string at this point, so treat both the
        # string and the torch.dtype spelling of bfloat16 as "no loss scaling
        # needed".
        enable_loss_scaler = dtype not in ('bfloat16', torch.bfloat16)
|
|
|
        if loss_scale == 'dynamic':
            # A string ``loss_scale`` must be 'dynamic': use dynamic loss
            # scaling with the scaler's default arguments.
            self.loss_scaler = scaler_type(enabled=enable_loss_scaler)
|
        elif isinstance(loss_scale, float):
            # Static loss scaling with ``loss_scale`` as the fixed scale.
            self._scale_update_param = loss_scale
|
self.loss_scaler = scaler_type( |
|
init_scale=loss_scale, enabled=enable_loss_scaler) |
|
        elif isinstance(loss_scale, dict):
            # A dict gives a more detailed GradScaler configuration; honor an
            # explicit ``enabled`` entry but still disable scaling for bf16.
            loss_scale['enabled'] = loss_scale.pop('enabled',
|
True) and enable_loss_scaler |
|
self.loss_scaler = scaler_type(**loss_scale) |
|
else: |
|
raise TypeError('loss_scale must be of type float, dict, or ' |
|
f'"dynamic", but got {loss_scale}') |
|
|
|
|
|
if isinstance(dtype, str): |
|
assert dtype in self.valid_dtypes, ( |
|
                f'dtype should be one of {self.valid_dtypes}, got {dtype}')
|
dtype = getattr(torch, dtype) |
|
|
|
assert dtype is None or isinstance(dtype, torch.dtype), ( |
|
            f'dtype should be None or an instance of torch.dtype, got {dtype}')
|
self.cast_dtype = dtype |
|
|
|
def backward(self, loss: torch.Tensor, **kwargs): |
|
"""Perform gradient back propagation with :attr:`loss_scaler`. |
|
|
|
Args: |
|
loss (torch.Tensor): The loss of current iteration. |
|
kwargs: Keyword arguments passed to :meth:`torch.Tensor.backward` |
|
""" |
|
self.loss_scaler.scale(loss).backward(**kwargs) |
|
self._inner_count += 1 |
|
|
|
def step(self, **kwargs): |
|
"""Update parameters with :attr:`loss_scaler`. |
|
|
|
Args: |
|
kwargs: Keyword arguments passed to |
|
:meth:`torch.optim.Optimizer.step`. |
|
""" |
|
if self.clip_grad_kwargs: |
|
self.loss_scaler.unscale_(self.optimizer) |
|
self._clip_grad() |
|
self.loss_scaler.step(self.optimizer, **kwargs) |
|
self.loss_scaler.update(self._scale_update_param) |
|
|
|
def state_dict(self) -> dict: |
|
"""Get the state dictionary of :attr:`optimizer` and |
|
:attr:`loss_scaler`. |
|
|
|
Based on the state dictionary of the optimizer, the returned state |
|
dictionary will add a key named "loss_scaler". |
|
|
|
Returns: |
|
dict: The merged state dict of :attr:`loss_scaler` and |
|
:attr:`optimizer`. |
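
        Examples:
            A small sketch; ``optim_wrapper`` is assumed to be an already
            constructed ``AmpOptimWrapper`` instance.

            >>> saved = optim_wrapper.state_dict()
            >>> 'loss_scaler' in saved
            True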
|
""" |
|
|
|
state_dict = super().state_dict() |
|
state_dict['loss_scaler'] = self.loss_scaler.state_dict() |
|
return state_dict |
|
|
|
def load_state_dict(self, state_dict: dict): |
|
"""Load and parse the state dictionary of :attr:`optimizer` and |
|
:attr:`loss_scaler`. |
|
|
|
        If ``state_dict`` contains the "loss_scaler" key, :attr:`loss_scaler`
        will load the corresponding state. Otherwise, only the
        :attr:`optimizer` will load the state dictionary.
|
|
|
Args: |
|
state_dict (dict): The state dict of :attr:`optimizer` and |
|
:attr:`loss_scaler` |
|
""" |
|
if 'loss_scaler' in state_dict: |
|
self.loss_scaler.load_state_dict(state_dict.pop('loss_scaler')) |
|
|
|
if 'base_param_settings' in state_dict: |
|
self.base_param_settings = state_dict.pop('base_param_settings') |
|
|
|
|
|
self.optimizer.load_state_dict(state_dict) |
|
|
|
@contextmanager |
|
def optim_context(self, model: nn.Module): |
|
"""Enables the context for mixed precision training, and enables the |
|
context for disabling gradient synchronization during gradient |
|
accumulation context. |
|
|
|
Args: |
|
model (nn.Module): The training model. |
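
        Examples:
            A minimal sketch; ``optim_wrapper``, ``model`` and ``inputs`` are
            hypothetical placeholders for an existing wrapper, the training
            model and an input batch.

            >>> with optim_wrapper.optim_context(model):
            ...     loss = model(inputs).sum()
            >>> optim_wrapper.update_params(loss)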
|
""" |
|
from mmengine.runner.amp import autocast |
|
with super().optim_context(model), autocast(dtype=self.cast_dtype): |
|
yield |