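# Inline CUDA source for a pair_wise_distance PyTorch extension: for each pixel
# it computes squared feature-space distances to the 9 superpixels in the 3x3
# neighbourhood of its current superpixel, plus the matching backward pass.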
source = '''
#include <cuda.h>
#include <cuda_runtime.h>

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Atomic.cuh>  // gpuAtomicAdd; replaces the THC/THCAtomics.cuh headers removed from modern PyTorch

#define CUDA_NUM_THREADS 256
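// forward_kernel: for every (batch, pixel, candidate) triple, compute the squared
// Euclidean distance in feature space between a pixel and one of the 9 superpixels
// in the 3x3 neighbourhood of its initial superpixel on the superpixel grid.
// Candidates that fall off the grid get a large sentinel distance (1e16).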
template <typename scalar_t>
__global__ void forward_kernel(
    const scalar_t* __restrict__ pixel_features,
    const scalar_t* __restrict__ spixel_features,
    const scalar_t* __restrict__ spixel_indices,
    scalar_t* __restrict__ dist_matrix,
    int batchsize, int channels, int num_pixels, int num_spixels,
    int num_spixels_w, int num_spixels_h
){
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    if (index >= batchsize * num_pixels * 9) return;

    int cp = channels * num_pixels;
    int cs = channels * num_spixels;

    // Decompose the flat thread index: batch varies fastest, then the
    // 3x3 neighbourhood offset, then the pixel index.
    int b = index % batchsize;
    int spixel_offset = (index / batchsize) % 9;
    int p = (index / (batchsize * 9)) % num_pixels;

    // Superpixel the pixel is currently assigned to, and its (x, y)
    // position on the num_spixels_w x num_spixels_h superpixel grid.
    int init_spix_index = spixel_indices[b * num_pixels + p];
    int x_index = init_spix_index % num_spixels_w;
    int y_index = init_spix_index / num_spixels_w;

    // Offsets in {-1, 0, 1} selecting one of the 9 neighbouring superpixels.
    int spixel_offset_x = (spixel_offset % 3 - 1);
    int spixel_offset_y = (spixel_offset / 3 - 1);

    if (x_index + spixel_offset_x < 0 || x_index + spixel_offset_x >= num_spixels_w ||
        y_index + spixel_offset_y < 0 || y_index + spixel_offset_y >= num_spixels_h) {
        // Candidate lies outside the superpixel grid: mark it unreachable.
        dist_matrix[b * (9 * num_pixels) + spixel_offset * num_pixels + p] = 1e16;
    }
    else {
        int query_spixel_index = init_spix_index + spixel_offset_x + num_spixels_w * spixel_offset_y;
        scalar_t sum_squared_diff = 0;
        for (int c = 0; c < channels; c++) {
            scalar_t diff = pixel_features[b * cp + c * num_pixels + p]
                          - spixel_features[b * cs + c * num_spixels + query_spixel_index];
            sum_squared_diff += diff * diff;  // cheaper than pow(diff, 2)
        }
        dist_matrix[b * (9 * num_pixels) + spixel_offset * num_pixels + p] = sum_squared_diff;
    }
}
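// forward_cuda: host-side launcher. Expects pixel_features (B, C, N),
// spixel_features (B, C, S), spixel_indices (B, N) and a preallocated
// dist_matrix (B, 9, N); returns dist_matrix filled in place.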
torch::Tensor forward_cuda(
    const torch::Tensor pixel_features,
    const torch::Tensor spixel_features,
    const torch::Tensor spixel_indices,
    torch::Tensor dist_matrix,
    int num_spixels_w, int num_spixels_h
){
    int batchsize = pixel_features.size(0);
    int channels = pixel_features.size(1);
    int num_pixels = pixel_features.size(2);
    int num_spixels = spixel_features.size(2);

    // One thread per (batch, pixel, neighbour) triple.
    dim3 grid((batchsize * 9 * num_pixels + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);

    AT_DISPATCH_FLOATING_TYPES(dist_matrix.scalar_type(), "forward_kernel", ([&] {
        forward_kernel<scalar_t><<<grid, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
            pixel_features.data_ptr<scalar_t>(),
            spixel_features.data_ptr<scalar_t>(),
            spixel_indices.data_ptr<scalar_t>(),
            dist_matrix.data_ptr<scalar_t>(),
            batchsize, channels, num_pixels,
            num_spixels, num_spixels_w, num_spixels_h
        );
    }));
    return dist_matrix;
}
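// backward_kernel: gradient of the squared distance. Since
// dist = sum_c (pix_c - spix_c)^2, we have d(dist)/d(pix_c) = 2 * (pix_c - spix_c)
// and d(dist)/d(spix_c) = -2 * (pix_c - spix_c), each scaled by the incoming
// gradient. Atomic adds are required because many threads touch the same entries.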
template <typename scalar_t>
__global__ void backward_kernel(
    const scalar_t* __restrict__ dist_matrix_grad,
    const scalar_t* __restrict__ pixel_features,
    const scalar_t* __restrict__ spixel_features,
    const scalar_t* __restrict__ spixel_indices,
    scalar_t* __restrict__ pixel_feature_grad,
    scalar_t* __restrict__ spixel_feature_grad,
    int batchsize, int channels, int num_pixels, int num_spixels,
    int num_spixels_w, int num_spixels_h
){
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    if (index >= batchsize * num_pixels * 9) return;

    int cp = channels * num_pixels;
    int cs = channels * num_spixels;

    // Same index decomposition as the forward kernel.
    int b = index % batchsize;
    int spixel_offset = (index / batchsize) % 9;
    int p = (index / (batchsize * 9)) % num_pixels;

    int init_spix_index = spixel_indices[b * num_pixels + p];
    int x_index = init_spix_index % num_spixels_w;
    int y_index = init_spix_index / num_spixels_w;
    int spixel_offset_x = (spixel_offset % 3 - 1);
    int spixel_offset_y = (spixel_offset / 3 - 1);

    // Out-of-grid candidates carried no distance, so they carry no gradient.
    if (x_index + spixel_offset_x < 0 || x_index + spixel_offset_x >= num_spixels_w) return;
    if (y_index + spixel_offset_y < 0 || y_index + spixel_offset_y >= num_spixels_h) return;

    int query_spixel_index = init_spix_index + spixel_offset_x + num_spixels_w * spixel_offset_y;
    scalar_t dist_matrix_grad_val = dist_matrix_grad[b * (9 * num_pixels) + spixel_offset * num_pixels + p];

    for (int c = 0; c < channels; c++) {
        scalar_t pix_value = pixel_features[b * cp + c * num_pixels + p];
        scalar_t spix_value = spixel_features[b * cs + c * num_spixels + query_spixel_index];
        scalar_t diff = (pix_value - spix_value) * dist_matrix_grad_val;
        // gpuAtomicAdd (from ATen/cuda/Atomic.cuh) supports double on all
        // architectures, unlike the raw atomicAdd intrinsic.
        gpuAtomicAdd(&pixel_feature_grad[b * cp + c * num_pixels + p], 2 * diff);
        gpuAtomicAdd(&spixel_feature_grad[b * cs + c * num_spixels + query_spixel_index], -2 * diff);
    }
}
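// backward_cuda: host-side launcher for the gradient kernel. The two grad
// tensors must be zero-initialised by the caller, since the kernel accumulates
// into them with atomic adds.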
std::vector<torch::Tensor> backward_cuda(
    const torch::Tensor dist_matrix_grad,
    const torch::Tensor pixel_features,
    const torch::Tensor spixel_features,
    const torch::Tensor spixel_indices,
    torch::Tensor pixel_features_grad,
    torch::Tensor spixel_features_grad,
    int num_spixels_w, int num_spixels_h
){
    int batchsize = pixel_features.size(0);
    int channels = pixel_features.size(1);
    int num_pixels = pixel_features.size(2);
    int num_spixels = spixel_features.size(2);

    dim3 grid((batchsize * 9 * num_pixels + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);

    AT_DISPATCH_FLOATING_TYPES(pixel_features_grad.scalar_type(), "backward_kernel", ([&] {
        backward_kernel<scalar_t><<<grid, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
            dist_matrix_grad.data_ptr<scalar_t>(),
            pixel_features.data_ptr<scalar_t>(),
            spixel_features.data_ptr<scalar_t>(),
            spixel_indices.data_ptr<scalar_t>(),
            pixel_features_grad.data_ptr<scalar_t>(),
            spixel_features_grad.data_ptr<scalar_t>(),
            batchsize, channels, num_pixels,
            num_spixels, num_spixels_w, num_spixels_h
        );
    }));
    return {pixel_features_grad, spixel_features_grad};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward_cuda, "pair_wise_distance forward");
    m.def("backward", &backward_cuda, "pair_wise_distance backward");
}
'''
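
# A minimal usage sketch, assuming the string above is meant to be JIT-compiled
# with torch.utils.cpp_extension.load_inline (the module name and tensor sizes
# below are illustrative assumptions, not part of the original file). Shapes
# follow the kernels' layout: pixel_features (B, C, N), spixel_features (B, C, S),
# spixel_indices (B, N), dist_matrix (B, 9, N).
if __name__ == "__main__":
    import torch
    from torch.utils.cpp_extension import load_inline

    pair_wise_distance_cuda = load_inline(
        "pair_wise_distance", cpp_sources="", cuda_sources=source)

    B, C, N = 2, 16, 1024  # batch, channels, pixels (hypothetical sizes)
    Sw, Sh = 8, 8          # superpixel grid width and height
    S = Sw * Sh
    pixel_features = torch.randn(B, C, N, device="cuda")
    spixel_features = torch.randn(B, C, S, device="cuda")
    # Indices are read as scalar_t in the kernels, so keep them floating point.
    spixel_indices = torch.randint(0, S, (B, N), device="cuda").float()
    dist_matrix = torch.empty(B, 9, N, device="cuda")

    dist = pair_wise_distance_cuda.forward(
        pixel_features, spixel_features, spixel_indices, dist_matrix, Sw, Sh)
    print(dist.shape)  # torch.Size([2, 9, 1024])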