# CUDA/C++ source of the "pair_wise_distance" PyTorch extension, kept in a
# Python string so it can be JIT-compiled (presumably with
# torch.utils.cpp_extension.load_inline; confirm against the importer).
#
# The extension exposes two functions through PYBIND11_MODULE:
#   forward(pixel_features, spixel_features, spixel_indices, dist_matrix,
#           num_spixels_w, num_spixels_h) -> dist_matrix
#   backward(dist_matrix_grad, pixel_features, spixel_features, spixel_indices,
#            pixel_features_grad, spixel_features_grad,
#            num_spixels_w, num_spixels_h) -> [pixel_features_grad, spixel_features_grad]
# Both write into the caller-supplied output tensors in place and return them.
# backward accumulates with atomicAdd, so its gradient outputs must be
# pre-initialised by the caller (normally with zeros).
source = '''
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <limits>
#include <vector>

#define CUDA_NUM_THREADS 256

// Layouts assumed by the raw-pointer indexing below (dense, row-major):
//   pixel_features,  pixel_features_grad  : B x C x N   (N = num_pixels)
//   spixel_features, spixel_features_grad : B x C x S   (S = num_spixels)
//   spixel_indices                        : B * N values, read at [b * N + p]
//   dist_matrix,     dist_matrix_grad     : B x 9 x N
// spixel_indices holds superpixel ids stored as scalar_t. Each id is assumed to
// be in the range 0 .. num_spixels_w * num_spixels_h - 1, and S is assumed to
// equal num_spixels_w * num_spixels_h; neither is checked here
// (TODO confirm with the caller).

// One thread handles one (batch b, neighbour slot, pixel p) triple. Slot k in
// 0..8 addresses the superpixel at offset (k % 3 - 1, k / 3 - 1) from the
// initial superpixel of the pixel, inside the num_spixels_w x num_spixels_h grid.
// Decodes `index` into b / slot / p and, when that neighbour lies inside the
// grid, stores its flat id in query_spixel_index and returns true.
template <typename scalar_t>
__device__ __forceinline__ bool resolve_neighbor(
    int index,
    const scalar_t* __restrict__ spixel_indices,
    int batchsize, int num_pixels, int num_spixels_w, int num_spixels_h,
    int& b, int& spixel_offset, int& p, int& query_spixel_index
){
    // The batch index varies fastest, then the slot, then the pixel.
    b = index % batchsize;
    spixel_offset = (index / batchsize) % 9;
    p = (index / (batchsize * 9)) % num_pixels;

    // Ids are stored as floating point; the conversion truncates toward zero.
    const int init_spix_index = static_cast<int>(spixel_indices[b * num_pixels + p]);
    const int dx = spixel_offset % 3 - 1;
    const int dy = spixel_offset / 3 - 1;
    const int x_index = init_spix_index % num_spixels_w;
    const int y_index = init_spix_index / num_spixels_w;

    if (x_index + dx < 0 || x_index + dx >= num_spixels_w) return false;
    if (y_index + dy < 0 || y_index + dy >= num_spixels_h) return false;

    query_spixel_index = init_spix_index + dx + num_spixels_w * dy;
    return true;
}

// Writes the squared Euclidean distance over the C channels between pixel p
// and each of its 9 neighbouring superpixels into dist_matrix[b][slot][p].
// Slots whose neighbour falls outside the superpixel grid get the sentinel 1e16.
// Launch: 1-D grid covering batchsize * 9 * num_pixels threads.
template <typename scalar_t>
__global__ void forward_kernel(
    const scalar_t* __restrict__ pixel_features,
    const scalar_t* __restrict__ spixel_features,
    const scalar_t* __restrict__ spixel_indices,
    scalar_t* __restrict__ dist_matrix,
    int batchsize, int channels, int num_pixels, int num_spixels,
    int num_spixels_w, int num_spixels_h
){
    const int index = blockIdx.x * blockDim.x + threadIdx.x;
    if (index >= batchsize * num_pixels * 9) return;

    int b, spixel_offset, p, query_spixel_index;
    const bool inside = resolve_neighbor(
        index, spixel_indices, batchsize, num_pixels, num_spixels_w, num_spixels_h,
        b, spixel_offset, p, query_spixel_index);

    scalar_t* out = dist_matrix + b * (9 * num_pixels) + spixel_offset * num_pixels + p;
    if (!inside) {
        *out = static_cast<scalar_t>(1e16);
        return;
    }

    const int cp = channels * num_pixels;
    const int cs = channels * num_spixels;
    scalar_t sum_squared_diff = 0;
    for (int c = 0; c < channels; ++c) {
        const scalar_t diff = pixel_features[b * cp + c * num_pixels + p]
                            - spixel_features[b * cs + c * num_spixels + query_spixel_index];
        sum_squared_diff += diff * diff;
    }
    *out = sum_squared_diff;
}

// Accumulates the gradient of dist = sum_c (x_c - y_c)^2 into the pixel and
// superpixel feature gradients: +2 (x - y) g and -2 (x - y) g respectively,
// with g = dist_matrix_grad[b][slot][p]. Out-of-grid slots contribute nothing.
// Both outputs are accumulated with atomicAdd, so they must be pre-initialised
// by the caller. Double precision relies on the native atomicAdd(double*, double),
// which needs compute capability 6.0 or newer.
// Launch: 1-D grid covering batchsize * 9 * num_pixels threads.
template <typename scalar_t>
__global__ void backward_kernel(
    const scalar_t* __restrict__ dist_matrix_grad,
    const scalar_t* __restrict__ pixel_features,
    const scalar_t* __restrict__ spixel_features,
    const scalar_t* __restrict__ spixel_indices,
    scalar_t* __restrict__ pixel_feature_grad,
    scalar_t* __restrict__ spixel_feature_grad,
    int batchsize, int channels, int num_pixels, int num_spixels,
    int num_spixels_w, int num_spixels_h
){
    const int index = blockIdx.x * blockDim.x + threadIdx.x;
    if (index >= batchsize * num_pixels * 9) return;

    int b, spixel_offset, p, query_spixel_index;
    if (!resolve_neighbor(
            index, spixel_indices, batchsize, num_pixels, num_spixels_w, num_spixels_h,
            b, spixel_offset, p, query_spixel_index)) {
        return;
    }

    const scalar_t grad =
        dist_matrix_grad[b * (9 * num_pixels) + spixel_offset * num_pixels + p];
    const int cp = channels * num_pixels;
    const int cs = channels * num_spixels;
    for (int c = 0; c < channels; ++c) {
        const int pix_idx = b * cp + c * num_pixels + p;
        const int spix_idx = b * cs + c * num_spixels + query_spixel_index;
        const scalar_t scaled_diff =
            (pixel_features[pix_idx] - spixel_features[spix_idx]) * grad;
        // Atomic: the 9 slots of one pixel, and every pixel that sees a given
        // superpixel, update the same gradient entries concurrently.
        atomicAdd(&pixel_feature_grad[pix_idx], 2 * scaled_diff);
        atomicAdd(&spixel_feature_grad[spix_idx], -2 * scaled_diff);
    }
}

// Throws unless t is a CUDA tensor; returns it as a contiguous tensor, which
// is a no-op when it already is one.
static torch::Tensor prepare_input(const torch::Tensor& t, const char* name) {
    TORCH_CHECK(t.is_cuda(), "pair_wise_distance: ", name, " must be a CUDA tensor");
    return t.contiguous();
}

// Throws unless t is a contiguous CUDA tensor. Used for the outputs, which are
// written in place through raw pointers and returned to the caller.
static void check_output(const torch::Tensor& t, const char* name) {
    TORCH_CHECK(t.is_cuda(), "pair_wise_distance: ", name, " must be a CUDA tensor");
    TORCH_CHECK(t.is_contiguous(), "pair_wise_distance: ", name, " must be contiguous");
}

// Returns the number of threads to launch, one per (batch, slot, pixel), after
// checking that every flat index the kernels compute fits in an int.
static int64_t checked_thread_count(const torch::Tensor& pix, const torch::Tensor& spix) {
    const int64_t total = pix.size(0) * 9 * pix.size(2);
    const int64_t limit = std::numeric_limits<int>::max();
    TORCH_CHECK(total <= limit && pix.numel() <= limit && spix.numel() <= limit,
                "pair_wise_distance: tensors too large for 32-bit indexing");
    return total;
}

// Host entry for the forward pass; fills dist_matrix in place and returns it.
torch::Tensor forward_cuda(
    const torch::Tensor pixel_features,
    const torch::Tensor spixel_features,
    const torch::Tensor spixel_indices,
    torch::Tensor dist_matrix,
    int num_spixels_w, int num_spixels_h
){
    const torch::Tensor pix = prepare_input(pixel_features, "pixel_features");
    const torch::Tensor spix = prepare_input(spixel_features, "spixel_features");
    const torch::Tensor idx = prepare_input(spixel_indices, "spixel_indices");
    check_output(dist_matrix, "dist_matrix");

    const int64_t total = checked_thread_count(pix, spix);
    if (total == 0) return dist_matrix;  // a zero-sized grid is an invalid launch

    const int batchsize = static_cast<int>(pix.size(0));
    const int channels = static_cast<int>(pix.size(1));
    const int num_pixels = static_cast<int>(pix.size(2));
    const int num_spixels = static_cast<int>(spix.size(2));

    const c10::cuda::CUDAGuard device_guard(pix.device());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    const dim3 grid(static_cast<unsigned int>((total + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS));

    AT_DISPATCH_FLOATING_TYPES(dist_matrix.scalar_type(), "forward_kernel", ([&] {
        forward_kernel<scalar_t><<<grid, CUDA_NUM_THREADS, 0, stream>>>(
            pix.data_ptr<scalar_t>(),
            spix.data_ptr<scalar_t>(),
            idx.data_ptr<scalar_t>(),
            dist_matrix.data_ptr<scalar_t>(),
            batchsize, channels, num_pixels, num_spixels,
            num_spixels_w, num_spixels_h);
    }));
    AT_CUDA_CHECK(cudaGetLastError());
    return dist_matrix;
}

// Host entry for the backward pass; accumulates into pixel_features_grad and
// spixel_features_grad in place and returns both.
std::vector<torch::Tensor> backward_cuda(
    const torch::Tensor dist_matrix_grad,
    const torch::Tensor pixel_features,
    const torch::Tensor spixel_features,
    const torch::Tensor spixel_indices,
    torch::Tensor pixel_features_grad,
    torch::Tensor spixel_features_grad,
    int num_spixels_w, int num_spixels_h
){
    const torch::Tensor grad = prepare_input(dist_matrix_grad, "dist_matrix_grad");
    const torch::Tensor pix = prepare_input(pixel_features, "pixel_features");
    const torch::Tensor spix = prepare_input(spixel_features, "spixel_features");
    const torch::Tensor idx = prepare_input(spixel_indices, "spixel_indices");
    check_output(pixel_features_grad, "pixel_features_grad");
    check_output(spixel_features_grad, "spixel_features_grad");

    const int64_t total = checked_thread_count(pix, spix);
    if (total == 0) return {pixel_features_grad, spixel_features_grad};

    const int batchsize = static_cast<int>(pix.size(0));
    const int channels = static_cast<int>(pix.size(1));
    const int num_pixels = static_cast<int>(pix.size(2));
    const int num_spixels = static_cast<int>(spix.size(2));

    const c10::cuda::CUDAGuard device_guard(pix.device());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    const dim3 grid(static_cast<unsigned int>((total + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS));

    AT_DISPATCH_FLOATING_TYPES(pixel_features_grad.scalar_type(), "backward_kernel", ([&] {
        backward_kernel<scalar_t><<<grid, CUDA_NUM_THREADS, 0, stream>>>(
            grad.data_ptr<scalar_t>(),
            pix.data_ptr<scalar_t>(),
            spix.data_ptr<scalar_t>(),
            idx.data_ptr<scalar_t>(),
            pixel_features_grad.data_ptr<scalar_t>(),
            spixel_features_grad.data_ptr<scalar_t>(),
            batchsize, channels, num_pixels, num_spixels,
            num_spixels_w, num_spixels_h);
    }));
    AT_CUDA_CHECK(cudaGetLastError());
    return {pixel_features_grad, spixel_features_grad};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward_cuda, "pair_wise_distance forward");
    m.def("backward", &backward_cuda, "pair_wise_distance backward");
}
'''