|
import torch |
|
from typing import Tuple, List |
|
from . import forward_ad as fwAD |
|
from torch._vmap_internals import _vmap |
|
|
|
|
|
|
|
def _as_tuple_nocheck(x): |
|
if isinstance(x, tuple): |
|
return x |
|
elif isinstance(x, list): |
|
return tuple(x) |
|
else: |
|
return x, |
|
|
|
def _as_tuple(inp, arg_name=None, fn_name=None):
    # Ensures that inp is a tuple of Tensors.
    # Returns whether the original inp was a tuple and the tupled version of the input.
if arg_name is None and fn_name is None: |
|
return _as_tuple_nocheck(inp) |
|
|
|
is_inp_tuple = True |
|
if not isinstance(inp, tuple): |
|
inp = (inp,) |
|
is_inp_tuple = False |
|
|
|
for i, el in enumerate(inp): |
|
if not isinstance(el, torch.Tensor): |
|
if is_inp_tuple: |
|
raise TypeError("The {} given to {} must be either a Tensor or a tuple of Tensors but the" |
|
" value at index {} has type {}.".format(arg_name, fn_name, i, type(el))) |
|
else: |
|
raise TypeError("The {} given to {} must be either a Tensor or a tuple of Tensors but the" |
|
" given {} has type {}.".format(arg_name, fn_name, arg_name, type(el))) |
|
|
|
return is_inp_tuple, inp |
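
# A minimal usage sketch of the normalization above (illustrative REPL session,
# kept as a comment so nothing runs at import time):
#   >>> _as_tuple(torch.ones(2), "inputs", "vjp")
#   (False, (tensor([1., 1.]),))
#   >>> _as_tuple((torch.ones(2), torch.ones(3)), "inputs", "vjp")
#   (True, (tensor([1., 1.]), tensor([1., 1., 1.])))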
|
|
|
def _tuple_postprocess(res, to_unpack):
    # Unpacks a potentially nested tuple of Tensors to match the original, possibly
    # non-tuple, structure recorded by _as_tuple. `to_unpack` is either a single bool
    # (was the object a tuple?) or a pair of bools for a nested tuple of tuples.
if isinstance(to_unpack, tuple): |
|
assert len(to_unpack) == 2 |
|
if not to_unpack[1]: |
|
res = tuple(el[0] for el in res) |
|
if not to_unpack[0]: |
|
res = res[0] |
|
else: |
|
if not to_unpack: |
|
res = res[0] |
|
return res |
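
# A short sketch of how the flags recorded by _as_tuple are undone here
# (illustrative, comment-only):
#   >>> _tuple_postprocess((torch.ones(2),), False)
#   tensor([1., 1.])
#   >>> _tuple_postprocess(((torch.ones(2),),), (True, False))
#   (tensor([1., 1.]),)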
|
|
|
def _grad_preprocess(inputs, create_graph, need_graph):
    # Preprocess the inputs to make sure they require gradient.
    # inputs is a tuple of Tensors to preprocess.
    # create_graph specifies if the user wants gradients to flow back to the Tensors in inputs.
    # need_graph specifies if we internally want gradients to flow back to the Tensors in res.
    # Note that we *always* create a new Tensor object so that we can tell the
    # original inputs apart from their copies.
res = [] |
|
for inp in inputs: |
|
        if create_graph and inp.requires_grad:
            # Create at least a new Tensor object in a differentiable way
            if not inp.is_sparse:
                # Use .view_as() to get a shallow copy
                res.append(inp.view_as(inp))
            else:
                # We cannot use view for sparse Tensors so we clone
                res.append(inp.clone())
|
else: |
|
res.append(inp.detach().requires_grad_(need_graph)) |
|
return tuple(res) |
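
# A sketch of the copy semantics above (illustrative, comment-only; `x` is a leaf tensor):
#   >>> x = torch.randn(2, requires_grad=True)
#   >>> y, = _grad_preprocess((x,), create_graph=False, need_graph=True)
#   >>> (y is x, y.requires_grad, y.grad_fn is None)   # detached copy with a fresh graph
#   (False, True, True)
#   >>> z, = _grad_preprocess((x,), create_graph=True, need_graph=True)
#   >>> z.grad_fn is not None                          # view, still attached to x's graph
#   True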
|
|
|
|
|
def _grad_postprocess(inputs, create_graph):
    # Postprocess the generated Tensors to avoid returning Tensors with history
    # when the user did not request it (create_graph=False).
if isinstance(inputs[0], torch.Tensor): |
|
if not create_graph: |
|
return tuple(inp.detach() for inp in inputs) |
|
else: |
|
return inputs |
|
else: |
|
return tuple(_grad_postprocess(inp, create_graph) for inp in inputs) |
|
|
|
def _validate_v(v, other, is_other_tuple):
    # This assumes that other is the correct shape, and v should match.
    # Both are assumed to be tuples of Tensors.
if len(other) != len(v): |
|
if is_other_tuple: |
|
raise RuntimeError("v is a tuple of invalid length: should be {} but got {}.".format(len(other), len(v))) |
|
else: |
|
raise RuntimeError("The given v should contain a single Tensor.") |
|
|
|
for idx, (el_v, el_other) in enumerate(zip(v, other)): |
|
if el_v.size() != el_other.size(): |
|
prepend = "" |
|
if is_other_tuple: |
|
prepend = "Entry {} in ".format(idx) |
|
raise RuntimeError("{}v has invalid size: should be {} but got {}.".format( |
|
prepend, el_other.size(), el_v.size())) |
|
|
|
|
|
def _check_requires_grad(inputs, input_type, strict):
    # Used to make all the necessary checks to raise nice errors in strict mode.
if not strict: |
|
return |
|
|
|
if input_type not in ["outputs", "grad_inputs", "jacobian", "hessian"]: |
|
raise RuntimeError("Invalid input_type to _check_requires_grad") |
|
for i, inp in enumerate(inputs): |
|
        if inp is None:
            # This can only be reached for grad_inputs.
raise RuntimeError("The output of the user-provided function is independent of input {}." |
|
" This is not allowed in strict mode.".format(i)) |
|
if not inp.requires_grad: |
|
if input_type == "hessian": |
|
raise RuntimeError("The hessian of the user-provided function with respect to input {}" |
|
" is independent of the input. This is not allowed in strict mode." |
|
" You should ensure that your function is thrice differentiable and that" |
|
" the hessian depends on the inputs.".format(i)) |
|
elif input_type == "jacobian": |
|
raise RuntimeError("While computing the hessian, found that the jacobian of the user-provided" |
|
" function with respect to input {} is independent of the input. This is not" |
|
" allowed in strict mode. You should ensure that your function is twice" |
|
" differentiable and that the jacobian depends on the inputs (this would be" |
|
" violated by a linear function for example).".format(i)) |
|
elif input_type == "grad_inputs": |
|
raise RuntimeError("The gradient with respect to input {} is independent of the inputs of the" |
|
" user-provided function. This is not allowed in strict mode.".format(i)) |
|
else: |
|
raise RuntimeError("Output {} of the user-provided function does not require gradients." |
|
" The outputs must be computed in a differentiable manner from the input" |
|
" when running in strict mode.".format(i)) |
|
|
|
def _autograd_grad(outputs, inputs, grad_outputs=None, create_graph=False, retain_graph=None, is_grads_batched=False):
    # Version of autograd.grad that accepts `None` in outputs and does not compute
    # gradients for them. This has the extra constraint that inputs has to be a tuple.
assert isinstance(outputs, tuple) |
|
if grad_outputs is None: |
|
grad_outputs = (None,) * len(outputs) |
|
assert isinstance(grad_outputs, tuple) |
|
assert len(outputs) == len(grad_outputs) |
|
|
|
new_outputs: Tuple[torch.Tensor, ...] = tuple() |
|
new_grad_outputs: Tuple[torch.Tensor, ...] = tuple() |
|
for out, grad_out in zip(outputs, grad_outputs): |
|
if out is not None and out.requires_grad: |
|
new_outputs += (out,) |
|
new_grad_outputs += (grad_out,) |
|
|
|
    if len(new_outputs) == 0:
        # No differentiable output, we don't need to call the autograd engine
return (None,) * len(inputs) |
|
else: |
|
return torch.autograd.grad(new_outputs, inputs, new_grad_outputs, allow_unused=True, |
|
create_graph=create_graph, retain_graph=retain_graph, |
|
is_grads_batched=is_grads_batched) |
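
# A sketch of the None handling above: outputs without a graph are dropped before
# calling torch.autograd.grad (illustrative, comment-only):
#   >>> x = torch.randn(3, requires_grad=True)
#   >>> outs = (x.sum(), torch.ones(()))            # second output requires no grad
#   >>> _autograd_grad(outs, (x,), (torch.ones(()), torch.ones(())))
#   (tensor([1., 1., 1.]),)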
|
|
|
def _fill_in_zeros(grads, refs, strict, create_graph, stage):
    # Used to detect None in the grads and, depending on the flags, either replace them
    # with Tensors full of 0s of the appropriate size based on the refs or raise an error.
    # strict and create_graph allow us to detect when it is appropriate to raise an error,
    # stage tells us which backward call we are considering so we can give a good error message.
if stage not in ["back", "back_trick", "double_back", "double_back_trick"]: |
|
raise RuntimeError("Invalid stage argument '{}' to _fill_in_zeros".format(stage)) |
|
|
|
res: Tuple[torch.Tensor, ...] = tuple() |
|
for i, grads_i in enumerate(grads): |
|
if grads_i is None: |
|
if strict: |
|
if stage == "back": |
|
raise RuntimeError("The output of the user-provided function is independent of " |
|
"input {}. This is not allowed in strict mode.".format(i)) |
|
elif stage == "back_trick": |
|
raise RuntimeError("The gradient with respect to the input is independent of entry {}" |
|
" in the grad_outputs when using the double backward trick to compute" |
|
" forward mode gradients. This is not allowed in strict mode.".format(i)) |
|
elif stage == "double_back": |
|
raise RuntimeError("The jacobian of the user-provided function is independent of " |
|
"input {}. This is not allowed in strict mode.".format(i)) |
|
else: |
|
raise RuntimeError("The hessian of the user-provided function is independent of " |
|
"entry {} in the grad_jacobian. This is not allowed in strict " |
|
"mode as it prevents from using the double backward trick to " |
|
"replace forward mode AD.".format(i)) |
|
|
|
grads_i = torch.zeros_like(refs[i]) |
|
else: |
|
if strict and create_graph and not grads_i.requires_grad: |
|
if "double" not in stage: |
|
raise RuntimeError("The jacobian of the user-provided function is independent of " |
|
"input {}. This is not allowed in strict mode when create_graph=True.".format(i)) |
|
else: |
|
raise RuntimeError("The hessian of the user-provided function is independent of " |
|
"input {}. This is not allowed in strict mode when create_graph=True.".format(i)) |
|
|
|
res += (grads_i,) |
|
|
|
return res |
|
|
|
|
|
|
|
def vjp(func, inputs, v=None, create_graph=False, strict=False): |
|
r"""Function that computes the dot product between a vector ``v`` and the |
|
Jacobian of the given function at the point given by the inputs. |
|
|
|
Args: |
|
func (function): a Python function that takes Tensor inputs and returns |
|
a tuple of Tensors or a Tensor. |
|
inputs (tuple of Tensors or Tensor): inputs to the function ``func``. |
|
v (tuple of Tensors or Tensor): The vector for which the vector |
|
Jacobian product is computed. Must be the same size as the output |
|
of ``func``. This argument is optional when the output of ``func`` |
|
contains a single element and (if it is not provided) will be set |
|
as a Tensor containing a single ``1``. |
|
        create_graph (bool, optional): If ``True``, both the output and result
            will be computed in a differentiable way. Note that when ``strict``
            is ``False``, the result can end up not requiring gradients or being
            disconnected from the inputs. Defaults to ``False``.
|
strict (bool, optional): If ``True``, an error will be raised when we |
|
detect that there exists an input such that all the outputs are |
|
independent of it. If ``False``, we return a Tensor of zeros as the |
|
vjp for said inputs, which is the expected mathematical value. |
|
Defaults to ``False``. |
|
|
|
Returns: |
|
output (tuple): tuple with: |
|
func_output (tuple of Tensors or Tensor): output of ``func(inputs)`` |
|
|
|
vjp (tuple of Tensors or Tensor): result of the dot product with |
|
the same shape as the inputs. |
|
|
|
Example: |
|
|
|
>>> def exp_reducer(x): |
|
... return x.exp().sum(dim=1) |
|
>>> inputs = torch.rand(4, 4) |
|
>>> v = torch.ones(4) |
|
>>> # xdoctest: +IGNORE_WANT("non-deterministic") |
|
>>> vjp(exp_reducer, inputs, v) |
|
(tensor([5.7817, 7.2458, 5.7830, 6.7782]), |
|
tensor([[1.4458, 1.3962, 1.3042, 1.6354], |
|
[2.1288, 1.0652, 1.5483, 2.5035], |
|
[2.2046, 1.1292, 1.1432, 1.3059], |
|
[1.3225, 1.6652, 1.7753, 2.0152]])) |
|
|
|
>>> vjp(exp_reducer, inputs, v, create_graph=True) |
|
(tensor([5.7817, 7.2458, 5.7830, 6.7782], grad_fn=<SumBackward1>), |
|
tensor([[1.4458, 1.3962, 1.3042, 1.6354], |
|
[2.1288, 1.0652, 1.5483, 2.5035], |
|
[2.2046, 1.1292, 1.1432, 1.3059], |
|
[1.3225, 1.6652, 1.7753, 2.0152]], grad_fn=<MulBackward0>)) |
|
|
|
>>> def adder(x, y): |
|
... return 2 * x + 3 * y |
|
>>> inputs = (torch.rand(2), torch.rand(2)) |
|
>>> v = torch.ones(2) |
|
>>> vjp(adder, inputs, v) |
|
(tensor([2.4225, 2.3340]), |
|
(tensor([2., 2.]), tensor([3., 3.]))) |
|
""" |
|
|
|
with torch.enable_grad(): |
|
is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vjp") |
|
inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True) |
|
|
|
outputs = func(*inputs) |
|
is_outputs_tuple, outputs = _as_tuple(outputs, "outputs of the user-provided function", "vjp") |
|
_check_requires_grad(outputs, "outputs", strict=strict) |
|
|
|
if v is not None: |
|
_, v = _as_tuple(v, "v", "vjp") |
|
v = _grad_preprocess(v, create_graph=create_graph, need_graph=False) |
|
_validate_v(v, outputs, is_outputs_tuple) |
|
else: |
|
if len(outputs) != 1 or outputs[0].nelement() != 1: |
|
raise RuntimeError("The vector v can only be None if the " |
|
"user-provided function returns " |
|
"a single Tensor with a single element.") |
|
|
|
enable_grad = True if create_graph else torch.is_grad_enabled() |
|
with torch.set_grad_enabled(enable_grad): |
|
grad_res = _autograd_grad(outputs, inputs, v, create_graph=create_graph) |
|
vjp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "back") |
|
|
|
|
|
outputs = _grad_postprocess(outputs, create_graph) |
|
vjp = _grad_postprocess(vjp, create_graph) |
|
|
|
return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(vjp, is_inputs_tuple) |
|
|
|
|
|
def jvp(func, inputs, v=None, create_graph=False, strict=False): |
|
r"""Function that computes the dot product between the Jacobian of |
|
the given function at the point given by the inputs and a vector ``v``. |
|
|
|
Args: |
|
func (function): a Python function that takes Tensor inputs and returns |
|
a tuple of Tensors or a Tensor. |
|
inputs (tuple of Tensors or Tensor): inputs to the function ``func``. |
|
v (tuple of Tensors or Tensor): The vector for which the Jacobian |
|
vector product is computed. Must be the same size as the input of |
|
``func``. This argument is optional when the input to ``func`` |
|
contains a single element and (if it is not provided) will be set |
|
as a Tensor containing a single ``1``. |
|
        create_graph (bool, optional): If ``True``, both the output and result
            will be computed in a differentiable way. Note that when ``strict``
            is ``False``, the result can end up not requiring gradients or being
            disconnected from the inputs. Defaults to ``False``.
|
strict (bool, optional): If ``True``, an error will be raised when we |
|
detect that there exists an input such that all the outputs are |
|
independent of it. If ``False``, we return a Tensor of zeros as the |
|
jvp for said inputs, which is the expected mathematical value. |
|
Defaults to ``False``. |
|
|
|
Returns: |
|
output (tuple): tuple with: |
|
func_output (tuple of Tensors or Tensor): output of ``func(inputs)`` |
|
|
|
jvp (tuple of Tensors or Tensor): result of the dot product with |
|
the same shape as the output. |
|
|
|
Note: |
|
``autograd.functional.jvp`` computes the jvp by using the backward of |
|
the backward (sometimes called the double backwards trick). This is not |
|
the most performant way of computing the jvp. Please consider using |
|
`functorch's jvp <https://github.com/pytorch/functorch#jvp>`_ |
|
or the :ref:`low-level forward-mode AD API <forward-mode-ad>` instead. |
|
|
|
Example: |
|
|
|
>>> def exp_reducer(x): |
|
... return x.exp().sum(dim=1) |
|
>>> inputs = torch.rand(4, 4) |
|
>>> v = torch.ones(4, 4) |
|
>>> # xdoctest: +IGNORE_WANT("non-deterministic") |
|
>>> jvp(exp_reducer, inputs, v) |
|
(tensor([6.3090, 4.6742, 7.9114, 8.2106]), |
|
tensor([6.3090, 4.6742, 7.9114, 8.2106])) |
|
|
|
>>> jvp(exp_reducer, inputs, v, create_graph=True) |
|
(tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=<SumBackward1>), |
|
tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=<SqueezeBackward1>)) |
|
|
|
>>> def adder(x, y): |
|
... return 2 * x + 3 * y |
|
>>> inputs = (torch.rand(2), torch.rand(2)) |
|
>>> v = (torch.ones(2), torch.ones(2)) |
|
>>> jvp(adder, inputs, v) |
|
(tensor([2.2399, 2.5005]), |
|
tensor([5., 5.])) |
|
|
|
""" |
|
|
|
with torch.enable_grad(): |
|
is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jvp") |
|
inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True) |
|
|
|
if v is not None: |
|
_, v = _as_tuple(v, "v", "jvp") |
|
v = _grad_preprocess(v, create_graph=create_graph, need_graph=False) |
|
_validate_v(v, inputs, is_inputs_tuple) |
|
else: |
|
if len(inputs) != 1 or inputs[0].nelement() != 1: |
|
raise RuntimeError("The vector v can only be None if the input to " |
|
"the user-provided function is a single Tensor " |
|
"with a single element.") |
|
|
|
outputs = func(*inputs) |
|
is_outputs_tuple, outputs = _as_tuple(outputs, "outputs of the user-provided function", "jvp") |
|
_check_requires_grad(outputs, "outputs", strict=strict) |
|
|
|
|
|
|
|
        # The backward is linear, so the value of grad_outputs is not important as
        # it won't appear in the double backward graph. We only need to ensure that
        # it does not contain inf or nan.
        grad_outputs = tuple(torch.zeros_like(out, requires_grad=True) for out in outputs)
|
|
|
grad_inputs = _autograd_grad(outputs, inputs, grad_outputs, create_graph=True) |
|
_check_requires_grad(grad_inputs, "grad_inputs", strict=strict) |
|
|
|
if create_graph: |
|
with torch.enable_grad(): |
|
grad_res = _autograd_grad(grad_inputs, grad_outputs, v, create_graph=create_graph) |
|
jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick") |
|
else: |
|
grad_res = _autograd_grad(grad_inputs, grad_outputs, v, create_graph=create_graph) |
|
jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick") |
|
|
|
|
|
outputs = _grad_postprocess(outputs, create_graph) |
|
jvp = _grad_postprocess(jvp, create_graph) |
|
|
|
return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(jvp, is_outputs_tuple) |
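
# The double-backward trick used by jvp above, spelled out against the public
# autograd API (an illustrative sketch; the names `x`, `v` and the dummy
# cotangent `u` are hypothetical):
#   >>> x = torch.randn(3, requires_grad=True)
#   >>> v = torch.randn(3)
#   >>> out = x.exp()
#   >>> u = torch.zeros_like(out, requires_grad=True)        # placeholder cotangent
#   >>> vjp_u, = torch.autograd.grad(out, x, u, create_graph=True)
#   >>> jvp_v, = torch.autograd.grad(vjp_u, u, v)             # d(J^T u)/du applied to v gives J v
#   >>> torch.allclose(jvp_v, x.exp() * v)
#   True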
|
|
|
|
|
def _construct_standard_basis_for(tensors: Tuple[torch.Tensor, ...], tensor_numels: Tuple[int, ...]) -> Tuple[torch.Tensor, ...]:
    # Constructs the N x N identity matrix, N = sum(tensor_numels), split column-wise
    # into one chunk per tensor: chunk i has shape (N, tensor_numels[i]) and the same
    # dtype/device as tensors[i]. Concatenating the chunks along dim=1 recovers the identity.
    #
    # Precondition: tensor_numels == tuple(tensor.numel() for tensor in tensors)
    # Precondition: tensors has at least one element.
assert len(tensors) == len(tensor_numels) |
|
assert len(tensors) > 0 |
|
total_numel = sum(tensor_numels) |
|
chunks = tuple(tensor.new_zeros(total_numel, tensor_numel) |
|
for tensor, tensor_numel in zip(tensors, tensor_numels)) |
|
diag_start_idx = 0 |
|
for chunk, numel in zip(chunks, tensor_numels): |
|
chunk.diagonal(diag_start_idx).fill_(1) |
|
diag_start_idx -= numel |
|
return chunks |
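
# A concrete sketch of the output above for two tensors with 2 and 3 elements
# (illustrative, comment-only): the chunks stack column-wise into the 5x5 identity.
#   >>> x, y = torch.zeros(2), torch.zeros(3)
#   >>> a, b = _construct_standard_basis_for((x, y), (2, 3))
#   >>> (a.shape, b.shape)
#   (torch.Size([5, 2]), torch.Size([5, 3]))
#   >>> torch.equal(torch.cat([a, b], dim=1), torch.eye(5))
#   True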
|
|
|
|
|
def _jacfwd(func, inputs, strict=False, vectorize=False): |
|
if strict: |
|
raise RuntimeError('torch.autograd.functional.jacobian: `strict=True` ' |
|
'and `strategy="forward-mode"` are not supported together (yet). ' |
|
'Please either set `strict=False` or ' |
|
'`strategy="reverse-mode"`.') |
|
is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jacobian") |
|
output_info = [] |
|
|
|
    if vectorize:
        # Step 1: Prepare tangents, i.e. the standard basis split per input, so that
        # each vmapped call pushes one basis direction through forward-mode AD.
        input_numels = tuple(input.numel() for input in inputs)
        tangents = _construct_standard_basis_for(inputs, input_numels)

        # Step 2: Compute the jvp for all tangents at once with dual tensors under vmap.
        def jvp(tangents):
|
with fwAD.dual_level(): |
|
dual_inputs = tuple( |
|
fwAD.make_dual(input, tangent.view_as(input)) for input, tangent in zip(inputs, tangents)) |
|
_is_outputs_tuple, dual_outputs = _as_tuple(func(*dual_inputs), "outputs") |
|
output_info.append(_is_outputs_tuple) |
|
jv = [] |
|
primal_outs = [] |
|
for dual_out in dual_outputs: |
|
primal, tangent = fwAD.unpack_dual(dual_out) |
|
primal_outs.append(primal) |
|
if tangent is not None: |
|
jv.append(tangent) |
|
else: |
|
jv.append(torch.zeros_like(primal)) |
|
output_info.append(primal_outs) |
|
return tuple(jv) |
|
|
|
outputs_before_split = _vmap(jvp)(tangents) |
|
is_outputs_tuple, outputs = output_info |
|
|
|
jacobian_input_output = [] |
|
for jac, output_i in zip(outputs_before_split, outputs): |
|
jacobian_output_i_output = [] |
|
            for jac, input_j in zip(jac.split(input_numels, dim=0), inputs):
                # In forward-mode AD the vmapped (basis) dimension is the leading one and
                # indexes input elements, so move it to the back and reshape into
                # (output_shape + input_shape).
                jacobian_input_i_output_j = jac.permute(*range(1, jac.ndim), 0) \
                    .reshape(tuple([*output_i.shape, *input_j.shape]))
|
|
|
jacobian_output_i_output.append(jacobian_input_i_output_j) |
|
jacobian_input_output.append(jacobian_output_i_output) |
|
|
|
|
|
return _tuple_postprocess(jacobian_input_output, (is_outputs_tuple, is_inputs_tuple)) |
|
else: |
|
        raise NotImplementedError("Computing Jacobian using forward-AD or forward-over-reverse Hessian is "
|
"only implemented for `vectorize=True`.") |
|
|
|
|
|
def jacobian(func, inputs, create_graph=False, strict=False, vectorize=False, strategy="reverse-mode"): |
|
r"""Function that computes the Jacobian of a given function. |
|
|
|
Args: |
|
func (function): a Python function that takes Tensor inputs and returns |
|
a tuple of Tensors or a Tensor. |
|
inputs (tuple of Tensors or Tensor): inputs to the function ``func``. |
|
        create_graph (bool, optional): If ``True``, the Jacobian will be
            computed in a differentiable manner. Note that when ``strict`` is
            ``False``, the result can end up not requiring gradients or being
            disconnected from the inputs. Defaults to ``False``.
|
strict (bool, optional): If ``True``, an error will be raised when we |
|
detect that there exists an input such that all the outputs are |
|
independent of it. If ``False``, we return a Tensor of zeros as the |
|
jacobian for said inputs, which is the expected mathematical value. |
|
Defaults to ``False``. |
|
vectorize (bool, optional): This feature is experimental. |
|
Please consider using |
|
`functorch's jacrev or jacfwd <https://github.com/pytorch/functorch#what-are-the-transforms>`_ |
|
instead if you are looking for something less experimental and more performant. |
|
When computing the jacobian, usually we invoke |
|
``autograd.grad`` once per row of the jacobian. If this flag is |
|
``True``, we perform only a single ``autograd.grad`` call with |
|
``batched_grad=True`` which uses the vmap prototype feature. |
|
Though this should lead to performance improvements in many cases, |
|
because this feature is still experimental, there may be performance |
|
cliffs. See :func:`torch.autograd.grad`'s ``batched_grad`` parameter for |
|
more information. |
|
strategy (str, optional): Set to ``"forward-mode"`` or ``"reverse-mode"`` to |
|
determine whether the Jacobian will be computed with forward or reverse |
|
            mode AD. Currently, ``"forward-mode"`` requires ``vectorize=True``.
|
Defaults to ``"reverse-mode"``. If ``func`` has more outputs than |
|
inputs, ``"forward-mode"`` tends to be more performant. Otherwise, |
|
prefer to use ``"reverse-mode"``. |
|
|
|
Returns: |
|
Jacobian (Tensor or nested tuple of Tensors): if there is a single |
|
input and output, this will be a single Tensor containing the |
|
Jacobian for the linearized inputs and output. If one of the two is |
|
a tuple, then the Jacobian will be a tuple of Tensors. If both of |
|
        them are tuples, then the Jacobian will be a tuple of tuples of
|
Tensors where ``Jacobian[i][j]`` will contain the Jacobian of the |
|
``i``\th output and ``j``\th input and will have as size the |
|
concatenation of the sizes of the corresponding output and the |
|
corresponding input and will have same dtype and device as the |
|
corresponding input. If strategy is ``forward-mode``, the dtype will be |
|
that of the output; otherwise, the input. |
|
|
|
Example: |
|
|
|
>>> def exp_reducer(x): |
|
... return x.exp().sum(dim=1) |
|
>>> inputs = torch.rand(2, 2) |
|
>>> # xdoctest: +IGNORE_WANT("non-deterministic") |
|
>>> jacobian(exp_reducer, inputs) |
|
tensor([[[1.4917, 2.4352], |
|
[0.0000, 0.0000]], |
|
[[0.0000, 0.0000], |
|
[2.4369, 2.3799]]]) |
|
|
|
>>> jacobian(exp_reducer, inputs, create_graph=True) |
|
tensor([[[1.4917, 2.4352], |
|
[0.0000, 0.0000]], |
|
[[0.0000, 0.0000], |
|
[2.4369, 2.3799]]], grad_fn=<ViewBackward>) |
|
|
|
>>> def exp_adder(x, y): |
|
... return 2 * x.exp() + 3 * y |
|
>>> inputs = (torch.rand(2), torch.rand(2)) |
|
>>> jacobian(exp_adder, inputs) |
|
(tensor([[2.8052, 0.0000], |
|
[0.0000, 3.3963]]), |
|
tensor([[3., 0.], |
|
[0., 3.]])) |
|
""" |
|
assert strategy in ("forward-mode", "reverse-mode"), ( |
|
'Expected strategy to be either "forward-mode" or "reverse-mode". Hint: If your ' |
|
'function has more outputs than inputs, "forward-mode" tends to be more performant. ' |
|
'Otherwise, prefer to use "reverse-mode".') |
|
if strategy == "forward-mode": |
|
if create_graph: |
|
raise NotImplementedError('torch.autograd.functional.jacobian: `create_graph=True` ' |
|
'and `strategy="forward-mode"` are not supported together (yet). ' |
|
'Please either set `create_graph=False` or ' |
|
'`strategy="reverse-mode"`.') |
|
return _jacfwd(func, inputs, strict, vectorize) |
|
|
|
with torch.enable_grad(): |
|
is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jacobian") |
|
inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True) |
|
|
|
outputs = func(*inputs) |
|
is_outputs_tuple, outputs = _as_tuple(outputs, |
|
"outputs of the user-provided function", |
|
"jacobian") |
|
_check_requires_grad(outputs, "outputs", strict=strict) |
|
|
|
if vectorize: |
|
if strict: |
|
raise RuntimeError('torch.autograd.functional.jacobian: `strict=True` ' |
|
                               'and `vectorize=True` are not supported together. '
|
'Please either set `strict=False` or ' |
|
'`vectorize=False`.') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
            # Vectorized strategy: build the standard basis of R^N (N = total number of
            # output elements), split it per output, and pass it as a batched grad_output
            # (is_grads_batched=True below). A single autograd call then yields all rows
            # of the Jacobian at once; they are split per output and reshaped afterwards.

            # Step 1: Construct grad_outputs by splitting the standard basis
            output_numels = tuple(output.numel() for output in outputs)
            grad_outputs = _construct_standard_basis_for(outputs, output_numels)
            flat_outputs = tuple(output.reshape(-1) for output in outputs)
|
|
|
|
|
            # Step 2: Compute a single batched vjp covering all basis directions; missing
            # gradients (inputs the outputs do not depend on) become zeros of the batched shape.
            def vjp(grad_output):
|
vj = list(_autograd_grad(flat_outputs, inputs, grad_output, create_graph=create_graph, is_grads_batched=True)) |
|
for el_idx, vj_el in enumerate(vj): |
|
if vj_el is not None: |
|
continue |
|
vj[el_idx] = torch.zeros_like(inputs[el_idx]).expand((sum(output_numels),) + inputs[el_idx].shape) |
|
return tuple(vj) |
|
|
|
jacobians_of_flat_output = vjp(grad_outputs) |
|
|
|
|
|
|
|
            # Step 3: The returned jacobian is one big tensor per input. Split each
            # tensor by output.
            jacobian_input_output = []
|
for jac, input_i in zip(jacobians_of_flat_output, inputs): |
|
jacobian_input_i_output = [] |
|
for jac, output_j in zip(jac.split(output_numels, dim=0), outputs): |
|
jacobian_input_i_output_j = jac.view(output_j.shape + input_i.shape) |
|
jacobian_input_i_output.append(jacobian_input_i_output_j) |
|
jacobian_input_output.append(jacobian_input_i_output) |
|
|
|
|
|
|
|
|
|
|
|
|
|
            # Step 4: Right now, `jacobian_input_output` is a list (one per input) of lists
            # (one per output). Swap the nesting so it is indexed output-first, as documented.
            jacobian_output_input = tuple(zip(*jacobian_input_output))
|
|
|
jacobian_output_input = _grad_postprocess(jacobian_output_input, create_graph) |
|
return _tuple_postprocess(jacobian_output_input, (is_outputs_tuple, is_inputs_tuple)) |
|
|
|
jacobian: Tuple[torch.Tensor, ...] = tuple() |
|
|
|
for i, out in enumerate(outputs): |
|
|
|
|
|
        jac_i: Tuple[List[torch.Tensor], ...] = tuple([] for _ in range(len(inputs)))
|
for j in range(out.nelement()): |
|
vj = _autograd_grad((out.reshape(-1)[j],), inputs, |
|
retain_graph=True, create_graph=create_graph) |
|
|
|
for el_idx, (jac_i_el, vj_el, inp_el) in enumerate(zip(jac_i, vj, inputs)): |
|
if vj_el is not None: |
|
if strict and create_graph and not vj_el.requires_grad: |
|
msg = ("The jacobian of the user-provided function is " |
|
"independent of input {}. This is not allowed in " |
|
"strict mode when create_graph=True.".format(i)) |
|
raise RuntimeError(msg) |
|
jac_i_el.append(vj_el) |
|
else: |
|
if strict: |
|
msg = ("Output {} of the user-provided function is " |
|
"independent of input {}. This is not allowed in " |
|
"strict mode.".format(i, el_idx)) |
|
raise RuntimeError(msg) |
|
jac_i_el.append(torch.zeros_like(inp_el)) |
|
|
|
jacobian += (tuple(torch.stack(jac_i_el, dim=0).view(out.size() |
|
+ inputs[el_idx].size()) for (el_idx, jac_i_el) in enumerate(jac_i)), ) |
|
|
|
jacobian = _grad_postprocess(jacobian, create_graph) |
|
|
|
return _tuple_postprocess(jacobian, (is_outputs_tuple, is_inputs_tuple)) |
|
|
|
def hessian(func, inputs, create_graph=False, strict=False, vectorize=False, outer_jacobian_strategy="reverse-mode"): |
|
r"""Function that computes the Hessian of a given scalar function. |
|
|
|
Args: |
|
func (function): a Python function that takes Tensor inputs and returns |
|
a Tensor with a single element. |
|
inputs (tuple of Tensors or Tensor): inputs to the function ``func``. |
|
        create_graph (bool, optional): If ``True``, the Hessian will be computed in
            a differentiable manner. Note that when ``strict`` is ``False``, the result
            can end up not requiring gradients or being disconnected from the inputs.
|
Defaults to ``False``. |
|
strict (bool, optional): If ``True``, an error will be raised when we detect that there exists an input |
|
such that all the outputs are independent of it. If ``False``, we return a Tensor of zeros as the |
|
hessian for said inputs, which is the expected mathematical value. |
|
Defaults to ``False``. |
|
vectorize (bool, optional): This feature is experimental. |
|
Please consider using |
|
`functorch <https://github.com/pytorch/functorch#what-are-the-transforms>`_ |
|
instead if you are looking for something less experimental and more performant. |
|
When computing the hessian, usually we invoke |
|
``autograd.grad`` once per row of the hessian. If this flag is |
|
``True``, we use the vmap prototype feature as the backend to |
|
vectorize calls to ``autograd.grad`` so we only invoke it once |
|
instead of once per row. This should lead to performance |
|
improvements in many use cases, however, due to this feature |
|
being incomplete, there may be performance cliffs. Please |
|
            use ``torch._C._debug_only_display_vmap_fallback_warnings(True)``
|
to show any performance warnings and file us issues if |
|
warnings exist for your use case. Defaults to ``False``. |
|
outer_jacobian_strategy (str, optional): The Hessian is computed by |
|
computing the Jacobian of a Jacobian. The inner Jacobian is always |
|
computed in reverse-mode AD. Setting strategy to ``"forward-mode"`` |
|
or ``"reverse-mode"`` determines whether the outer Jacobian will be |
|
computed with forward or reverse mode AD. Currently, computing the outer |
|
Jacobian in ``"forward-mode"`` requires ``vectorized=True``. Defaults |
|
to ``"reverse-mode"``. |
|
|
|
Returns: |
|
Hessian (Tensor or a tuple of tuple of Tensors): if there is a single input, |
|
this will be a single Tensor containing the Hessian for the input. |
|
If it is a tuple, then the Hessian will be a tuple of tuples where |
|
            ``Hessian[i][j]`` will contain the Hessian of the ``i``\th input
            and ``j``\th input, with size the concatenation of the sizes of the
            ``i``\th and ``j``\th inputs. ``Hessian[i][j]`` will have the same
|
dtype and device as the corresponding ``i``\th input. |
|
|
|
Example: |
|
|
|
>>> def pow_reducer(x): |
|
... return x.pow(3).sum() |
|
>>> inputs = torch.rand(2, 2) |
|
>>> # xdoctest: +IGNORE_WANT("non-deterministic") |
|
>>> hessian(pow_reducer, inputs) |
|
tensor([[[[5.2265, 0.0000], |
|
[0.0000, 0.0000]], |
|
[[0.0000, 4.8221], |
|
[0.0000, 0.0000]]], |
|
[[[0.0000, 0.0000], |
|
[1.9456, 0.0000]], |
|
[[0.0000, 0.0000], |
|
[0.0000, 3.2550]]]]) |
|
|
|
>>> hessian(pow_reducer, inputs, create_graph=True) |
|
tensor([[[[5.2265, 0.0000], |
|
[0.0000, 0.0000]], |
|
[[0.0000, 4.8221], |
|
[0.0000, 0.0000]]], |
|
[[[0.0000, 0.0000], |
|
[1.9456, 0.0000]], |
|
[[0.0000, 0.0000], |
|
[0.0000, 3.2550]]]], grad_fn=<ViewBackward>) |
|
|
|
|
|
>>> def pow_adder_reducer(x, y): |
|
... return (2 * x.pow(2) + 3 * y.pow(2)).sum() |
|
>>> inputs = (torch.rand(2), torch.rand(2)) |
|
>>> hessian(pow_adder_reducer, inputs) |
|
((tensor([[4., 0.], |
|
[0., 4.]]), |
|
tensor([[0., 0.], |
|
[0., 0.]])), |
|
(tensor([[0., 0.], |
|
[0., 0.]]), |
|
tensor([[6., 0.], |
|
[0., 6.]]))) |
|
""" |
|
|
|
is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hessian") |
|
assert outer_jacobian_strategy in ("forward-mode", "reverse-mode"), ( |
|
'Expected strategy to be either "forward-mode" or "reverse-mode".') |
|
|
|
def ensure_single_output_function(*inp): |
|
out = func(*inp) |
|
is_out_tuple, t_out = _as_tuple(out, "outputs of the user-provided function", "hessian") |
|
_check_requires_grad(t_out, "outputs", strict=strict) |
|
|
|
if is_out_tuple or not isinstance(out, torch.Tensor): |
|
raise RuntimeError("The function given to hessian should return a single Tensor") |
|
|
|
if out.nelement() != 1: |
|
raise RuntimeError("The Tensor returned by the function given to hessian should contain a single element") |
|
|
|
return out.squeeze() |
|
|
|
def jac_func(*inp): |
|
        if outer_jacobian_strategy == "forward-mode":
            # _grad_preprocess requires create_graph=True and the inputs to require grad,
            # or else the inputs would be detached inside the inner jacobian call.
            inp = tuple(t.requires_grad_(True) for t in inp)
|
jac = jacobian(ensure_single_output_function, inp, create_graph=True) |
|
_check_requires_grad(jac, "jacobian", strict=strict) |
|
return jac |
|
|
|
res = jacobian(jac_func, inputs, create_graph=create_graph, strict=strict, vectorize=vectorize, |
|
strategy=outer_jacobian_strategy) |
|
return _tuple_postprocess(res, (is_inputs_tuple, is_inputs_tuple)) |
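
# A sketch of the Jacobian-of-Jacobian identity implemented above, written against
# the public API (illustrative only):
#   >>> def f(x):
#   ...     return (x ** 3).sum()
#   >>> x = torch.randn(2)
#   >>> h = hessian(f, x)
#   >>> jj = jacobian(lambda t: jacobian(f, t, create_graph=True), x)
#   >>> torch.allclose(h, jj)
#   True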
|
|
|
|
|
def vhp(func, inputs, v=None, create_graph=False, strict=False): |
|
r"""Function that computes the dot product between a vector ``v`` and the |
|
Hessian of a given scalar function at the point given by the inputs. |
|
|
|
Args: |
|
func (function): a Python function that takes Tensor inputs and returns |
|
a Tensor with a single element. |
|
inputs (tuple of Tensors or Tensor): inputs to the function ``func``. |
|
v (tuple of Tensors or Tensor): The vector for which the vector Hessian |
|
product is computed. Must be the same size as the input of |
|
``func``. This argument is optional when ``func``'s input contains |
|
a single element and (if it is not provided) will be set as a |
|
Tensor containing a single ``1``. |
|
        create_graph (bool, optional): If ``True``, both the output and result
            will be computed in a differentiable way. Note that when ``strict``
            is ``False``, the result can end up not requiring gradients or being
            disconnected from the inputs.
|
Defaults to ``False``. |
|
strict (bool, optional): If ``True``, an error will be raised when we |
|
detect that there exists an input such that all the outputs are |
|
independent of it. If ``False``, we return a Tensor of zeros as the |
|
vhp for said inputs, which is the expected mathematical value. |
|
Defaults to ``False``. |
|
|
|
Returns: |
|
output (tuple): tuple with: |
|
func_output (tuple of Tensors or Tensor): output of ``func(inputs)`` |
|
|
|
vhp (tuple of Tensors or Tensor): result of the dot product with the |
|
same shape as the inputs. |
|
|
|
Example: |
|
|
|
>>> def pow_reducer(x): |
|
... return x.pow(3).sum() |
|
>>> inputs = torch.rand(2, 2) |
|
>>> v = torch.ones(2, 2) |
|
>>> # xdoctest: +IGNORE_WANT("non-deterministic") |
|
>>> vhp(pow_reducer, inputs, v) |
|
(tensor(0.5591), |
|
tensor([[1.0689, 1.2431], |
|
[3.0989, 4.4456]])) |
|
>>> vhp(pow_reducer, inputs, v, create_graph=True) |
|
(tensor(0.5591, grad_fn=<SumBackward0>), |
|
tensor([[1.0689, 1.2431], |
|
[3.0989, 4.4456]], grad_fn=<MulBackward0>)) |
|
>>> def pow_adder_reducer(x, y): |
|
... return (2 * x.pow(2) + 3 * y.pow(2)).sum() |
|
>>> inputs = (torch.rand(2), torch.rand(2)) |
|
>>> v = (torch.zeros(2), torch.ones(2)) |
|
>>> vhp(pow_adder_reducer, inputs, v) |
|
(tensor(4.8053), |
|
(tensor([0., 0.]), |
|
tensor([6., 6.]))) |
|
""" |
|
|
|
with torch.enable_grad(): |
|
is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vhp") |
|
inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True) |
|
|
|
if v is not None: |
|
_, v = _as_tuple(v, "v", "vhp") |
|
v = _grad_preprocess(v, create_graph=create_graph, need_graph=False) |
|
_validate_v(v, inputs, is_inputs_tuple) |
|
else: |
|
if len(inputs) != 1 or inputs[0].nelement() != 1: |
|
raise RuntimeError("The vector v can only be None if the input to the user-provided function " |
|
"is a single Tensor with a single element.") |
|
outputs = func(*inputs) |
|
is_outputs_tuple, outputs = _as_tuple(outputs, "outputs of the user-provided function", "vhp") |
|
_check_requires_grad(outputs, "outputs", strict=strict) |
|
|
|
if is_outputs_tuple or not isinstance(outputs[0], torch.Tensor): |
|
raise RuntimeError("The function given to vhp should return a single Tensor") |
|
|
|
if outputs[0].nelement() != 1: |
|
raise RuntimeError("The Tensor returned by the function given to vhp should contain a single element") |
|
|
|
jac = _autograd_grad(outputs, inputs, create_graph=True) |
|
_check_requires_grad(jac, "jacobian", strict=strict) |
|
|
|
enable_grad = True if create_graph else torch.is_grad_enabled() |
|
with torch.set_grad_enabled(enable_grad): |
|
grad_res = _autograd_grad(jac, inputs, v, create_graph=create_graph) |
|
vhp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "double_back") |
|
|
|
outputs = _grad_postprocess(outputs, create_graph) |
|
vhp = _grad_postprocess(vhp, create_graph) |
|
|
|
return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(vhp, is_inputs_tuple) |
|
|
|
|
|
def hvp(func, inputs, v=None, create_graph=False, strict=False): |
|
r"""Function that computes the dot product between the Hessian of a given scalar |
|
function and a vector ``v`` at the point given by the inputs. |
|
|
|
Args: |
|
func (function): a Python function that takes Tensor inputs and returns |
|
a Tensor with a single element. |
|
inputs (tuple of Tensors or Tensor): inputs to the function ``func``. |
|
v (tuple of Tensors or Tensor): The vector for which the Hessian vector |
|
product is computed. Must be the same size as the input of |
|
``func``. This argument is optional when ``func``'s input contains |
|
a single element and (if it is not provided) will be set as a |
|
Tensor containing a single ``1``. |
|
        create_graph (bool, optional): If ``True``, both the output and result will be
            computed in a differentiable way. Note that when ``strict`` is
            ``False``, the result can end up not requiring gradients or being
            disconnected from the inputs. Defaults to ``False``.
|
strict (bool, optional): If ``True``, an error will be raised when we |
|
detect that there exists an input such that all the outputs are |
|
independent of it. If ``False``, we return a Tensor of zeros as the |
|
hvp for said inputs, which is the expected mathematical value. |
|
Defaults to ``False``. |
|
Returns: |
|
output (tuple): tuple with: |
|
func_output (tuple of Tensors or Tensor): output of ``func(inputs)`` |
|
|
|
hvp (tuple of Tensors or Tensor): result of the dot product with |
|
the same shape as the inputs. |
|
|
|
Example: |
|
|
|
>>> def pow_reducer(x): |
|
... return x.pow(3).sum() |
|
>>> inputs = torch.rand(2, 2) |
|
>>> v = torch.ones(2, 2) |
|
>>> # xdoctest: +IGNORE_WANT("non-deterministic") |
|
>>> hvp(pow_reducer, inputs, v) |
|
(tensor(0.1448), |
|
tensor([[2.0239, 1.6456], |
|
[2.4988, 1.4310]])) |
|
|
|
>>> hvp(pow_reducer, inputs, v, create_graph=True) |
|
(tensor(0.1448, grad_fn=<SumBackward0>), |
|
tensor([[2.0239, 1.6456], |
|
[2.4988, 1.4310]], grad_fn=<MulBackward0>)) |
|
|
|
|
|
>>> def pow_adder_reducer(x, y): |
|
... return (2 * x.pow(2) + 3 * y.pow(2)).sum() |
|
>>> inputs = (torch.rand(2), torch.rand(2)) |
|
>>> v = (torch.zeros(2), torch.ones(2)) |
|
>>> hvp(pow_adder_reducer, inputs, v) |
|
(tensor(2.3030), |
|
(tensor([0., 0.]), |
|
tensor([6., 6.]))) |
|
|
|
Note: |
|
|
|
        This function is significantly slower than ``vhp`` due to backward-mode AD
        constraints. If your function is twice continuously differentiable, then
        ``hvp = vhp.t()``. So if you know that your function satisfies this condition,
        consider using ``vhp`` instead, which is much faster with the current implementation.
|
|
|
""" |
|
|
|
with torch.enable_grad(): |
|
is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hvp") |
|
inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True) |
|
|
|
if v is not None: |
|
_, v = _as_tuple(v, "v", "hvp") |
|
v = _grad_preprocess(v, create_graph=create_graph, need_graph=False) |
|
_validate_v(v, inputs, is_inputs_tuple) |
|
else: |
|
if len(inputs) != 1 or inputs[0].nelement() != 1: |
|
raise RuntimeError("The vector v can only be None if the input to the user-provided function " |
|
"is a single Tensor with a single element.") |
|
outputs = func(*inputs) |
|
is_outputs_tuple, outputs = _as_tuple(outputs, "outputs of the user-provided function", "hvp") |
|
_check_requires_grad(outputs, "outputs", strict=strict) |
|
|
|
if is_outputs_tuple or not isinstance(outputs[0], torch.Tensor): |
|
raise RuntimeError("The function given to hvp should return a single Tensor") |
|
|
|
if outputs[0].nelement() != 1: |
|
raise RuntimeError("The Tensor returned by the function given to hvp should contain a single element") |
|
|
|
jac = _autograd_grad(outputs, inputs, create_graph=True) |
|
_check_requires_grad(jac, "jacobian", strict=strict) |
|
|
|
grad_jac = tuple(torch.zeros_like(inp, requires_grad=True) for inp in inputs) |
|
|
|
double_back = _autograd_grad(jac, inputs, grad_jac, create_graph=True) |
|
_check_requires_grad(jac, "hessian", strict=strict) |
|
|
|
enable_grad = True if create_graph else torch.is_grad_enabled() |
|
with torch.set_grad_enabled(enable_grad): |
|
grad_res = _autograd_grad(double_back, grad_jac, v, create_graph=create_graph) |
|
hvp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "double_back_trick") |
|
|
|
outputs = _grad_postprocess(outputs, create_graph) |
|
hvp = _grad_postprocess(hvp, create_graph) |
|
|
|
return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(hvp, is_inputs_tuple) |
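
# A sketch of the note in `hvp` above: for a twice continuously differentiable
# scalar function the Hessian is symmetric, so hvp and vhp agree (illustrative only):
#   >>> def f(x):
#   ...     return (x ** 3).sum()
#   >>> x, v = torch.randn(2, 2), torch.ones(2, 2)
#   >>> torch.allclose(hvp(f, x, v)[1], vhp(f, x, v)[1])
#   True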
|
|