// NOTE: the lines that stood here were residue from a Hugging Face Spaces
// file viewer (status "Running", file size 4,487 bytes, commit ffbe0b4, and a
// 1-139 line-number gutter). They are not C++ source and are kept only as
// this comment.
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <vector>
#include <iostream>
// declarations
//
// CPU implementations of the spatial correlation sampler. They are only
// declared in this file; their definitions live elsewhere.

// Forward pass on CPU tensors. The int arguments are the kernel size (kH, kW),
// patch size, padding, kernel dilation, patch dilation and stride (dH, dW).
torch::Tensor correlation_cpp_forward(
    torch::Tensor input1,
    torch::Tensor input2,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW);

// Backward pass on CPU tensors.
// Parameter names follow the argument order used by every caller in this file,
// (input1, input2, grad_output). Without USE_CUDA this function is bound
// directly as the Python `backward`, which must keep the same argument order
// as correlation_sample_backward. Parameter names in a declaration are not
// binding; only the order in the definition matters.
std::vector<torch::Tensor> correlation_cpp_backward(
    torch::Tensor input1,
    torch::Tensor input2,
    torch::Tensor grad_output,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW);
#ifdef USE_CUDA
// Argument-validation helpers for the CUDA dispatch path. Each one fails via
// TORCH_CHECK with a message naming the offending argument (#x). Arguments are
// parenthesized so an expression argument is evaluated as a unit.
#define CHECK_CUDA(x) TORCH_CHECK((x).device().is_cuda(), #x, " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x, " must be contiguous")
// Wrapped in do { } while (0) so the two checks form ONE statement. With a bare
// "A; B" expansion, `if (cond) CHECK_INPUT(t);` would guard only the first
// check and run the contiguity check unconditionally.
#define CHECK_INPUT(x)       \
    do {                     \
        CHECK_CUDA(x);       \
        CHECK_CONTIGUOUS(x); \
    } while (0)
#define CHECK_SAME_DEVICE(x, y) \
    TORCH_CHECK((x).device() == (y).device(), #x " is not on same device as " #y)
// CUDA implementations of the spatial correlation sampler. They are only
// declared in this file; their definitions live elsewhere.

// Forward pass on CUDA tensors. Called by correlation_sample_forward after the
// inputs pass the CUDA/contiguity/same-device checks.
torch::Tensor correlation_cuda_forward(
    torch::Tensor input1,
    torch::Tensor input2,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW);

// Backward pass on CUDA tensors.
// Parameter names follow the argument order used by the caller,
// correlation_sample_backward: (input1, input2, grad_output). Parameter names
// in a declaration are not binding; only the order in the definition matters.
std::vector<torch::Tensor> correlation_cuda_backward(
    torch::Tensor input1,
    torch::Tensor input2,
    torch::Tensor grad_output,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW);
// C++ interface
// Device-dispatching forward entry point, bound to Python as `forward` in the
// USE_CUDA build.
//
// input1 and input2 must be on the same device.
// - CUDA: both must also be contiguous. The call is routed to
//   correlation_cuda_forward with input1's device made current.
// - Otherwise: the call is routed to correlation_cpp_forward.
// All int arguments are forwarded unchanged. The returned tensor is whatever
// the selected backend produces.
torch::Tensor correlation_sample_forward(
    torch::Tensor input1,
    torch::Tensor input2,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW) {
    // Check devices before dispatching. Otherwise a CPU input1 paired with a
    // CUDA input2 would be handed to the CPU implementation.
    CHECK_SAME_DEVICE(input1, input2);

    if (input1.device().is_cuda()) {
        CHECK_INPUT(input1);
        CHECK_INPUT(input2);
        // Make input1's device the current CUDA device for this call.
        // https://pytorch.org/cppdocs/api/structc10_1_1cuda_1_1_optional_c_u_d_a_guard.html
        const at::cuda::OptionalCUDAGuard guard_input1(device_of(input1));
        return correlation_cuda_forward(input1, input2, kH, kW, patchH, patchW,
                                        padH, padW, dilationH, dilationW,
                                        dilation_patchH, dilation_patchW,
                                        dH, dW);
    }

    return correlation_cpp_forward(input1, input2, kH, kW, patchH, patchW,
                                   padH, padW, dilationH, dilationW,
                                   dilation_patchH, dilation_patchW,
                                   dH, dW);
}
// Device-dispatching backward entry point, bound to Python as `backward` in
// the USE_CUDA build.
//
// input1, input2 and grad_output must all be on the same device.
// - CUDA: input1 and input2 must also be contiguous. The call is routed to
//   correlation_cuda_backward with input1's device made current.
// - Otherwise: the call is routed to correlation_cpp_backward.
// Returns the vector produced by the selected backend. It presumably holds
// the gradients w.r.t. input1 and input2 (TODO confirm against the backend
// definitions).
std::vector<torch::Tensor> correlation_sample_backward(
    torch::Tensor input1,
    torch::Tensor input2,
    torch::Tensor grad_output,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW) {
    // Check devices before dispatching. Otherwise mismatched devices (e.g. CPU
    // grad_output with CUDA inputs) would reach the CPU implementation.
    CHECK_SAME_DEVICE(input1, input2);
    CHECK_SAME_DEVICE(input1, grad_output);

    // All three tensors now share a device, so dispatching on input1 is
    // equivalent to dispatching on grad_output. It also matches the forward
    // entry point and the device guard below.
    if (input1.device().is_cuda()) {
        CHECK_INPUT(input1);
        CHECK_INPUT(input2);
        // NOTE(review): grad_output is device-checked but, unlike the inputs,
        // not required to be contiguous. Confirm the CUDA kernel accepts
        // strided gradients.
        // Make input1's device the current CUDA device for this call.
        const at::cuda::OptionalCUDAGuard guard_input1(device_of(input1));
        return correlation_cuda_backward(input1, input2, grad_output,
                                         kH, kW, patchH, patchW,
                                         padH, padW,
                                         dilationH, dilationW,
                                         dilation_patchH, dilation_patchW,
                                         dH, dW);
    }

    return correlation_cpp_backward(input1, input2, grad_output,
                                    kH, kW, patchH, patchW,
                                    padH, padW,
                                    dilationH, dilationW,
                                    dilation_patchH, dilation_patchW,
                                    dH, dW);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // USE_CUDA build: expose the device-dispatching wrappers. A single Python
    // API then serves both CPU and CUDA tensors, with argument checks applied
    // before the backend is chosen.
    m.def("forward",
          &correlation_sample_forward,
          "Spatial Correlation Sampler Forward");
    m.def("backward",
          &correlation_sample_backward,
          "Spatial Correlation Sampler backward");
}
#else
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // CPU-only build (USE_CUDA undefined): the CPU implementations are bound
    // directly under the same Python names, without the device/contiguity
    // checks that the USE_CUDA-build wrappers perform.
    m.def("forward",
          &correlation_cpp_forward,
          "Spatial Correlation Sampler Forward");
    m.def("backward",
          &correlation_cpp_backward,
          "Spatial Correlation Sampler backward");
}
#endif
// (end of file; a stray "|" left here by the file viewer was not C++ source)