Spaces:
Build error
Build error
from typing import Dict, List | |
import torch | |
# `Iterable` is the abstract base class used to recognise (possibly nested)
# containers of tensors. It has lived in `collections.abc` since Python 3.3,
# so there is no need to reach into the private `torch._six` alias (which the
# old version switch had to avoid on newer PyTorch) or to compare version
# strings. Importing the submodule explicitly guarantees it is loaded.
import collections.abc

Iterable = collections.abc.Iterable
from torch.cuda.amp import GradScaler | |
class _MultiDeviceReplicator(object): | |
""" | |
Lazily serves copies of a tensor to requested devices. Copies are cached per-device. | |
""" | |
def __init__(self, master_tensor: torch.Tensor) -> None: | |
assert master_tensor.is_cuda | |
self.master = master_tensor | |
self._per_device_tensors: Dict[torch.device, torch.Tensor] = {} | |
def get(self, device) -> torch.Tensor: | |
retval = self._per_device_tensors.get(device, None) | |
if retval is None: | |
retval = self.master.to(device=device, non_blocking=True, copy=True) | |
self._per_device_tensors[device] = retval | |
return retval | |
class MaxClipGradScaler(GradScaler):
    """``GradScaler`` whose loss scale is capped at ``max_scale``.

    Before every :meth:`scale` call the current scale is compared with
    ``max_scale``:

    * below the cap, the growth factor is set to 2;
    * at the cap, growth is frozen (factor 1);
    * above the cap, the scale is clamped back down and growth is frozen.

    Args:
        init_scale: Initial loss scale, forwarded to ``GradScaler``.
        max_scale: Upper bound for the loss scale.
        growth_interval: Forwarded to ``GradScaler``.
    """

    def __init__(self, init_scale, max_scale: float, growth_interval=100):
        super().__init__(init_scale=init_scale, growth_interval=growth_interval)
        self.max_scale = max_scale

    def scale_clip(self):
        """Keep the loss scale at or below ``max_scale`` and set the growth factor.

        The current scale is read once and then compared with the cap. If the
        scale tensor has not been created yet (``self._scale is None``), there
        is nothing to fill, so the pending initial value ``_init_scale`` is
        clamped instead; the tensor is then created already capped.
        """
        current = self.get_scale()
        if current == self.max_scale:
            self.set_growth_factor(1)
        elif current < self.max_scale:
            self.set_growth_factor(2)
        elif current > self.max_scale:
            if self._scale is None:
                # Scale tensor is created lazily on the first `scale()` call
                # that sees a tensor; calling `fill_` on None would crash here.
                self._init_scale = self.max_scale
            else:
                self._scale.fill_(self.max_scale)
            self.set_growth_factor(1)
        # A NaN scale matches none of the comparisons and is left untouched.

    def scale(self, outputs):
        """
        Multiplies ('scales') a tensor or list of tensors by the scale factor,
        after capping the scale at ``max_scale`` via :meth:`scale_clip`.

        Returns scaled outputs. If this instance of :class:`GradScaler` is not
        enabled, outputs are returned unmodified. Lists and tuples are
        rebuilt with the same type; any other iterable is returned as a lazy
        ``map`` object.

        Arguments:
            outputs (Tensor or iterable of Tensors): Outputs to scale. Every
                tensor must be a CUDA tensor.

        Raises:
            ValueError: if ``outputs`` (or a nested element) is neither a
                Tensor nor an iterable.
        """
        if not self._enabled:
            return outputs
        self.scale_clip()

        # Short-circuit for the common case.
        if isinstance(outputs, torch.Tensor):
            assert outputs.is_cuda
            if self._scale is None:
                self._lazy_init_scale_growth_tracker(outputs.device)
            assert self._scale is not None
            return outputs * self._scale.to(device=outputs.device, non_blocking=True)

        # Invoke the more complex machinery only if we're treating multiple outputs.
        # A single replicator is created on the first tensor seen and shared by
        # all later ones, so each device receives at most one copy of the scale.
        stash: List[_MultiDeviceReplicator] = []

        def apply_scale(val):
            if isinstance(val, torch.Tensor):
                assert val.is_cuda
                if not stash:
                    if self._scale is None:
                        self._lazy_init_scale_growth_tracker(val.device)
                    assert self._scale is not None
                    stash.append(_MultiDeviceReplicator(self._scale))
                return val * stash[0].get(val.device)
            elif isinstance(val, Iterable):
                iterable = map(apply_scale, val)
                if isinstance(val, (list, tuple)):
                    return type(val)(iterable)
                else:
                    return iterable
            else:
                raise ValueError("outputs must be a Tensor or an iterable of Tensors")

        return apply_scale(outputs)