File size: 9,530 Bytes
3ef85e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 |
// Copyright 2022-present NAVER Corp.
// CC BY-NC-SA 4.0
// Available only for non-commercial use
#include <torch/extension.h>
using namespace torch::indexing; // Slice
#include <vector>
#define MIN(x, y) ((x) < (y) ? (x) : (y))
#define MAX(x, y) ((x) < (y) ? (y) : (x))
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
// Select one border-trimmed slice: drop the last element when `x` is true,
// drop the first element otherwise. Used to align overlapping grids.
inline Slice sl(bool x) {
    return x ? Slice(0, -1) : Slice(1, None);
}
// CUDA kernel entry point (defined in the companion .cu file).
torch::Tensor forward_agg_cuda( int level, float norm, const torch::Tensor lower,
        const at::optional<at::Tensor> weights, torch::Tensor upper );

// Forward aggregation of one pyramid level (CUDA only).
// level:   pyramid level, >= 1. At level 1 the output grid grows by one cell
//          in each of the first two dimensions.
// norm:    normalization parameter forwarded to the CUDA kernel
//          (semantics defined by the kernel).
// lower:   input tensor of shape (LH1, LW1, LH2, LW2); must be on GPU.
// weights: optional per-cell weights; must have shape (LH1, LW1) if given.
// Returns {upper, new_weights}: upper has shape (UH1, UW1, LH2, LW2).
std::vector<torch::Tensor> forward_agg( int level, float norm, const torch::Tensor lower,
        const at::optional<at::Tensor> weights = at::nullopt ) {
    TORCH_CHECK(level >= 1, "level must be >= 1");
    TORCH_CHECK(lower.dim() == 4, "input must have 4 dimensions");
    const auto LH1 = lower.size(0);
    const auto LW1 = lower.size(1);
    const auto LH2 = lower.size(2);
    const auto LW2 = lower.size(3);
    if (weights)
        TORCH_CHECK(weights->size(0) == LH1 && weights->size(1) == LW1,
                    "weights should have shape == lower.shape[:2]");
    // level 1 aggregates into a grid one cell larger along the first two dims
    const auto UH1 = (level == 1) ? LH1+1 : LH1;
    const auto UW1 = (level == 1) ? LW1+1 : LW1;
    TORCH_CHECK(lower.is_cuda(), "lower must be a CUDA tensor");  // was missing ';' and message
    auto upper = torch::zeros({UH1, UW1, LH2, LW2}, lower.options());
    torch::Tensor new_weights = forward_agg_cuda( level, norm, lower, weights, upper );
    return {upper, new_weights};
}
// CUDA kernel entry point (defined in the companion .cu file).
torch::Tensor forward_pool_agg_cuda( int level, float norm, const torch::Tensor lower,
        const at::optional<at::Tensor> weights, torch::Tensor upper );

// Fused pooling + forward aggregation of one pyramid level (CUDA only).
// Same contract as forward_agg, except the last two dimensions of the output
// are downsampled by 2 (ceil division: LH2 -> 1+(LH2-1)/2, LW2 -> 1+(LW2-1)/2).
// level:   pyramid level, >= 1. At level 1 the output grid grows by one cell
//          in each of the first two dimensions.
// norm:    normalization parameter forwarded to the CUDA kernel
//          (semantics defined by the kernel).
// lower:   input tensor of shape (LH1, LW1, LH2, LW2); must be on GPU.
// weights: optional per-cell weights; must have shape (LH1, LW1) if given.
// Returns {upper, new_weights}.
std::vector<torch::Tensor> forward_pool_agg( int level, float norm, const torch::Tensor lower,
        const at::optional<at::Tensor> weights = at::nullopt ) {
    TORCH_CHECK(level >= 1, "level must be >= 1");
    TORCH_CHECK(lower.dim() == 4, "input must have 4 dimensions");
    const auto LH1 = lower.size(0);
    const auto LW1 = lower.size(1);
    const auto LH2 = lower.size(2);
    const auto LW2 = lower.size(3);
    if (weights)
        TORCH_CHECK(weights->size(0) == LH1 && weights->size(1) == LW1,
                    "weights should have shape == lower.shape[:2]");
    // level 1 aggregates into a grid one cell larger along the first two dims
    const auto UH1 = (level == 1) ? LH1+1 : LH1;
    const auto UW1 = (level == 1) ? LW1+1 : LW1;
    TORCH_CHECK(lower.is_cuda(), "lower must be a CUDA tensor");  // was missing ';' and message
    auto upper = torch::zeros({UH1, UW1, 1+(LH2-1)/2, 1+(LW2-1)/2}, lower.options());
    torch::Tensor new_weights = forward_pool_agg_cuda( level, norm, lower, weights, upper );
    return {upper, new_weights};
}
// CUDA implementation (defined in the companion .cu file).
void backward_agg_unpool_cuda( int level, const torch::Tensor upper, torch::Tensor lower, bool exclude_borders );

