// NOTE(review): in the text this file was reviewed from, everything that sits
// between angle brackets was missing (include targets, accessor template
// arguments, the kernel launch configuration and the data-pointer element
// types). Those pieces are reconstructed below from how each value is used in
// the code; confirm them against the callers and the build before merging.
#include <assert.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdint.h>
#include <stdio.h>
#include <torch/extension.h>

// Decompresses packed residual bytes into half-precision output rows.
//
// Work decomposition: block `blockIdx.x` handles output row `i`, and thread
// `threadIdx.x` handles packed byte `j` of that row. Each thread expands its
// byte into `packed_size / nbits` consecutive output values.
//
// For every value the thread:
//   1. remaps the packed byte through `reversed_bit_map`,
//   2. uses the remapped byte as a row of `bucket_weight_combinations`, whose
//      k-th column is an index into `bucket_weights`,
//   3. adds `centroids[codes[i]][output_idx]` to that bucket weight,
//   4. stores the sum at `output[i * dim + output_idx]`.
//
// Parameters:
//   binary_residuals            raw bytes, read as a dense row-major array with
//                               `dim * nbits / packed_size` bytes per row.
//   bucket_weights              1-D lookup of weights.
//   reversed_bit_map            1-D byte-to-byte remapping table.
//   bucket_weight_combinations  2-D table indexed as [remapped byte][k].
//   codes                       1-D, one centroid row index per output row.
//   centroids                   2-D, indexed as [code][output_idx].
//   n                           number of rows to decompress.
//   dim                         number of output values per row.
//   nbits                       bits used per packed value.
//   packed_size                 bits per packed element of `binary_residuals`.
//   output                      dense row-major buffer written as n x dim.
//
// NOTE(review): accessor element types (at::Half / uint8_t / int) and
// RestrictPtrTraits are inferred from the assignments in the body — confirm
// they match the dtypes the callers pass.
__global__ void decompress_residuals_kernel(
    const uint8_t* binary_residuals,
    const torch::PackedTensorAccessor32<at::Half, 1, torch::RestrictPtrTraits>
        bucket_weights,
    const torch::PackedTensorAccessor32<uint8_t, 1, torch::RestrictPtrTraits>
        reversed_bit_map,
    const torch::PackedTensorAccessor32<uint8_t, 2, torch::RestrictPtrTraits>
        bucket_weight_combinations,
    const torch::PackedTensorAccessor32<int, 1, torch::RestrictPtrTraits> codes,
    const torch::PackedTensorAccessor32<at::Half, 2, torch::RestrictPtrTraits>
        centroids,
    const int n, const int dim, const int nbits, const int packed_size,
    at::Half* output) {
    const int packed_dim = dim * nbits / packed_size;
    const int row = blockIdx.x;
    const int byte_idx = threadIdx.x;

    // Threads outside the row count or past the last packed byte do nothing.
    if (row >= n || byte_idx >= packed_dim) return;

    const int code = codes[row];
    const uint8_t packed =
        reversed_bit_map[binary_residuals[row * packed_dim + byte_idx]];

    const int values_per_byte = packed_size / nbits;
    int output_idx = byte_idx * packed_size / nbits;
    for (int k = 0; k < values_per_byte; ++k, ++output_idx) {
        assert(output_idx < dim);
        const int bucket_weight_idx = bucket_weight_combinations[packed][k];

        // Accumulate in a local so each output element is stored to global
        // memory once; the same at::Half assignment and += are applied as
        // before, so the stored value is unchanged.
        at::Half value = bucket_weights[bucket_weight_idx];
        value += centroids[code][output_idx];
        output[row * dim + output_idx] = value;
    }
}

// Host entry point: allocates a zero-initialised float16 tensor of shape
// (binary_residuals.size(0), dim) on CUDA device 0, launches
// decompress_residuals_kernel over it and returns the tensor.
//
// Launch shape: `dim / (packed_size / nbits)` threads per block and
// `binary_residuals.size(0) * binary_residuals.size(1) / threads` blocks.
//
// Errors (raised through TORCH_CHECK):
//   - `nbits` is not a positive divisor of the packed element width,
//   - `dim` is too small to yield at least one thread per block,
//   - the kernel launch itself is rejected by the CUDA runtime.
//
// An input with no rows or no packed bytes returns the zero tensor without
// launching anything.
//
// NOTE(review): the output device is hard-coded to cuda:0, and the raw
// pointer handed to the kernel presumes `binary_residuals` is a contiguous
// uint8 CUDA tensor — verify against the callers.
torch::Tensor decompress_residuals_cuda(
    const torch::Tensor binary_residuals, const torch::Tensor bucket_weights,
    const torch::Tensor reversed_bit_map,
    const torch::Tensor bucket_weight_combinations, const torch::Tensor codes,
    const torch::Tensor centroids, const int dim, const int nbits) {
    // TODO: Set this automatically?
    const int packed_size = 8;

    // Reject values that would otherwise cause an integer division by zero
    // below (nbits > packed_size) or an inconsistent byte layout.
    TORCH_CHECK(nbits > 0 && packed_size % nbits == 0,
                "nbits must be a positive divisor of ", packed_size, ", got ",
                nbits);
    const int values_per_byte = packed_size / nbits;

    const int threads = dim / values_per_byte;
    TORCH_CHECK(threads > 0, "dim (", dim, ") must be at least ",
                values_per_byte, " when nbits is ", nbits);

    auto options = torch::TensorOptions()
                       .dtype(torch::kFloat16)
                       .device(torch::kCUDA, 0)
                       .requires_grad(false);
    torch::Tensor output = torch::zeros(
        {binary_residuals.size(0), static_cast<int64_t>(dim)}, options);

    const int blocks = static_cast<int>(
        (binary_residuals.size(0) * binary_residuals.size(1)) / threads);

    // A grid with zero blocks is an invalid launch configuration; there is
    // nothing to decompress, so hand back the zero tensor.
    if (blocks <= 0) return output;

    decompress_residuals_kernel<<<blocks, threads>>>(
        binary_residuals.data_ptr<uint8_t>(),
        bucket_weights
            .packed_accessor32<at::Half, 1, torch::RestrictPtrTraits>(),
        reversed_bit_map
            .packed_accessor32<uint8_t, 1, torch::RestrictPtrTraits>(),
        bucket_weight_combinations
            .packed_accessor32<uint8_t, 2, torch::RestrictPtrTraits>(),
        codes.packed_accessor32<int, 1, torch::RestrictPtrTraits>(),
        centroids.packed_accessor32<at::Half, 2, torch::RestrictPtrTraits>(),
        static_cast<int>(binary_residuals.size(0)), dim, nbits, packed_size,
        output.data_ptr<at::Half>());

    // Surface launch failures (for example a block size the device rejects)
    // instead of silently returning a tensor of zeros.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "decompress_residuals_kernel launch failed: ",
                cudaGetErrorString(launch_status));

    return output;
}