import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair

from maskrcnn_benchmark import _C


class DeformConvFunction(Function):
    """Autograd function that dispatches deformable convolution to ``_C``.

    Only CUDA tensors are supported; CPU tensors raise
    ``NotImplementedError`` in both ``forward`` and ``backward``.
    """

    @staticmethod
    def forward(
        ctx,
        input,
        offset,
        weight,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        deformable_groups=1,
        im2col_step=64
    ):
        """Run ``_C.deform_conv_forward`` and return the freshly allocated output.

        Args:
            ctx: autograd context; hyper-parameters, the saved tensors and two
                scratch buffers are stored on it for ``backward``.
            input: 4D tensor (a ``ValueError`` is raised otherwise).
            offset: offset tensor, handed to the kernel unchanged.
            weight: convolution weight; dims 2 and 3 are used as the kernel
                height and width.
            stride, padding, dilation: int or pair, normalised with ``_pair``.
            groups, deformable_groups: forwarded to the kernel as given.
            im2col_step: upper bound on the step handed to the kernel; the
                effective step is ``min(im2col_step, batch_size)`` and must
                divide the batch size.

        Raises:
            ValueError: if ``input`` is not 4D or the computed output size has
                a non-positive entry.
            NotImplementedError: if ``input`` is not a CUDA tensor.
        """
        if input is not None and input.dim() != 4:
            raise ValueError(
                "Expected 4D tensor as input, got {}D tensor instead.".format(
                    input.dim()))

        ctx.stride = _pair(stride)
        ctx.padding = _pair(padding)
        ctx.dilation = _pair(dilation)
        ctx.groups = groups
        ctx.deformable_groups = deformable_groups
        ctx.im2col_step = im2col_step

        ctx.save_for_backward(input, offset, weight)

        output = input.new_empty(
            DeformConvFunction._output_size(
                input, weight, ctx.padding, ctx.dilation, ctx.stride))

        ctx.bufs_ = [input.new_empty(0), input.new_empty(0)]  # columns, ones

        if not input.is_cuda:
            raise NotImplementedError

        cur_im2col_step = min(ctx.im2col_step, input.shape[0])
        assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
        # Spatial arguments are passed width-first (index 1), then height
        # (index 0), matching the kernel-size arguments weight.size(3),
        # weight.size(2).
        _C.deform_conv_forward(
            input,
            weight,
            offset,
            output,
            ctx.bufs_[0],
            ctx.bufs_[1],
            weight.size(3),
            weight.size(2),
            ctx.stride[1],
            ctx.stride[0],
            ctx.padding[1],
            ctx.padding[0],
            ctx.dilation[1],
            ctx.dilation[0],
            ctx.groups,
            ctx.deformable_groups,
            cur_im2col_step
        )
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Compute gradients for ``input``, ``offset`` and ``weight``.

        Returns a 9-tuple with one entry per ``forward`` argument (after
        ``ctx``); the six non-tensor arguments always get ``None``. A tensor
        gradient is ``None`` when ``ctx.needs_input_grad`` did not request it.

        Raises:
            NotImplementedError: if ``grad_output`` is not a CUDA tensor.
        """
        input, offset, weight = ctx.saved_tensors

        grad_input = grad_offset = grad_weight = None

        if not grad_output.is_cuda:
            raise NotImplementedError

        cur_im2col_step = min(ctx.im2col_step, input.shape[0])
        assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'

        # One kernel call fills both the input and the offset gradient, so
        # both are allocated whenever either of them is required.
        if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
            grad_input = torch.zeros_like(input)
            grad_offset = torch.zeros_like(offset)
            _C.deform_conv_backward_input(
                input,
                offset,
                grad_output,
                grad_input,
                grad_offset,
                weight,
                ctx.bufs_[0],
                weight.size(3),
                weight.size(2),
                ctx.stride[1],
                ctx.stride[0],
                ctx.padding[1],
                ctx.padding[0],
                ctx.dilation[1],
                ctx.dilation[0],
                ctx.groups,
                ctx.deformable_groups,
                cur_im2col_step
            )

        if ctx.needs_input_grad[2]:
            grad_weight = torch.zeros_like(weight)
            _C.deform_conv_backward_parameters(
                input,
                offset,
                grad_output,
                grad_weight,
                ctx.bufs_[0],
                ctx.bufs_[1],
                weight.size(3),
                weight.size(2),
                ctx.stride[1],
                ctx.stride[0],
                ctx.padding[1],
                ctx.padding[0],
                ctx.dilation[1],
                ctx.dilation[0],
                ctx.groups,
                ctx.deformable_groups,
                1,  # NOTE(review): presumably a gradient scale factor - confirm against the kernel
                cur_im2col_step
            )

        # forward() accepts nine arguments after ctx, so nine gradients are
        # returned. Autograd drops surplus trailing None entries, so callers
        # that pass fewer arguments (relying on defaults) are unaffected.
        return (grad_input, grad_offset, grad_weight,
                None, None, None, None, None, None)

    @staticmethod
    def _output_size(input, weight, padding, dilation, stride):
        """Return the output shape ``(batch, out_channels, *spatial)``.

        Each spatial extent is
        ``(in + 2 * pad - (dilation * (kernel - 1) + 1)) // stride + 1``.

        Raises:
            ValueError: if any entry of the resulting shape is not positive.
        """
        channels = weight.size(0)
        output_size = (input.size(0), channels)
        for d in range(input.dim() - 2):
            in_size = input.size(d + 2)
            pad = padding[d]
            kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
            stride_ = stride[d]
            output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
        if not all(s > 0 for s in output_size):
            raise ValueError(
                "convolution input is too small (output would be {})".format(
                    'x'.join(map(str, output_size))))
        return output_size


class ModulatedDeformConvFunction(Function):
    """Autograd function that dispatches modulated deformable convolution to ``_C``.

    ``stride``, ``padding`` and ``dilation`` are used as single numbers: the
    same value is passed to the kernel for both spatial dimensions and
    ``_infer_shape`` does plain arithmetic on them. Only CUDA tensors are
    supported.
    """

    @staticmethod
    def forward(
        ctx,
        input,
        offset,
        mask,
        weight,
        bias=None,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        deformable_groups=1
    ):
        """Run ``_C.modulated_deform_conv_forward`` and return the output.

        When ``bias`` is ``None`` a one-element placeholder tensor is passed
        to the kernel together with ``with_bias=False``.

        Raises:
            NotImplementedError: if ``input`` is not a CUDA tensor.
        """
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.groups = groups
        ctx.deformable_groups = deformable_groups
        ctx.with_bias = bias is not None
        if not ctx.with_bias:
            bias = input.new_empty(1)  # fake tensor
        if not input.is_cuda:
            raise NotImplementedError
        # Tensors are only kept for backward when some gradient can be needed.
        if weight.requires_grad or mask.requires_grad or offset.requires_grad \
                or input.requires_grad:
            ctx.save_for_backward(input, offset, mask, weight, bias)
        output = input.new_empty(
            ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
        ctx._bufs = [input.new_empty(0), input.new_empty(0)]
        _C.modulated_deform_conv_forward(
            input,
            weight,
            bias,
            ctx._bufs[0],
            offset,
            mask,
            output,
            ctx._bufs[1],
            weight.shape[2],
            weight.shape[3],
            ctx.stride,
            ctx.stride,
            ctx.padding,
            ctx.padding,
            ctx.dilation,
            ctx.dilation,
            ctx.groups,
            ctx.deformable_groups,
            ctx.with_bias
        )
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Compute gradients for ``input``, ``offset``, ``mask``, ``weight`` and ``bias``.

        Returns a 10-tuple with one entry per ``forward`` argument (after
        ``ctx``); the five non-tensor arguments get ``None``, and so does
        ``bias`` when ``forward`` was called without one.

        Raises:
            NotImplementedError: if ``grad_output`` is not a CUDA tensor.
        """
        if not grad_output.is_cuda:
            raise NotImplementedError
        input, offset, mask, weight, bias = ctx.saved_tensors
        grad_input = torch.zeros_like(input)
        grad_offset = torch.zeros_like(offset)
        grad_mask = torch.zeros_like(mask)
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros_like(bias)
        _C.modulated_deform_conv_backward(
            input,
            weight,
            bias,
            ctx._bufs[0],
            offset,
            mask,
            ctx._bufs[1],
            grad_input,
            grad_weight,
            grad_bias,
            grad_offset,
            grad_mask,
            grad_output,
            weight.shape[2],
            weight.shape[3],
            ctx.stride,
            ctx.stride,
            ctx.padding,
            ctx.padding,
            ctx.dilation,
            ctx.dilation,
            ctx.groups,
            ctx.deformable_groups,
            ctx.with_bias
        )
        # The placeholder bias created in forward() is not a real input, so
        # no gradient is reported for it.
        if not ctx.with_bias:
            grad_bias = None

        return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
                None, None, None, None, None)

    @staticmethod
    def _infer_shape(ctx, input, weight):
        """Return ``(batch, out_channels, height_out, width_out)``.

        Uses the scalar ``ctx.padding``, ``ctx.dilation`` and ``ctx.stride``
        for both spatial dimensions.
        """
        n = input.size(0)
        channels_out = weight.size(0)
        height, width = input.shape[2:4]
        kernel_h, kernel_w = weight.shape[2:4]
        height_out = (height + 2 * ctx.padding -
                      (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
        width_out = (width + 2 * ctx.padding -
                     (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
        return n, channels_out, height_out, width_out


deform_conv = DeformConvFunction.apply
modulated_deform_conv = ModulatedDeformConvFunction.apply