// Backward aggregation + max-unpooling from `upper` into `lower` (CUDA only).
// level:           pyramid level, >= 1.
// upper, lower:    4-dimensional CUDA tensors; `lower` is written in place.
// exclude_borders: forwarded to the kernel (semantics defined by the kernel).
void backward_agg_unpool( int level, const torch::Tensor upper, torch::Tensor lower, bool exclude_borders = true ) {
    TORCH_CHECK(level >= 1, "level must be >= 1");
    TORCH_CHECK( upper.dim() == 4 && lower.dim() == 4, "inputs should be 4-dimensional" );
    TORCH_CHECK(upper.is_cuda() && lower.is_cuda(), "upper and lower must be CUDA tensors");  // was missing ';' and message
    backward_agg_unpool_cuda(level, upper, lower, exclude_borders);
}
// CUDA implementation (defined in the companion .cu file).
void max_pool3d_cuda( const torch::Tensor tensor, const int kernel_size, const int stride,
        torch::Tensor maxima, torch::Tensor indices );

// 2D max-pooling over the last two dims of a (B, C, H, W) tensor, reducing
// over C as well; designed to handle inputs too big for the stock op (CUDA only).
// Returns {maxima, indices}: both of shape (B, OH, OW); indices are int64.
std::vector<torch::Tensor> max_pool3d( const torch::Tensor tensor, const int kernel_size, const int stride ) {
    TORCH_CHECK(tensor.dim() == 4, "tensor should be 4-dimensional: BxCxHxW");
    // NOTE: TORCH_CHECK streams its extra args; the old printf-style "%d"
    // appeared literally in the message. Stream the values instead.
    TORCH_CHECK( 1 <= kernel_size, "bad kernel size ", kernel_size );
    TORCH_CHECK( 1 <= stride, "bad stride ", stride );
    const int IB = tensor.size(0);
    const int IH = tensor.size(2); // input height
    const int IW = tensor.size(3); // input width
    // reject kernels larger than the input (would yield a non-positive output size)
    TORCH_CHECK( kernel_size <= IH && kernel_size <= IW,
                 "kernel size ", kernel_size, " exceeds input size ", IH, "x", IW );
    // output size
    const int OH = 1 + (IH - kernel_size) / stride;
    const int OW = 1 + (IW - kernel_size) / stride;
    torch::Tensor maxima = torch::empty({IB, OH, OW}, tensor.options());
    torch::Tensor indices = torch::empty({IB, OH, OW}, tensor.options().dtype(torch::kInt64));
    if (tensor.is_cuda())
        max_pool3d_cuda( tensor, kernel_size, stride, maxima, indices );
    else
        TORCH_CHECK(false, "CPU max_pool3d not implemented yet");
    return {maxima, indices};
}
// Apply one row [m0 m1 m2] of an affine transform to the point (x, y):
// returns m0*x + m1*y + m2.
static inline float ptdot( const float* m, float x, float y ) {
    const float wx = x * m[0];
    const float wy = y * m[1];
    return wx + wy + m[2];
}
// Square a float (cheaper and clearer than std::pow(v, 2)).
static inline float pow2(float value) {
    return value * value;
}
// CPU fallback: merge correspondences `corres` (computed on a rescaled/rotated
// image with a fixed bin step of 4) into the reference-frame grid `all_corres`.
// corres:     (H, W, 6) float tensor. Channels [offset, offset+1] hold a 2D
//             point in the reference frame; channel 4 holds the score.
// offset:     channel offset of the point to match on (presumably 0 or 2,
//             selecting one endpoint of the correspondence — confirm with caller).
// _inv_rot:   contiguous float affine transform, two rows of 3 coefficients,
//             mapping reference-frame coords to the rescaled/rotated frame.
// dmax:       maximum match distance in pixels (squared internally).
// all_corres: (AH, AW, 6) accumulator, updated in place: a cell is overwritten
//             by a spatially-close correspondence with a strictly better score.
// all_step:   bin size in pixels of the all_corres grid.
void merge_corres_cpu( const torch::Tensor corres, int offset, const torch::Tensor _inv_rot,
        float dmax, torch::Tensor all_corres, const int all_step ) {
    const int H = corres.size(0);
    const int W = corres.size(1);
    const float tol = 2*2; // squared spatial tolerance factor
    dmax *= dmax; // compare squared distances (avoids sqrt)
    TORCH_CHECK( _inv_rot.is_contiguous() );
    const float* inv_rot = _inv_rot.data_ptr<float>();
    auto corres_a = corres.accessor<float,3>();
    auto all_corres_a = all_corres.accessor<float,3>();
    // for each bin of the final histograms, we get the nearest-neighbour bin in corres0 and corres1
    for (int j=0; j<all_corres.size(0); j++)
      for (int i=0; i<all_corres.size(1); i++) {
        auto all_cor = all_corres_a[j][i];
        // center of the bin in the reference frame
        float x = i*all_step + all_step/2;
        float y = j*all_step + all_step/2;
        // center of the bin on the rescaled+rotated image
        float xr = ptdot( inv_rot + 0, x, y );
        float yr = ptdot( inv_rot + 3, x, y );
        // nearest bin indices (rescaled+rotated desc always has step 4)
        int xb = (int)(0.5+ xr/4);
        int yb = (int)(0.5+ yr/4);
        // first pass: squared distance to the closest correspondence, capped at dmax.
        // BUGFIX: upper bounds were MIN(H, yb+1)/MIN(W, xb+1), which let v==H or
        // u==W and read one row/column past the end of `corres` (the commented-out
        // asserts below documented the intended invariant). Clamp to H-1/W-1.
        float best = dmax;
        for (int v = MAX(0,yb-1); v <= MIN(H-1,yb+1); v++)
          for (int u = MAX(0,xb-1); u <= MIN(W-1,xb+1); u++) {
            // assert( v >= 0 && v < corres_a.size(0) );
            // assert( u >= 0 && u < corres_a.size(1) );
            auto cor = corres_a[v][u];
            float d = pow2(cor[offset]-x) + pow2(cor[offset+1]-y);
            if( d < best ) best = d;
          }
        // second pass: merge every correspondence spatially close to the best one
        for (int v = MAX(0,yb-1); v <= MIN(H-1,yb+1); v++)
          for (int u = MAX(0,xb-1); u <= MIN(W-1,xb+1); u++) {
            auto cor = corres_a[v][u];
            float d = pow2(cor[offset]-x) + pow2(cor[offset+1]-y);
            if( d <= tol*best ) { // spatially close
                // merge correspondence if its score (channel 4) beats the current one
                if( cor[4] > all_cor[4] )
                    for (int k = 0; k < all_corres.size(2); k++)
                        all_cor[k] = cor[k];
            }
          }
      }
}
// CUDA implementation (defined in the companion .cu file).
void merge_corres_cuda( const torch::Tensor corres, int offset, const torch::Tensor inv_rot,
        float dmax, torch::Tensor all_corres, const int all_step );

