// Copyright 2022-present NAVER Corp.
// CC BY-NC-SA 4.0
// Available only for non-commercial use

#include <torch/extension.h>
#include <vector>

using namespace torch::indexing; // for Slice, None

#define MIN(x, y)           ((x) < (y) ? (x) : (y))
#define MAX(x, y)           ((x) < (y) ? (y) : (x))
#define CHECK_CUDA(x)       TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)      do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)

// slice helper: sl(true) drops the last element (Slice(0,-1)),
// sl(false) drops the first (Slice(1,None))
inline Slice sl(bool x) {
    return x ? Slice(0, -1) : Slice(1, None);
}

torch::Tensor forward_agg_cuda( int level, float norm, const torch::Tensor lower, 
                                const at::optional<at::Tensor> weights, torch::Tensor upper );

std::vector<torch::Tensor> forward_agg( int level, float norm, const torch::Tensor lower, 
                                        const at::optional<at::Tensor> weights = at::nullopt ) {
    TORCH_CHECK(level >= 1, "level must be >= 1");
    TORCH_CHECK(lower.dim() == 4, "input must have 4 dimensions");
    const auto LH1 = lower.size(0);
    const auto LW1 = lower.size(1);
    const auto LH2 = lower.size(2);
    const auto LW2 = lower.size(3);
    if (weights) TORCH_CHECK(weights->size(0) == LH1 && weights->size(1) == LW1, "weights should have shape == lower.shape[:2]");
    const auto UH1 = (level == 1) ? LH1+1 : LH1;
    const auto UW1 = (level == 1) ? LW1+1 : LW1;
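    // illustrative: at level 1 the upper grid gains one cell per axis,
    // (LH1, LW1) -> (LH1+1, LW1+1); for level >= 2 the grid size is unchanged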

    TORCH_CHECK(lower.is_cuda(), "lower must be a CUDA tensor");
    auto upper = torch::zeros({UH1, UW1, LH2, LW2}, lower.options());
    torch::Tensor new_weights = forward_agg_cuda( level, norm, lower, weights, upper );
    return {upper, new_weights};
}


torch::Tensor forward_pool_agg_cuda( int level, float norm, const torch::Tensor lower,
                                     const at::optional<at::Tensor> weights, torch::Tensor upper );

std::vector<torch::Tensor> forward_pool_agg( int level, float norm, const torch::Tensor lower, 
                                        const at::optional<at::Tensor> weights = at::nullopt ) {
    TORCH_CHECK(level >= 1, "level must be >= 1");
    TORCH_CHECK(lower.dim() == 4, "input must have 4 dimensions");
    const auto LH1 = lower.size(0);
    const auto LW1 = lower.size(1);
    const auto LH2 = lower.size(2);
    const auto LW2 = lower.size(3);
    if (weights) TORCH_CHECK(weights->size(0) == LH1 && weights->size(1) == LW1, "weights should have shape == lower.shape[:2]");
    const auto UH1 = (level == 1) ? LH1+1 : LH1;
    const auto UW1 = (level == 1) ? LW1+1 : LW1;

    TORCH_CHECK(lower.is_cuda(), "lower must be a CUDA tensor");
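    // the spatial dims (2,3) are halved with ceil(n/2) = 1 + (n-1)/2,
    // e.g. (illustrative) LH2 = 64 -> 32 and LH2 = 65 -> 33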
    auto upper = torch::zeros({UH1, UW1, 1+(LH2-1)/2, 1+(LW2-1)/2}, lower.options());
    torch::Tensor new_weights = forward_pool_agg_cuda( level, norm, lower, weights, upper );
    return {upper, new_weights};
}

// forward declaration
void backward_agg_unpool_cuda( int level, const torch::Tensor upper, torch::Tensor lower, bool exclude_borders );

void backward_agg_unpool( int level, const torch::Tensor upper, torch::Tensor lower, bool exclude_borders = true ) {
    TORCH_CHECK(level >= 1, "level must be >= 1");
    TORCH_CHECK( upper.dim() == 4 && lower.dim() == 4, "inputs should be 4-dimensional" );

    TORCH_CHECK(upper.is_cuda() && lower.is_cuda(), "upper and lower must be CUDA tensors");
    backward_agg_unpool_cuda(level, upper, lower, exclude_borders);
}


void max_pool3d_cuda( const torch::Tensor tensor, const int kernel_size, const int stride,
                            torch::Tensor maxima, torch::Tensor indices );

std::vector<torch::Tensor> max_pool3d( const torch::Tensor tensor, const int kernel_size, const int stride ) {
    TORCH_CHECK(tensor.dim() == 4, "tensor should be 4-dimensional: BxCxHxW");
    TORCH_CHECK( kernel_size >= 1, "bad kernel size ", kernel_size );
    TORCH_CHECK( stride >= 1, "bad stride ", stride );
    const int IB = tensor.size(0);
    const int IH = tensor.size(2); // input height
    const int IW = tensor.size(3); // input width

    // output size
    const int OH = 1 + (IH - kernel_size) / stride;
    const int OW = 1 + (IW - kernel_size) / stride;
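    // e.g. (illustrative) IH = 100, kernel_size = 4, stride = 4 -> OH = 1 + 96/4 = 25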
    
    torch::Tensor maxima  = torch::empty({IB, OH, OW}, tensor.options());
    torch::Tensor indices = torch::empty({IB, OH, OW}, tensor.options().dtype(torch::kInt64));

    if (tensor.is_cuda())
        max_pool3d_cuda( tensor, kernel_size, stride, maxima, indices );
    else
        TORCH_CHECK(false, "CPU max_pool3d not implemented yet");
    return {maxima, indices};
}

// dot product of one matrix row (3 floats) with the homogeneous point (x, y, 1)
static inline float ptdot( const float* m, float x, float y ) {
  return x*m[0] + y*m[1] + m[2];
}

static inline float pow2(float v) {
    return v*v;
}

// CPU fallback for merge_corres_cuda: merges correspondences from the step-4
// grid `corres` into the reference-frame grid `all_corres`, keeping the
// best-scoring entry per bin
void merge_corres_cpu( const torch::Tensor corres, int offset, const torch::Tensor _inv_rot, 
                       float dmax, torch::Tensor all_corres, const int all_step ) {
    const int H = corres.size(0);
    const int W = corres.size(1);
    const float tol = 2*2; // tolerance of 2 px, squared
    dmax *= dmax; // all distances below are compared squared

    TORCH_CHECK( _inv_rot.is_contiguous() );
    const float* inv_rot = _inv_rot.data_ptr<float>();

    auto corres_a = corres.accessor<float,3>();
    auto all_corres_a = all_corres.accessor<float,3>();

    // for each bin of the final histograms, we get the nearest-neighbour bin in corres0 and corres1
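    // illustrative numbers, assuming all_step == 8 and an identity inv_rot:
    // bin (i,j) has center (8i+4, 8j+4) in the reference frame, which rounds
    // to bin (xb,yb) = (2i+1, 2j+1) of the step-4 grid searched below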
    for (int j=0; j<all_corres.size(0); j++) 
      for (int i=0; i<all_corres.size(1); i++) {
        // printf("accessing all_corres[%d,%d]", j, i);
        auto all_cor = all_corres_a[j][i];
        
        // center of the bin in the reference frame
        float x = i*all_step + all_step/2;
        float y = j*all_step + all_step/2;
        // printf(" -> (%g,%g) in ref img", x, y);

        // center of the bin on the rescaled+rotated image
        float xr = ptdot( inv_rot + 0, x, y ); 
        float yr = ptdot( inv_rot + 3, x, y );
        // printf(" -> (%g,%g) in rescaled", xr, yr);

        // iterate on the nearby bins
        int xb = (int)(0.5+ xr/4); // rescaled+rotated desc always has step 4
        int yb = (int)(0.5+ yr/4);
        // printf(" -> (%d,%d) in bins\n", xb, yb);

        float best = dmax;
        for (int v = MAX(0,yb-1); v <= MIN(H-1,yb+1); v++)
          for (int u = MAX(0,xb-1); u <= MIN(W-1,xb+1); u++) {
            // assert( v >= 0 && v < corres_a.size(0) );
            // assert( u >= 0 && u < corres_a.size(1) );
            auto cor = corres_a[v][u];
            float d = pow2(cor[offset]-x) + pow2(cor[offset+1]-y);
            if( d < best )  best = d;
        }

        for (int v = MAX(0,yb-1); v <= MIN(H-1,yb+1); v++)
          for (int u = MAX(0,xb-1); u <= MIN(W-1,xb+1); u++) {
            // assert( v >= 0 && v < corres_a.size(0) );
            // assert( u >= 0 && u < corres_a.size(1) );
            auto cor = corres_a[v][u];
            float d = pow2(cor[offset]-x) + pow2(cor[offset+1]-y);
            if( d <= tol*best ) { // spatially close
                // overwrite the stored correspondence if this one scores better
                // printf("update all_corres[%d,%d]\n", v,u);
                if( cor[4] > all_cor[4] )
                  for (int k = 0; k < all_corres.size(2); k++) 
                    all_cor[k] = cor[k];
              }
        }
    }
}

void merge_corres_cuda( const torch::Tensor corres, int offset, const torch::Tensor inv_rot, 
                        float dmax, torch::Tensor all_corres, const int all_step );

void merge_corres( const torch::Tensor corres, int offset, const torch::Tensor rot, 
                   torch::Tensor all_corres, const int all_step ) {
    TORCH_CHECK(     corres.dim() == 3 &&     corres.size(2) == 6,     "corres.shape should be (H,W,6)" );
    TORCH_CHECK( all_corres.dim() == 3 && all_corres.size(2) == 6, "all_corres.shape should be (H,W,6)" );

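    // nearest-neighbor search radius: 8 px times the scale of the map `rot`,
    // taken as sqrt(det(rot)) (the linear scale factor of an affine transform)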
    float dmax = 8 * torch::sqrt(torch::det(rot)).item<float>();
    torch::Tensor inv_rot = torch::inverse(rot).contiguous();

    if (all_corres.is_cuda()) 
        merge_corres_cuda( corres, offset, inv_rot, dmax, all_corres, all_step );
    else
        merge_corres_cpu( corres, offset, inv_rot, dmax, all_corres, all_step );
}


void mask_correlations_radial_cuda( torch::Tensor corr, const torch::Tensor targets, 
                                    const float radius, const float alpha);

void mask_correlations_radial( torch::Tensor corr, const torch::Tensor targets, 
                                    const float radius, const float alpha) {
    // radius: protected area in pixels around each target center
    // alpha: in [0,1]. If alpha = 0: no effect. If alpha = 1: full effect.
    TORCH_CHECK( corr.dim() == 4 );
    TORCH_CHECK( targets.dim() == 3 );
    TORCH_CHECK( targets.size(0) == corr.size(0) && targets.size(1) == corr.size(1) && targets.size(2) == 2, 
        "targets should have shape == corr.shape[:2] + (2,)" );

    if (corr.is_cuda()) 
        mask_correlations_radial_cuda( corr, targets, radius, alpha );
    else
        TORCH_CHECK(false, "CPU mask_correlations_radial not implemented yet");
}


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward_agg", &forward_agg, "forward aggregation (CUDA)");
  m.def("forward_pool_agg", &forward_pool_agg, "forward pooling and aggregation (CUDA)");
  m.def("backward_agg_unpool", &backward_agg_unpool, "backward sparse-conv and max-unpooling (C++ & CUDA)");
  m.def("max_pool3d", &max_pool3d, "max_pool3d that can handle big inputs (CUDA)");
  m.def("merge_corres_one_side", &merge_corres, "merge correspondences on CPU or GPU" );
  m.def("mask_correlations_radial", &mask_correlations_radial, "mask correlations radially (CUDA)" );
}
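
// Note: the Python-visible module name is substituted for TORCH_EXTENSION_NAME
// by torch.utils.cpp_extension at build time.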