#include <torch/extension.h>
#include <vector>

using namespace torch::indexing;
|
#define MIN(x, y) ((x) < (y) ? (x) : (y))
#define MAX(x, y) ((x) < (y) ? (y) : (x))

// Input validation helpers. Note: x.is_cuda() replaces the deprecated x.type().is_cuda().
// CHECK_INPUT is wrapped in do/while so that `CHECK_INPUT(x);` behaves as a single statement.
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)

|
// Shorthand for the two half-open slices used when aggregating adjacent cells:
// sl(true) selects all but the last element, sl(false) all but the first.
inline Slice sl(bool x) {
    return x ? Slice(0, -1) : Slice(1, None);
}
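// For illustration (a sketch; `t` is any tensor), these are equivalent to:
//   t.index({sl(true)})   ==  t[0:-1]   (Python slicing)
//   t.index({sl(false)})  ==  t[1:]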
|
|
|
// CUDA launcher, implemented in the companion .cu file.
torch::Tensor forward_agg_cuda( int level, float norm, const torch::Tensor lower,
                                const at::optional<at::Tensor> weights, torch::Tensor upper );

// Aggregate the 4D tensor `lower` into the next pyramid level.
// Allocates `upper` here, lets the kernel fill it, and returns {upper, new_weights}.
std::vector<torch::Tensor> forward_agg( int level, float norm, const torch::Tensor lower,
                                        const at::optional<at::Tensor> weights = at::nullopt ) {
    TORCH_CHECK(level >= 1, "level must be >= 1");
    TORCH_CHECK(lower.dim() == 4, "input must have 4 dimensions");
    const auto LH1 = lower.size(0);
    const auto LW1 = lower.size(1);
    const auto LH2 = lower.size(2);
    const auto LW2 = lower.size(3);
    if (weights) TORCH_CHECK(weights->size(0) == LH1 && weights->size(1) == LW1,
                             "weights should have shape == lower.shape[:2]");
    // The first level grows the output grid by one cell in each direction.
    const auto UH1 = (level == 1) ? LH1+1 : LH1;
    const auto UW1 = (level == 1) ? LW1+1 : LW1;

    CHECK_CUDA(lower);
    auto upper = torch::zeros({UH1, UW1, LH2, LW2}, lower.options());
    torch::Tensor new_weights = forward_agg_cuda( level, norm, lower, weights, upper );
    return {upper, new_weights};
}
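// Usage sketch (illustrative values; any float CUDA tensor of shape (H1,W1,H2,W2) works):
//   auto lower = torch::rand({8, 8, 32, 32}, torch::device(torch::kCUDA));
//   auto out   = forward_agg(/*level=*/1, /*norm=*/1.0f, lower);  // weights default to nullopt
//   // out[0] has shape {9, 9, 32, 32} at level 1; out[1] holds the new weights.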
|
|
|
|
|
// CUDA launcher, implemented in the companion .cu file.
torch::Tensor forward_pool_agg_cuda( int level, float norm, const torch::Tensor lower,
                                     const at::optional<at::Tensor> weights, torch::Tensor upper );

// Same as forward_agg(), except the kernel also pools the last two dimensions
// by a factor of 2 (output spatial size 1+(N-1)/2).
std::vector<torch::Tensor> forward_pool_agg( int level, float norm, const torch::Tensor lower,
                                             const at::optional<at::Tensor> weights = at::nullopt ) {
    TORCH_CHECK(level >= 1, "level must be >= 1");
    TORCH_CHECK(lower.dim() == 4, "input must have 4 dimensions");
    const auto LH1 = lower.size(0);
    const auto LW1 = lower.size(1);
    const auto LH2 = lower.size(2);
    const auto LW2 = lower.size(3);
    if (weights) TORCH_CHECK(weights->size(0) == LH1 && weights->size(1) == LW1,
                             "weights should have shape == lower.shape[:2]");
    // The first level grows the output grid by one cell in each direction.
    const auto UH1 = (level == 1) ? LH1+1 : LH1;
    const auto UW1 = (level == 1) ? LW1+1 : LW1;

    CHECK_CUDA(lower);
    auto upper = torch::zeros({UH1, UW1, 1+(LH2-1)/2, 1+(LW2-1)/2}, lower.options());
    torch::Tensor new_weights = forward_pool_agg_cuda( level, norm, lower, weights, upper );
    return {upper, new_weights};
}
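// Usage sketch (illustrative values): with `lower` of shape {8, 8, 32, 32} on the GPU,
//   auto out = forward_pool_agg(/*level=*/1, /*norm=*/1.0f, lower);
//   // out[0] has shape {9, 9, 16, 16}: grid grown by one, last two dims pooled by 2.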
|
|
|
|
|
// CUDA launcher, implemented in the companion .cu file.
void backward_agg_unpool_cuda( int level, const torch::Tensor upper, torch::Tensor lower, bool exclude_borders );

// Backward pass of the aggregation: propagate `upper` back into `lower` (in-place),
// optionally excluding border cells.
void backward_agg_unpool( int level, const torch::Tensor upper, torch::Tensor lower, bool exclude_borders = true ) {
    TORCH_CHECK(level >= 1, "level must be >= 1");
    TORCH_CHECK( upper.dim() == 4 && lower.dim() == 4, "inputs should be 4-dimensional" );

    TORCH_CHECK(upper.is_cuda() && lower.is_cuda(), "inputs must be CUDA tensors");
    backward_agg_unpool_cuda(level, upper, lower, exclude_borders);
}
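// Usage sketch (shapes chosen to mirror forward_agg at level 1; a guess, not a spec):
//   auto upper = torch::rand({9, 9, 32, 32}, torch::device(torch::kCUDA));
//   auto lower = torch::zeros({8, 8, 32, 32}, torch::device(torch::kCUDA));
//   backward_agg_unpool(/*level=*/1, upper, lower);  // writes into `lower`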
|
|
|
|
|
// CUDA launcher, implemented in the companion .cu file.
void max_pool3d_cuda( const torch::Tensor tensor, const int kernel_size, const int stride,
                      torch::Tensor maxima, torch::Tensor indices );

// Max-pool a BxCxHxW tensor; judging from the output shape {B, OH, OW}, the pooling
// window spans the full channel dimension plus a kernel_size x kernel_size spatial patch.
// Returns {maxima, indices}.
std::vector<torch::Tensor> max_pool3d( const torch::Tensor tensor, const int kernel_size, const int stride ) {
    TORCH_CHECK(tensor.dim() == 4, "tensor should be 4-dimensional: BxCxHxW");
    // TORCH_CHECK streams its arguments; printf-style "%d" formatting is not supported.
    TORCH_CHECK( 1 <= kernel_size, "bad kernel size ", kernel_size );
    TORCH_CHECK( 1 <= stride, "bad stride ", stride );
    const int IB = tensor.size(0);
    const int IH = tensor.size(2);
    const int IW = tensor.size(3);

    // Number of full kernel windows along each spatial dimension.
    const int OH = 1 + (IH - kernel_size) / stride;
    const int OW = 1 + (IW - kernel_size) / stride;

    torch::Tensor maxima  = torch::empty({IB, OH, OW}, tensor.options());
    torch::Tensor indices = torch::empty({IB, OH, OW}, tensor.options().dtype(torch::kInt64));

    if (tensor.is_cuda())
        max_pool3d_cuda( tensor, kernel_size, stride, maxima, indices );
    else
        TORCH_CHECK(false, "CPU max_pool3d not implemented yet");
    return {maxima, indices};
}
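// Usage sketch (illustrative values):
//   auto t   = torch::rand({2, 64, 100, 100}, torch::device(torch::kCUDA));
//   auto res = max_pool3d(t, /*kernel_size=*/3, /*stride=*/2);
//   // res[0]: float maxima of shape {2, 49, 49}; res[1]: int64 argmax indices.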
|
|
|
// Apply one row (m[0], m[1], m[2]) of a 2x3 affine matrix to the point (x, y).
static inline float ptdot( const float* m, float x, float y ) {
    return x*m[0] + y*m[1] + m[2];
}

static inline float pow2(float v) {
    return v*v;
}
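// For example, with the affine row m = {2, 0, 5} (a made-up value),
// ptdot(m, 3, 7) = 3*2 + 7*0 + 5 = 11.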
|
|
|
// Merge correspondences from a rotated/scaled view (`corres`) into the accumulator
// `all_corres`. For each cell of `all_corres`, the cell center is mapped through
// `_inv_rot` into the frame of `corres`; among nearby candidates, any correspondence
// close to the best match and with a higher score overwrites the accumulator cell.
void merge_corres_cpu( const torch::Tensor corres, int offset, const torch::Tensor _inv_rot,
                       float dmax, torch::Tensor all_corres, const int all_step ) {
    const int H = corres.size(0);
    const int W = corres.size(1);
    const float tol = 2*2;  // accept matches within 2x the best distance (squared: 4)
    dmax *= dmax;           // compare squared distances throughout

    TORCH_CHECK( _inv_rot.is_contiguous() );
    const float* inv_rot = _inv_rot.data_ptr<float>();

    auto corres_a = corres.accessor<float,3>();
    auto all_corres_a = all_corres.accessor<float,3>();

    for (int j=0; j<all_corres.size(0); j++)
      for (int i=0; i<all_corres.size(1); i++) {
        auto all_cor = all_corres_a[j][i];

        // center of the current cell in image coordinates
        float x = i*all_step + all_step/2;
        float y = j*all_step + all_step/2;

        // map the center into the frame of `corres` using the 2x3 inverse transform
        float xr = ptdot( inv_rot + 0, x, y );
        float yr = ptdot( inv_rot + 3, x, y );

        // nearest cell in `corres`, which is sampled with a fixed step of 4 pixels
        int xb = (int)(0.5+ xr/4);
        int yb = (int)(0.5+ yr/4);

        // first pass: smallest squared distance in the 3x3 neighborhood
        // (bounds clamped to H-1/W-1 to stay inside the accessors)
        float best = dmax;
        for (int v = MAX(0,yb-1); v <= MIN(H-1,yb+1); v++)
          for (int u = MAX(0,xb-1); u <= MIN(W-1,xb+1); u++) {
            auto cor = corres_a[v][u];
            // distance between the endpoint stored at cor[offset..offset+1] and the cell center
            float d = pow2(cor[offset]-x) + pow2(cor[offset+1]-y);
            if( d < best ) best = d;
          }

        // second pass: among candidates within tol*best, keep the one whose
        // score (channel 4) beats the current accumulator value
        for (int v = MAX(0,yb-1); v <= MIN(H-1,yb+1); v++)
          for (int u = MAX(0,xb-1); u <= MIN(W-1,xb+1); u++) {
            auto cor = corres_a[v][u];
            float d = pow2(cor[offset]-x) + pow2(cor[offset+1]-y);
            if( d <= tol*best && cor[4] > all_cor[4] )
                for (int k = 0; k < all_corres.size(2); k++)
                    all_cor[k] = cor[k];
          }
    }
}
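// Worked example of the cell-center mapping (made-up numbers): with all_step = 8,
// cell (j=2, i=3) has center (x, y) = (3*8+4, 2*8+4) = (28, 20). With an identity
// transform, (xr, yr) = (28, 20), which snaps to the corres grid cell
// (xb, yb) = ((int)(0.5 + 28/4.f), (int)(0.5 + 20/4.f)) = (7, 5).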
|
|
|
// CUDA launcher, implemented in the companion .cu file.
void merge_corres_cuda( const torch::Tensor corres, int offset, const torch::Tensor inv_rot,
                        float dmax, torch::Tensor all_corres, const int all_step );

// Entry point: derives the search radius from the transform's scale and dispatches
// to the CPU or CUDA implementation depending on where `all_corres` lives.
void merge_corres( const torch::Tensor corres, int offset, const torch::Tensor rot,
                   torch::Tensor all_corres, const int all_step ) {
    TORCH_CHECK( corres.dim() == 3 && corres.size(2) == 6, "corres.shape should be (H,W,6)" );
    TORCH_CHECK( all_corres.dim() == 3 && all_corres.size(2) == 6, "all_corres.shape should be (H,W,6)" );

    // sqrt(det(rot)) is the linear scale factor of the affine transform
    float dmax = 8 * torch::sqrt(torch::det(rot)).item<float>();
    torch::Tensor inv_rot = torch::inverse(rot).contiguous();

    if (all_corres.is_cuda())
        merge_corres_cuda( corres, offset, inv_rot, dmax, all_corres, all_step );
    else
        merge_corres_cpu( corres, offset, inv_rot, dmax, all_corres, all_step );
}
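// Usage sketch (illustrative values; identity transform, so both maps share a frame):
//   auto rot        = torch::eye(3);                // 3x3 affine, float32
//   auto corres     = torch::zeros({32, 32, 6});    // (H, W, 6) candidate matches
//   auto all_corres = torch::zeros({16, 16, 6});    // accumulator sampled with step 8
//   merge_corres(corres, /*offset=*/0, rot, all_corres, /*all_step=*/8);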
|
|
|
|
|
// CUDA launcher, implemented in the companion .cu file.
void mask_correlations_radial_cuda( torch::Tensor corr, const torch::Tensor targets,
                                    const float radius, const float alpha);

// Mask the 4D correlation volume radially around per-cell target positions (in-place).
void mask_correlations_radial( torch::Tensor corr, const torch::Tensor targets,
                               const float radius, const float alpha) {
    TORCH_CHECK( corr.dim() == 4, "corr should be 4-dimensional" );
    TORCH_CHECK( targets.dim() == 3, "targets should be 3-dimensional" );
    TORCH_CHECK( targets.size(0) == corr.size(0) && targets.size(1) == corr.size(1) && targets.size(2) == 2,
                 "targets should have shape == corr.shape[:2] + (2,)" );

    if (corr.is_cuda())
        mask_correlations_radial_cuda( corr, targets, radius, alpha );
    else
        TORCH_CHECK(false, "CPU mask_correlations_radial not implemented yet");
}
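// Usage sketch (illustrative shapes; radius/alpha values are arbitrary here):
//   auto corr    = torch::rand({16, 16, 32, 32}, torch::device(torch::kCUDA));
//   auto targets = torch::rand({16, 16, 2},      torch::device(torch::kCUDA));
//   mask_correlations_radial(corr, targets, /*radius=*/5.0f, /*alpha=*/0.5f);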
|
|
|
|
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward_agg", &forward_agg, "forward aggregation (CUDA)");
    m.def("forward_pool_agg", &forward_pool_agg, "forward pooling and aggregation (CUDA)");
    m.def("backward_agg_unpool", &backward_agg_unpool, "backward sparse-conv and max-unpooling (C++ & CUDA)");
    m.def("max_pool3d", &max_pool3d, "max_pool3d that can handle big inputs (CUDA)");
    m.def("merge_corres_one_side", &merge_corres, "merge correspondences on CPU or GPU");
    m.def("mask_correlations_radial", &mask_correlations_radial, "mask correlations radially (CUDA)");
}
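// Note: TORCH_EXTENSION_NAME is defined by PyTorch's build machinery
// (e.g. torch.utils.cpp_extension), so the Python module name matches the build target.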
|
|