"""Wrappers around ``conv2d`` / ``conv_transpose2d`` with a custom autograd path.

``conv2d`` and ``conv_transpose2d`` mirror the ``torch.nn.functional``
signatures.  When ``could_use_op`` approves the input they dispatch to an
``autograd.Function`` built by ``conv2d_gradfix``; otherwise they call the
stock ``torch.nn.functional`` implementation.
"""

import contextlib
import warnings

import torch
from torch import autograd
from torch.nn import functional as F

# Module-level switch read by could_use_op().
enabled = True

# While True, Conv2d.backward() (see conv2d_gradfix) skips computing the
# gradient with respect to the weight.  Toggled by no_weight_gradients().
weight_gradients_disabled = False


@contextlib.contextmanager
def no_weight_gradients():
    """Temporarily set ``weight_gradients_disabled`` to True.

    The previous value of the flag is restored when the ``with`` block exits,
    including when the block exits by raising an exception.
    """
    global weight_gradients_disabled

    old = weight_gradients_disabled
    weight_gradients_disabled = True
    try:
        yield
    finally:
        # Restore even on error so a failed step cannot leave weight
        # gradients switched off for every later backward pass.
        weight_gradients_disabled = old


def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """2D convolution with the same parameters as ``F.conv2d``.

    Uses the custom autograd op from ``conv2d_gradfix`` when
    ``could_use_op(input)`` is true, and ``F.conv2d`` otherwise.
    """
    if could_use_op(input):
        return conv2d_gradfix(
            transpose=False,
            weight_shape=weight.shape,
            stride=stride,
            padding=padding,
            output_padding=0,
            dilation=dilation,
            groups=groups,
        ).apply(input, weight, bias)

    return F.conv2d(
        input=input,
        weight=weight,
        bias=bias,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
    )


def conv_transpose2d(
    input,
    weight,
    bias=None,
    stride=1,
    padding=0,
    output_padding=0,
    groups=1,
    dilation=1,
):
    """2D transposed convolution with the same parameters as ``F.conv_transpose2d``.

    Uses the custom autograd op from ``conv2d_gradfix`` when
    ``could_use_op(input)`` is true, and ``F.conv_transpose2d`` otherwise.
    """
    if could_use_op(input):
        return conv2d_gradfix(
            transpose=True,
            weight_shape=weight.shape,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            dilation=dilation,
        ).apply(input, weight, bias)

    return F.conv_transpose2d(
        input=input,
        weight=weight,
        bias=bias,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
        dilation=dilation,
        groups=groups,
    )


def could_use_op(input):
    """Return whether the custom autograd op should handle ``input``.

    NOTE(review): the unconditional ``return False`` on the first line
    short-circuits every check below it, so this function currently always
    returns False and ``conv2d`` / ``conv_transpose2d`` always fall back to
    ``torch.nn.functional``.  This looks like a deliberate kill switch --
    confirm before removing it.  The checks underneath are kept unchanged so
    the op can be re-enabled by deleting that single line.
    """
    return False

    # --- unreachable while the kill switch above is in place ---
    if (not enabled) or (not torch.backends.cudnn.enabled):
        return False

    if input.device.type != "cuda":
        return False

    if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
        return True

    warnings.warn(
        f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
    )

    return False


def ensure_tuple(xs, ndim):
    """Return ``xs`` as a tuple.

    A tuple or list is converted element-for-element; any other value is
    repeated ``ndim`` times.
    """
    xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim

    return xs


# Maps a full convolution configuration (see ``key`` in conv2d_gradfix) to the
# autograd.Function class generated for it.
conv2d_gradfix_cache = {}


