#include <torch/extension.h>
using namespace torch;

#include <vector>

#define WITHIN_BOUNDS(x, y, H, W) (x >= 0 && x < H && y >= 0 && y < W)

// Accumulates into *dst the correlation between one kH x kW patch of input1
// anchored at (u, v) and the corresponding patch of input2 shifted by
// (shiftU, shiftV), summed over all channels. Out-of-bounds taps are skipped.
template <typename scalar_t>
static void correlate_patch(
    TensorAccessor<scalar_t, 3> input1,
    TensorAccessor<scalar_t, 3> input2,
    scalar_t *dst,
    int kH, int kW,
    int dilationH, int dilationW,
    int u, int v,
    int shiftU, int shiftV){
  const int C = input1.size(0);
  const int iH = input1.size(1);
  const int iW = input1.size(2);
  for (int c = 0; c < C; ++c){
    for (int i = 0; i < kH; ++i){
      int i1 = u + i * dilationH;
      int i2 = i1 + shiftU;
      if WITHIN_BOUNDS(i1, i2, iH, iH){
        for (int j = 0; j < kW; ++j){
          int j1 = v + j * dilationW;
          int j2 = j1 + shiftV;
          if WITHIN_BOUNDS(j1, j2, iW, iW){
            scalar_t v1 = input1[c][i1][j1];
            scalar_t v2 = input2[c][i2][j2];
            *dst += v1 * v2;
          }
        }
      }
    }
  }
}

// Backward counterpart of correlate_patch: scatters a single gradOutput
// scalar back to both input gradients, mirroring the forward
// multiply-accumulate (d(v1*v2)/dv1 = v2, d(v1*v2)/dv2 = v1).
template <typename scalar_t>
static void correlate_patch_grad(
    TensorAccessor<scalar_t, 3> input1,
    TensorAccessor<scalar_t, 3> gradInput1,
    TensorAccessor<scalar_t, 3> input2,
    TensorAccessor<scalar_t, 3> gradInput2,
    scalar_t gradOutput,
    int kH, int kW,
    int dilationH, int dilationW,
    int u, int v,
    int shiftU, int shiftV){
  const int C = input1.size(0);
  const int iH = input1.size(1);
  const int iW = input1.size(2);
  for (int c = 0; c < C; ++c){
    for (int i = 0; i < kH; ++i){
      int i1 = u + i * dilationH;
      int i2 = i1 + shiftU;
      if WITHIN_BOUNDS(i1, i2, iH, iH){
        for (int j = 0; j < kW; ++j){
          int j1 = v + j * dilationW;
          int j2 = j1 + shiftV;
          if WITHIN_BOUNDS(j1, j2, iW, iW){
            scalar_t v1 = input1[c][i1][j1];
            scalar_t v2 = input2[c][i2][j2];
            gradInput2[c][i2][j2] += gradOutput * v1;
            gradInput1[c][i1][j1] += gradOutput * v2;
          }
        }
      }
    }
  }
}

torch::Tensor correlation_cpp_forward(
    torch::Tensor input1,
    torch::Tensor input2,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW) {

  const auto batch_size = input1.size(0);
  const auto iH = input1.size(2);
  const auto iW = input1.size(3);
  const int patchRadH = (patchH - 1) / 2;
  const int patchRadW = (patchW - 1) / 2;
  // Standard convolution output-size arithmetic with a dilated kernel.
  const int dilatedKH = (kH - 1) * dilationH + 1;
  const int dilatedKW = (kW - 1) * dilationW + 1;
  const auto oH = (iH + 2 * padH - dilatedKH) / dH + 1;
  const auto oW = (iW + 2 * padW - dilatedKW) / dW + 1;

  // One correlation map per patch displacement (ph, pw).
  auto output = torch::zeros({batch_size, patchH, patchW, oH, oW}, input1.options());

  int n, ph, pw, h, w;
  #pragma omp parallel for private(n, ph, pw, h, w)
  for (n = 0; n < batch_size; ++n) {
    for (ph = 0; ph < patchH; ++ph) {
      for (pw = 0; pw < patchW; ++pw) {
        AT_DISPATCH_FLOATING_TYPES(input1.scalar_type(), "correlation_forward_cpp", ([&] {
          auto input1_acc = input1.accessor<scalar_t, 4>();
          auto input2_acc = input2.accessor<scalar_t, 4>();
          auto output_acc = output.accessor<scalar_t, 5>();
          for (h = 0; h < oH; ++h) {
            for (w = 0; w < oW; ++w) {
              correlate_patch(input1_acc[n],
                              input2_acc[n],
                              &output_acc[n][ph][pw][h][w],
                              kH, kW,
                              dilationH, dilationW,
                              -padH + h * dH,
                              -padW + w * dW,
                              (ph - patchRadH) * dilation_patchH,
                              (pw - patchRadW) * dilation_patchW);
            }
          }
        }));
      }
    }
  }
  return output;
}

std::vector<torch::Tensor> correlation_cpp_backward(
    torch::Tensor input1,
    torch::Tensor input2,
    torch::Tensor gradOutput,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW) {

  const int batch_size = input1.size(0);
  const int patchRadH = (patchH - 1) / 2;
  const int patchRadW = (patchW - 1) / 2;
  const int oH = gradOutput.size(3);
  const int oW = gradOutput.size(4);

  auto gradInput1 = torch::zeros_like(input1);
  auto gradInput2 = torch::zeros_like(input2);

  int n, ph, pw, h, w;
  #pragma omp parallel for private(n, ph, pw, h, w)
  for (n = 0; n < batch_size; ++n) {
    AT_DISPATCH_FLOATING_TYPES(input1.scalar_type(), "correlation_backward_cpp", ([&] {
      auto input1_acc = input1.accessor<scalar_t, 4>();
      auto gradInput1_acc = gradInput1.accessor<scalar_t, 4>();
      auto input2_acc = input2.accessor<scalar_t, 4>();
      auto gradInput2_acc = gradInput2.accessor<scalar_t, 4>();
      auto gradOutput_acc = gradOutput.accessor<scalar_t, 5>();
      for (ph = 0; ph < patchH; ++ph) {
        for (pw = 0; pw < patchW; ++pw) {
          for (h = 0; h < oH; ++h) {
            for (w = 0; w < oW; ++w) {
              correlate_patch_grad(input1_acc[n], gradInput1_acc[n],
                                   input2_acc[n], gradInput2_acc[n],
                                   gradOutput_acc[n][ph][pw][h][w],
                                   kH, kW,
                                   dilationH, dilationW,
                                   -padH + h * dH,
                                   -padW + w * dW,
                                   (ph - patchRadH) * dilation_patchH,
                                   (pw - patchRadW) * dilation_patchW);
            }
          }
        }
      }
    }));
  }
  return {gradInput1, gradInput2};
}
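
// --- Illustrative addition, not part of the original source file ---
// To call correlation_cpp_forward/backward from Python, a PyTorch C++
// extension also needs a pybind11 module definition. The original project may
// bind these functions in a separate translation unit or under different
// names; the minimal sketch below, with the exported names "forward" and
// "backward", is an assumption for illustration only.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &correlation_cpp_forward,
        "Spatial correlation sampler forward (CPU)");
  m.def("backward", &correlation_cpp_backward,
        "Spatial correlation sampler backward (CPU)");
}

// Hypothetical usage from Python, assuming the binding sketch above is
// compiled JIT with torch.utils.cpp_extension.load (file and module names
// are illustrative):
//
//   from torch.utils.cpp_extension import load
//   corr = load(name="correlation_cpp", sources=["correlation.cpp"])
//   out = corr.forward(x1, x2,   # two (N, C, H, W) float tensors
//                      1, 1,     # kH, kW
//                      21, 21,   # patchH, patchW
//                      0, 0,     # padH, padW
//                      1, 1,     # dilationH, dilationW
//                      1, 1,     # dilation_patchH, dilation_patchW
//                      1, 1)     # dH, dW
//   # out has shape (N, patchH, patchW, oH, oW)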