# Spaces: Paused
import contextlib
import warnings

import torch
from torch import autograd
from torch.nn import functional as F
# Master switch read by could_use_op(): when False, the custom-gradient op is
# never selected and the stock torch.nn.functional convolutions are used.
enabled = True

# Read by Conv2d.backward(): when True, the weight gradient is not computed
# (it is left as None). Toggle it with no_weight_gradients().
weight_gradients_disabled = False


@contextlib.contextmanager
def no_weight_gradients():
    """Context manager that disables weight-gradient computation inside it.

    Sets the module-level ``weight_gradients_disabled`` flag to ``True`` for
    the duration of the ``with`` block and restores the previous value on
    exit, so nested uses unwind correctly.

    Yields:
        None.
    """
    global weight_gradients_disabled

    old = weight_gradients_disabled
    weight_gradients_disabled = True
    try:
        yield
    finally:
        # Restore even when the body raises; otherwise one failed step would
        # leave weight gradients switched off for the rest of the process.
        weight_gradients_disabled = old
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """2D convolution that uses the custom-gradient op when it is usable.

    The arguments are forwarded unchanged either to the autograd Function
    produced by ``conv2d_gradfix`` (when ``could_use_op(input)`` is true) or
    to ``torch.nn.functional.conv2d`` otherwise.

    Args:
        input: Input tensor.
        weight: Convolution kernel; its shape selects the cached op.
        bias: Optional bias tensor.
        stride: Convolution stride.
        padding: Input padding.
        dilation: Kernel dilation.
        groups: Number of groups.

    Returns:
        The convolution result.
    """
    conv_kwargs = dict(
        stride=stride, padding=padding, dilation=dilation, groups=groups
    )

    # Fallback first: stock PyTorch op when the custom one is unavailable.
    if not could_use_op(input):
        return F.conv2d(input=input, weight=weight, bias=bias, **conv_kwargs)

    op = conv2d_gradfix(
        transpose=False,
        weight_shape=weight.shape,
        output_padding=0,
        **conv_kwargs,
    )
    return op.apply(input, weight, bias)
def conv_transpose2d(
    input,
    weight,
    bias=None,
    stride=1,
    padding=0,
    output_padding=0,
    groups=1,
    dilation=1,
):
    """2D transposed convolution that uses the custom-gradient op when usable.

    The arguments are forwarded unchanged either to the autograd Function
    produced by ``conv2d_gradfix`` (when ``could_use_op(input)`` is true) or
    to ``torch.nn.functional.conv_transpose2d`` otherwise.

    Args:
        input: Input tensor.
        weight: Convolution kernel; its shape selects the cached op.
        bias: Optional bias tensor.
        stride: Convolution stride.
        padding: Padding argument of the transposed convolution.
        output_padding: Extra size added to the output.
        groups: Number of groups.
        dilation: Kernel dilation.

    Returns:
        The transposed-convolution result.
    """
    conv_kwargs = dict(
        stride=stride,
        padding=padding,
        output_padding=output_padding,
        dilation=dilation,
        groups=groups,
    )

    # Fallback first: stock PyTorch op when the custom one is unavailable.
    if not could_use_op(input):
        return F.conv_transpose2d(
            input=input, weight=weight, bias=bias, **conv_kwargs
        )

    op = conv2d_gradfix(transpose=True, weight_shape=weight.shape, **conv_kwargs)
    return op.apply(input, weight, bias)
def could_use_op(input):
    """Decide whether the custom-gradient convolution op may be used.

    The op is selected only when the module-level ``enabled`` flag and cuDNN
    are both on, ``input`` lives on a CUDA device, and the running PyTorch
    version starts with ``1.7.`` or ``1.8.``. When only the version check
    fails, a warning is emitted before returning ``False``.

    Args:
        input: Tensor whose ``device.type`` is inspected.

    Returns:
        ``True`` if the custom op should be used, ``False`` otherwise.
    """
    if not (enabled and torch.backends.cudnn.enabled):
        return False

    if input.device.type != "cuda":
        return False

    # str.startswith accepts a tuple of prefixes.
    if torch.__version__.startswith(("1.7.", "1.8.")):
        return True

    warnings.warn(
        f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
    )
    return False
def ensure_tuple(xs, ndim):
    """Normalize a scalar-or-sequence argument to a tuple of length ``ndim``.

    Args:
        xs: A single value, or a tuple/list of values.
        ndim: Required length of the result.

    Returns:
        A tuple of length ``ndim``. A non-sequence value is repeated ``ndim``
        times, a one-element sequence is broadcast the same way, and a
        sequence that already has ``ndim`` elements is returned as a tuple.

    Raises:
        ValueError: If ``xs`` is a tuple/list whose length is neither 1 nor
            ``ndim``.
    """
    if not isinstance(xs, (tuple, list)):
        return (xs,) * ndim

    xs = tuple(xs)
    if len(xs) == 1:
        # Broadcast like a scalar so later per-dimension indexing
        # (e.g. stride[1] in calc_output_padding) cannot run off the end.
        return xs * ndim
    if len(xs) != ndim:
        raise ValueError(
            f"expected a scalar or a sequence of length {ndim}, "
            f"got a sequence of length {len(xs)}"
        )
    return xs
# Generated autograd Function classes, keyed by the normalized convolution
# configuration, so each configuration is built only once.
conv2d_gradfix_cache = {}


