// PUMP: core/cuda_deepm/func.cpp
// Copyright 2022-present NAVER Corp.
// CC BY-NC-SA 4.0
// Available only for non-commercial use
#include <torch/extension.h>
using namespace torch::indexing; // Slice
#include <vector>
#define MIN(x, y) ((x) < (y) ? (x) : (y))
#define MAX(x, y) ((x) < (y) ? (y) : (x))
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
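// helper: slice dropping the last element (x = true) or the first one (x = false)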
inline Slice sl(bool x) {
if (x)
return Slice(0, -1);
else
return Slice(1, None);
}
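// forward_agg: aggregate the 4D correlation volume `lower` (LH1 x LW1 x LH2 x LW2)
// one pyramid level up; at level 1 the upper grid gains one row and one column.
// `norm` is a normalization factor and `weights` an optional per-cell reweighting,
// both forwarded to the CUDA kernel. Returns {upper, new_weights}.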
torch::Tensor forward_agg_cuda( int level, float norm, const torch::Tensor lower,
const at::optional<at::Tensor> weights, torch::Tensor upper );
std::vector<torch::Tensor> forward_agg( int level, float norm, const torch::Tensor lower,
const at::optional<at::Tensor> weights = at::nullopt ) {
TORCH_CHECK(level >= 1, "level must be >= 1");
TORCH_CHECK(lower.dim() == 4, "input must have 4 dimensions");
const auto LH1 = lower.size(0);
const auto LW1 = lower.size(1);
const auto LH2 = lower.size(2);
const auto LW2 = lower.size(3);
if (weights) TORCH_CHECK(weights->size(0) == LH1 && weights->size(1) == LW1, "weights should have shape == lower.shape[:2]");
const auto UH1 = (level == 1) ? LH1+1 : LH1;
const auto UW1 = (level == 1) ? LW1+1 : LW1;
TORCH_CHECK(lower.is_cuda(), "lower must be a CUDA tensor");
auto upper = torch::zeros({UH1, UW1, LH2, LW2}, lower.options());
torch::Tensor new_weights = forward_agg_cuda( level, norm, lower, weights, upper );
return {upper, new_weights};
}
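// forward_pool_agg: same aggregation as forward_agg, fused with a stride-2 pooling
// of the last two dimensions, hence the (1+(LH2-1)/2, 1+(LW2-1)/2) target grid.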
torch::Tensor forward_pool_agg_cuda( int level, float norm, const torch::Tensor lower,
const at::optional<at::Tensor> weights, torch::Tensor upper );
std::vector<torch::Tensor> forward_pool_agg( int level, float norm, const torch::Tensor lower,
const at::optional<at::Tensor> weights = at::nullopt ) {
TORCH_CHECK(level >= 1, "level must be >= 1");
TORCH_CHECK(lower.dim() == 4, "input must have 4 dimensions");
const auto LH1 = lower.size(0);
const auto LW1 = lower.size(1);
const auto LH2 = lower.size(2);
const auto LW2 = lower.size(3);
if (weights) TORCH_CHECK(weights->size(0) == LH1 && weights->size(1) == LW1, "weights should have shape == lower.shape[:2]");
const auto UH1 = (level == 1) ? LH1+1 : LH1;
const auto UW1 = (level == 1) ? LW1+1 : LW1;
TORCH_CHECK(lower.is_cuda(), "lower must be a CUDA tensor");
auto upper = torch::zeros({UH1, UW1, 1+(LH2-1)/2, 1+(LW2-1)/2}, lower.options());
torch::Tensor new_weights = forward_pool_agg_cuda( level, norm, lower, weights, upper );
return {upper, new_weights};
}
// forward declaration
void backward_agg_unpool_cuda( int level, const torch::Tensor upper, torch::Tensor lower, bool exclude_borders );
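// backward_agg_unpool: backward sparse-conv and max-unpooling (cf. the module
// definition below); scatters values from `upper` back into `lower` in-place,
// leaving the border bins out when exclude_borders is set.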
void backward_agg_unpool( int level, const torch::Tensor upper, torch::Tensor lower, bool exclude_borders = true ) {
TORCH_CHECK(level >= 1, "level must be >= 1");
TORCH_CHECK( upper.dim() == 4 && lower.dim() == 4, "inputs should be 4-dimensional" );
TORCH_CHECK(upper.is_cuda() && lower.is_cuda(), "inputs must be CUDA tensors");
backward_agg_unpool_cuda(level, upper, lower, exclude_borders);
}
void max_pool3d_cuda( const torch::Tensor tensor, const int kernel_size, const int stride,
torch::Tensor maxima, torch::Tensor indices );
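// max_pool3d: pool a (B,C,H,W) tensor over full-channel windows: each output cell
// stores the maximum (and its flat argmax index) over a C x kernel_size x kernel_size
// block, hence the channel-less (B,OH,OW) outputs. Output size follows the usual
// valid-pooling formula, e.g. IH = 10, kernel_size = 3, stride = 2 gives
// OH = 1 + (10-3)/2 = 4.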
std::vector<torch::Tensor> max_pool3d( const torch::Tensor tensor, const int kernel_size, const int stride ) {
TORCH_CHECK(tensor.dim() == 4, "tensor should be 4-dimensional: BxCxHxW");
TORCH_CHECK( 1 <= kernel_size, "bad kernel size ", kernel_size );
TORCH_CHECK( 1 <= stride, "bad stride ", stride );
const int IB = tensor.size(0);
const int IH = tensor.size(2); // input height
const int IW = tensor.size(3); // input width
// output size
const int OH = 1 + (IH - kernel_size) / stride;
const int OW = 1 + (IW - kernel_size) / stride;
torch::Tensor maxima = torch::empty({IB, OH, OW}, tensor.options());
torch::Tensor indices = torch::empty({IB, OH, OW}, tensor.options().dtype(torch::kInt64));
if (tensor.is_cuda())
max_pool3d_cuda( tensor, kernel_size, stride, maxima, indices );
else
TORCH_CHECK(false, "CPU max_pool3d not implemented yet");
return {maxima, indices};
}
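// ptdot: apply one row (m[0], m[1], m[2]) of a row-major affine matrix to the point (x, y, 1)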
static inline float ptdot( const float* m, float x, float y ) {
return x*m[0] + y*m[1] + m[2];
}
static inline float pow2(float v) {
return v*v;
}
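// merge_corres_cpu: merge correspondences `corres`, computed on a rescaled+rotated
// image with bin step 4, into the reference histogram `all_corres` (bin step
// all_step). Each reference bin center is mapped through inv_rot into the rotated
// frame; the nearest of the 3x3 surrounding bins sets the distance scale, and any
// spatially close candidate whose score (channel 4) beats the current one
// overwrites the bin. `offset` selects the (x,y) columns of the 6-channel
// correspondences that the distances are measured on.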
void merge_corres_cpu( const torch::Tensor corres, int offset, const torch::Tensor _inv_rot,
float dmax, torch::Tensor all_corres, const int all_step ) {
const int H = corres.size(0);
const int W = corres.size(1);
const float tol = 2*2; // squared
dmax *= dmax; // squared
TORCH_CHECK( _inv_rot.is_contiguous() );
const float* inv_rot = _inv_rot.data_ptr<float>();
auto corres_a = corres.accessor<float,3>();
auto all_corres_a = all_corres.accessor<float,3>();
// for each bin of the final histograms, we get the nearest-neighbour bin in corres0 and corres1
for (int j=0; j<all_corres.size(0); j++)
for (int i=0; i<all_corres.size(1); i++) {
// printf("accessing all_corres[%d,%d]", j, i);
auto all_cor = all_corres_a[j][i];
// center of the bin in the reference frame
float x = i*all_step + all_step/2;
float y = j*all_step + all_step/2;
// printf(" -> (%g,%g) in ref img", x, y);
// center of the bin on the rescaled+rotated image
float xr = ptdot( inv_rot + 0, x, y );
float yr = ptdot( inv_rot + 3, x, y );
// printf(" -> (%g,%g) in rescaled", xr, yr);
// iterate on the nearby bins
int xb = (int)(0.5+ xr/4); // rescaled+rotated desc always has step 4
int yb = (int)(0.5+ yr/4);
// printf(" -> (%d,%d) in bins\n", xb, yb);
float best = dmax;
for (int v = MAX(0,yb-1); v <= MIN(H-1,yb+1); v++)
for (int u = MAX(0,xb-1); u <= MIN(W-1,xb+1); u++) {
// assert( v >= 0 && v < corres_a.size(0) );
// assert( u >= 0 && u < corres_a.size(1) );
auto cor = corres_a[v][u];
float d = pow2(cor[offset]-x) + pow2(cor[offset+1]-y);
if( d < best ) best = d;
}
for (int v = MAX(0,yb-1); v <= MIN(H-1,yb+1); v++)
for (int u = MAX(0,xb-1); u <= MIN(W-1,xb+1); u++) {
// assert( v >= 0 && v < corres_a.size(0) );
// assert( u >= 0 && u < corres_a.size(1) );
auto cor = corres_a[v][u];
float d = pow2(cor[offset]-x) + pow2(cor[offset+1]-y);
if( d <= tol*best ) { // spatially close to the best match
// overwrite the reference bin if this correspondence has a higher score
// printf("update all_corres[%d,%d]\n", j, i);
if( cor[4] > all_cor[4] )
for (int k = 0; k < all_corres.size(2); k++)
all_cor[k] = cor[k];
}
}
}
}
void merge_corres_cuda( const torch::Tensor corres, int offset, const torch::Tensor inv_rot,
float dmax, torch::Tensor all_corres, const int all_step );
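// merge_corres: CPU/CUDA dispatcher. The merge radius dmax scales with the
// rescaling factor of `rot` (the square root of its determinant).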
void merge_corres( const torch::Tensor corres, int offset, const torch::Tensor rot,
torch::Tensor all_corres, const int all_step ) {
TORCH_CHECK( corres.dim() == 3 && corres.size(2) == 6, "corres.shape should be (H,W,6)" );
TORCH_CHECK( all_corres.dim() == 3 && all_corres.size(2) == 6, "all_corres.shape should be (H,W,6)" );
float dmax = 8 * torch::sqrt(torch::det(rot)).item<float>();
torch::Tensor inv_rot = torch::inverse(rot).contiguous();
if (all_corres.is_cuda())
merge_corres_cuda( corres, offset, inv_rot, dmax, all_corres, all_step );
else
merge_corres_cpu( corres, offset, inv_rot, dmax, all_corres, all_step );
}
void mask_correlations_radial_cuda( torch::Tensor corr, const torch::Tensor targets,
const float radius, const float alpha);
void mask_correlations_radial( torch::Tensor corr, const torch::Tensor targets,
const float radius, const float alpha) {
// radius: protected area in pixels around each target center
// alpha: in [0,1]. If alpha = 0: no effect. If alpha = 1: full effect.
TORCH_CHECK( corr.dim() == 4 );
TORCH_CHECK( targets.dim() == 3 );
TORCH_CHECK( targets.size(0) == corr.size(0) && targets.size(1) == corr.size(1) && targets.size(2) == 2,
"targets should have shape == corr.shape[:2] + (2,)" );
if (corr.is_cuda())
mask_correlations_radial_cuda( corr, targets, radius, alpha );
else
TORCH_CHECK(false, "CPU mask_correlations_radial not implemented yet");
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward_agg", &forward_agg, "forward aggregation (CUDA)");
m.def("forward_pool_agg", &forward_pool_agg, "forward pooling and aggregation (CUDA)");
m.def("backward_agg_unpool", &backward_agg_unpool, "backward sparse-conv and max-unpooling (C++ & CUDA)");
m.def("max_pool3d", &max_pool3d, "max_pool3d that can handle big inputs (CUDA)");
m.def("merge_corres_one_side", &merge_corres, "merge correspondences on CPU or GPU" );
m.def("mask_correlations_radial", &mask_correlations_radial, "mask correlations radially (CUDA)" );
}
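// Example usage from Python; a sketch only, since the module name `cuda_deepm` and
// the level/norm values below are assumptions, not fixed by this file (the real
// name is set by TORCH_EXTENSION_NAME at build time):
//   import torch, cuda_deepm
//   corr = torch.rand(32, 32, 48, 64, device='cuda')    # (LH1, LW1, LH2, LW2)
//   upper, weights = cuda_deepm.forward_pool_agg(1, 0.9, corr, None)
//   maxima, indices = cuda_deepm.max_pool3d(corr, 3, 2)  # corr read as (B,C,H,W)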