import torch
from torch._ops import HigherOrderOperator
from torch._C._functorch import TransformType
from torch._functorch.utils import enable_single_level_autograd_function
import torch.utils._pytree as pytree
from torch._C._functorch import (
    _wrap_for_grad,
    _unwrap_for_grad,
    current_level,
)
from torch._functorch.vmap import (
    wrap_batched,
    unwrap_batched,
    restore_vmap,
    _add_batch_dim,
)
from torch._functorch.apis import vmap
from torch._functorch.vmap import _broadcast_to_and_flatten
from torch.autograd.forward_ad import _set_fwd_grad_enabled
from typing import Any, NamedTuple, Tuple

# autograd.Function technically runs before the regular PyTorch dispatcher.
# This is how features like autocast and torch_dispatch (e.g. PythonTLSSnapshot)
# work with it. One day we might decide to change this, but until then,
# we need to give the illusion that autograd.Function runs before those things.
#
# We do this by creating a custom HigherOrderOperator that only functorch
# dispatches specially.
class CustomFunctionHigherOrderOperator(HigherOrderOperator):
    def __init__(self):
        super().__init__('custom_function_call')

    def __call__(self, autograd_function, *args, **kwargs):
        # When custom_function_call is done dispatching through functorch,
        # it should just invoke the autograd.Function. This is consistent
        # with the autograd.Function behavior of being invoked before the
        # PyTorch dispatcher.
        #
        # This will lead us into trouble later down the line, but this is
        # pre-existing. There is an invariant that a function traced by
        # make_fx should have the same behavior when provided the same
        # Tensor. However, make_fx sees autograd.Function as a composite
        # (because autograd.Function happens before the Python dispatch key)
        # and only traces the forward pass.
        if torch._C._are_functorch_transforms_active():
            return super().__call__(autograd_function, *args, **kwargs)
        return autograd_function.apply(*args, **kwargs)

# "custom_function_call" | |
# This is the mechanism for an autograd.Function that works with functorch transforms. | |
# It wraps an autograd.Function; interactions with functorch transforms are defined | |
# via PyDispatcher and HigherOrderOperator rather than through the traditional PyTorch | |
# dispatcher. | |
custom_function_call = CustomFunctionHigherOrderOperator() | |
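# As a hedged illustration of how this routes in practice (MyCube below is
# illustrative and not part of this module): when no functorch transforms are
# active, custom_function_call(MyCube, x) is just MyCube.apply(x); when a
# transform such as grad or vmap is active, it instead dispatches to the
# functorch rules defined later in this file.
#
# class MyCube(torch.autograd.Function):
#     @staticmethod
#     def forward(x):
#         return x ** 3
#
#     @staticmethod
#     def setup_context(ctx, inputs, output):
#         ctx.save_for_backward(inputs[0])
#
#     @staticmethod
#     def backward(ctx, gy):
#         x, = ctx.saved_tensors
#         return 3 * x ** 2 * gy
#
# x = torch.randn([])
# y = custom_function_call(MyCube, x)    # same as MyCube.apply(x) here
# gx = torch.func.grad(MyCube.apply)(x)  # goes through custom_function_call
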
# The grad rule for custom_function_call is to construct a new _SingleLevelFunction
# (autograd.Function that only works with a single layer (level) of functorch) that:
# - unwraps the inputs
# - redispatches to custom_function_call
# - wraps the outputs
# and whose backward pass calls the original autograd.Function's backward.
#
# Why do we need to redispatch to custom_function_call?
# -----------------------------------------------------
# This is consistent with how ATen operators work with functorch's grad transform:
# they always redispatch to the original operator.
# Consider torch.sin, and let's say we do grad0(grad1(torch.sin))(x)
#
# grad1 will:
# - set up the autograd graph
# - unwrap the inputs
# - redispatch to at::sin (*)
# - rewrap the outputs on the return
#
# On the redispatch in (*), grad0 will:
# - set up the autograd graph
# - unwrap the inputs
# - redispatch to at::sin
# - rewrap the outputs on the return
#
# To "set up the autograd graph", we generate a _SingleLevelFunction
# and apply it.
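#
# As a hedged sketch (MyCube is an illustrative autograd.Function, not part of
# this module), grad0(grad1(MyCube.apply))(x) unfolds the same way:
#
# grad1 will:
# - construct MyCubeGenerated (a _SingleLevelFunction) and apply it, which
#   sets up the autograd graph at grad1's level
# - unwrap the GradTensorWrapper inputs
# - redispatch to custom_function_call (*)
# - rewrap the outputs on the return
#
# On the redispatch in (*), grad0 will do the same with its own
# MyCubeGenerated, and its redispatch finally falls through to
# MyCube.apply on plain Tensors.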
def custom_function_call_grad(interpreter, autograd_function, *operands):
    Generated = generate_single_level_function(interpreter, autograd_function)
    with enable_single_level_autograd_function():
        flat_out = Generated.apply(*operands)
    return flat_out

def generate_single_level_function(interpreter, autograd_function):
    level = interpreter.level()

    def forward(*operands):
        unwrapped_operands = pytree.tree_map_only(
            torch.Tensor,
            lambda x: _unwrap_for_grad(x, level),
            operands)
        # Both enable_grad() and _set_fwd_grad_enabled() are necessary no matter
        # the transform. _SingleLevelFunction will turn off both fwd and bwd
        # gradient computation and we need to turn it back on here.
        with torch.enable_grad(), _set_fwd_grad_enabled(True), interpreter.lower():
            unwrapped_output = custom_function_call(autograd_function, *unwrapped_operands)

        # See NOTE [mark_dirty object identity check]
        def wrap_fn(output):
            return _wrap_for_grad(output, level)

        return wrap_outputs_maintaining_identity(
            unwrapped_output,
            unwrapped_operands,
            operands,
            wrap_fn)

    def setup_context(ctx, inputs, output):
        return autograd_function.setup_context(ctx, inputs, output)

    # backward is only used if the transform is TransformType.Grad
    def backward(ctx, *grads):
        result = autograd_function.backward(ctx, *grads)
        return result

    # jvp is only used if the transform is TransformType.Jvp
    def jvp(ctx, *tangents):
        result = autograd_function.jvp(ctx, *tangents)
        return result

    # This is the sequence of magic words to dynamically generate a Subclass with
    # a given name. A Tensor's .grad_fn field has a class name that is the original
    # autograd.Function's name + Backward, so we do this to generate some
    # meaningful name.
    name = f'{autograd_function.__name__}Generated'
    Generated = type(
        name,
        (torch.autograd.function._SingleLevelFunction,),
        {
            'forward': staticmethod(forward),
            'backward': staticmethod(backward),
            'jvp': staticmethod(jvp),
            'setup_context': staticmethod(setup_context),
        },
    )
    return Generated

# wrap_outputs_maintaining_identity handles outputs from the vmap,
# backward (vjp), and jvp staticmethod. The way it distinguishes
# between the vmap case and the {backward, jvp} case is whether or not
# out_dims is specified.
#
# NB: we cannot use out_dims=None as the deciding factor. This is because
# out_dims=None can still happen in the vmap staticmethod! What the
# user is saying in that case is that their output does not have a
# dimension that is being vmapped over, which is valid.
NO_OUT_DIMS = "not specified"
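
# For example (illustrative only, not part of this module), a vmap staticmethod
# may legitimately return out_dims=None when its output is not batched:
#
# class SharedZeros(torch.autograd.Function):
#     generate_vmap_rule = False
#
#     @staticmethod
#     def forward(x):
#         return torch.zeros(3)
#
#     @staticmethod
#     def setup_context(ctx, inputs, output):
#         pass
#
#     @staticmethod
#     def backward(ctx, gy):
#         return None
#
#     @staticmethod
#     def vmap(info, in_dims, x):
#         # The output is the same for every batch element, so there is no
#         # dimension being vmapped over: out_dims is None.
#         return torch.zeros(3), None
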
# NOTE [mark_dirty object identity check]
# autograd.Function's ctx.mark_dirty expects a returned input
# to have the same object identity as the input.
# Mode-only functorch will greatly simplify this logic.
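#
# For example (illustrative, not part of this module), an in-place Function
# returns its own (mutated) input, and the tensor handed back to the caller
# must be that very same wrapped object:
#
# class MyRelu_(torch.autograd.Function):
#     @staticmethod
#     def forward(x):
#         return x.relu_()
#
#     @staticmethod
#     def setup_context(ctx, inputs, output):
#         ctx.mark_dirty(inputs[0])
#         ctx.save_for_backward(output)
#
#     @staticmethod
#     def backward(ctx, gy):
#         y, = ctx.saved_tensors
#         return gy * (y > 0)
#
# Because forward returns the input object itself, wrap_outputs_maintaining_identity
# below maps the unwrapped output back to the original wrapped input (via the
# id() lookup) rather than re-wrapping it, preserving object identity.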
def wrap_outputs_maintaining_identity(
        outputs, unwrapped_inputs, orig_inputs, wrap_fn, out_dims=NO_OUT_DIMS):
    flat_unwrapped_inputs = pytree.arg_tree_leaves(*unwrapped_inputs)
    flat_orig_inputs = pytree.arg_tree_leaves(*orig_inputs)

    unwrapped_input_to_orig_input = {
        id(unwrapped): orig
        for unwrapped, orig in zip(flat_unwrapped_inputs, flat_orig_inputs)
    }

    flat_outputs, spec = pytree.tree_flatten(outputs)
    result = []

    out_dims_specified = out_dims != NO_OUT_DIMS

    if out_dims_specified:
        flat_out_dims = _broadcast_to_and_flatten(out_dims, spec)
        # _broadcast_to_and_flatten returns None if it is unable to broadcast.
        # TODO: update following link from master to stable once that's out
        if flat_out_dims is None:
            raise RuntimeError(
                f"The autograd.Function's vmap staticmethod returned an "
                f"incompatible (output, out_dims) tuple. "
                f"Expected out_dims={out_dims} "
                f"to be compatible with the structure of `output`. "
                f"out_dims has structure {pytree.tree_flatten(out_dims)[1]} "
                f"but output has structure {spec}. "
                f"For more details, please see "
                f"https://pytorch.org/docs/master/notes/extending.func.html"
            )

    for i, output in enumerate(flat_outputs):
        if not isinstance(output, torch.Tensor):
            result.append(output)
            continue
        if id(output) in unwrapped_input_to_orig_input:
            result.append(unwrapped_input_to_orig_input[id(output)])
            continue
        if out_dims_specified:
            result.append(wrap_fn(output, flat_out_dims[i]))  # type: ignore[index]
        else:
            result.append(wrap_fn(output))

    return pytree.tree_unflatten(result, spec)

# NOTE: [functorch vjp and autograd interaction]
# There's an edge case with the functorch vjp and autograd interaction
# that will eventually be fixed by mode-only functorch.
# The TL;DR is that there's no way to unwrap a dead GradTensorWrapper,
# so we (the framework) need to do it manually. Regular PyTorch operators
# automatically do so, so this is consistent.
#
# class MyExp(torch.autograd.Function):
#     @staticmethod
#     def forward(x):
#         return x.exp()
#
#     @staticmethod
#     def setup_context(ctx, inputs, output):
#         y = output
#         ctx.save_for_backward(y)
#
#     @staticmethod
#     def backward(ctx, gy):
#         y, = ctx.saved_tensors
#         return MyMul.apply(gy, y)
#
# x = torch.randn([], requires_grad=True)
# gy = torch.randn([], requires_grad=True)
# _, vjp_fn = vjp(MyExp.apply, x)
# result = vjp_fn(gy)
#
# MyMul is an autograd.Function that is not shown here.
# It saves a `y` for backward (since gy requires grad).
#
# In vjp_fn(gy), we get:
# > MyMul.apply(gy, GradTensorWrapper(y, level=dead))
# because the y that MyExp saved for backward is a GradTensorWrapper,
# and it is now dead since we are outside the vjp context.
#
# PyTorch dispatcher operations, upon seeing a dead GradTensorWrapper,
# will automatically unwrap the GradTensorWrapper when applied.
# But since autograd.Function technically sits above the regular PyTorch
# dispatcher, it doesn't get this treatment. So we manually do
# the unwrapping to be consistent with regular PyTorch dispatcher operations.
class VmapInfo(NamedTuple):
    batch_size: int
    randomness: str


def has_overriden_vmap_rule(autograd_function):
    return autograd_function.vmap is not torch.autograd.Function.vmap

def validate_vmap_returns_tuple_of_two_elements(result):
    base_error_msg = (
        "Expected the vmap staticmethod to have two returns, an output "
        "and out_dims with pytree structure compatible with the output. "
    )
    if not isinstance(result, tuple):
        raise RuntimeError(base_error_msg + f"Got a {type(result)} instead")
    if len(result) != 2:
        raise RuntimeError(base_error_msg + f"Got {len(result)} returns instead")

def custom_function_call_vmap(interpreter, autograd_function, *operands):
    if autograd_function.generate_vmap_rule:
        if has_overriden_vmap_rule(autograd_function):
            # TODO: Update link to stable once that's out
            # https://github.com/pytorch/pytorch/issues/92029
            raise RuntimeError(
                f"You tried to vmap over {autograd_function.__name__}, but "
                f"it has both generate_vmap_rule=True and an overridden vmap "
                f"staticmethod. Please set generate_vmap_rule=False or delete "
                f"the overridden vmap staticmethod to avoid ambiguity. "
                f"For more details, please see "
                f"https://pytorch.org/docs/master/notes/extending.func.html")
        return custom_function_call_vmap_generate_rule(interpreter, autograd_function, *operands)

    if not has_overriden_vmap_rule(autograd_function):
        # TODO: Update link to stable once that's out
        # https://github.com/pytorch/pytorch/issues/92029
        raise RuntimeError(
            f"You tried to vmap over {autograd_function.__name__}, but "
            f"it does not have vmap support. Please override and implement the "
            f"vmap staticmethod or set generate_vmap_rule=True. "
            f"For more details, please see "
            f"https://pytorch.org/docs/master/notes/extending.func.html")

    current_level = interpreter.level()
    info = VmapInfo(
        batch_size=interpreter.batch_size(),
        randomness=interpreter.randomness(),
    )
    unwrapped_operands, in_dims = unwrap_batched(operands, current_level)

    # If none of the tensors are batched at the current level, then we skip the
    # current level. This saves the user from needing to handle this case in
    # their vmap staticmethod (and is consistent with our C++ batching rule API)
    if pytree.tree_all(lambda dim: dim is None, in_dims):
        with interpreter.lower():
            return custom_function_call(autograd_function, *operands)

    with interpreter.lower():
        result = autograd_function.vmap(info, in_dims, *unwrapped_operands)
    validate_vmap_returns_tuple_of_two_elements(result)
    unwrapped_output, out_dims = result

    # See NOTE [mark_dirty object identity check]
    def wrap_fn(output, out_dim):
        return output if out_dim is None else _add_batch_dim(output, out_dim, current_level)

    return wrap_outputs_maintaining_identity(
        unwrapped_output,
        unwrapped_operands,
        operands,
        wrap_fn,
        out_dims=out_dims)

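# A hedged sketch of the vmap staticmethod this rule consumes (MyPerSampleExp
# is illustrative, not part of this module). It receives (info, in_dims,
# *unwrapped_operands) and must return an (output, out_dims) tuple:
#
# class MyPerSampleExp(torch.autograd.Function):
#     generate_vmap_rule = False
#
#     @staticmethod
#     def forward(x):
#         return x.exp()
#
#     @staticmethod
#     def setup_context(ctx, inputs, output):
#         ctx.save_for_backward(output)
#
#     @staticmethod
#     def backward(ctx, gy):
#         y, = ctx.saved_tensors
#         return gy * y
#
#     @staticmethod
#     def vmap(info, in_dims, x):
#         x_bdim, = in_dims
#         # exp is elementwise, so the batch dimension passes straight through.
#         x = x.movedim(x_bdim, 0)
#         return x.exp(), 0
#
# vmap(MyPerSampleExp.apply)(torch.randn(3, 4))  # dispatches to the rule above
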
def custom_function_call_vmap_generate_rule(interpreter, autograd_function, *operands):
    unwrapped_operands, in_dims = unwrap_batched(operands, interpreter.level())
    vmapped_function, get_out_dims = vmapify_autograd_function(
        autograd_function, in_dims, interpreter.batch_size(), interpreter.randomness())

    with interpreter.lower():
        output = custom_function_call(vmapped_function, *unwrapped_operands)

    out_dims = get_out_dims()
    return wrap_batched(output, out_dims, interpreter.level())

def custom_function_call_functionalize(interpreter, autograd_function, generate_vmap_rule, *operands):
    raise RuntimeError("NYI: Functionalize rule for custom_function_call")

def vmapify_autograd_function(autograd_function, in_dims, batch_size, randomness):
    # The following values are saved from the forward() and setup_context()
    # and used in backward().
    # Why do we save the values out here instead of on the ctx object?
    # - out_dims: There's no way to retrieve this from forward()
    # - input_shapes, saved_tensors_bdims: I'm a bit scared of nesting
    #   vmap(vmap( but not completely sure if it is a problem. If we
    #   assigned those fields to the ctx object, the worry is that they
    #   get overwritten.
    init_val = "not populated"
    out_dims = init_val
    input_shapes: Any = init_val
    saved_tensors_bdims: Any = init_val

    def forward(*operands):
        nonlocal out_dims
        outputs, out_dims = restore_vmap(
            autograd_function.forward, in_dims, batch_size, randomness)(*operands)
        return outputs

    def setup_context(ctx, inputs, outputs):
        input_shapes_ = None
        saved_tensors_bdims_ = None

        def inner(inputs, outputs):
            # wrapped_ctx.save_for_backward will:
            # - unwrap batchedtensors into (tensor, bdim)
            # - save_for_backward(*unwrapped_tensors)
            # - assign the bdims to wrapped_ctx._pt_saved_tensors_bdims
            wrapped_ctx = CtxCustomSave(ctx, current_level())
            autograd_function.setup_context(wrapped_ctx, inputs, outputs)

            # input_shapes are used for reductify later to reduce expanded gradients
            # to the correct shape.
            # See NOTE: [Why can't we rely on autograd to reduce expanded gradients?]
            # for more details
            nonlocal input_shapes_
            input_shapes_ = tuple(inp.shape if isinstance(inp, torch.Tensor) else None
                                  for inp in inputs)
            nonlocal saved_tensors_bdims_
            saved_tensors_bdims_ = wrapped_ctx._pt_saved_tensors_bdims

        # See NOTE: [Why do we need to run setup_context under a vmap?]
        restore_vmap(
            inner,
            (in_dims, out_dims),
            batch_size,
            randomness,
        )(inputs, outputs)

        nonlocal input_shapes
        input_shapes = input_shapes_
        nonlocal saved_tensors_bdims
        saved_tensors_bdims = saved_tensors_bdims_

    def jvp(ctx, *tangents):
        assert out_dims != init_val
        assert saved_tensors_bdims != init_val

        def jvp_no_context(saved_tensors, tangents):
            wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors)
            return autograd_function.jvp(wrapped_ctx, *tangents)

        tangent_in_dims = get_tangents_in_dims(in_dims, tangents)
        out_tangents, out_tangents_dims = restore_vmap(
            jvp_no_context, (saved_tensors_bdims, tangent_in_dims), batch_size, randomness)(
                ctx.saved_tensors, tangents)

        result = reductify(out_tangents, out_tangents_dims, out_dims, batch_size)
        return result

    def backward(ctx, *grad_outputs):
        assert out_dims != init_val
        assert input_shapes != init_val
        assert saved_tensors_bdims != init_val

        def backward_no_context(inputs):
            saved_tensors, grad_outputs = inputs
            wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors)
            return autograd_function.backward(wrapped_ctx, *grad_outputs)

        grad_ins, grad_ins_dims = restore_vmap(
            backward_no_context, ((saved_tensors_bdims, out_dims),), batch_size, randomness)(
                (ctx.saved_tensors, grad_outputs))
        result = reductify(grad_ins, grad_ins_dims, in_dims, batch_size, input_shapes)
        return result

    name = f'Vmapped{autograd_function.__name__}'
    Generated = type(
        name,
        (torch.autograd.Function,),
        {
            'forward': staticmethod(forward),
            'backward': staticmethod(backward),
            'jvp': staticmethod(jvp),
            'setup_context': staticmethod(setup_context),
            'generate_vmap_rule': True
        }
    )

    def get_out_dims():
        assert out_dims != init_val
        return out_dims

    return Generated, get_out_dims

# tangents might be None, so we need to replace
# the corresponding in_dims with None.
def get_tangents_in_dims(input_dims, tangents):
    flat_in_dims, spec = pytree.tree_flatten(input_dims)
    flat_tangents = pytree.arg_tree_leaves(*tangents)
    result = [None if tangent is None else in_dim
              for in_dim, tangent in zip(flat_in_dims, flat_tangents)]
    return pytree.tree_unflatten(result, spec)

# NOTE: [Why do we need to run setup_context under a vmap?]
# Consider the following autograd.Function
#
# class Sum(torch.autograd.Function):
#     @staticmethod
#     def forward(x):
#         return x.sum()
#     @staticmethod
#     def setup_context(ctx, inputs, outputs):
#         ctx.x_shape = inputs[0].shape
#     @staticmethod
#     def backward(ctx, gy):
#         return gy.expand(ctx.x_shape)
#
# x = torch.randn(B, 4)
# in_dims = 0
# vmap(Sum.apply, in_dims)(x)
#
# Let's assume for a moment that we didn't vmap setup_context in VmappedSum:
#
# class VmappedSum(torch.autograd.Function):
#     @staticmethod
#     def forward(x):
#         return vmap(Sum.forward, in_dims)(x)
#
#     @staticmethod
#     def setup_context(ctx, inputs, outputs):
#         Sum.setup_context(ctx, inputs, outputs)
#
#     @staticmethod
#     def backward(ctx, gy):
#         def backward_no_context(gy):
#             return gy.expand(ctx.x_shape)
#
#         dims = (0,)
#         gx = vmap(backward_no_context, dims)(gy)
#         return gx
#
# We end up saving [B, 4] as x_shape. In the backward, gy has shape [B],
# and we're doing:
#
# def backward_no_context(gy):
#     return gy.expand([B, 4])
#
# gx = vmap(backward_no_context, dims)(gy: "Tensor[B]")
#
# This gives us the wrong result (gx has shape [B, B, 4], but it should
# have shape [B, 4]). Performing vmap over setup_context means the saved
# x_shape is [4], which leads to the correct shape [B, 4] for gx.
# Wraps a ctx object. Forwards all attr accesses to the underlying object
# except for the attrs in _pt_reserved_attrs.
class WrappedCtx:
    _pt_reserved_attrs: Tuple[str, ...] = ('_pt_reserved_attrs', '_pt_inner_ctx')

    def __init__(self, ctx):
        if not isinstance(ctx, WrappedCtx):
            reserved_attrs = type(self)._pt_reserved_attrs
            for name in reserved_attrs:
                if not hasattr(ctx, name):
                    continue
                raise RuntimeError(
                    f'PyTorch reserves the {reserved_attrs} field on ctx. '
                    'Please name your fields on ctx something else to avoid name '
                    'collision.')
        self._pt_inner_ctx = ctx

    def __getattr__(self, name):
        return getattr(self._pt_inner_ctx, name)

    def __setattr__(self, name, value):
        if name in type(self)._pt_reserved_attrs:
            self.__dict__[name] = value
            return
        return setattr(self._pt_inner_ctx, name, value)

# Wraps ctx to create a new ctx object that overrides saved_tensors.
class CtxWithSavedTensors(WrappedCtx):
    _pt_reserved_attrs = ('_pt_new_saved_tensors', *WrappedCtx._pt_reserved_attrs)

    def __init__(self, ctx, new_saved_tensors):
        super().__init__(ctx)
        self._pt_new_saved_tensors = new_saved_tensors

    @property
    def saved_tensors(self):
        return self._pt_new_saved_tensors

class CtxCustomSave(WrappedCtx):
    _pt_reserved_attrs = ('_pt_saved_tensors_bdims', '_pt_current_level',
                          *WrappedCtx._pt_reserved_attrs)

    def __init__(self, ctx, current_level):
        super().__init__(ctx)
        self._pt_saved_tensors_bdims = ()
        self._pt_current_level = current_level

    def save_for_backward(self, *tensors):
        unwrapped_tensors, bdims = unwrap_batched(tensors, self._pt_current_level)
        self._pt_inner_ctx.save_for_backward(*unwrapped_tensors)
        self._pt_saved_tensors_bdims = bdims

    def save_for_forward(self, *tensors):
        unwrapped_tensors, bdims = unwrap_batched(tensors, self._pt_current_level)
        self._pt_inner_ctx.save_for_forward(*unwrapped_tensors)
        self._pt_saved_tensors_bdims = bdims

def reductify(grad_input, grad_input_bdim, input_bdim, batch_size,
              target_shape_without_bdim_to_reduce_to=None):
    if not isinstance(grad_input, tuple):
        grad_input = (grad_input,)
    if not isinstance(grad_input_bdim, tuple):
        grad_input_bdim = (grad_input_bdim,)
    if not isinstance(input_bdim, tuple):
        input_bdim = (input_bdim,)

    if target_shape_without_bdim_to_reduce_to is None:
        target_shape_without_bdim_to_reduce_to = len(grad_input) * (None,)
    result = tuple(
        reductify_leaf(gi, gi_bdim, i_bdim, batch_size, maybe_ishape)
        for gi, gi_bdim, i_bdim, maybe_ishape in
        zip(grad_input, grad_input_bdim, input_bdim, target_shape_without_bdim_to_reduce_to)
    )
    return result

def reductify_leaf(grad_input, grad_input_bdim, input_bdim, batch_size,
                   target_shape_without_bdim_to_reduce_to=None):
    if grad_input is None:
        return None

    if grad_input_bdim is None and input_bdim is None:
        return grad_input

    if grad_input_bdim is not None and input_bdim is None:
        return grad_input.sum(grad_input_bdim)

    # NOTE: [Why can't we rely on autograd to reduce expanded gradients?]
    # For reverse-mode AD,
    # given a grad_input and input, it is valid for the user to return a
    # grad_input that has a broadcasted shape when compared to the input.
    # In this situation, autograd automatically reduces the grad_input to
    # the shape of the input.
    #
    # However, when input_bdim is not None, we have problems.
    #
    # [example 1]
    # grad_input: Tensor[3, 4], input: Tensor[B, 4]
    # We can expand grad_input to Tensor[B, 3, 4], but that isn't broadcastable
    # from [B, 4].
    #
    # [example 2]
    # grad_input: Tensor[3, B, 4], input: Tensor[B, 4]
    # We can swizzle grad_input to Tensor[B, 3, 4], but that isn't broadcastable
    # from [B, 4].
    #
    # This means that we need to also reduce the grad_input to the shape of the
    # input. This behavior is controlled by `target_shape_without_bdim_to_reduce_to`:
    # if it is not None, we do the reduction manually; otherwise, we do not reduce.
    assert input_bdim is not None

    if grad_input_bdim is None:
        grad_input = grad_input.unsqueeze(input_bdim)
        new_shape = list(grad_input.shape)
        new_shape[input_bdim] = batch_size
        grad_input = grad_input.expand(new_shape)
        grad_input_bdim = input_bdim

    if target_shape_without_bdim_to_reduce_to is not None:
        return vmap(torch.Tensor.sum_to_size, in_dims=(grad_input_bdim, None), out_dims=input_bdim)(
            grad_input, target_shape_without_bdim_to_reduce_to)

    if input_bdim != grad_input_bdim:
        grad_input = grad_input.movedim(grad_input_bdim, input_bdim)
    return grad_input
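
# A hedged usage sketch of reductify_leaf for [example 1] above (B = 2 is
# illustrative):
#
#   B = 2
#   grad_input = torch.randn(3, 4)   # produced without a batch dim
#   out = reductify_leaf(grad_input, None, 0, B,
#                        target_shape_without_bdim_to_reduce_to=(4,))
#   # grad_input is expanded to [B, 3, 4] along input_bdim=0, then
#   # sum_to_size((4,)) is vmapped over that dim, so out has shape [B, 4],
#   # matching the input Tensor[B, 4].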