from typing import Dict, List

import torch
from torch.cuda.amp import GradScaler

# torch._six.container_abcs was removed in torch 1.9; compare the version
# numerically rather than lexicographically (a plain string comparison would
# treat '1.10' as smaller than '1.9').
if tuple(int(v) for v in torch.__version__.split('+')[0].split('.')[:2]) < (1, 9):
    Iterable = torch._six.container_abcs.Iterable
else:
    import collections

    Iterable = collections.abc.Iterable


class _MultiDeviceReplicator(object):
    """
    Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
    """

    def __init__(self, master_tensor: torch.Tensor) -> None:
        assert master_tensor.is_cuda
        self.master = master_tensor
        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}

    def get(self, device) -> torch.Tensor:
        retval = self._per_device_tensors.get(device, None)
        if retval is None:
            retval = self.master.to(device=device, non_blocking=True, copy=True)
            self._per_device_tensors[device] = retval
        return retval
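
# Illustrative usage sketch (not part of the original module): the replicator
# hands out a cached copy of one master tensor per device, e.g.
#
#     scale = torch.full((1,), 65536.0, device="cuda:0")
#     rep = _MultiDeviceReplicator(scale)
#     rep.get(torch.device("cuda:0"))   # copy made and cached for cuda:0
#     rep.get(torch.device("cuda:1"))   # copied once to cuda:1, then reused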


class MaxClipGradScaler(GradScaler):
    """A :class:`GradScaler` whose loss scale is capped at ``max_scale``."""

    def __init__(self, init_scale, max_scale: float, growth_interval=100):
        super().__init__(init_scale=init_scale, growth_interval=growth_interval)
        self.max_scale = max_scale

    def scale_clip(self):
        # Grow the scale normally (factor 2) while it is below max_scale, stop
        # growing once it reaches the cap, and clamp it back down if it has
        # overshot.
        if self.get_scale() == self.max_scale:
            self.set_growth_factor(1)
        elif self.get_scale() < self.max_scale:
            self.set_growth_factor(2)
        elif self.get_scale() > self.max_scale:
            self._scale.fill_(self.max_scale)
            self.set_growth_factor(1)

    def scale(self, outputs):
        """
        Multiplies ('scales') a tensor or list of tensors by the scale factor.

        Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
        unmodified.

        Arguments:
            outputs (Tensor or iterable of Tensors): Outputs to scale.
        """
        if not self._enabled:
            return outputs
        # Clamp the scale (and adjust the growth factor) before applying it.
        self.scale_clip()

        if isinstance(outputs, torch.Tensor):
            assert outputs.is_cuda
            if self._scale is None:
                self._lazy_init_scale_growth_tracker(outputs.device)
            assert self._scale is not None
            return outputs * self._scale.to(device=outputs.device, non_blocking=True)

        # For iterables of tensors, lazily build a single _MultiDeviceReplicator
        # and reuse it while recursing through the (possibly nested) structure.
        stash: List[_MultiDeviceReplicator] = []

        def apply_scale(val):
            if isinstance(val, torch.Tensor):
                assert val.is_cuda
                if len(stash) == 0:
                    if self._scale is None:
                        self._lazy_init_scale_growth_tracker(val.device)
                    assert self._scale is not None
                    stash.append(_MultiDeviceReplicator(self._scale))
                return val * stash[0].get(val.device)
            elif isinstance(val, Iterable):
                iterable = map(apply_scale, val)
                if isinstance(val, (list, tuple)):
                    return type(val)(iterable)
                else:
                    return iterable
            else:
                raise ValueError("outputs must be a Tensor or an iterable of Tensors")

        return apply_scale(outputs)
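

# A minimal, hypothetical training-loop sketch (not part of the original module)
# showing how MaxClipGradScaler can be used like a regular torch.cuda.amp
# GradScaler; the model, optimizer and data below are placeholders.
if __name__ == "__main__" and torch.cuda.is_available():
    model = torch.nn.Linear(16, 4).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = MaxClipGradScaler(init_scale=2 ** 10, max_scale=2 ** 16, growth_interval=100)
    for _ in range(3):
        inputs = torch.randn(8, 16, device="cuda")
        targets = torch.randn(8, 4, device="cuda")
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss = torch.nn.functional.mse_loss(model(inputs), targets)
        # scale() calls scale_clip() internally, so the loss scale never grows
        # past max_scale.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()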