// Merge one side's correspondences into the global map `all_corres`,
// dispatching to the CPU or CUDA implementation based on where the
// accumulator tensor lives. `rot` is the 2D affine transform of the
// rescaled/rotated image; its inverse is passed down to the kernels.
void merge_corres( const torch::Tensor corres, int offset, const torch::Tensor rot,
        torch::Tensor all_corres, const int all_step ) {
    TORCH_CHECK( corres.dim() == 3 && corres.size(2) == 6, "corres.shape should be (H,W,6)" );
    TORCH_CHECK( all_corres.dim() == 3 && all_corres.size(2) == 6, "all_corres.shape should be (H,W,6)" );
    // merge radius scales with sqrt(det) of the transform (its linear scale factor)
    const float dmax = 8 * torch::sqrt(torch::det(rot)).item<float>();
    const torch::Tensor inv_rot = torch::inverse(rot).contiguous();
    if (!all_corres.is_cuda())
        merge_corres_cpu( corres, offset, inv_rot, dmax, all_corres, all_step );
    else
        merge_corres_cuda( corres, offset, inv_rot, dmax, all_corres, all_step );
}
// CUDA implementation (defined in the companion .cu file).
void mask_correlations_radial_cuda( torch::Tensor corr, const torch::Tensor targets,
        const float radius, const float alpha);

