File size: 4,487 Bytes
ffbe0b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <vector>
#include <iostream>

// declarations

// CPU implementations of the correlation sampler. They are only declared here;
// the definitions live in another translation unit.

/// CPU forward pass.
/// The integer arguments are geometry parameters (kernel, patch, padding,
/// dilation, patch dilation, stride — as the names suggest; confirm against
/// the definition). This file forwards them verbatim.
/// @return the tensor produced by the CPU kernel.
torch::Tensor correlation_cpp_forward(
    torch::Tensor input1,
    torch::Tensor input2,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW);

/// CPU backward pass.
/// Argument order is (input1, input2, grad_output, ...): that is the order
/// every caller in this file passes, and the order exposed to Python by the
/// CPU-only binding. Parameter names here document that order only; they do
/// not affect the call.
/// @return the gradient tensors produced by the CPU kernel (presumably
///         {grad_input1, grad_input2} — confirm against the definition).
std::vector<torch::Tensor> correlation_cpp_backward(
    torch::Tensor input1,
    torch::Tensor input2,
    torch::Tensor grad_output,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW);

#ifdef USE_CUDA

// Argument-validation helpers for the CUDA dispatch path. Failures are
// reported through TORCH_CHECK. Each macro expands to exactly one statement
// (do { ... } while (0)), so it is safe to use as the body of an unbraced
// if/else. The previous two-statement CHECK_INPUT was not.
#define CHECK_CUDA(x) do { TORCH_CHECK((x).device().is_cuda(), #x, " must be a CUDA tensor"); } while (0)
#define CHECK_CONTIGUOUS(x) do { TORCH_CHECK((x).is_contiguous(), #x, " must be contiguous"); } while (0)
#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)
#define CHECK_SAME_DEVICE(x, y) do { TORCH_CHECK((x).device() == (y).device(), #x " is not on same device as " #y); } while (0)

// CUDA implementations of the correlation sampler. They are only declared
// here; the definitions live in another translation unit (presumably the
// .cu source).

/// CUDA forward pass. Integer arguments are forwarded verbatim from the
/// dispatching wrapper below.
/// @return the tensor produced by the CUDA kernel.
torch::Tensor correlation_cuda_forward(
    torch::Tensor input1,
    torch::Tensor input2,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW);

/// CUDA backward pass.
/// Argument order is (input1, input2, grad_output, ...), matching how
/// correlation_sample_backward calls it. Parameter names here document that
/// order only; they do not affect the call.
/// @return the gradient tensors produced by the CUDA kernel (presumably
///         {grad_input1, grad_input2} — confirm against the definition).
std::vector<torch::Tensor> correlation_cuda_backward(
    torch::Tensor input1,
    torch::Tensor input2,
    torch::Tensor grad_output,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW);

// C++ interface

/// Device-dispatching forward pass of the spatial correlation sampler.
///
/// If `input1` is a CUDA tensor, both inputs must be contiguous CUDA tensors on
/// the same device. That device is made current for the duration of the call,
/// then the CUDA implementation runs. Otherwise both inputs must be on the same
/// (non-CUDA) device and the CPU implementation runs.
///
/// @param input1, input2  the two tensors to correlate.
/// @param kH..dW          geometry parameters, forwarded verbatim to the backend.
/// @return the tensor produced by the selected backend.
/// Raises a Torch error (via TORCH_CHECK) on a device or contiguity mismatch.
torch::Tensor correlation_sample_forward(
    torch::Tensor input1,
    torch::Tensor input2,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW) {
  if (input1.device().is_cuda()) {
    CHECK_INPUT(input1);
    CHECK_INPUT(input2);

    // Make input1's device the current CUDA device while the kernel runs.
    // https://pytorch.org/cppdocs/api/structc10_1_1cuda_1_1_optional_c_u_d_a_guard.html
    const at::cuda::OptionalCUDAGuard guard_input1(device_of(input1));
    CHECK_SAME_DEVICE(input1, input2);

    return correlation_cuda_forward(input1, input2, kH, kW, patchH, patchW,
                                    padH, padW, dilationH, dilationW,
                                    dilation_patchH, dilation_patchW,
                                    dH, dW);
  }

  // CPU path. Previously input2 was never validated here, so a CUDA input2
  // paired with a CPU input1 reached the CPU implementation. Reject mixed
  // devices up front instead.
  CHECK_SAME_DEVICE(input1, input2);

  return correlation_cpp_forward(input1, input2, kH, kW, patchH, patchW,
                                 padH, padW, dilationH, dilationW,
                                 dilation_patchH, dilation_patchW,
                                 dH, dW);
}

/// Device-dispatching backward pass of the spatial correlation sampler.
///
/// Dispatch is decided by the device of `grad_output`. On the CUDA path both
/// inputs must be contiguous CUDA tensors, and all three tensors must share one
/// device, which is made current for the call. On the CPU path all three
/// tensors must share one (non-CUDA) device.
///
/// @param input1, input2  the tensors that were passed to the forward pass.
/// @param grad_output     gradient with respect to the forward output.
/// @param kH..dW          geometry parameters, forwarded verbatim to the backend.
/// @return the gradient tensors produced by the selected backend.
/// Raises a Torch error (via TORCH_CHECK) on a device or contiguity mismatch.
std::vector<torch::Tensor> correlation_sample_backward(
    torch::Tensor input1,
    torch::Tensor input2,
    torch::Tensor grad_output,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW) {
  if (grad_output.device().is_cuda()) {
    CHECK_INPUT(input1);
    CHECK_INPUT(input2);
    // NOTE(review): grad_output is only device-checked, not contiguity-checked.
    // Autograd may deliver strided gradients; confirm the CUDA kernel handles them.

    // Make input1's device the current CUDA device while the kernel runs.
    const at::cuda::OptionalCUDAGuard guard_input1(device_of(input1));
    CHECK_SAME_DEVICE(input1, input2);
    CHECK_SAME_DEVICE(input1, grad_output);

    return correlation_cuda_backward(input1, input2, grad_output,
                                     kH, kW, patchH, patchW,
                                     padH, padW,
                                     dilationH, dilationW,
                                     dilation_patchH, dilation_patchW,
                                     dH, dW);
  }

  // CPU path. Previously neither input was validated here, so CUDA inputs
  // paired with a CPU grad_output reached the CPU implementation. Reject mixed
  // devices up front instead.
  CHECK_SAME_DEVICE(input1, input2);
  CHECK_SAME_DEVICE(input1, grad_output);

  return correlation_cpp_backward(input1, input2, grad_output,
                                  kH, kW, patchH, patchW,
                                  padH, padW,
                                  dilationH, dilationW,
                                  dilation_patchH, dilation_patchW,
                                  dH, dW);
}

// Python bindings for the CUDA-enabled build. Both entry points go through the
// device-dispatching wrappers, so CPU and CUDA tensors are accepted.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
  ext.def("forward",
          &correlation_sample_forward,
          "Spatial Correlation Sampler Forward");
  ext.def("backward",
          &correlation_sample_backward,
          "Spatial Correlation Sampler backward");
}

#else

// Python bindings for the CPU-only build (USE_CUDA not defined). The CPU
// implementations are exposed directly, with no dispatch layer.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
  ext.def("forward",
          &correlation_cpp_forward,
          "Spatial Correlation Sampler Forward");
  ext.def("backward",
          &correlation_cpp_backward,
          "Spatial Correlation Sampler backward");
}

#endif