import collections
import functools

import torch

try:
    import numpy as np

    HAS_NUMPY = True
except ModuleNotFoundError:
    HAS_NUMPY = False
    np = None  # type: ignore[assignment]
from typing import Any

__all__ = ["autocast", "custom_fwd", "custom_bwd"]


class autocast(torch.amp.autocast_mode.autocast):
    r"""See :class:`torch.autocast`.

    ``torch.cuda.amp.autocast(args...)`` is equivalent to ``torch.autocast("cuda", args...)``
    """

    def __init__(
        self,
        enabled: bool = True,
        dtype: torch.dtype = torch.float16,
        cache_enabled: bool = True,
    ):
        if torch._jit_internal.is_scripting():
            self._enabled = enabled
            self.device = "cuda"
            self.fast_dtype = dtype
            return
        super().__init__(
            "cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled
        )

    def __enter__(self):
        if torch._jit_internal.is_scripting():
            return self
        return super().__enter__()

    # TODO: discuss a unified TorchScript-friendly API for autocast
    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
        if torch._jit_internal.is_scripting():
            return
        return super().__exit__(exc_type, exc_val, exc_tb)

    def __call__(self, func):
        if torch._jit_internal.is_scripting():
            return func
        return super().__call__(func)
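

# Illustrative usage sketch (not part of the original module): how the `autocast`
# class above is typically used, both as a context manager and as a decorator.
# The helper name below is hypothetical, it is never invoked at import time, and
# it assumes a CUDA-capable device.
def _example_autocast_usage():
    if not torch.cuda.is_available():
        return None
    model = torch.nn.Linear(8, 8).cuda()
    data = torch.randn(4, 8, device="cuda")

    # As a context manager: eligible ops (e.g. the linear layer) run in float16.
    with autocast():
        out = model(data)
        assert out.dtype is torch.float16

    # As a decorator: the wrapped function runs under the same autocast settings.
    @autocast()
    def forward_fn(x):
        return model(x)

    return forward_fn(data)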


# Casts Tensors and containers of Tensors. Special-cases passthroughs for strings and np.ndarrays, which
# may be falsely detected as "Iterables."
def _cast(value, dtype):
    if isinstance(value, torch.Tensor):
        is_eligible = (
            value.is_floating_point()
            and value.is_cuda
            and (value.dtype is not torch.float64)
        )
        return value.to(dtype) if is_eligible else value
    elif isinstance(value, (str, bytes)):
        return value
    elif HAS_NUMPY and isinstance(value, np.ndarray):
        return value
    elif isinstance(value, collections.abc.Mapping):
        return {_cast(k, dtype): _cast(v, dtype) for k, v in value.items()}
    elif isinstance(value, collections.abc.Iterable):
        iterable = (_cast(v, dtype) for v in value)
        if isinstance(value, (list, tuple)):
            return type(value)(iterable)
        else:
            return iterable
    else:
        return value
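

# Illustrative sketch (not part of the original module) of what `_cast` above does
# to a nested container: floating-point CUDA tensors are converted to the target
# dtype, while non-floating-point tensors and strings pass through unchanged. The
# helper name is hypothetical, it is never invoked at import time, and it assumes
# a CUDA-capable device.
def _example_cast_behavior():
    if not torch.cuda.is_available():
        return None
    payload = {
        "x": torch.randn(2, device="cuda"),                      # float32 -> cast
        "mask": torch.ones(2, device="cuda", dtype=torch.bool),  # non-floating -> untouched
        "label": "keep-me",                                      # string -> passthrough
    }
    casted = _cast(payload, torch.float16)
    assert casted["x"].dtype is torch.float16
    assert casted["mask"].dtype is torch.bool
    assert casted["label"] == "keep-me"
    return casted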


# custom_fwd is a decorator that may or may not be used with arguments, following
# https://github.com/dabeaz/python-cookbook/tree/master/src/9/defining_a_decorator_that_takes_an_optional_argument.
# this works:
#     @custom_fwd
#     def forward(...):
# this also works:
#     @custom_fwd(cast_inputs=torch.float)
#     def forward(...):
def custom_fwd(fwd=None, *, cast_inputs=None):
    """
    Create a helper decorator for ``forward`` methods of custom autograd functions.

    Autograd functions are subclasses of :class:`torch.autograd.Function`.
    See the :ref:`example page<amp-custom-examples>` for more detail.

    Args:
        cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``,
            when ``forward`` runs in an autocast-enabled region, casts incoming
            floating-point CUDA Tensors to the target dtype (non-floating-point Tensors are not affected),
            then executes ``forward`` with autocast disabled.
            If ``None``, ``forward``'s internal ops execute with the current autocast state.

    .. note::
        If the decorated ``forward`` is called outside an autocast-enabled region,
        :func:`custom_fwd<custom_fwd>` is a no-op and ``cast_inputs`` has no effect.
    """
    if fwd is None:
        return functools.partial(custom_fwd, cast_inputs=cast_inputs)

    @functools.wraps(fwd)
    def decorate_fwd(*args, **kwargs):
        args[0]._dtype = torch.get_autocast_gpu_dtype()
        if cast_inputs is None:
            args[0]._fwd_used_autocast = torch.is_autocast_enabled()
            return fwd(*args, **kwargs)
        else:
            autocast_context = torch.is_autocast_enabled()
            args[0]._fwd_used_autocast = False
            if autocast_context:
                with autocast(enabled=False):
                    return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs))
            else:
                return fwd(*args, **kwargs)

    return decorate_fwd
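

# Illustrative sketch (not part of the original module): a custom autograd Function
# whose ``forward`` is decorated with ``custom_fwd(cast_inputs=torch.float32)``.
# Inside an autocast-enabled region the float16 input is cast back to float32 and
# the body of ``forward`` runs with autocast disabled. The Function and helper names
# are hypothetical, nothing runs at import time, and a CUDA device is assumed.
def _example_custom_fwd_usage():
    if not torch.cuda.is_available():
        return None

    class ScaleByTwo(torch.autograd.Function):
        @staticmethod
        @custom_fwd(cast_inputs=torch.float32)
        def forward(ctx, x):
            # The decorator has already cast x to float32 and disabled autocast here.
            assert x.dtype is torch.float32
            assert not torch.is_autocast_enabled()
            return x * 2

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output * 2

    inp = torch.randn(4, device="cuda", dtype=torch.float16, requires_grad=True)
    with autocast():
        out = ScaleByTwo.apply(inp)
    assert out.dtype is torch.float32
    return out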


# Autograd ensures incoming gradients are the same type as forward outputs. Allowing a separate
# cast_inputs argument on custom_bwd is unnecessary and could cause errors if it doesn't match
# cast_inputs supplied to custom_fwd.
def custom_bwd(bwd):
    """Create a helper decorator for backward methods of custom autograd functions.

    Autograd functions are subclasses of :class:`torch.autograd.Function`.
    Ensures that ``backward`` executes with the same autocast state as ``forward``.
    See the :ref:`example page<amp-custom-examples>` for more detail.
    """

    @functools.wraps(bwd)
    def decorate_bwd(*args, **kwargs):
        with autocast(enabled=args[0]._fwd_used_autocast, dtype=args[0]._dtype):
            return bwd(*args, **kwargs)

    return decorate_bwd
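

# Illustrative sketch (not part of the original module), following the pattern shown on
# the :ref:`example page<amp-custom-examples>`: ``custom_fwd`` (without cast_inputs)
# paired with ``custom_bwd`` so that ``backward`` re-enters the same autocast state
# that ``forward`` observed. The Function and helper names are hypothetical, nothing
# runs at import time, and a CUDA device is assumed.
def _example_custom_bwd_usage():
    if not torch.cuda.is_available():
        return None

    class MatMul(torch.autograd.Function):
        @staticmethod
        @custom_fwd
        def forward(ctx, a, b):
            ctx.save_for_backward(a, b)
            return a.mm(b)

        @staticmethod
        @custom_bwd
        def backward(ctx, grad):
            # forward ran inside an autocast region, so custom_bwd re-enables autocast
            # here and these matmuls also run in the autocast dtype.
            a, b = ctx.saved_tensors
            return grad.mm(b.t()), a.t().mm(grad)

    a = torch.randn(4, 4, device="cuda", requires_grad=True)
    b = torch.randn(4, 4, device="cuda", requires_grad=True)
    with autocast():
        out = MatMul.apply(a, b)
    assert out.dtype is torch.float16
    out.sum().backward()
    return a.grad, b.grad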