def conv2d_gradfix(
    transpose, weight_shape, stride, padding, output_padding, dilation, groups
):
    """Build (or fetch from cache) an ``autograd.Function`` for one conv setup.

    Args:
        transpose: True for a transposed convolution, False for a regular one.
        weight_shape: Shape of the weight tensor; converted to a tuple.
        stride, padding, output_padding, dilation: A scalar or a 2-element
            tuple/list; normalised with ``ensure_tuple``.
        groups: Passed through to the underlying convolution.

    Returns:
        The ``Conv2d`` Function class specialised for this configuration.
        Calls with an equal normalised configuration return the same class
        object.
    """
    ndim = 2
    weight_shape = tuple(weight_shape)
    stride = ensure_tuple(stride, ndim)
    padding = ensure_tuple(padding, ndim)
    output_padding = ensure_tuple(output_padding, ndim)
    dilation = ensure_tuple(dilation, ndim)

    # All arguments are hashable tuples/scalars at this point.
    key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
    if key in conv2d_gradfix_cache:
        return conv2d_gradfix_cache[key]

    common_kwargs = dict(
        stride=stride, padding=padding, dilation=dilation, groups=groups
    )

    def calc_output_padding(input_shape, output_shape):
        """Output padding for the opposite-direction conv used in backward.

        For a transposed forward op this is always ``[0, 0]``.  Otherwise it
        is, per spatial dim, the input size minus the size a transposed
        convolution of ``output_shape`` yields with zero output padding.
        """
        if transpose:
            return [0, 0]

        return [
            input_shape[i + 2]
            - (output_shape[i + 2] - 1) * stride[i]
            - (1 - 2 * padding[i])
            - dilation[i] * (weight_shape[i + 2] - 1)
            for i in range(ndim)
        ]

    class Conv2d(autograd.Function):
        """Forward convolution whose backward is built from autograd Functions."""

        @staticmethod
        def forward(ctx, input, weight, bias):
            """Run ``F.conv2d`` or ``F.conv_transpose2d`` and save input/weight."""
            if not transpose:
                out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)

            else:
                out = F.conv_transpose2d(
                    input=input,
                    weight=weight,
                    bias=bias,
                    output_padding=output_padding,
                    **common_kwargs,
                )

            ctx.save_for_backward(input, weight)

            return out

        @staticmethod
        def backward(ctx, grad_output):
            """Return ``(grad_input, grad_weight, grad_bias)``; unneeded ones are None."""
            input, weight = ctx.saved_tensors
            grad_input, grad_weight, grad_bias = None, None, None

            if ctx.needs_input_grad[0]:
                # Gradient w.r.t. the input is the opposite-direction conv of
                # grad_output with the same weight.
                p = calc_output_padding(
                    input_shape=input.shape, output_shape=grad_output.shape
                )
                grad_input = conv2d_gradfix(
                    transpose=(not transpose),
                    weight_shape=weight_shape,
                    output_padding=p,
                    **common_kwargs,
                ).apply(grad_output, weight, None)

            # Skipped entirely inside a no_weight_gradients() block.
            if ctx.needs_input_grad[1] and not weight_gradients_disabled:
                grad_weight = Conv2dGradWeight.apply(grad_output, input)

            if ctx.needs_input_grad[2]:
                # Reduce over every dim except the channel dim (dim 1).
                grad_bias = grad_output.sum((0, 2, 3))

            return grad_input, grad_weight, grad_bias

    class Conv2dGradWeight(autograd.Function):
        """Weight gradient of ``Conv2d``, itself differentiable via ``backward``."""

        @staticmethod
        def forward(ctx, grad_output, input):
            """Compute the weight gradient with the cuDNN backward-weight aten op."""
            # Looked up by name through a private torch API.
            op = torch._C._jit_get_operation(
                "aten::cudnn_convolution_backward_weight"
                if not transpose
                else "aten::cudnn_convolution_transpose_backward_weight"
            )
            flags = [
                torch.backends.cudnn.benchmark,
                torch.backends.cudnn.deterministic,
                torch.backends.cudnn.allow_tf32,
            ]
            grad_weight = op(
                weight_shape,
                grad_output,
                input,
                padding,
                stride,
                dilation,
                groups,
                *flags,
            )
            ctx.save_for_backward(grad_output, input)

            return grad_weight

        @staticmethod
        def backward(ctx, grad_grad_weight):
            """Return ``(grad_grad_output, grad_grad_input)``; unneeded ones are None."""
            grad_output, input = ctx.saved_tensors
            grad_grad_output, grad_grad_input = None, None

            if ctx.needs_input_grad[0]:
                grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)

            if ctx.needs_input_grad[1]:
                p = calc_output_padding(
                    input_shape=input.shape, output_shape=grad_output.shape
                )
                grad_grad_input = conv2d_gradfix(
                    transpose=(not transpose),
                    weight_shape=weight_shape,
                    output_padding=p,
                    **common_kwargs,
                ).apply(grad_output, grad_grad_weight, None)

            return grad_grad_output, grad_grad_input

    conv2d_gradfix_cache[key] = Conv2d

    return Conv2d