# torch/_functorch/autograd_function.py
import torch
from torch._ops import HigherOrderOperator
from torch._C._functorch import TransformType
from torch._functorch.utils import enable_single_level_autograd_function
import torch.utils._pytree as pytree
from torch._C._functorch import (
_wrap_for_grad,
_unwrap_for_grad,
current_level,
)
from torch._functorch.vmap import (
wrap_batched,
unwrap_batched,
restore_vmap,
_add_batch_dim,
)
from torch._functorch.apis import vmap
from torch._functorch.vmap import _broadcast_to_and_flatten
from torch.autograd.forward_ad import _set_fwd_grad_enabled
from typing import Any, NamedTuple, Tuple
# autograd.Function technically runs before the regular PyTorch dispatcher.
# This is how features like autocast and torch_dispatch (e.g. PythonTLSSnapshot)
# work with it. One day we might decide to change this, but until then,
# we need to give the illusion that autograd.Function runs before those things.
#
# We do this by creating a custom HigherOrderOperator that only functorch
# dispatches specially.
class CustomFunctionHigherOrderOperator(HigherOrderOperator):
def __init__(self):
super().__init__('custom_function_call')
def __call__(self, autograd_function, *args, **kwargs):
# When custom_function_call is done dispatching through functorch,
# it should just invoke the autograd.Function. This is consistent
# with the autograd.Function behavior of being invoked before the
# PyTorch dispatcher.
#
# This will lead us into trouble later down the line, but this is
# pre-existing. There is an invariant that a function traced by
# make_fx should have the same behavior when provided the same
# Tensor. However, make_fx sees autograd.Function as a composite
# (because autograd.Function happens before the Python dispatch key)
# and only traces the forward pass.
if torch._C._are_functorch_transforms_active():
return super().__call__(autograd_function, *args, **kwargs)
return autograd_function.apply(*args, **kwargs)
# "custom_function_call"
# This is the mechanism for an autograd.Function that works with functorch transforms.
# It wraps an autograd.Function; interactions with functorch transforms are defined
# via PyDispatcher and HigherOrderOperator rather than through the traditional PyTorch
# dispatcher.
custom_function_call = CustomFunctionHigherOrderOperator()
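# Illustrative usage sketch (MyFn and x are hypothetical, not part of this
# module): outside of any functorch transform, custom_function_call(MyFn, x)
# is just MyFn.apply(x); under a transform, the PyDispatcher routes the call
# to one of the py_impl rules registered below.
#
#   class MyFn(torch.autograd.Function):
#       @staticmethod
#       def forward(x):
#           return x.sin()
#
#       @staticmethod
#       def setup_context(ctx, inputs, output):
#           ctx.save_for_backward(inputs[0])
#
#       @staticmethod
#       def backward(ctx, gy):
#           x, = ctx.saved_tensors
#           return gy * x.cos()
#
#   x = torch.randn(3, requires_grad=True)
#   out = custom_function_call(MyFn, x)  # equivalent to MyFn.apply(x) here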
# The grad rule for custom_function_call is to construct a new _SingleLevelFunction
# (autograd.Function that only works with a single layer (level) of functorch) that:
# - unwraps the inputs
# - redispatches to custom_function_call
# - wraps the outputs
# and whose backward pass calls the original autograd.Function's backward.
#
# Why do we need to redispatch to custom_function_call?
# -----------------------------------------------------
# This is consistent with how ATen operators work with functorch's grad transform:
# they always redispatch to the original operator.
# Consider torch.sin, and let's say we do grad0(grad1(torch.sin))(x)
#
# grad1 will:
# - set up the autograd graph
# - unwrap the inputs
# - redispatch to at::sin (*)
# - rewrap the outputs on the return
#
# On the redispatch in (*), grad0 will:
# - set up the autograd graph
# - unwrap the inputs
# - redispatch to at::sin
# - rewrap the outputs on the return
#
# To "set up the autograd graph", we generate a _SingleLevelFunction
# and apply it.
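# Concretely (an illustrative sketch, reusing the hypothetical MyFn from the
# sketch near custom_function_call above):
#
#   from torch.func import grad
#   out = grad(grad(MyFn.apply))(torch.randn([]))  # computes -sin(x)
#
# triggers this rule once per grad level: each level generates a fresh
# MyFnGenerated _SingleLevelFunction, applies it, and redispatches to
# custom_function_call for the level below.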
@custom_function_call.py_impl(TransformType.Grad)
@custom_function_call.py_impl(TransformType.Jvp)
def custom_function_call_grad(interpreter, autograd_function, *operands):
Generated = generate_single_level_function(interpreter, autograd_function)
with enable_single_level_autograd_function():
flat_out = Generated.apply(*operands)
return flat_out
def generate_single_level_function(interpreter, autograd_function):
level = interpreter.level()
def forward(*operands):
unwrapped_operands = pytree.tree_map_only(
torch.Tensor,
lambda x: _unwrap_for_grad(x, level),
operands)
# Both enable_grad() and _set_fwd_grad_enabled() are necessary no matter
# the transform. _SingleLevelFunction will turn off both fwd and bwd
# gradient computation and we need to turn it back on here.
with torch.enable_grad(), _set_fwd_grad_enabled(True), interpreter.lower():
unwrapped_output = custom_function_call(autograd_function, *unwrapped_operands)
# See NOTE [mark_dirty object identity check]
def wrap_fn(output):
return _wrap_for_grad(output, level)
return wrap_outputs_maintaining_identity(
unwrapped_output,
unwrapped_operands,
operands,
wrap_fn)
def setup_context(ctx, inputs, output):
return autograd_function.setup_context(ctx, inputs, output)
# backward is only used if the transform is TransformType.Grad
def backward(ctx, *grads):
result = autograd_function.backward(ctx, *grads)
return result
# jvp is only used if the transform is TransformType.Jvp
def jvp(ctx, *tangents):
result = autograd_function.jvp(ctx, *tangents)
return result
# This is the sequence of magic words to dynamically generate a Subclass with
# a given name. A Tensor's .grad_fn field has a class name that is the original
# autograd.Function's name + Backward, so we do this to generate some
# meaningful name.
name = f'{autograd_function.__name__}Generated'
Generated = type(
name,
(torch.autograd.function._SingleLevelFunction,),
{
'forward': staticmethod(forward),
'backward': staticmethod(backward),
'jvp': staticmethod(jvp),
'setup_context': staticmethod(setup_context),
},
)
return Generated
# wrap_outputs_maintaining_identity handles outputs from the vmap,
# backward (vjp), and jvp staticmethod. The way it distinguishes
# between the vmap case and the {backward, jvp} case is if the out_dims
# are specified or not.
#
# NB: we cannot use out_dims=None as the deciding factor. This is because
# out_dims=None can still happen in the vmap staticmethod! What the
# user is saying in that case is that their output does not have a
# dimension that is being vmapped over, which is valid.
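# For example (hypothetical sketch), a Function whose output does not depend on
# the vmapped dimension may legitimately return out_dims=None:
#
#   class Ones(torch.autograd.Function):
#       @staticmethod
#       def forward(x):
#           return torch.ones(3)
#
#       # setup_context/backward omitted for brevity
#
#       @staticmethod
#       def vmap(info, in_dims, x):
#           # The output is identical for every batch element, so it has no
#           # batched dimension: out_dims is None.
#           return torch.ones(3), None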
NO_OUT_DIMS = "not specified"
# NOTE [mark_dirty object identity check]
# autograd.Function's ctx.mark_dirty expects a returned input
# to have the same object identity as the input.
# Mode-only functorch will greatly simplify this logic.
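# For example (hypothetical sketch), an in-place Function relies on this:
#
#   class InplaceRelu(torch.autograd.Function):
#       @staticmethod
#       def forward(x):
#           return x.relu_()  # returns the very same object as the input
#
#       @staticmethod
#       def setup_context(ctx, inputs, output):
#           ctx.mark_dirty(inputs[0])  # requires output `is` input
#
#       # backward omitted for brevity
#
# so when re-wrapping outputs below, an unwrapped output that *is* one of the
# unwrapped inputs must be mapped back to the original (wrapped) input object.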
def wrap_outputs_maintaining_identity(
outputs, unwrapped_inputs, orig_inputs, wrap_fn, out_dims=NO_OUT_DIMS):
flat_unwrapped_inputs = pytree.arg_tree_leaves(*unwrapped_inputs)
flat_orig_inputs = pytree.arg_tree_leaves(*orig_inputs)
unwrapped_input_to_orig_input = {
id(unwrapped): orig
for unwrapped, orig in zip(flat_unwrapped_inputs, flat_orig_inputs)
}
flat_outputs, spec = pytree.tree_flatten(outputs)
result = []
out_dims_specified = out_dims != NO_OUT_DIMS
if out_dims_specified:
flat_out_dims = _broadcast_to_and_flatten(out_dims, spec)
# _broadcast_to_and_flatten returns None if it is unable to broadcast.
# TODO: update following link from master to stable once that's out
if flat_out_dims is None:
raise RuntimeError(
f"The autograd.Function's vmap staticmethod returned an "
f"incompatible (output, out_dims) tuple. "
f"Expected out_dims={out_dims} "
f"to be compatible with the structure of `output`. "
f"out_dims has structure {pytree.tree_flatten(out_dims)[1]} "
f"but output has structure {spec}. "
f"For more details, please see "
f"https://pytorch.org/docs/master/notes/extending.func.html"
)
for i, output in enumerate(flat_outputs):
if not isinstance(output, torch.Tensor):
result.append(output)
continue
if id(output) in unwrapped_input_to_orig_input:
result.append(unwrapped_input_to_orig_input[id(output)])
continue
if out_dims_specified:
result.append(wrap_fn(output, flat_out_dims[i])) # type: ignore[index]
else:
result.append(wrap_fn(output))
return pytree.tree_unflatten(result, spec)
# NOTE: [functorch vjp and autograd interaction]
# There's an edge case with the functorch vjp and autograd interaction
# that will eventually be fixed by mode-only functorch.
# The TL;DR is that there's no way to unwrap a dead GradTensorWrapper,
# so we (the framework) need to do it manually. Regular PyTorch operators
# unwrap it automatically, so handling it manually here keeps things consistent.
#
# class MyExp(torch.autograd.Function):
# @staticmethod
# def forward(x):
# return x.exp()
#
# @staticmethod
# def setup_context(ctx, inputs, output):
# y = output
# ctx.save_for_backward(y)
#
# @staticmethod
#     def backward(ctx, gy):
#         y, = ctx.saved_tensors
# return MyMul.apply(gy, y)
#
# x = torch.randn([], requires_grad=True)
# gy = torch.randn([], requires_grad=True)
# _, vjp_fn = vjp(MyExp.apply, x)
# result = vjp_fn(gy)
#
# MyMul is an autograd.Function that is not shown here (a sketch is given after this note).
# It saves a `y` for backward (since gy requires grad).
#
# in vjp_fn(gy), we get:
# > MyMul.apply(gy, GradTensorWrapper(y, level=dead))
# Because the y that is saved for backward by MyExp is a GradTensorWrapper
# but is now dead since we are outside the vjp context.
#
# PyTorch dispatcher operations, upon seeing a dead GradTensorWrapper,
# will automatically unwrap the GradTensorWrapper when applied.
# But since autograd.Function technically sits above the regular PyTorch
# dispatcher, it doesn't get this treatment. So we manually do
# the unwrapping to be consistent with regular PyTorch dispatcher operations.
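# For concreteness, a minimal sketch of what MyMul could look like
# (hypothetical; the note above leaves it out):
#
#   class MyMul(torch.autograd.Function):
#       @staticmethod
#       def forward(x, y):
#           return x * y
#
#       @staticmethod
#       def setup_context(ctx, inputs, output):
#           ctx.save_for_backward(*inputs)
#
#       @staticmethod
#       def backward(ctx, gz):
#           x, y = ctx.saved_tensors
#           return gz * y, gz * x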
class VmapInfo(NamedTuple):
batch_size: int
randomness: str
def has_overriden_vmap_rule(autograd_function):
return autograd_function.vmap is not torch.autograd.Function.vmap
def validate_vmap_returns_tuple_of_two_elements(result):
base_error_msg = (
"Expected the vmap staticmethod to have two returns, an output "
"and out_dims with pytree structure compatible with the output. "
)
if not isinstance(result, tuple):
raise RuntimeError(base_error_msg + f"Got a {type(result)} instead")
if not len(result) == 2:
raise RuntimeError(base_error_msg + f"Got {len(result)} returns instead")
@custom_function_call.py_impl(TransformType.Vmap)
def custom_function_call_vmap(interpreter, autograd_function, *operands):
if autograd_function.generate_vmap_rule:
if has_overriden_vmap_rule(autograd_function):
# TODO: Update link to stable once that's out
# https://github.com/pytorch/pytorch/issues/92029
raise RuntimeError(
f"You tried to vmap over {autograd_function.__name__}, but "
f"it has both generate_vmap_rule=True and an overriden vmap "
f"staticmethod. Please set generate_vmap_rule=False or delete "
f"the overriden vmap staticmethod to avoid ambiguity. "
f"For more details, please see "
f"https://pytorch.org/docs/master/notes/extending.func.html")
return custom_function_call_vmap_generate_rule(interpreter, autograd_function, *operands)
if not has_overriden_vmap_rule(autograd_function):
# TODO: Update link to stable once that's out
# https://github.com/pytorch/pytorch/issues/92029
raise RuntimeError(
f"You tried to vmap over {autograd_function.__name__}, but "
f"it does not have vmap support. Please override and implement the "
f"vmap staticmethod or set generate_vmap_rule=True. "
f"For more details, please see "
f"https://pytorch.org/docs/master/notes/extending.func.html")
current_level = interpreter.level()
info = VmapInfo(
batch_size=interpreter.batch_size(),
randomness=interpreter.randomness(),
)
unwrapped_operands, in_dims = unwrap_batched(operands, current_level)
# If none of the tensors are batched at the current level, then we skip the
# current level. This saves the user from needing to handle this case in
# their vmap staticmethod (and is consistent with our C++ batching rule API)
if pytree.tree_all(lambda dim: dim is None, in_dims):
with interpreter.lower():
return custom_function_call(autograd_function, *operands)
with interpreter.lower():
result = autograd_function.vmap(info, in_dims, *unwrapped_operands)
validate_vmap_returns_tuple_of_two_elements(result)
unwrapped_output, out_dims = result
# See NOTE [mark_dirty object identity check]
def wrap_fn(output, out_dim):
return output if out_dim is None else _add_batch_dim(output, out_dim, current_level)
return wrap_outputs_maintaining_identity(
unwrapped_output,
unwrapped_operands,
operands,
wrap_fn,
out_dims=out_dims)
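# For reference, a user-defined vmap staticmethod that the rule above invokes
# looks roughly like this (hypothetical sketch; MyMatmul is not a real class):
#
#   class MyMatmul(torch.autograd.Function):
#       # forward/setup_context/backward omitted for brevity
#
#       @staticmethod
#       def vmap(info, in_dims, x, w):
#           x_bdim, w_bdim = in_dims
#           assert w_bdim is None     # assume only x is batched
#           x = x.movedim(x_bdim, 0)  # move the batch dim to the front
#           return x @ w, 0           # output is batched along dim 0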
def custom_function_call_vmap_generate_rule(interpreter, autograd_function, *operands):
unwrapped_operands, in_dims = unwrap_batched(operands, interpreter.level())
vmapped_function, get_out_dims = vmapify_autograd_function(
autograd_function, in_dims, interpreter.batch_size(), interpreter.randomness())
with interpreter.lower():
output = custom_function_call(vmapped_function, *unwrapped_operands)
out_dims = get_out_dims()
return wrap_batched(output, out_dims, interpreter.level())
@custom_function_call.py_impl(TransformType.Functionalize)
def custom_function_call_functionalize(interpreter, autograd_function, generate_vmap_rule, *operands):
raise RuntimeError("NYI: Functionalize rule for custom_function_call")
def vmapify_autograd_function(autograd_function, in_dims, batch_size, randomness):
# The following values are saved from the forward() and setup_context()
# and used in backward().
# Why do we save the values out here instead of on the ctx object?
# - out_dims: There's no way to retrieve this from forward()
    # - input_shapes, saved_tensors_bdims: I'm a bit wary of nested
    #   vmap(vmap(...)), but not completely sure if it is a problem. If we
# assigned those fields to the ctx object, the worry is that they
# get overwritten.
init_val = "not populated"
out_dims = init_val
input_shapes: Any = init_val
saved_tensors_bdims: Any = init_val
def forward(*operands):
nonlocal out_dims
outputs, out_dims = restore_vmap(
autograd_function.forward, in_dims, batch_size, randomness)(*operands)
return outputs
def setup_context(ctx, inputs, outputs):
input_shapes_ = None
saved_tensors_bdims_ = None
def inner(inputs, outputs):
# wrapped_ctx.save_for_backward will:
# - unwrap batchedtensors into (tensor, bdim)
# - save_for_backward(*unwrapped_tensors)
# - assign the bdims to wrapped_ctx._pt_saved_tensors_bdims
wrapped_ctx = CtxCustomSave(ctx, current_level())
autograd_function.setup_context(wrapped_ctx, inputs, outputs)
# input_shapes are used for reductify later to reduce expanded gradients
# to the correct shape.
# See NOTE: [Why can't we rely on autograd to reduce expanded gradients?]
# for more details
nonlocal input_shapes_
input_shapes_ = tuple(inp.shape if isinstance(inp, torch.Tensor) else None
for inp in inputs)
nonlocal saved_tensors_bdims_
saved_tensors_bdims_ = wrapped_ctx._pt_saved_tensors_bdims
# See NOTE: [Why do we need to run setup_context under a vmap?]
restore_vmap(
inner,
(in_dims, out_dims),
batch_size,
randomness,
)(inputs, outputs)
nonlocal input_shapes
input_shapes = input_shapes_
nonlocal saved_tensors_bdims
saved_tensors_bdims = saved_tensors_bdims_
def jvp(ctx, *tangents):
assert out_dims != init_val
assert saved_tensors_bdims != init_val
def jvp_no_context(saved_tensors, tangents):
wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors)
return autograd_function.jvp(wrapped_ctx, *tangents)
tangent_in_dims = get_tangents_in_dims(in_dims, tangents)
out_tangents, out_tangents_dims = restore_vmap(
jvp_no_context, (saved_tensors_bdims, tangent_in_dims), batch_size, randomness)(
ctx.saved_tensors, tangents)
result = reductify(out_tangents, out_tangents_dims, out_dims, batch_size)
return result
def backward(ctx, *grad_outputs):
assert out_dims != init_val
assert input_shapes != init_val
assert saved_tensors_bdims != init_val
def backward_no_context(inputs):
saved_tensors, grad_outputs = inputs
wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors)
return autograd_function.backward(wrapped_ctx, *grad_outputs)
grad_ins, grad_ins_dims = restore_vmap(
backward_no_context, ((saved_tensors_bdims, out_dims),), batch_size, randomness)(
(ctx.saved_tensors, grad_outputs))
result = reductify(grad_ins, grad_ins_dims, in_dims, batch_size, input_shapes)
return result
name = f'Vmapped{autograd_function.__name__}'
Generated = type(
name,
(torch.autograd.Function,),
{
'forward': staticmethod(forward),
'backward': staticmethod(backward),
'jvp': staticmethod(jvp),
'setup_context': staticmethod(setup_context),
'generate_vmap_rule': True
}
)
def get_out_dims():
assert out_dims != init_val
return out_dims
return Generated, get_out_dims
# tangents might be None, so we need to replace
# the corresponding in_dims with None.
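# E.g. (hypothetical values): input_dims = (0, 0) with tangents = (t0, None)
# yields tangent in_dims of (0, None).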
def get_tangents_in_dims(input_dims, tangents):
flat_in_dims, spec = pytree.tree_flatten(input_dims)
flat_tangents = pytree.arg_tree_leaves(*tangents)
result = [None if tangent is None else in_dim
for in_dim, tangent in zip(flat_in_dims, flat_tangents)]
return pytree.tree_unflatten(result, spec)
# NOTE: [Why do we need to run setup_context under a vmap?]
# Consider the following autograd.Function
#
# class Sum(torch.autograd.Function):
# @staticmethod
# def forward(x):
# return x.sum()
# @staticmethod
# def setup_context(ctx, inputs, outputs):
#        ctx.x_shape = inputs[0].shape
# @staticmethod
# def backward(ctx, gy):
# return gy.expand(ctx.x_shape)
#
# x = torch.randn(B, 4)
# in_dims = 0
# vmap(Sum.apply, in_dims)(x)
#
# Let’s assume for a moment that we didn’t vmap setup_context in VmappedSum:
#
# class VmappedSum(torch.autograd.Function):
# @staticmethod
# def forward(x):
# return vmap(Sum.forward, in_dims)(x)
#
# @staticmethod
# def setup_context(ctx, inputs, outputs):
# Sum.setup_context(ctx, inputs, outputs)
#
# @staticmethod
# def backward(ctx, gy):
# def backward_no_context(gy):
# return gy.expand(ctx.x_shape)
#
# dims = (0,)
# gx = vmap(backward_no_context, dims)(gy)
# return gx
#
# We end up saving [B, 4] as x_shape. In the backward, gy has shape [B],
# and we’re doing:
#
# def backward_no_context(gy):
# return gy.expand([B, 4])
#
# gx = vmap(backward_no_context, dims)(gy: "Tensor[B]")
#
# This gives us the wrong result (gx has shape [B, B, 4], but it should
# have shape [4]). Performing vmap over setup_context means the saved shape
# is [4], which leads to gx having the correct shape.
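# Roughly, the vmapped version produced by vmapify_autograd_function instead
# does the following (illustrative sketch; in_dims, out_dims, batch_size and
# randomness are closed over as in the real implementation above):
#
#   @staticmethod
#   def setup_context(ctx, inputs, outputs):
#       def inner(inputs, outputs):
#           Sum.setup_context(ctx, inputs, outputs)
#       restore_vmap(inner, (in_dims, out_dims), batch_size, randomness)(inputs, outputs)
#
# Under that vmap, inputs[0] is seen with shape [4], so x_shape is saved as [4]
# and the backward expand produces a correctly-shaped gradient.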
# Wraps a ctx object. Forwards all attr accesses to the underlying object
# except for the attrs in _pt_reserved_attrs
class WrappedCtx:
_pt_reserved_attrs: Tuple[str, ...] = ('_pt_reserved_attrs', '_pt_inner_ctx')
def __init__(self, ctx):
if not isinstance(ctx, WrappedCtx):
reserved_attrs = type(self)._pt_reserved_attrs
for name in reserved_attrs:
if not hasattr(ctx, name):
continue
raise RuntimeError(
f'PyTorch reserves the {reserved_attrs} field on ctx. '
'Please name your fields on ctx something else to avoid name '
'collision.')
self._pt_inner_ctx = ctx
def __getattr__(self, name):
return getattr(self._pt_inner_ctx, name)
def __setattr__(self, name, value):
if name in type(self)._pt_reserved_attrs:
self.__dict__[name] = value
return
return setattr(self._pt_inner_ctx, name, value)
# Wraps ctx to create a new ctx object that overrides saved_tensors.
class CtxWithSavedTensors(WrappedCtx):
_pt_reserved_attrs = ('_pt_new_saved_tensors', *WrappedCtx._pt_reserved_attrs)
def __init__(self, ctx, new_saved_tensors):
super().__init__(ctx)
self._pt_new_saved_tensors = new_saved_tensors
@property
def saved_tensors(self):
return self._pt_new_saved_tensors
class CtxCustomSave(WrappedCtx):
_pt_reserved_attrs = ('_pt_saved_tensors_bdims', '_pt_current_level',
*WrappedCtx._pt_reserved_attrs)
def __init__(self, ctx, current_level):
super().__init__(ctx)
self._pt_saved_tensors_bdims = ()
self._pt_current_level = current_level
def save_for_backward(self, *tensors):
unwrapped_tensors, bdims = unwrap_batched(tensors, self._pt_current_level)
self._pt_inner_ctx.save_for_backward(*unwrapped_tensors)
self._pt_saved_tensors_bdims = bdims
def save_for_forward(self, *tensors):
unwrapped_tensors, bdims = unwrap_batched(tensors, self._pt_current_level)
self._pt_inner_ctx.save_for_forward(*unwrapped_tensors)
self._pt_saved_tensors_bdims = bdims
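# For example (hypothetical): if a user's setup_context calls
# ctx.save_for_backward(x) on a CtxCustomSave while x is batched at the current
# level with bdim 0, the underlying plain tensor is saved on the real ctx and
# (0,) is recorded in _pt_saved_tensors_bdims, which backward()/jvp() in
# vmapify_autograd_function later feed back into restore_vmap.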
def reductify(grad_input, grad_input_bdim, input_bdim, batch_size,
target_shape_without_bdim_to_reduce_to=None):
if not isinstance(grad_input, tuple):
grad_input = (grad_input,)
if not isinstance(grad_input_bdim, tuple):
grad_input_bdim = (grad_input_bdim,)
if not isinstance(input_bdim, tuple):
input_bdim = (input_bdim,)
if target_shape_without_bdim_to_reduce_to is None:
target_shape_without_bdim_to_reduce_to = len(grad_input) * (None,)
result = tuple(
reductify_leaf(gi, gi_bdim, i_bdim, batch_size, maybe_ishape)
for gi, gi_bdim, i_bdim, maybe_ishape in
zip(grad_input, grad_input_bdim, input_bdim, target_shape_without_bdim_to_reduce_to)
)
return result
def reductify_leaf(grad_input, grad_input_bdim, input_bdim, batch_size,
target_shape_without_bdim_to_reduce_to=None):
if grad_input is None:
return None
if grad_input_bdim is None and input_bdim is None:
return grad_input
if grad_input_bdim is not None and input_bdim is None:
return grad_input.sum(grad_input_bdim)
# NOTE: [Why can't we rely on autograd to reduce expanded gradients?]
# For reverse-mode AD,
# given a grad_input and input, it is valid for the user to return a
# grad_input that has a broadcasted shape when compared to the input.
# In this situation, autograd automatically reduces the grad_input to
# the shape of the input.
#
# However, when input_bdim is not None, we have problems.
#
# [example 1]
# grad_input: Tensor[3, 4], input: Tensor[B, 4]
# We can expand grad_input to Tensor[B, 3, 4], but that isn't broadcastable
# from [B, 4].
#
# [example 2]
# grad_input: Tensor[3, B, 4], input: Tensor[B, 4]
# We can swizzle grad_input to Tensor[B, 3, 4], but that isn't broadcastable
# from [B, 4].
#
# This means that we need to also reduce the grad_input to the shape of the
    # input. This behavior is controlled by the `target_shape_without_bdim_to_reduce_to`
    # argument: if it is not None, we perform the reduction manually; otherwise, we do not.
assert input_bdim is not None
if grad_input_bdim is None:
grad_input = grad_input.unsqueeze(input_bdim)
new_shape = list(grad_input.shape)
new_shape[input_bdim] = batch_size
grad_input = grad_input.expand(new_shape)
grad_input_bdim = input_bdim
if target_shape_without_bdim_to_reduce_to is not None:
return vmap(torch.Tensor.sum_to_size, in_dims=(grad_input_bdim, None), out_dims=input_bdim)(
grad_input, target_shape_without_bdim_to_reduce_to)
if input_bdim != grad_input_bdim:
grad_input = grad_input.movedim(grad_input_bdim, input_bdim)
return grad_input
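# Worked example (hypothetical values), matching [example 1] above:
#
#   B = 5
#   gi = torch.ones(3, 4)  # grad_input with no batch dim
#   out = reductify_leaf(gi, None, 0, B, torch.Size([4]))
#   # gi is first expanded to [B, 3, 4]; then sum_to_size([4]) runs per batch
#   # element, so `out` has shape [B, 4], matching the batched input.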