def conv2d_gradfix(
    transpose, weight_shape, stride, padding, output_padding, dilation, groups
):
    """Return the autograd Function class for one convolution configuration.

    The class is created on first use and cached in ``conv2d_gradfix_cache``;
    later calls with an equal configuration return the same class object.
    Call ``.apply(input, weight, bias)`` on the result to run the convolution.

    Args:
        transpose: ``True`` for a transposed convolution, ``False`` for a
            regular one.
        weight_shape: Shape of the weight tensor the op will be applied to.
        stride: Scalar or 2-sequence.
        padding: Scalar or 2-sequence.
        output_padding: Scalar or 2-sequence; only used by the transposed
            forward pass.
        dilation: Scalar or 2-sequence.
        groups: Number of groups.

    Returns:
        A ``torch.autograd.Function`` subclass specialized to the arguments.
    """
    ndim = 2
    weight_shape = tuple(weight_shape)
    stride = ensure_tuple(stride, ndim)
    padding = ensure_tuple(padding, ndim)
    output_padding = ensure_tuple(output_padding, ndim)
    dilation = ensure_tuple(dilation, ndim)

    # Everything in the key is a tuple or scalar, hence hashable.
    key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
    if key in conv2d_gradfix_cache:
        return conv2d_gradfix_cache[key]

    common_kwargs = dict(
        stride=stride, padding=padding, dilation=dilation, groups=groups
    )

    def calc_output_padding(input_shape, output_shape):
        """Output padding for the opposite-direction op used in backward.

        For a regular convolution this is, per spatial dim, the difference
        between the real input size and the size a transposed convolution of
        ``output_shape`` would produce with zero output padding. For a
        transposed convolution the opposite op is a regular convolution,
        which takes no output padding, so zeros are returned.
        """
        if transpose:
            return [0, 0]
        return [
            input_shape[i + 2]
            - (output_shape[i + 2] - 1) * stride[i]
            - (1 - 2 * padding[i])
            - dilation[i] * (weight_shape[i + 2] - 1)
            for i in range(ndim)
        ]

    class Conv2d(autograd.Function):
        """Convolution whose backward pass is built from autograd Functions."""

        # autograd.Function is driven through ``.apply``, which requires
        # forward/backward to be static methods.
        @staticmethod
        def forward(ctx, input, weight, bias):
            if not transpose:
                out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
            else:
                out = F.conv_transpose2d(
                    input=input,
                    weight=weight,
                    bias=bias,
                    output_padding=output_padding,
                    **common_kwargs,
                )
            ctx.save_for_backward(input, weight)
            return out

        @staticmethod
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            grad_input, grad_weight, grad_bias = None, None, None

            if ctx.needs_input_grad[0]:
                # Input gradient = opposite-direction convolution of
                # grad_output with the same weight.
                p = calc_output_padding(
                    input_shape=input.shape, output_shape=grad_output.shape
                )
                grad_input = conv2d_gradfix(
                    transpose=(not transpose),
                    weight_shape=weight_shape,
                    output_padding=p,
                    **common_kwargs,
                ).apply(grad_output, weight, None)

            # The module-level flag lets callers skip the weight gradient
            # (see no_weight_gradients()).
            if ctx.needs_input_grad[1] and not weight_gradients_disabled:
                grad_weight = Conv2dGradWeight.apply(grad_output, input)

            if ctx.needs_input_grad[2]:
                # Reduce over every dimension except dim 1.
                grad_bias = grad_output.sum((0, 2, 3))

            return grad_input, grad_weight, grad_bias

    class Conv2dGradWeight(autograd.Function):
        """Weight gradient of ``Conv2d``, itself differentiable via backward."""

        @staticmethod
        def forward(ctx, grad_output, input):
            # Look up the cuDNN backward-weight operator by name; this path
            # therefore needs cuDNN and a PyTorch build exposing these ops
            # (the public wrappers gate on could_use_op()).
            op = torch._C._jit_get_operation(
                "aten::cudnn_convolution_backward_weight"
                if not transpose
                else "aten::cudnn_convolution_transpose_backward_weight"
            )
            flags = [
                torch.backends.cudnn.benchmark,
                torch.backends.cudnn.deterministic,
                torch.backends.cudnn.allow_tf32,
            ]
            grad_weight = op(
                weight_shape,
                grad_output,
                input,
                padding,
                stride,
                dilation,
                groups,
                *flags,
            )
            ctx.save_for_backward(grad_output, input)
            return grad_weight

        @staticmethod
        def backward(ctx, grad_grad_weight):
            grad_output, input = ctx.saved_tensors
            grad_grad_output, grad_grad_input = None, None

            if ctx.needs_input_grad[0]:
                # Same-direction convolution of the saved input, with
                # grad_grad_weight standing in for the weight.
                grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)

            if ctx.needs_input_grad[1]:
                # Opposite-direction convolution of grad_output, again with
                # grad_grad_weight standing in for the weight.
                p = calc_output_padding(
                    input_shape=input.shape, output_shape=grad_output.shape
                )
                grad_grad_input = conv2d_gradfix(
                    transpose=(not transpose),
                    weight_shape=weight_shape,
                    output_padding=p,
                    **common_kwargs,
                ).apply(grad_output, grad_grad_weight, None)

            return grad_grad_output, grad_grad_input

    conv2d_gradfix_cache[key] = Conv2d
    return Conv2d