File size: 23,967 Bytes
a80d6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
"""
Implementation of the line segment detection module.
"""
import math
import numpy as np
import torch


class LineSegmentDetectionModule(object):
    """ Module extracting line segments from junctions and line heatmaps. """
    def __init__(
        self, detect_thresh, num_samples=64, sampling_method="local_max",
        inlier_thresh=0., heatmap_low_thresh=0.15, heatmap_high_thresh=0.2,
        max_local_patch_radius=3, lambda_radius=2.,
        use_candidate_suppression=False, nms_dist_tolerance=3., 
        use_heatmap_refinement=False, heatmap_refine_cfg=None,
        use_junction_refinement=False, junction_refine_cfg=None):
        """
        Parameters:
            detect_thresh: The probability threshold for mean activation (0. ~ 1.)
            num_samples: Number of sampling locations along the line segments.
            sampling_method: Sampling method on locations ("bilinear" or "local_max").
            inlier_thresh: The min inlier ratio to satisfy (0. ~ 1.) => 0. means no threshold.
            heatmap_low_thresh: The lowest threshold for the pixel to be considered as candidate in junction recovery.
            heatmap_high_thresh: The higher threshold for NMS in junction recovery.
            max_local_patch_radius: The max patch to be considered in local maximum search.
            lambda_radius: The lambda factor in linear local maximum search formulation
            use_candidate_suppression: Apply candidate suppression to break long segments into short sub-segments.
            nms_dist_tolerance: The distance tolerance for nms. Decide whether the junctions are on the line.
            use_heatmap_refinement: Use heatmap refinement method or not.
            heatmap_refine_cfg: The configs for heatmap refinement methods.
            use_junction_refinement: Use junction refinement method or not.
            junction_refine_cfg: The configs for junction refinement methods.
        """
        # Line detection parameters
        self.detect_thresh = detect_thresh

        # Line sampling parameters
        self.num_samples = num_samples
        self.sampling_method = sampling_method
        self.inlier_thresh = inlier_thresh
        self.local_patch_radius = max_local_patch_radius
        self.lambda_radius = lambda_radius
        
        # Detecting junctions on the boundary parameters
        self.low_thresh = heatmap_low_thresh
        self.high_thresh = heatmap_high_thresh

        # Pre-compute the linspace sampler
        self.sampler = np.linspace(0, 1, self.num_samples)
        self.torch_sampler = torch.linspace(0, 1, self.num_samples)

        # Long line segment suppression configuration
        self.use_candidate_suppression = use_candidate_suppression
        self.nms_dist_tolerance = nms_dist_tolerance

        # Heatmap refinement configuration
        self.use_heatmap_refinement = use_heatmap_refinement
        self.heatmap_refine_cfg = heatmap_refine_cfg
        if self.use_heatmap_refinement and self.heatmap_refine_cfg is None:
            raise ValueError("[Error] Missing heatmap refinement config.")

        # Junction refinement configuration
        self.use_junction_refinement = use_junction_refinement
        self.junction_refine_cfg = junction_refine_cfg
        if self.use_junction_refinement and self.junction_refine_cfg is None:
            raise ValueError("[Error] Missing junction refinement config.")
        
    def convert_inputs(self, inputs, device):
        """ Convert inputs to desired torch tensor. """
        if isinstance(inputs, np.ndarray):
            outputs = torch.tensor(inputs, dtype=torch.float32, device=device)
        elif isinstance(inputs, torch.Tensor):
            outputs = inputs.to(torch.float32).to(device)
        else:
            raise ValueError(
        "[Error] Inputs must either be torch tensor or numpy ndarray.")
        
        return outputs
        
    def detect(self, junctions, heatmap, device=torch.device("cpu")):
        """ Main function performing line segment detection. """
        # Convert inputs to torch tensor
        junctions = self.convert_inputs(junctions, device=device)
        heatmap = self.convert_inputs(heatmap, device=device)
        
        # Perform the heatmap refinement
        if self.use_heatmap_refinement:
            if self.heatmap_refine_cfg["mode"] == "global":
                heatmap = self.refine_heatmap(
                    heatmap, 
                    self.heatmap_refine_cfg["ratio"],
                    self.heatmap_refine_cfg["valid_thresh"]
                )
            elif self.heatmap_refine_cfg["mode"] == "local":
                heatmap = self.refine_heatmap_local(
                    heatmap, 
                    self.heatmap_refine_cfg["num_blocks"],
                    self.heatmap_refine_cfg["overlap_ratio"],
                    self.heatmap_refine_cfg["ratio"],
                    self.heatmap_refine_cfg["valid_thresh"]
                )
        
        # Initialize empty line map
        num_junctions = junctions.shape[0]
        line_map_pred = torch.zeros([num_junctions, num_junctions],
                                    device=device, dtype=torch.int32)
        
        # Stop if there are not enough junctions
        if num_junctions < 2:
            return line_map_pred, junctions, heatmap

        # Generate the candidate map
        candidate_map = torch.triu(torch.ones(
            [num_junctions, num_junctions], device=device, dtype=torch.int32),
                                   diagonal=1)
        
        # Fetch the image boundary
        if len(heatmap.shape) > 2:
            H, W, _ = heatmap.shape
        else:
            H, W = heatmap.shape

        # Optionally perform candidate filtering
        if self.use_candidate_suppression:
            candidate_map = self.candidate_suppression(junctions,
                                                       candidate_map)

        # Fetch the candidates
        candidate_index_map = torch.where(candidate_map)
        candidate_index_map = torch.cat([candidate_index_map[0][..., None],
                                         candidate_index_map[1][..., None]],
                                        dim=-1)
        
        # Get the corresponding start and end junctions
        candidate_junc_start = junctions[candidate_index_map[:, 0], :]
        candidate_junc_end = junctions[candidate_index_map[:, 1], :]

        # Get the sampling locations (N x 64)
        sampler = self.torch_sampler.to(device)[None, ...]
        cand_samples_h = candidate_junc_start[:, 0:1] * sampler + \
                         candidate_junc_end[:, 0:1] * (1 - sampler)
        cand_samples_w = candidate_junc_start[:, 1:2] * sampler + \
                         candidate_junc_end[:, 1:2] * (1 - sampler)
        
        # Clip to image boundary
        cand_h = torch.clamp(cand_samples_h, min=0, max=H-1)
        cand_w = torch.clamp(cand_samples_w, min=0, max=W-1)
        
        # Local maximum search
        if self.sampling_method == "local_max":
            # Compute normalized segment lengths
            segments_length = torch.sqrt(torch.sum(
                (candidate_junc_start.to(torch.float32) -
                 candidate_junc_end.to(torch.float32)) ** 2, dim=-1))
            normalized_seg_length = (segments_length
                                     / (((H ** 2) + (W ** 2)) ** 0.5))
            
            # Perform local max search
            num_cand = cand_h.shape[0]
            group_size = 10000
            if num_cand > group_size:
                num_iter = math.ceil(num_cand / group_size)
                sampled_feat_lst = []
                for iter_idx in range(num_iter):
                    if not iter_idx == num_iter-1:
                        cand_h_ = cand_h[iter_idx * group_size:
                                         (iter_idx+1) * group_size, :]
                        cand_w_ = cand_w[iter_idx * group_size:
                                         (iter_idx+1) * group_size, :]
                        normalized_seg_length_ = normalized_seg_length[
                            iter_idx * group_size: (iter_idx+1) * group_size]
                    else:
                        cand_h_ = cand_h[iter_idx * group_size:, :]
                        cand_w_ = cand_w[iter_idx * group_size:, :]
                        normalized_seg_length_ = normalized_seg_length[
                            iter_idx * group_size:]
                    sampled_feat_ = self.detect_local_max(
                        heatmap, cand_h_, cand_w_, H, W,
                        normalized_seg_length_, device)
                    sampled_feat_lst.append(sampled_feat_)
                sampled_feat = torch.cat(sampled_feat_lst, dim=0)
            else:
                sampled_feat = self.detect_local_max(
                    heatmap, cand_h, cand_w, H, W, 
                    normalized_seg_length, device)
        # Bilinear sampling
        elif self.sampling_method == "bilinear":
            # Perform bilinear sampling
            sampled_feat = self.detect_bilinear(
                heatmap, cand_h, cand_w, H, W, device)
        else:
            raise ValueError("[Error] Unknown sampling method.")
     
        # [Simple threshold detection]
        # detection_results is a mask over all candidates
        detection_results = (torch.mean(sampled_feat, dim=-1)
                             > self.detect_thresh)
        
        # [Inlier threshold detection]
        if self.inlier_thresh > 0.:
            inlier_ratio = torch.sum(
                sampled_feat > self.detect_thresh,
                dim=-1).to(torch.float32) / self.num_samples
            detection_results_inlier = inlier_ratio >= self.inlier_thresh
            detection_results = detection_results * detection_results_inlier

        # Convert detection results back to line_map_pred
        detected_junc_indexes = candidate_index_map[detection_results, :]
        line_map_pred[detected_junc_indexes[:, 0],
                      detected_junc_indexes[:, 1]] = 1
        line_map_pred[detected_junc_indexes[:, 1],
                      detected_junc_indexes[:, 0]] = 1
        
        # Perform junction refinement
        if self.use_junction_refinement and len(detected_junc_indexes) > 0:
            junctions, line_map_pred = self.refine_junction_perturb(
                junctions, line_map_pred, heatmap, H, W, device)

        return line_map_pred, junctions, heatmap
    
    def refine_heatmap(self, heatmap, ratio=0.2, valid_thresh=1e-2):
        """ Global heatmap refinement method. """
        # Grab the top 10% values
        heatmap_values = heatmap[heatmap > valid_thresh]
        sorted_values = torch.sort(heatmap_values, descending=True)[0]
        top10_len = math.ceil(sorted_values.shape[0] * ratio)
        max20 = torch.mean(sorted_values[:top10_len])
        heatmap = torch.clamp(heatmap / max20, min=0., max=1.)
        return heatmap
    
    def refine_heatmap_local(self, heatmap, num_blocks=5, overlap_ratio=0.5,
                             ratio=0.2, valid_thresh=2e-3):
        """ Local heatmap refinement method. """
        # Get the shape of the heatmap
        H, W = heatmap.shape
        increase_ratio = 1 - overlap_ratio
        h_block = round(H / (1 + (num_blocks - 1) * increase_ratio))
        w_block = round(W / (1 + (num_blocks - 1) * increase_ratio))

        count_map = torch.zeros(heatmap.shape, dtype=torch.int,
                                device=heatmap.device)
        heatmap_output = torch.zeros(heatmap.shape, dtype=torch.float,
                                     device=heatmap.device)
        # Iterate through each block
        for h_idx in range(num_blocks):
            for w_idx in range(num_blocks):
                # Fetch the heatmap
                h_start = round(h_idx * h_block * increase_ratio)
                w_start = round(w_idx * w_block * increase_ratio)
                h_end = h_start + h_block if h_idx < num_blocks - 1 else H
                w_end = w_start + w_block if w_idx < num_blocks - 1 else W

                subheatmap = heatmap[h_start:h_end, w_start:w_end]
                if subheatmap.max() > valid_thresh:
                    subheatmap = self.refine_heatmap(
                        subheatmap, ratio, valid_thresh=valid_thresh)
                
                # Aggregate it to the final heatmap
                heatmap_output[h_start:h_end, w_start:w_end] += subheatmap
                count_map[h_start:h_end, w_start:w_end] += 1
        heatmap_output = torch.clamp(heatmap_output / count_map,
                                     max=1., min=0.)

        return heatmap_output

    def candidate_suppression(self, junctions, candidate_map):
        """ Suppress overlapping long lines in the candidate segments.

        A candidate segment is removed when at least one junction other
        than its own two endpoints lies on it, i.e. projects inside the
        segment and is within self.nms_dist_tolerance of the segment line.

        Parameters:
            junctions: Junction coordinates, indexed as (num_junc x 2).
            candidate_map: (num_junc x num_junc) candidate map. Only its
                strict upper triangle is read. It is modified in place.

        Returns:
            The same candidate_map object with the suppressed entries set
            to 0.
        """
        # Define the distance tolerance
        dist_tolerance = self.nms_dist_tolerance

        # Compute distance between junction pairs
        # (num_junc x 1 x 2) - (1 x num_junc x 2) => num_junc x num_junc map
        line_dist_map = torch.sum((torch.unsqueeze(junctions, dim=1)
                                  - junctions[None, ...]) ** 2, dim=-1) ** 0.5

        # Fetch all the "detected lines"
        seg_indexes = torch.where(torch.triu(candidate_map, diagonal=1))
        start_point_idxs = seg_indexes[0]
        end_point_idxs = seg_indexes[1]
        start_points = junctions[start_point_idxs, :]
        end_points = junctions[end_point_idxs, :]

        # Fetch corresponding entries
        # (Euclidean length of every candidate segment)
        line_dists = line_dist_map[start_point_idxs, end_point_idxs]

        # Check whether they are on the line
        # dir_vecs: unit direction of every segment (num_segs x 2)
        # NOTE(review): coincident junctions give a zero-length segment and
        # therefore a division by zero here -- confirm that callers never
        # pass duplicated junctions.
        dir_vecs = ((end_points - start_points)
                    / torch.norm(end_points - start_points,
                                 dim=-1)[..., None])
        # Get the orthogonal distance
        # cand_vecs: vector from each segment start to every junction
        # (num_segs x num_junc x 2)
        cand_vecs = junctions[None, ...] - start_points.unsqueeze(dim=1)
        cand_vecs_norm = torch.norm(cand_vecs, dim=-1)
        # Check whether they are projected directly onto the segment
        # (projection length along the segment divided by segment length)
        proj = (torch.einsum('bij,bjk->bik', cand_vecs, dir_vecs[..., None])
                / line_dists[..., None, None])
        # proj is num_segs x num_junction x 1
        proj_mask = (proj >=0) * (proj <= 1)
        # Angle between the segment direction and the vector to the junction
        # NOTE(review): for the start junction itself cand_vecs_norm is 0,
        # so this is 0 / 0; the resulting NaN presumably fails the
        # `<= dist_tolerance` test below and is thus excluded -- confirm.
        cand_angles = torch.acos(
            torch.einsum('bij,bjk->bik', cand_vecs, dir_vecs[..., None])
            / cand_vecs_norm[..., None])
        # Orthogonal distance of every junction to the segment line
        cand_dists = cand_vecs_norm[..., None] * torch.sin(cand_angles)
        junc_dist_mask = cand_dists <= dist_tolerance
        junc_mask = junc_dist_mask * proj_mask

        # Minus starting points
        # Count the junctions on each segment, then discount the segment's
        # own start and end junctions if they were counted
        num_segs = start_point_idxs.shape[0]
        junc_counts = torch.sum(junc_mask, dim=[1, 2])
        junc_counts -= junc_mask[..., 0][torch.arange(0, num_segs),
                                         start_point_idxs].to(torch.int)
        junc_counts -= junc_mask[..., 0][torch.arange(0, num_segs),
                                         end_point_idxs].to(torch.int)
        
        # Get the invalid candidate mask
        # (segments with at least one other junction lying on them)
        final_mask = junc_counts > 0
        candidate_map[start_point_idxs[final_mask],
                      end_point_idxs[final_mask]] = 0
            
        return candidate_map
    
    def refine_junction_perturb(self, junctions, line_map_pred,
                                heatmap, H, W, device):
        """ Refine the line endpoints in a similar way as in LSD. """
        # Get the config
        junction_refine_cfg = self.junction_refine_cfg

        # Fetch refinement parameters
        num_perturbs = junction_refine_cfg["num_perturbs"]
        perturb_interval = junction_refine_cfg["perturb_interval"]
        side_perturbs = (num_perturbs - 1) // 2
        # Fetch the 2D perturb mat
        perturb_vec = torch.arange(
            start=-perturb_interval*side_perturbs,
            end=perturb_interval*(side_perturbs+1),
            step=perturb_interval, device=device)
        w1_grid, h1_grid, w2_grid, h2_grid = torch.meshgrid(
            perturb_vec, perturb_vec, perturb_vec, perturb_vec)
        perturb_tensor = torch.cat([
            w1_grid[..., None], h1_grid[..., None], 
            w2_grid[..., None], h2_grid[..., None]], dim=-1)
        perturb_tensor_flat = perturb_tensor.view(-1, 2, 2)

        # Fetch the junctions and line_map
        junctions = junctions.clone()
        line_map = line_map_pred

        # Fetch all the detected lines
        detected_seg_indexes = torch.where(torch.triu(line_map, diagonal=1))
        start_point_idxs = detected_seg_indexes[0]
        end_point_idxs = detected_seg_indexes[1]
        start_points = junctions[start_point_idxs, :]
        end_points = junctions[end_point_idxs, :]

        line_segments = torch.cat([start_points.unsqueeze(dim=1),
                                   end_points.unsqueeze(dim=1)], dim=1)

        line_segment_candidates = (line_segments.unsqueeze(dim=1)
                                   + perturb_tensor_flat[None, ...])
        # Clip the boundaries
        line_segment_candidates[..., 0] = torch.clamp(
            line_segment_candidates[..., 0], min=0, max=H - 1)
        line_segment_candidates[..., 1] = torch.clamp(
            line_segment_candidates[..., 1], min=0, max=W - 1)

        # Iterate through all the segments
        refined_segment_lst = []
        num_segments = line_segments.shape[0]
        for idx in range(num_segments):
            segment = line_segment_candidates[idx, ...]
            # Get the corresponding start and end junctions
            candidate_junc_start = segment[:, 0, :]
            candidate_junc_end = segment[:, 1, :]

            # Get the sampling locations (N x 64)
            sampler = self.torch_sampler.to(device)[None, ...]
            cand_samples_h = (candidate_junc_start[:, 0:1] * sampler +
                              candidate_junc_end[:, 0:1] * (1 - sampler))
            cand_samples_w = (candidate_junc_start[:, 1:2] * sampler +
                              candidate_junc_end[:, 1:2] * (1 - sampler))
            
            # Clip to image boundary
            cand_h = torch.clamp(cand_samples_h, min=0, max=H - 1)
            cand_w = torch.clamp(cand_samples_w, min=0, max=W - 1)

            # Perform bilinear sampling
            segment_feat = self.detect_bilinear(
                heatmap, cand_h, cand_w, H, W, device)
            segment_results = torch.mean(segment_feat, dim=-1)
            max_idx = torch.argmax(segment_results)
            refined_segment_lst.append(segment[max_idx, ...][None, ...])
        
        # Concatenate back to segments
        refined_segments = torch.cat(refined_segment_lst, dim=0)

        # Convert back to junctions and line_map
        junctions_new = torch.cat(
            [refined_segments[:, 0, :], refined_segments[:, 1, :]], dim=0)
        junctions_new = torch.unique(junctions_new, dim=0)
        line_map_new = self.segments_to_line_map(junctions_new,
                                                 refined_segments)

        return junctions_new, line_map_new
    
    def segments_to_line_map(self, junctions, segments):
        """ Convert the list of segments to line map. """
        # Create empty line map
        device = junctions.device
        num_junctions = junctions.shape[0]
        line_map = torch.zeros([num_junctions, num_junctions], device=device)

        # Iterate through every segment
        for idx in range(segments.shape[0]):
            # Get the junctions from a single segement
            seg = segments[idx, ...]
            junction1 = seg[0, :]
            junction2 = seg[1, :]

            # Get index
            idx_junction1 = torch.where(
                (junctions == junction1).sum(axis=1) == 2)[0]
            idx_junction2 = torch.where(
                (junctions == junction2).sum(axis=1) == 2)[0]

            # label the corresponding entries
            line_map[idx_junction1, idx_junction2] = 1
            line_map[idx_junction2, idx_junction1] = 1

        return line_map

    def detect_bilinear(self, heatmap, cand_h, cand_w, H, W, device):
        """ Detection by bilinear sampling. """
        # Get the floor and ceiling locations
        cand_h_floor = torch.floor(cand_h).to(torch.long)
        cand_h_ceil = torch.ceil(cand_h).to(torch.long)
        cand_w_floor = torch.floor(cand_w).to(torch.long)
        cand_w_ceil = torch.ceil(cand_w).to(torch.long)

        # Perform the bilinear sampling
        cand_samples_feat = (
            heatmap[cand_h_floor, cand_w_floor] * (cand_h_ceil - cand_h)
            * (cand_w_ceil - cand_w) + heatmap[cand_h_floor, cand_w_ceil]
            * (cand_h_ceil - cand_h) * (cand_w - cand_w_floor) +
            heatmap[cand_h_ceil, cand_w_floor] * (cand_h - cand_h_floor)
            * (cand_w_ceil - cand_w) + heatmap[cand_h_ceil, cand_w_ceil]
            * (cand_h - cand_h_floor) * (cand_w - cand_w_floor))
        
        return cand_samples_feat
    
    def detect_local_max(self, heatmap, cand_h, cand_w, H, W,
                         normalized_seg_length, device):
        """ Detection by local maximum search. """
        # Compute the distance threshold
        dist_thresh = (0.5 * (2 ** 0.5)
                       + self.lambda_radius * normalized_seg_length)
        # Make it N x 64
        dist_thresh = torch.repeat_interleave(dist_thresh[..., None],
                                              self.num_samples, dim=-1)
        
        # Compute the candidate points
        cand_points = torch.cat([cand_h[..., None], cand_w[..., None]],
                                dim=-1)
        cand_points_round = torch.round(cand_points) # N x 64 x 2
        
        # Construct local patches 9x9 = 81
        patch_mask = torch.zeros([int(2 * self.local_patch_radius + 1), 
                                  int(2 * self.local_patch_radius + 1)],
                                 device=device)
        patch_center = torch.tensor(
            [[self.local_patch_radius, self.local_patch_radius]], 
            device=device, dtype=torch.float32)
        H_patch_points, W_patch_points = torch.where(patch_mask >= 0)
        patch_points = torch.cat([H_patch_points[..., None],
                                  W_patch_points[..., None]], dim=-1)
        # Fetch the circle region
        patch_center_dist = torch.sqrt(torch.sum(
            (patch_points - patch_center) ** 2, dim=-1))
        patch_points = (patch_points[patch_center_dist
                        <= self.local_patch_radius, :])
        # Shift [0, 0] to the center
        patch_points = patch_points - self.local_patch_radius
        
        # Construct local patch mask
        patch_points_shifted = (torch.unsqueeze(cand_points_round, dim=2)
                                + patch_points[None, None, ...])
        patch_dist = torch.sqrt(torch.sum((torch.unsqueeze(cand_points, dim=2)
                                          - patch_points_shifted) ** 2,
                                          dim=-1))
        patch_dist_mask = patch_dist < dist_thresh[..., None]
        
        # Get all points => num_points_center x num_patch_points x 2
        points_H = torch.clamp(patch_points_shifted[:, :, :, 0], min=0,
                               max=H - 1).to(torch.long)
        points_W = torch.clamp(patch_points_shifted[:, :, :, 1], min=0,
                               max=W - 1).to(torch.long)
        points = torch.cat([points_H[..., None], points_W[..., None]], dim=-1)
        
        # Sample the feature (N x 64 x 81)
        sampled_feat = heatmap[points[:, :, :, 0], points[:, :, :, 1]]
        # Filtering using the valid mask
        sampled_feat = sampled_feat * patch_dist_mask.to(torch.float32)
        if len(sampled_feat) == 0:
            sampled_feat_lmax = torch.empty(0, 64)
        else:
            sampled_feat_lmax, _ = torch.max(sampled_feat, dim=-1)
        
        return sampled_feat_lmax