// Radially mask the correlation maps around each target center, in place (CUDA only).
// corr:    (H1, W1, H2, W2) correlation tensor, modified in place by the kernel.
// targets: (H1, W1, 2) per-cell target centers.
// radius:  protected area in pixels around each target center.
// alpha:   in [0,1]. If alpha = 0: no effect. If alpha = 1: full effect.
void mask_correlations_radial( torch::Tensor corr, const torch::Tensor targets,
        const float radius, const float alpha) {
    TORCH_CHECK( corr.dim() == 4 );
    TORCH_CHECK( targets.dim() == 3 );
    // message now matches the actual check (old one omitted the trailing-dim requirement)
    TORCH_CHECK( targets.size(0) == corr.size(0) && targets.size(1) == corr.size(1) && targets.size(2) == 2,
                 "targets should have shape == corr.shape[:2] + (2,)" );
    if (corr.is_cuda())
        mask_correlations_radial_cuda( corr, targets, radius, alpha );
    else
        TORCH_CHECK(false, "CPU mask_correlations_radial not implemented yet");
}
// Python bindings.
// NOTE: pybind11 does not inherit C++ default arguments — without explicit
// py::arg(...) = default entries, `weights` and `exclude_borders` were
// mandatory when called from Python despite their C++ defaults.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward_agg", &forward_agg, "forward aggregation (CUDA)",
          py::arg("level"), py::arg("norm"), py::arg("lower"), py::arg("weights") = py::none());
    m.def("forward_pool_agg", &forward_pool_agg, "forward pooling and aggregation (CUDA)",
          py::arg("level"), py::arg("norm"), py::arg("lower"), py::arg("weights") = py::none());
    m.def("backward_agg_unpool", &backward_agg_unpool, "backward sparse-conv and max-unpooling (C++ & CUDA)",
          py::arg("level"), py::arg("upper"), py::arg("lower"), py::arg("exclude_borders") = true);
    m.def("max_pool3d", &max_pool3d, "max_pool3d that can handle big inputs (CUDA)");
    m.def("merge_corres_one_side", &merge_corres, "merge correspondences on CPU or GPU" );
    m.def("mask_correlations_radial", &mask_correlations_radial, "mask correlations radially (CUDA)" );
}
|