#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>

// Tile sizes: each thread block covers a BLOCK_H x BLOCK_W patch of fmap1 and
// streams the feature channels through shared memory in chunks of CHANNEL_STRIDE.
#define BLOCK_H 4
#define BLOCK_W 8
#define BLOCK_HW (BLOCK_H * BLOCK_W)
#define CHANNEL_STRIDE 32

__forceinline__ __device__
bool within_bounds(int h, int w, int H, int W) {
  return h >= 0 && h < H && w >= 0 && w < W;
}

// Forward pass: for every pixel (h1, w1) of fmap1 and every correspondence
// coords[b][n][h1][w1], accumulate dot products between fmap1 features and
// fmap2 features sampled on a (2r+1) x (2r+1) window around the target
// location, with bilinear weighting of the fractional offset.
// Output shape: [B, N, (2r+1)^2, H1, W1].
template <typename scalar_t>
__global__ void corr_forward_kernel(
    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
    torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr,
    int r)
{
  const int b = blockIdx.x;
  const int h0 = blockIdx.y * blockDim.x;
  const int w0 = blockIdx.z * blockDim.y;
  const int tid = threadIdx.x * blockDim.y + threadIdx.y;

  const int H1 = fmap1.size(1);
  const int W1 = fmap1.size(2);
  const int H2 = fmap2.size(1);
  const int W2 = fmap2.size(2);
  const int N = coords.size(1);
  const int C = fmap1.size(3);

  __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
  __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
  __shared__ scalar_t x2s[BLOCK_HW];
  __shared__ scalar_t y2s[BLOCK_HW];

  // Process the channel dimension in chunks of CHANNEL_STRIDE.
  for (int c=0; c<C; c+=CHANNEL_STRIDE) {

    // Stage the fmap1 tile for this channel chunk into shared memory.
    for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
      int k1 = k + tid / CHANNEL_STRIDE;
      int h1 = h0 + k1 / BLOCK_W;
      int w1 = w0 + k1 % BLOCK_W;
      int c1 = tid % CHANNEL_STRIDE;

      auto fptr = fmap1[b][h1][w1];
      if (within_bounds(h1, w1, H1, W1))
        f1[c1][k1] = fptr[c+c1];
      else
        f1[c1][k1] = 0.0;
    }

    __syncthreads();

    for (int n=0; n<N; n++) {
      int h1 = h0 + threadIdx.x;
      int w1 = w0 + threadIdx.y;
      if (within_bounds(h1, w1, H1, W1)) {
        x2s[tid] = coords[b][n][h1][w1][0];
        y2s[tid] = coords[b][n][h1][w1][1];
      }

      // Fractional part of the target location, used as bilinear weights.
      scalar_t dx = x2s[tid] - floor(x2s[tid]);
      scalar_t dy = y2s[tid] - floor(y2s[tid]);

      int rd = 2*r + 1;
      for (int iy=0; iy<rd+1; iy++) {
        for (int ix=0; ix<rd+1; ix++) {
          // Stage the fmap2 features at displacement (iy, ix) into shared memory.
          for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
            int k1 = k + tid / CHANNEL_STRIDE;
            int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
            int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
            int c2 = tid % CHANNEL_STRIDE;

            auto fptr = fmap2[b][h2][w2];
            if (within_bounds(h2, w2, H2, W2))
              f2[c2][k1] = fptr[c+c2];
            else
              f2[c2][k1] = 0.0;
          }

          __syncthreads();

          // Partial dot product over this channel chunk.
          scalar_t s = 0.0;
          for (int k=0; k<CHANNEL_STRIDE; k++)
            s += f1[k][tid] * f2[k][tid];

          // Each integer displacement contributes to up to four output cells
          // (nw/ne/sw/se) with the matching bilinear weights.
          int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
          int ix_ne = H1*W1*((iy-1) + rd*ix);
          int ix_sw = H1*W1*(iy + rd*(ix-1));
          int ix_se = H1*W1*(iy + rd*ix);

          scalar_t nw = s * (dy) * (dx);
          scalar_t ne = s * (dy) * (1-dx);
          scalar_t sw = s * (1-dy) * (dx);
          scalar_t se = s * (1-dy) * (1-dx);

          scalar_t* corr_ptr = &corr[b][n][0][h1][w1];

          if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
            *(corr_ptr + ix_nw) += nw;

          if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
            *(corr_ptr + ix_ne) += ne;

          if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
            *(corr_ptr + ix_sw) += sw;

          if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
            *(corr_ptr + ix_se) += se;
        }
      }
    }
  }
}

// Backward pass: scatter the correlation gradient back onto fmap1 and fmap2.
// fmap1_grad is accumulated per block in shared memory and written once per
// channel chunk; fmap2_grad may be touched by several blocks, so it is updated
// with atomicAdd. coords_grad is allocated by the host wrapper but is never
// written here (no gradient w.r.t. the sampling coordinates is computed).
template <typename scalar_t>
__global__ void corr_backward_kernel(
    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr_grad,
    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1_grad,
    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2_grad,
    torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords_grad,
    int r)
{
  const int b = blockIdx.x;
  const int h0 = blockIdx.y * blockDim.x;
  const int w0 = blockIdx.z * blockDim.y;
  const int tid = threadIdx.x * blockDim.y + threadIdx.y;

  const int H1 = fmap1.size(1);
  const int W1 = fmap1.size(2);
  const int H2 = fmap2.size(1);
  const int W2 = fmap2.size(2);
  const int N = coords.size(1);
  const int C = fmap1.size(3);

  __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
  __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];

  __shared__ scalar_t f1_grad[CHANNEL_STRIDE][BLOCK_HW+1];
  __shared__ scalar_t f2_grad[CHANNEL_STRIDE][BLOCK_HW+1];

  __shared__ scalar_t x2s[BLOCK_HW];
  __shared__ scalar_t y2s[BLOCK_HW];

  for (int c=0; c<C; c+=CHANNEL_STRIDE) {

    // Stage the fmap1 tile and zero its gradient accumulator.
    for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
      int k1 = k + tid / CHANNEL_STRIDE;
      int h1 = h0 + k1 / BLOCK_W;
      int w1 = w0 + k1 % BLOCK_W;
      int c1 = tid % CHANNEL_STRIDE;

      auto fptr = fmap1[b][h1][w1];
      if (within_bounds(h1, w1, H1, W1))
        f1[c1][k1] = fptr[c+c1];
      else
        f1[c1][k1] = 0.0;

      f1_grad[c1][k1] = 0.0;
    }

    __syncthreads();

    int h1 = h0 + threadIdx.x;
    int w1 = w0 + threadIdx.y;

    for (int n=0; n<N; n++) {
      if (within_bounds(h1, w1, H1, W1)) {
        x2s[tid] = coords[b][n][h1][w1][0];
        y2s[tid] = coords[b][n][h1][w1][1];
      }

      scalar_t dx = x2s[tid] - floor(x2s[tid]);
      scalar_t dy = y2s[tid] - floor(y2s[tid]);

      int rd = 2*r + 1;
      for (int iy=0; iy<rd+1; iy++) {
        for (int ix=0; ix<rd+1; ix++) {
          // Re-stage the fmap2 features at displacement (iy, ix) and zero
          // their gradient accumulator.
          for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
            int k1 = k + tid / CHANNEL_STRIDE;
            int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
            int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
            int c2 = tid % CHANNEL_STRIDE;

            auto fptr = fmap2[b][h2][w2];
            if (within_bounds(h2, w2, H2, W2))
              f2[c2][k1] = fptr[c+c2];
            else
              f2[c2][k1] = 0.0;

            f2_grad[c2][k1] = 0.0;
          }

          __syncthreads();

          // Gather the incoming gradient from the four output cells this
          // displacement contributed to, using the same bilinear weights
          // as the forward pass.
          const scalar_t* grad_ptr = &corr_grad[b][n][0][h1][w1];
          scalar_t g = 0.0;

          int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
          int ix_ne = H1*W1*((iy-1) + rd*ix);
          int ix_sw = H1*W1*(iy + rd*(ix-1));
          int ix_se = H1*W1*(iy + rd*ix);

          if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
            g += *(grad_ptr + ix_nw) * dy * dx;

          if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
            g += *(grad_ptr + ix_ne) * dy * (1-dx);

          if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
            g += *(grad_ptr + ix_sw) * (1-dy) * dx;

          if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
            g += *(grad_ptr + ix_se) * (1-dy) * (1-dx);

          // d(f1.f2)/df1 = f2 and d(f1.f2)/df2 = f1, scaled by the gathered gradient.
          for (int k=0; k<CHANNEL_STRIDE; k++) {
            f1_grad[k][tid] += g * f2[k][tid];
            f2_grad[k][tid] += g * f1[k][tid];
          }

          // Scatter the fmap2 gradient; different blocks can hit the same
          // fmap2 location, so the update must be atomic.
          for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
            int k1 = k + tid / CHANNEL_STRIDE;
            int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
            int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
            int c2 = tid % CHANNEL_STRIDE;

            scalar_t* fptr = &fmap2_grad[b][h2][w2][0];
            if (within_bounds(h2, w2, H2, W2))
              atomicAdd(fptr+c+c2, f2_grad[c2][k1]);
          }
        }
      }
    }
    __syncthreads();

    // Write back the accumulated fmap1 gradient for this channel chunk.
    for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
      int k1 = k + tid / CHANNEL_STRIDE;
      int h1 = h0 + k1 / BLOCK_W;
      int w1 = w0 + k1 % BLOCK_W;
      int c1 = tid % CHANNEL_STRIDE;

      scalar_t* fptr = &fmap1_grad[b][h1][w1][0];
      if (within_bounds(h1, w1, H1, W1))
        fptr[c+c1] += f1_grad[c1][k1];
    }
  }
}

// Host wrapper for the forward kernel. Expects channels-last feature maps
// fmap1/fmap2 of shape [B, H, W, C] and coords of shape [B, N, H, W, 2].
std::vector<torch::Tensor> corr_cuda_forward(
  torch::Tensor fmap1,
  torch::Tensor fmap2,
  torch::Tensor coords,
  int radius)
{
  const auto B = coords.size(0);
  const auto N = coords.size(1);
  const auto H = coords.size(2);
  const auto W = coords.size(3);

  const auto rd = 2 * radius + 1;
  auto opts = fmap1.options();
  auto corr = torch::zeros({B, N, rd*rd, H, W}, opts);

  const dim3 blocks(B, (H+BLOCK_H-1)/BLOCK_H, (W+BLOCK_W-1)/BLOCK_W);
  const dim3 threads(BLOCK_H, BLOCK_W);

  corr_forward_kernel<float><<<blocks, threads>>>(
    fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
    fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
    coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
    corr.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
    radius);

  return {corr};
}

// Host wrapper for the backward kernel. Returns gradients for fmap1, fmap2,
// and coords; coords_grad stays zero because the kernel does not compute it.
std::vector<torch::Tensor> corr_cuda_backward(
  torch::Tensor fmap1,
  torch::Tensor fmap2,
  torch::Tensor coords,
  torch::Tensor corr_grad,
  int radius)
{
  const auto B = coords.size(0);
  const auto N = coords.size(1);

  const auto H1 = fmap1.size(1);
  const auto W1 = fmap1.size(2);
  const auto H2 = fmap2.size(1);
  const auto W2 = fmap2.size(2);
  const auto C = fmap1.size(3);

  auto opts = fmap1.options();
  auto fmap1_grad = torch::zeros({B, H1, W1, C}, opts);
  auto fmap2_grad = torch::zeros({B, H2, W2, C}, opts);
  auto coords_grad = torch::zeros({B, N, H1, W1, 2}, opts);

  const dim3 blocks(B, (H1+BLOCK_H-1)/BLOCK_H, (W1+BLOCK_W-1)/BLOCK_W);
  const dim3 threads(BLOCK_H, BLOCK_W);

  corr_backward_kernel<float><<<blocks, threads>>>(
    fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
    fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
    coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
    corr_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
    fmap1_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
    fmap2_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
    coords_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
    radius);

  return {fmap1_grad, fmap2_grad, coords_grad};
}
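// For context, corr_cuda_forward / corr_cuda_backward are typically exposed to
// Python through a small pybind11 binding compiled alongside this kernel. The
// sketch below is a minimal illustration only; the companion file name
// (correlation.cpp), the exported function names, and the input checks are
// assumptions, not part of this source file.
//
// // correlation.cpp (hypothetical companion translation unit)
// #include <torch/extension.h>
// #include <vector>
//
// // Forward declarations of the CUDA entry points defined in the .cu file.
// std::vector<torch::Tensor> corr_cuda_forward(
//     torch::Tensor fmap1, torch::Tensor fmap2, torch::Tensor coords, int radius);
// std::vector<torch::Tensor> corr_cuda_backward(
//     torch::Tensor fmap1, torch::Tensor fmap2, torch::Tensor coords,
//     torch::Tensor corr_grad, int radius);
//
// #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
// #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
// #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
//
// std::vector<torch::Tensor> corr_forward(
//     torch::Tensor fmap1, torch::Tensor fmap2, torch::Tensor coords, int radius) {
//   CHECK_INPUT(fmap1);
//   CHECK_INPUT(fmap2);
//   CHECK_INPUT(coords);
//   return corr_cuda_forward(fmap1, fmap2, coords, radius);
// }
//
// std::vector<torch::Tensor> corr_backward(
//     torch::Tensor fmap1, torch::Tensor fmap2, torch::Tensor coords,
//     torch::Tensor corr_grad, int radius) {
//   CHECK_INPUT(fmap1);
//   CHECK_INPUT(fmap2);
//   CHECK_INPUT(coords);
//   CHECK_INPUT(corr_grad);
//   return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
// }
//
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//   m.def("forward", &corr_forward, "correlation forward (CUDA)");
//   m.def("backward", &corr_backward, "correlation backward (CUDA)");
// }
//
// Keeping the binding in its own .cpp file lets the host compiler handle the
// pybind11 code while nvcc only sees the kernels above.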