// NOTE(review): the original header of this file was lost to extraction
// garbling; the includes and validation macros below are reconstructed from
// what the code demonstrably uses (torch::Tensor, std::vector,
// at::cuda::OptionalCUDAGuard, CHECK_INPUT, CHECK_SAME_DEVICE,
// PYBIND11_MODULE) — confirm against the upstream source.
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <vector>

// Input-validation helpers used by the dispatch functions below.
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define CHECK_SAME_DEVICE(x, y) TORCH_CHECK(x.device() == y.device(), #x " is not on the same device as " #y)
// declarations

/// Forward pass of the spatial correlation sampler, CPU backend.
/// Defined in the companion CPU translation unit; declared here so the
/// dispatcher below can call it.
torch::Tensor correlation_cpp_forward(
    torch::Tensor input1,
    torch::Tensor input2,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW);
/// Backward pass of the spatial correlation sampler, CPU backend.
/// Returns the gradients w.r.t. input1 and input2.
/// NOTE(review): the original declaration named the tensors
/// (grad_output, input1, input2), but the call site in
/// correlation_sample_backward passes (input1, input2, grad_output).
/// Parameter names in a declaration are not binding, so this compiled,
/// but the names are reordered here to match the call — confirm against
/// the actual definition.
std::vector<torch::Tensor> correlation_cpp_backward(
    torch::Tensor input1,
    torch::Tensor input2,
    torch::Tensor grad_output,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW);
/// Forward pass of the spatial correlation sampler, CUDA backend.
/// Defined in the companion .cu translation unit.
torch::Tensor correlation_cuda_forward(
    torch::Tensor input1,
    torch::Tensor input2,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW);
/// Backward pass of the spatial correlation sampler, CUDA backend.
/// Returns the gradients w.r.t. input1 and input2.
/// NOTE(review): tensor parameter names reordered from
/// (grad_output, input1, input2) to match the call site in
/// correlation_sample_backward, which passes (input1, input2, grad_output)
/// — confirm against the .cu definition.
std::vector<torch::Tensor> correlation_cuda_backward(
    torch::Tensor input1,
    torch::Tensor input2,
    torch::Tensor grad_output,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW);
// C++ interface

/// Dispatches the correlation forward pass to the CUDA or CPU backend,
/// based on the device of input1.
///
/// On the CUDA path, both inputs are checked for being contiguous CUDA
/// tensors on the same device, and input1's device is made the current
/// CUDA device for the kernel launch. The CPU path performs no checks
/// here (mirroring the original code).
torch::Tensor correlation_sample_forward(
    torch::Tensor input1,
    torch::Tensor input2,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW) {
  if (input1.device().is_cuda()) {
    CHECK_INPUT(input1);
    CHECK_INPUT(input2);
    // Set the device of input1 as the default CUDA device for the launch.
    // https://pytorch.org/cppdocs/api/structc10_1_1cuda_1_1_optional_c_u_d_a_guard.html
    const at::cuda::OptionalCUDAGuard guard_input1(device_of(input1));
    CHECK_SAME_DEVICE(input1, input2);
    return correlation_cuda_forward(input1, input2, kH, kW, patchH, patchW,
                                    padH, padW, dilationH, dilationW,
                                    dilation_patchH, dilation_patchW,
                                    dH, dW);
  } else {
    return correlation_cpp_forward(input1, input2, kH, kW, patchH, patchW,
                                   padH, padW, dilationH, dilationW,
                                   dilation_patchH, dilation_patchW,
                                   dH, dW);
  }
}
/// Dispatches the correlation backward pass to the CUDA or CPU backend,
/// based on the device of grad_output.
///
/// Returns the gradients w.r.t. input1 and input2 (in that order, as
/// produced by the backend). On the CUDA path, inputs are checked for
/// being contiguous CUDA tensors and all three tensors for device
/// agreement, with input1's device made current for the kernel launch.
std::vector<torch::Tensor> correlation_sample_backward(
    torch::Tensor input1,
    torch::Tensor input2,
    torch::Tensor grad_output,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW) {
  if (grad_output.device().is_cuda()) {
    CHECK_INPUT(input1);
    CHECK_INPUT(input2);
    // Set the device of input1 as the default CUDA device for the launch.
    const at::cuda::OptionalCUDAGuard guard_input1(device_of(input1));
    CHECK_SAME_DEVICE(input1, input2);
    CHECK_SAME_DEVICE(input1, grad_output);
    return correlation_cuda_backward(input1, input2, grad_output,
                                     kH, kW, patchH, patchW,
                                     padH, padW,
                                     dilationH, dilationW,
                                     dilation_patchH, dilation_patchW,
                                     dH, dW);
  } else {
    return correlation_cpp_backward(
        input1, input2, grad_output,
        kH, kW, patchH, patchW,
        padH, padW,
        dilationH, dilationW,
        dilation_patchH, dilation_patchW,
        dH, dW);
  }
}
// Python bindings: expose the device-dispatching entry points.
//
// NOTE(review): the original file defined PYBIND11_MODULE twice — the
// second binding correlation_cpp_forward/backward directly. Two module
// definitions in one translation unit redefine the module-init function
// and cannot compile; the duplicate was removed and the dispatching
// wrappers (which cover both the CPU and CUDA paths) kept.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &correlation_sample_forward, "Spatial Correlation Sampler Forward");
  m.def("backward", &correlation_sample_backward, "Spatial Correlation Sampler backward");
}