import contextlib
import warnings

import torch
from torch import autograd
from torch.nn import functional as F

enabled = True
weight_gradients_disabled = False

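# Context manager that temporarily disables weight-gradient computation in the custom
# ops below (checked in Conv2d.backward); useful when only input gradients are needed.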
@contextlib.contextmanager
def no_weight_gradients():
    global weight_gradients_disabled

    old = weight_gradients_disabled
    weight_gradients_disabled = True
    yield
    weight_gradients_disabled = old

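# Drop-in replacement for F.conv2d; routes through the custom autograd op when the
# current setup supports it (see could_use_op), otherwise falls back to F.conv2d.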
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    if could_use_op(input):
        return conv2d_gradfix(
            transpose=False,
            weight_shape=weight.shape,
            stride=stride,
            padding=padding,
            output_padding=0,
            dilation=dilation,
            groups=groups,
        ).apply(input, weight, bias)

    return F.conv2d(
        input=input,
        weight=weight,
        bias=bias,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
    )

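# Drop-in replacement for F.conv_transpose2d with the same fallback behaviour.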
def conv_transpose2d(
    input,
    weight,
    bias=None,
    stride=1,
    padding=0,
    output_padding=0,
    groups=1,
    dilation=1,
):
    if could_use_op(input):
        return conv2d_gradfix(
            transpose=True,
            weight_shape=weight.shape,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            dilation=dilation,
        ).apply(input, weight, bias)

    return F.conv_transpose2d(
        input=input,
        weight=weight,
        bias=bias,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
        dilation=dilation,
        groups=groups,
    )

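# Returns True only when the custom op can be used: the module is enabled, cuDNN is
# enabled, the input lives on a CUDA device, and PyTorch is 1.7.x or 1.8.x.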
def could_use_op(input):
    if (not enabled) or (not torch.backends.cudnn.enabled):
        return False

    if input.device.type != "cuda":
        return False

    if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
        return True

    warnings.warn(
        f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
    )

    return False

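# Broadcasts a scalar argument (e.g. stride=1) into an ndim-length tuple.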
def ensure_tuple(xs, ndim):
    xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim

    return xs

conv2d_gradfix_cache = dict()

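# Builds (and caches) a custom autograd.Function for a given combination of
# convolution parameters; the returned class supports higher-order gradients because
# both of its backward passes are expressed in terms of the same custom ops.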
def conv2d_gradfix(
    transpose, weight_shape, stride, padding, output_padding, dilation, groups
):
    ndim = 2
    weight_shape = tuple(weight_shape)
    stride = ensure_tuple(stride, ndim)
    padding = ensure_tuple(padding, ndim)
    output_padding = ensure_tuple(output_padding, ndim)
    dilation = ensure_tuple(dilation, ndim)

    key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
    if key in conv2d_gradfix_cache:
        return conv2d_gradfix_cache[key]

    common_kwargs = dict(
        stride=stride, padding=padding, dilation=dilation, groups=groups
    )

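    # Output padding needed so that the "reversed" convolution used in the backward
    # pass reproduces the original input's spatial size.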
    def calc_output_padding(input_shape, output_shape):
        if transpose:
            return [0, 0]

        return [
            input_shape[i + 2]
            - (output_shape[i + 2] - 1) * stride[i]
            - (1 - 2 * padding[i])
            - dilation[i] * (weight_shape[i + 2] - 1)
            for i in range(ndim)
        ]

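    # Forward pass is a plain (or transposed) convolution; backward expresses the
    # input gradient as another gradfix convolution and the weight gradient via
    # Conv2dGradWeight, so second- and higher-order gradients remain available.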
    class Conv2d(autograd.Function):
        @staticmethod
        def forward(ctx, input, weight, bias):
            if not transpose:
                out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)

            else:
                out = F.conv_transpose2d(
                    input=input,
                    weight=weight,
                    bias=bias,
                    output_padding=output_padding,
                    **common_kwargs,
                )

            ctx.save_for_backward(input, weight)

            return out

        @staticmethod
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            grad_input, grad_weight, grad_bias = None, None, None

            if ctx.needs_input_grad[0]:
                p = calc_output_padding(
                    input_shape=input.shape, output_shape=grad_output.shape
                )
                grad_input = conv2d_gradfix(
                    transpose=(not transpose),
                    weight_shape=weight_shape,
                    output_padding=p,
                    **common_kwargs,
                ).apply(grad_output, weight, None)

            if ctx.needs_input_grad[1] and not weight_gradients_disabled:
                grad_weight = Conv2dGradWeight.apply(grad_output, input)

            if ctx.needs_input_grad[2]:
                grad_bias = grad_output.sum((0, 2, 3))

            return grad_input, grad_weight, grad_bias

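    # Computes the weight gradient by invoking cuDNN's backward-weight kernel
    # directly, and defines its own backward so that gradients of the weight
    # gradient (double backward) can be computed as well.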
    class Conv2dGradWeight(autograd.Function):
        @staticmethod
        def forward(ctx, grad_output, input):
            op = torch._C._jit_get_operation(
                "aten::cudnn_convolution_backward_weight"
                if not transpose
                else "aten::cudnn_convolution_transpose_backward_weight"
            )
            flags = [
                torch.backends.cudnn.benchmark,
                torch.backends.cudnn.deterministic,
                torch.backends.cudnn.allow_tf32,
            ]
            grad_weight = op(
                weight_shape,
                grad_output,
                input,
                padding,
                stride,
                dilation,
                groups,
                *flags,
            )
            ctx.save_for_backward(grad_output, input)

            return grad_weight

        @staticmethod
        def backward(ctx, grad_grad_weight):
            grad_output, input = ctx.saved_tensors
            grad_grad_output, grad_grad_input = None, None

            if ctx.needs_input_grad[0]:
                grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)

            if ctx.needs_input_grad[1]:
                p = calc_output_padding(
                    input_shape=input.shape, output_shape=grad_output.shape
                )
                grad_grad_input = conv2d_gradfix(
                    transpose=(not transpose),
                    weight_shape=weight_shape,
                    output_padding=p,
                    **common_kwargs,
                ).apply(grad_output, grad_grad_weight, None)

            return grad_grad_output, grad_grad_input

    conv2d_gradfix_cache[key] = Conv2d

    return Conv2d
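
# Minimal usage sketch (illustrative only; the module path "conv2d_gradfix" and the
# tensor shapes below are assumptions, not part of this file):
#
#     from conv2d_gradfix import conv2d, no_weight_gradients
#
#     x = torch.randn(4, 3, 32, 32, device="cuda", requires_grad=True)
#     w = torch.randn(8, 3, 3, 3, device="cuda", requires_grad=True)
#     y = conv2d(x, w, padding=1)
#
#     # Second-order gradients, e.g. for R1 or path-length regularization:
#     (grad_x,) = autograd.grad(y.sum(), x, create_graph=True)
#     grad_x.square().sum().backward()
#
#     # Skip weight gradients during a regularization pass:
#     with no_weight_gradients():
#         y = conv2d(x, w, padding=1)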