Spaces:
Runtime error
Runtime error
# Copyright 2019 Kakao Brain | |
# | |
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | |
# | |
# This source code is licensed under the BSD license found in the | |
# LICENSE file in the root directory of this source tree. | |
"""Checkpointing with preceding recomputation. | |
PyTorch already provides the official checkpointing utilities in | |
:mod:`torch.utils.checkpoint`. The official checkpointing combines | |
recomputation and recursive backpropagation into one autograd function named | |
``CheckpointFunction``. Hence, the recomputation can be started only when the | |
gradients arrive to the function. In Pipe, the recomputation needs to precede | |
the gradient arrival to minimize the GPU idle time. | |
We solve this problem by introducing separate autograd functions named | |
:class:`Recompute` and :class:`Checkpoint`. Each function represents | |
recomputation and recursive backpropagation, respectively. We can manipulate | |
the control flow in aspect of both the autograd engine and CUDA with a pair of | |
the functions. | |
Specifically, we place CUDA stream synchronization between :class:`Recompute` | |
and :class:`Checkpoint` to delay only :class:`Checkpoint` until the gradient is | |
copied entirely. | |
""" | |
from collections import deque | |
from contextlib import contextmanager | |
import threading | |
from typing import ( | |
Any, | |
Deque, | |
Generator, | |
List, | |
Optional, | |
Protocol, | |
Union, | |
Sequence, | |
Tuple | |
) | |
import torch | |
from torch import Tensor | |
import torch.autograd | |
from .dependency import fork, join | |
from .microbatch import Batch | |
from .phony import get_phony | |
__all__ = ["Function", "checkpoint", "Checkpointing", "ThreadLocal", "enable_checkpointing", | |
"enable_recomputing", "is_checkpointing", "is_recomputing", "Context", "save_rng_states", | |
"restore_rng_states", "Checkpoint", "Recompute"] | |
Tensors = Sequence[Tensor] | |
TensorOrTensors = Union[Tensor, Tensors] | |
# Types for shared memory between Checkpoint and Recompute. | |
Recomputed = Tuple[TensorOrTensors, Tensors] # (output, input_leaf) | |
RNGStates = Tuple[Tensor, Optional[Tensor]] # (cpu_rng_state, gpu_rng_state) | |
# Protocol with __call__ instead of Callable can be used as an attribute type. | |
# See: https://github.com/python/mypy/issues/708#issuecomment-561735949 | |
class Function(Protocol): | |
def __call__(self, input: TensorOrTensors) -> TensorOrTensors: | |
... | |
def checkpoint(function: Function, input): | |
"""Make a checkpoint with a simple interface like | |
:func:`torch.utils.checkpoint.checkpoint`. It's only used to test or debug | |
:class:`Checkpoint` and :class:`Recompute` without boilerplate. | |
""" | |
batch = Batch(input) | |
chk = Checkpointing(function, batch) | |
batch = chk.checkpoint() | |
chk.recompute(batch) | |
return batch.values | |
class Checkpointing: | |
"""Generates a pair of :class:`Checkpoint` and :class:`Recompute`.""" | |
def __init__(self, function: Function, batch: Batch) -> None: | |
self.function = function | |
self.batch = batch | |
# Shared memory between Checkpoint and Recompute. 1-length deque is | |
# used for mutability and length limitation. | |
self.recomputed: Deque[Recomputed] = deque(maxlen=1) | |
self.rng_states: Deque[RNGStates] = deque(maxlen=1) | |
def checkpoint(self) -> Batch: | |
"""Return a batch applied by :class:`Checkpoint`.""" | |
input_atomic = self.batch.atomic | |
inputs = tuple(self.batch) | |
# Use a phony which requires grad to ensure that Checkpoint can be | |
# tracked by the autograd engine even when none of the input tensors | |
# require grad. | |
phony = get_phony(self.batch.get_device(), requires_grad=True) | |
output = Checkpoint.apply(phony, self.recomputed, self.rng_states, self.function, input_atomic, *inputs) | |
# Gradients are only supported for float Tensors. | |
if isinstance(output, tuple): | |
output = tuple([x.detach() if torch.is_tensor(x) and not x.is_floating_point() else x for x in output]) | |
return Batch(output) | |
def recompute(self, batch: Batch) -> None: | |
"""Apply :class:`Recompute` to the batch in place.""" | |
input_atomic = self.batch.atomic | |
inputs = tuple(self.batch) | |
# Use a tensor in the batch to tie together fork-join | |
tensor_idx = batch.find_tensor_idx() | |
# batch[tensor_idx] is always requiring grad, because it has been passed | |
# checkpoint with a phony requiring grad. | |
batch[tensor_idx], phony = fork(batch[tensor_idx]) | |
phony = Recompute.apply(phony, self.recomputed, self.rng_states, self.function, input_atomic, *inputs) | |
batch[tensor_idx] = join(batch[tensor_idx], phony) | |
class ThreadLocal(threading.local): | |
def __init__(self) -> None: | |
self.is_checkpointing = False | |
self.is_recomputing = False | |
thread_local = ThreadLocal() | |
def enable_checkpointing() -> Generator[None, None, None]: | |
"""Make :func:`is_checkpointing` return :data:`True` within a context.""" | |
orig = thread_local.is_checkpointing | |
thread_local.is_checkpointing = True | |
try: | |
yield | |
finally: | |
thread_local.is_checkpointing = orig | |
def enable_recomputing() -> Generator[None, None, None]: | |
"""Makes :func:`is_recomputing` return :data:`True` within a context.""" | |
orig = thread_local.is_recomputing | |
thread_local.is_recomputing = True | |
try: | |
yield | |
finally: | |
thread_local.is_recomputing = orig | |
def is_checkpointing() -> bool: | |
"""Whether the current forward propagation is under checkpointing. | |
Returns: | |
bool: :data:`True` if it's under checkpointing. | |
""" | |
return thread_local.is_checkpointing | |
def is_recomputing() -> bool: | |
"""Whether the current forward propagation is under checkpoint recomputation. | |
Use this to prevent duplicated side-effects at forward | |
propagation:: | |
class Counter(nn.Module): | |
def __init__(self): | |
super().__init__() | |
self.counter = 0 | |
def forward(self, input): | |
if not is_recomputing(): | |
self.counter += 1 | |
return input | |
Returns: | |
bool: :data:`True` if it's under checkpoint recomputation. | |
.. seealso:: :ref:`Detecting Recomputation` | |
""" | |
return thread_local.is_recomputing | |
class Context: | |
"""The common interface between the :class:`Checkpoint` and :class:`Recompute` context.""" | |
recomputed: Deque[Recomputed] | |
rng_states: Deque[RNGStates] | |
function: Function | |
input_atomic: bool | |
inputs: Sequence[Any] | |
saved_tensors: Tuple[Tensor, ...] | |
def save_for_backward(self, *tensors: Tensor) -> None: # pragma: no cover | |
pass | |
def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None: | |
""": | |
Capture the current random number generator states. | |
meth:`Checkpoint.forward` captures the current PyTorch's random number | |
generator states at CPU and GPU to reuse in :meth:`Recompute.backward`. | |
.. seealso:: :ref:`Referential Transparency` | |
""" | |
cpu_rng_state = torch.get_rng_state() | |
gpu_rng_state: Optional[Tensor] | |
if device.type == "cuda": | |
gpu_rng_state = torch.cuda.get_rng_state(device) | |
else: | |
gpu_rng_state = None | |
rng_states.append((cpu_rng_state, gpu_rng_state)) | |
def restore_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> Generator[None, None, None]: | |
""": | |
Restore the random number generator state. | |
meth:`Recompute.backward` restores the random number generator states | |
captured by :func:`save_rng_states` within its context. | |
.. seealso:: :ref:`Referential Transparency` | |
""" | |
cpu_rng_state, gpu_rng_state = rng_states.pop() | |
gpu_devices: List[torch.device] = [] | |
if device.type == "cuda": | |
gpu_devices.append(device) | |
with torch.random.fork_rng(gpu_devices): | |
torch.set_rng_state(cpu_rng_state) | |
if gpu_rng_state is not None: | |
torch.cuda.set_rng_state(gpu_rng_state, device) | |
yield | |
class Checkpoint(torch.autograd.Function): | |
# type: ignore[override] | |
def forward( | |
ctx: Context, | |
phony: Tensor, | |
recomputed: Deque[Recomputed], | |
rng_states: Deque[RNGStates], | |
function: Function, | |
input_atomic: bool, | |
*inputs, | |
): | |
ctx.recomputed = recomputed | |
ctx.rng_states = rng_states | |
save_rng_states(phony.device, ctx.rng_states) | |
ctx.function = function | |
ctx.input_atomic = input_atomic | |
if input_atomic: | |
tensors = [inputs[0]] | |
else: | |
tensors = [] | |
for input in inputs: | |
if torch.is_tensor(input): | |
tensors.append(input) | |
ctx.save_for_backward(*tensors) | |
with torch.no_grad(), enable_checkpointing(): | |
if input_atomic: | |
assert len(inputs) == 1 | |
output = function(inputs[0]) | |
else: | |
output = function(*inputs) | |
return output | |
def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]: # pragma: no cover | |
output, input_leaf = ctx.recomputed.pop() | |
if isinstance(output, tuple): | |
outputs = output | |
else: | |
outputs = (output,) | |
if any(torch.is_tensor(y) and y.requires_grad for y in outputs): | |
tensors = tuple([x for x in outputs if torch.is_tensor(x) and x.requires_grad]) | |
torch.autograd.backward(tensors, grad_output) | |
grad_input: List[Optional[Tensor]] = [None, None, None, None, None] | |
grad_input.extend(x.grad if torch.is_tensor(x) else None for x in input_leaf) | |
return tuple(grad_input) | |
class Recompute(torch.autograd.Function): | |
# type: ignore[override] | |
def forward( | |
ctx: Context, | |
phony: Tensor, | |
recomputed: Deque[Recomputed], | |
rng_states: Deque[RNGStates], | |
function: Function, | |
input_atomic: bool, | |
*inputs, | |
) -> Tensor: | |
ctx.recomputed = recomputed | |
ctx.rng_states = rng_states | |
ctx.function = function | |
ctx.input_atomic = input_atomic | |
ctx.inputs = inputs | |
if input_atomic: | |
tensors = [inputs[0]] | |
else: | |
tensors = [] | |
for input in inputs: | |
if torch.is_tensor(input): | |
tensors.append(input) | |
ctx.save_for_backward(*tensors) | |
return phony | |
def backward(ctx: Context, *grad_output: Tensor) -> Tuple[None, ...]: # pragma: no cover | |
inputs = ctx.inputs | |
inputs_leaf = tuple(x.detach().requires_grad_(x.requires_grad) if torch.is_tensor(x) else x for x in inputs) | |
# Get the device for the inputs from a tensor | |
device = None | |
for input in inputs: | |
if torch.is_tensor(input): | |
device = input.device | |
break | |
if device is None: | |
raise RuntimeError(f'No tensors found in {inputs}') | |
with restore_rng_states(device, ctx.rng_states): | |
with torch.enable_grad(), enable_recomputing(): | |
if ctx.input_atomic: | |
assert len(inputs_leaf) == 1 | |
output = ctx.function(inputs_leaf[0]) | |
else: | |
output = ctx.function(*inputs_leaf) | |
ctx.recomputed.append((output, inputs_leaf)) | |
grad_input: List[None] = [None, None, None, None, None] | |
grad_input.extend(None for _ in ctx.inputs) | |
return tuple(grad_input) | |