# flake8: noqa
from torch.nn.modules.module import Module
from torch.autograd import Function
from torch.cuda.amp import autocast

import resample2d_cuda
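# NOTE: resample2d_cuda is a custom C++/CUDA extension; it must be
# compiled and installed (for example via the repository's setup.py)
# before this module can be imported.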


class Resample2dFunction(Function):
    """Autograd wrapper for the resample2d CUDA kernel, which warps
    input1 by the flow field input2 using bilinear sampling."""

    @staticmethod
    def forward(ctx, input1, input2, kernel_size=1):
        # The CUDA kernel requires contiguous inputs.
        assert input1.is_contiguous()
        assert input2.is_contiguous()

        ctx.save_for_backward(input1, input2)
        ctx.kernel_size = kernel_size

        # The output keeps input1's channel count and input2's spatial size.
        _, d, _, _ = input1.size()
        b, _, h, w = input2.size()
        output = input1.new_zeros(b, d, h, w)

        resample2d_cuda.forward(input1, input2, output, kernel_size)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        # The CUDA kernel also requires contiguous gradients.
        grad_output = grad_output.contiguous()

        input1, input2 = ctx.saved_tensors

        grad_input1 = input1.new_zeros(input1.size())
        grad_input2 = input1.new_zeros(input2.size())

        resample2d_cuda.backward(input1, input2, grad_output,
                                 grad_input1, grad_input2,
                                 ctx.kernel_size)

        # One gradient per forward() argument; the non-tensor
        # kernel_size argument gets None.
        return grad_input1, grad_input2, None


class Resample2d(Module):
    """Module wrapper that warps input1 by the flow field input2."""

    def __init__(self, kernel_size=1, bilinear=True):
        super(Resample2d, self).__init__()
        self.kernel_size = kernel_size
        # Kept for API compatibility; the CUDA op no longer takes this
        # flag and always samples bilinearly.
        self.bilinear = bilinear

    @autocast(False)
    def forward(self, input1, input2):
        # The kernel is fp32-only, so disable autocast and upcast the
        # inputs before handing them to the CUDA op.
        input1, input2 = input1.float(), input2.float()
        return Resample2dFunction.apply(
            input1.contiguous(), input2.contiguous(), self.kernel_size)
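

# ---------------------------------------------------------------------------
# Minimal self-check (an illustrative sketch, not part of the original
# module). It assumes the compiled resample2d_cuda extension and a CUDA
# device are available. For kernel_size=1, resample2d amounts to bilinear
# warping of input1 by the flow field input2, so a grid_sample-based warp
# serves as a pure-PyTorch reference below; treat the comparison as
# approximate, since out-of-bounds handling in the CUDA kernel may differ.
def _warp_with_grid_sample(image, flow):
    """Bilinear warp of image (N, C, H, W) by flow (N, 2, H, W)."""
    import torch
    import torch.nn.functional as F

    _, _, h, w = flow.shape
    ys, xs = torch.meshgrid(
        torch.arange(h, device=flow.device, dtype=flow.dtype),
        torch.arange(w, device=flow.device, dtype=flow.dtype),
        indexing="ij")
    # Absolute sampling positions, normalized to [-1, 1] as grid_sample
    # expects; flow[:, 0] is the x displacement, flow[:, 1] the y.
    x = (xs + flow[:, 0]) * 2.0 / max(w - 1, 1) - 1.0
    y = (ys + flow[:, 1]) * 2.0 / max(h - 1, 1) - 1.0
    grid = torch.stack((x, y), dim=-1)  # (N, H, W, 2), x before y
    return F.grid_sample(image, grid, mode="bilinear", align_corners=True)


if __name__ == "__main__":
    import torch

    image = torch.randn(1, 3, 8, 8, device="cuda")
    flow = torch.randn(1, 2, 8, 8, device="cuda")

    warped = Resample2d()(image, flow)
    reference = _warp_with_grid_sample(image, flow)
    print("max abs difference vs. grid_sample:",
          (warped - reference).abs().max().item())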