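# Drop-in replacements for torch.nn.functional.conv2d / conv_transpose2d whose
# backward pass goes through custom autograd.Function classes. This keeps the
# weight gradient itself differentiable (so higher-order gradients work) and
# lets it be skipped entirely via the no_weight_gradients() context manager.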
import contextlib
import warnings

import torch
from torch import autograd
from torch.nn import functional as F
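
# Module-level switches: `enabled` turns the custom op on or off globally;
# `weight_gradients_disabled` is toggled by no_weight_gradients() below.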
enabled = True
weight_gradients_disabled = False
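

# Temporarily skip the weight-gradient computation in the custom backward pass
# for everything executed inside this context.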
@contextlib.contextmanager
def no_weight_gradients():
    global weight_gradients_disabled

    old = weight_gradients_disabled
    weight_gradients_disabled = True
    yield
    weight_gradients_disabled = old
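

# Same signature as F.conv2d. Uses the custom gradfix op when could_use_op()
# allows it, otherwise falls back to the stock implementation.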
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    if could_use_op(input):
        return conv2d_gradfix(
            transpose=False,
            weight_shape=weight.shape,
            stride=stride,
            padding=padding,
            output_padding=0,
            dilation=dilation,
            groups=groups,
        ).apply(input, weight, bias)

    return F.conv2d(
        input=input,
        weight=weight,
        bias=bias,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
    )
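

# Transposed-convolution counterpart of conv2d() above, mirroring
# F.conv_transpose2d.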
def conv_transpose2d(
    input,
    weight,
    bias=None,
    stride=1,
    padding=0,
    output_padding=0,
    groups=1,
    dilation=1,
):
    if could_use_op(input):
        return conv2d_gradfix(
            transpose=True,
            weight_shape=weight.shape,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            dilation=dilation,
        ).apply(input, weight, bias)

    return F.conv_transpose2d(
        input=input,
        weight=weight,
        bias=bias,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
        dilation=dilation,
        groups=groups,
    )
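

# The custom op calls cuDNN backward kernels through private ATen bindings that
# are only available in PyTorch 1.7-1.10, so restrict it to CUDA tensors on
# those versions and warn otherwise.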
def could_use_op(input):
    if (not enabled) or (not torch.backends.cudnn.enabled):
        return False

    if input.device.type != "cuda":
        return False

    if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8.", "1.9", "1.10"]):
        return True

    warnings.warn(
        f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
    )

    return False
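

# Broadcast a scalar argument to an `ndim`-tuple; tuples and lists pass through.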
def ensure_tuple(xs, ndim):
    xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim

    return xs
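

# Cache of generated autograd.Function classes, keyed by convolution parameters.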
conv2d_gradfix_cache = dict()
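

# Build (or fetch from cache) an autograd.Function that performs the requested
# convolution and implements its own differentiable backward pass.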
def conv2d_gradfix(
    transpose, weight_shape, stride, padding, output_padding, dilation, groups
):
    ndim = 2
    weight_shape = tuple(weight_shape)
    stride = ensure_tuple(stride, ndim)
    padding = ensure_tuple(padding, ndim)
    output_padding = ensure_tuple(output_padding, ndim)
    dilation = ensure_tuple(dilation, ndim)

    key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
    if key in conv2d_gradfix_cache:
        return conv2d_gradfix_cache[key]

    common_kwargs = dict(
        stride=stride, padding=padding, dilation=dilation, groups=groups
    )
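
    # For the gradient w.r.t. the input, the convolution is re-expressed with
    # `transpose` flipped; this computes the output_padding needed so that the
    # flipped convolution reproduces the original input's spatial size.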
    def calc_output_padding(input_shape, output_shape):
        if transpose:
            return [0, 0]

        return [
            input_shape[i + 2]
            - (output_shape[i + 2] - 1) * stride[i]
            - (1 - 2 * padding[i])
            - dilation[i] * (weight_shape[i + 2] - 1)
            for i in range(ndim)
        ]
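
    # Forward pass is the ordinary F.conv2d / F.conv_transpose2d call; only the
    # backward pass is customized.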
    class Conv2d(autograd.Function):
        @staticmethod
        def forward(ctx, input, weight, bias):
            if not transpose:
                out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)

            else:
                out = F.conv_transpose2d(
                    input=input,
                    weight=weight,
                    bias=bias,
                    output_padding=output_padding,
                    **common_kwargs,
                )

            ctx.save_for_backward(input, weight)

            return out
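
        # grad_input is expressed through conv2d_gradfix with `transpose`
        # flipped, and grad_weight through Conv2dGradWeight below, so both
        # remain differentiable for higher-order gradients.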
        @staticmethod
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            grad_input, grad_weight, grad_bias = None, None, None

            if ctx.needs_input_grad[0]:
                p = calc_output_padding(
                    input_shape=input.shape, output_shape=grad_output.shape
                )
                grad_input = conv2d_gradfix(
                    transpose=(not transpose),
                    weight_shape=weight_shape,
                    output_padding=p,
                    **common_kwargs,
                ).apply(grad_output, weight, None)

            if ctx.needs_input_grad[1] and not weight_gradients_disabled:
                grad_weight = Conv2dGradWeight.apply(grad_output, input)

            if ctx.needs_input_grad[2]:
                grad_bias = grad_output.sum((0, 2, 3))

            return grad_input, grad_weight, grad_bias
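
    # Weight gradient as its own autograd.Function: forward calls the cuDNN
    # backward-weight kernel directly, and backward defines the second-order
    # gradients w.r.t. grad_output and input.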
    class Conv2dGradWeight(autograd.Function):
        @staticmethod
        def forward(ctx, grad_output, input):
            op = torch._C._jit_get_operation(
                "aten::cudnn_convolution_backward_weight"
                if not transpose
                else "aten::cudnn_convolution_transpose_backward_weight"
            )
            flags = [
                torch.backends.cudnn.benchmark,
                torch.backends.cudnn.deterministic,
                torch.backends.cudnn.allow_tf32,
            ]
            grad_weight = op(
                weight_shape,
                grad_output,
                input,
                padding,
                stride,
                dilation,
                groups,
                *flags,
            )
            ctx.save_for_backward(grad_output, input)

            return grad_weight

        @staticmethod
        def backward(ctx, grad_grad_weight):
            grad_output, input = ctx.saved_tensors
            grad_grad_output, grad_grad_input = None, None

            if ctx.needs_input_grad[0]:
                grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)

            if ctx.needs_input_grad[1]:
                p = calc_output_padding(
                    input_shape=input.shape, output_shape=grad_output.shape
                )
                grad_grad_input = conv2d_gradfix(
                    transpose=(not transpose),
                    weight_shape=weight_shape,
                    output_padding=p,
                    **common_kwargs,
                ).apply(grad_output, grad_grad_weight, None)

            return grad_grad_output, grad_grad_input
    conv2d_gradfix_cache[key] = Conv2d

    return Conv2d