// CUDA kernels and host wrappers for patch extraction (patchify) and
// local correlation lookup (corr), with forward and backward passes.
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>

using namespace torch::indexing;

#define THREADS 256
#define BLOCKS(n) (((n) + THREADS - 1) / THREADS)

__forceinline__ __device__ bool within_bounds(int h, int w, int H, int W) {
  return h >= 0 && h < H && w >= 0 && w < W;
}

template <typename scalar_t>
__global__ void patchify_forward_kernel(int R,
    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> net,
    const torch::PackedTensorAccessor32<float,3,torch::RestrictPtrTraits> coords,
    torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> patches)
{
  // diameter
  const int D = 2*R + 2;

  const int B = coords.size(0);
  const int M = coords.size(1);
  const int C = net.size(1);
  const int H = net.size(2);
  const int W = net.size(3);

  int n = blockIdx.x * blockDim.x + threadIdx.x;
  if (n < B * M * D * D) {
    const int ii = n % D; n /= D;
    const int jj = n % D; n /= D;
    const int  m = n % M; n /= M;

    const float x = coords[n][m][0];
    const float y = coords[n][m][1];
    const int i = static_cast<int>(floor(y)) + (ii - R);
    const int j = static_cast<int>(floor(x)) + (jj - R);

    if (within_bounds(i, j, H, W)) {
      // copy the feature vector at (i, j) into the patch slot (ii, jj)
      for (int k=0; k<C; k++)
        patches[n][m][k][ii][jj] = net[n][k][i][j];
    }
  }
}

template <typename scalar_t>
__global__ void patchify_backward_kernel(int R,
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> patch_gradient,
    const torch::PackedTensorAccessor32<float,3,torch::RestrictPtrTraits> coords,
    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> gradient)
{
  // diameter
  const int D = 2*R + 2;

  const int B = coords.size(0);
  const int M = coords.size(1);
  const int C = gradient.size(1);
  const int H = gradient.size(2);
  const int W = gradient.size(3);

  int n = blockIdx.x * blockDim.x + threadIdx.x;
  if (n < B * M * D * D) {
    const int ii = n % D; n /= D;
    const int jj = n % D; n /= D;
    const int  m = n % M; n /= M;

    const float x = coords[n][m][0];
    const float y = coords[n][m][1];
    const int i = static_cast<int>(floor(y)) + (ii - R);
    const int j = static_cast<int>(floor(x)) + (jj - R);

    if (within_bounds(i, j, H, W)) {
      // patches can overlap, so gradients must be accumulated atomically
      for (int k=0; k<C; k++)
        atomicAdd(&gradient[n][k][i][j], patch_gradient[n][m][k][ii][jj]);
    }
  }
}

template <typename scalar_t>
__global__ void corr_forward_kernel(int R,
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> fmap1,
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> fmap2,
    const torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> coords,
    const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> us,
    const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> vs,
    torch::PackedTensorAccessor32<scalar_t,6,torch::RestrictPtrTraits> corr)
{
  // diameter
  const int D = 2*R + 2;

  const int B = coords.size(0);
  const int M = coords.size(1);
  const int H = coords.size(3);
  const int W = coords.size(4);

  const int C = fmap1.size(2);
  const int H2 = fmap2.size(3);
  const int W2 = fmap2.size(4);

  int n = blockIdx.x * blockDim.x + threadIdx.x;

  if (n < B * M * H * W * D * D) {
    const int jj = n % D; n /= D;
    const int ii = n % D; n /= D;
    const int j0 = n % W; n /= W;
    const int i0 = n % H; n /= H;
    const int  m = n % M; n /= M;

    const int ix = us[m];
    const int jx = vs[m];

    const float x = coords[n][m][0][i0][j0];
    const float y = coords[n][m][1][i0][j0];

    const int i1 = static_cast<int>(floor(y)) + (ii - R);
    const int j1 = static_cast<int>(floor(x)) + (jj - R);

    // dot product between the feature vectors at (i0, j0) and (i1, j1)
    scalar_t s = 0;
    if (within_bounds(i1, j1, H2, W2)) {
      #pragma unroll 8
      for (int i=0; i<C; i++)
        s += fmap1[n][ix][i][i0][j0] * fmap2[n][jx][i][i1][j1];
    }

    corr[n][m][ii][jj][i0][j0] = s;
  }
}
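// Note on the sampling grid: the kernels sample correlation / features on an
// integer grid of diameter D = 2R + 2 (one extra row and column). The host
// wrapper corr_cuda_forward below bilinearly interpolates the four shifted
// (D-1 x D-1) windows of that grid, yielding a (2R+1 x 2R+1) lookup centered
// on the fractional coordinate (x, y). With dx = x - floor(x) and
// dy = y - floor(y), each interpolated entry is:
//
//   out[ii][jj] = (1-dy) * (1-dx) * corr[ii  ][jj  ]
//               + (1-dy) *   dx   * corr[ii  ][jj+1]
//               +   dy   * (1-dx) * corr[ii+1][jj  ]
//               +   dy   *   dx   * corr[ii+1][jj+1]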
template <typename scalar_t>
__global__ void corr_backward_kernel(int R,
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> fmap1,
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> fmap2,
    const torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> coords,
    const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> us,
    const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> vs,
    const torch::PackedTensorAccessor32<scalar_t,6,torch::RestrictPtrTraits> corr_grad,
    torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> fmap1_grad,
    torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> fmap2_grad)
{
  // diameter
  const int D = 2*R + 2;

  const int B = coords.size(0);
  const int M = coords.size(1);
  const int H = coords.size(3);
  const int W = coords.size(4);

  const int C = fmap1.size(2);
  const int H2 = fmap2.size(3);
  const int W2 = fmap2.size(4);

  int n = blockIdx.x * blockDim.x + threadIdx.x;

  if (n < B * M * H * W * D * D) {
    const int jj = n % D; n /= D;
    const int ii = n % D; n /= D;
    const int j0 = n % W; n /= W;
    const int i0 = n % H; n /= H;
    const int  m = n % M; n /= M;

    const int ix = us[m];
    const int jx = vs[m];

    const float x = coords[n][m][0][i0][j0];
    const float y = coords[n][m][1][i0][j0];

    const int i1 = static_cast<int>(floor(y)) + (ii - R);
    const int j1 = static_cast<int>(floor(x)) + (jj - R);

    const scalar_t g = (scalar_t) corr_grad[n][m][ii][jj][i0][j0];

    if (within_bounds(i1, j1, H2, W2)) {
      // d(corr)/d(fmap1) = fmap2 and vice versa; accumulate atomically since
      // many correlation entries read the same feature vector
      #pragma unroll 32
      for (int i=0; i<C; i++) {
        atomicAdd(&fmap1_grad[n][ix][i][i0][j0], g * fmap2[n][jx][i][i1][j1]);
        atomicAdd(&fmap2_grad[n][jx][i][i1][j1], g * fmap1[n][ix][i][i0][j0]);
      }
    }
  }
}

// Shape conventions, as implied by the accessors above (fmap1/fmap2 may be
// float or half, coords is float, ii/jj are int64):
//   fmap1, fmap2 : [B, N, C, H2, W2]   per-frame feature maps
//   coords       : [B, M, 2, H, W]     (x, y) lookup centers
//   ii, jj       : [M]                 frame indices into fmap1 / fmap2
//   output corr  : [B, M, 2R+1, 2R+1, H, W]
std::vector<torch::Tensor> corr_cuda_forward(
  torch::Tensor fmap1,
  torch::Tensor fmap2,
  torch::Tensor coords,
  torch::Tensor ii,
  torch::Tensor jj,
  int radius)
{
  const int B = coords.size(0);
  const int M = coords.size(1);
  const int H = coords.size(3);
  const int W = coords.size(4);
  const int D = 2 * radius + 2;

  auto opts = fmap1.options();
  auto corr = torch::empty({B, M, D, D, H, W}, opts);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(fmap1.scalar_type(), "corr_forward_kernel", ([&] {
    corr_forward_kernel<scalar_t><<<BLOCKS(B * M * H * W * D * D), THREADS>>>(radius,
      fmap1.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
      fmap2.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
      coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
      ii.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
      jj.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
      corr.packed_accessor32<scalar_t,6,torch::RestrictPtrTraits>());
  }));

  // bilinear interpolation over the four shifted (D-1 x D-1) windows
  torch::Tensor x = coords.index({Slice(), Slice(), 0, None, None});
  torch::Tensor y = coords.index({Slice(), Slice(), 1, None, None});
  torch::Tensor dx = x - x.floor(); dx = dx.to(fmap1.dtype());
  torch::Tensor dy = y - y.floor(); dy = dy.to(fmap2.dtype());

  torch::Tensor out;
  out  = (1 - dx) * (1 - dy) * corr.index({Slice(), Slice(), Slice(0, D-1), Slice(0, D-1)});
  out += (dx) * (1 - dy) * corr.index({Slice(), Slice(), Slice(0, D-1), Slice(1, D-0)});
  out += (1 - dx) * (dy) * corr.index({Slice(), Slice(), Slice(1, D-0), Slice(0, D-1)});
  out += (dx) * (dy) * corr.index({Slice(), Slice(), Slice(1, D-0), Slice(1, D-0)});

  return { out.permute({0,1,3,2,4,5}) };
}

std::vector<torch::Tensor> corr_cuda_backward(
  torch::Tensor fmap1,
  torch::Tensor fmap2,
  torch::Tensor coords,
  torch::Tensor ii,
  torch::Tensor jj,
  torch::Tensor grad,
  int radius)
{
  const int B = coords.size(0);
  const int M = coords.size(1);
  const int H = coords.size(3);
  const int W = coords.size(4);
  const int D = 2 * radius + 2;

  // undo the forward permute, then redistribute the incoming gradient
  // over the four bilinear taps
  grad = grad.permute({0,1,3,2,4,5}).contiguous();

  torch::Tensor x = coords.index({Slice(), Slice(), 0, None, None});
  torch::Tensor y = coords.index({Slice(), Slice(), 1, None, None});
  torch::Tensor dx = x - x.floor();
  torch::Tensor dy = y - y.floor();

  auto opts = torch::TensorOptions().dtype(torch::kFloat).device(torch::kCUDA);
  torch::Tensor g1 = torch::zeros({B, M, D, D, H, W}, grad.options());
  torch::Tensor g2 = torch::zeros({B, M, D, D, H, W}, grad.options());
  torch::Tensor g3 = torch::zeros({B, M, D, D, H, W}, grad.options());
  torch::Tensor g4 = torch::zeros({B, M, D, D, H, W}, grad.options());

  g1.index_put_({Slice(), Slice(), Slice(0, D-1), Slice(0, D-1)}, (1 - dx) * (1 - dy) * grad);
  g2.index_put_({Slice(), Slice(), Slice(0, D-1), Slice(1, D-0)}, (dx) * (1 - dy) * grad);
  g3.index_put_({Slice(), Slice(), Slice(1, D-0), Slice(0, D-1)}, (1 - dx) * (dy) * grad);
  g4.index_put_({Slice(), Slice(), Slice(1, D-0), Slice(1, D-0)}, (dx) * (dy) * grad);

  torch::Tensor corr_grad = g1 + g2 + g3 + g4;
  auto fmap1_grad = torch::zeros_like(fmap1);
  auto fmap2_grad = torch::zeros_like(fmap2);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(fmap1.scalar_type(), "corr_backward_kernel", ([&] {
    corr_backward_kernel<scalar_t><<<BLOCKS(B * M * H * W * D * D), THREADS>>>(radius,
      fmap1.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
      fmap2.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
      coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
      ii.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
      jj.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
      corr_grad.packed_accessor32<scalar_t,6,torch::RestrictPtrTraits>(),
      fmap1_grad.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
      fmap2_grad.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>());
  }));

  return {fmap1_grad, fmap2_grad};
}
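// A minimal smoke test sketched against the wrappers above (hypothetical,
// not part of the extension): it builds random inputs with the shapes the
// accessors expect and runs one forward pass. Kept as a comment so the file
// compiles unchanged; paste into a standalone program linked against
// libtorch to try it.
//
//   const int B = 1, N = 2, M = 4, C = 64, H = 30, W = 40, radius = 3;
//   auto fmap   = torch::randn({B, N, C, H, W}, torch::kCUDA);
//   auto coords = torch::rand({B, M, 2, H, W}, torch::kCUDA) * 20;
//   auto ii = torch::zeros({M}, torch::dtype(torch::kLong).device(torch::kCUDA));
//   auto jj = torch::ones({M}, torch::dtype(torch::kLong).device(torch::kCUDA));
//   auto out = corr_cuda_forward(fmap, fmap, coords, ii, jj, radius)[0];
//   // expected shape: [B, M, 2*radius+1, 2*radius+1, H, W]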
std::vector<torch::Tensor> patchify_cuda_forward(
  torch::Tensor net,
  torch::Tensor coords,
  int radius)
{
  const int B = coords.size(0);
  const int M = coords.size(1);
  const int C = net.size(1);
  const int D = 2 * radius + 2;

  auto opts = net.options();
  auto patches = torch::zeros({B, M, C, D, D}, opts);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(net.scalar_type(), "patchify_forward_kernel", ([&] {
    patchify_forward_kernel<scalar_t><<<BLOCKS(B * M * D * D), THREADS>>>(radius,
      net.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
      coords.packed_accessor32<float,3,torch::RestrictPtrTraits>(),
      patches.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>());
  }));

  return { patches };
}

std::vector<torch::Tensor> patchify_cuda_backward(
  torch::Tensor net,
  torch::Tensor coords,
  torch::Tensor gradient,
  int radius)
{
  const int B = coords.size(0);
  const int M = coords.size(1);
  const int C = net.size(1);
  const int H = net.size(2);
  const int W = net.size(3);
  const int D = 2 * radius + 2;

  torch::Tensor net_gradient = torch::zeros_like(net);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(net.scalar_type(), "patchify_backward_kernel", ([&] {
    patchify_backward_kernel<scalar_t><<<BLOCKS(B * M * D * D), THREADS>>>(radius,
      gradient.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
      coords.packed_accessor32<float,3,torch::RestrictPtrTraits>(),
      net_gradient.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>());
  }));

  return { net_gradient };
}
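// A minimal sketch of Python bindings for the four entry points above.
// In a typical split-compilation setup these live in a companion .cpp file;
// the exposed names ("corr_forward", etc.) are assumptions for illustration,
// not necessarily the names the original extension uses. Guarded so the
// sketch only compiles inside a torch extension build.
#ifdef TORCH_EXTENSION_NAME
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("corr_forward", &corr_cuda_forward, "correlation lookup forward (CUDA)");
  m.def("corr_backward", &corr_cuda_backward, "correlation lookup backward (CUDA)");
  m.def("patchify_forward", &patchify_cuda_forward, "patch extraction forward (CUDA)");
  m.def("patchify_backward", &patchify_cuda_backward, "patch extraction backward (CUDA)");
}
#endif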