Spaces:
Running
Running
"""
Implementation of the line segment detection module.
"""
import math | |
import numpy as np | |
import torch | |
class LineSegmentDetectionModule(object):
    """Module extracting line segments from junctions and line heatmaps."""

    def __init__(
        self,
        detect_thresh,
        num_samples=64,
        sampling_method="local_max",
        inlier_thresh=0.0,
        heatmap_low_thresh=0.15,
        heatmap_high_thresh=0.2,
        max_local_patch_radius=3,
        lambda_radius=2.0,
        use_candidate_suppression=False,
        nms_dist_tolerance=3.0,
        use_heatmap_refinement=False,
        heatmap_refine_cfg=None,
        use_junction_refinement=False,
        junction_refine_cfg=None,
    ):
        """
        Parameters:
            detect_thresh: The probability threshold for mean activation (0. ~ 1.)
            num_samples: Number of sampling locations along the line segments.
            sampling_method: Sampling method on locations ("bilinear" or "local_max").
            inlier_thresh: The min inlier ratio to satisfy (0. ~ 1.) => 0. means no threshold.
            heatmap_low_thresh: The lowest threshold for the pixel to be considered as candidate in junction recovery.
            heatmap_high_thresh: The higher threshold for NMS in junction recovery.
            max_local_patch_radius: The max patch to be considered in local maximum search.
            lambda_radius: The lambda factor in linear local maximum search formulation
            use_candidate_suppression: Apply candidate suppression to break long segments into short sub-segments.
            nms_dist_tolerance: The distance tolerance for nms. Decide whether the junctions are on the line.
            use_heatmap_refinement: Use heatmap refinement method or not.
            heatmap_refine_cfg: The configs for heatmap refinement methods
                (keys read: "mode" = "global" | "local", "ratio", "valid_thresh",
                and for "local" also "num_blocks", "overlap_ratio").
            use_junction_refinement: Use junction refinement method or not.
            junction_refine_cfg: The configs for junction refinement methods
                (keys read: "num_perturbs", "perturb_interval").

        Raises:
            ValueError: If a refinement is enabled but its config is None.
        """
        # Line detection parameters
        self.detect_thresh = detect_thresh
        # Line sampling parameters
        self.num_samples = num_samples
        self.sampling_method = sampling_method
        self.inlier_thresh = inlier_thresh
        self.local_patch_radius = max_local_patch_radius
        self.lambda_radius = lambda_radius
        # Detecting junctions on the boundary parameters
        self.low_thresh = heatmap_low_thresh
        self.high_thresh = heatmap_high_thresh
        # Pre-compute the linspace sampler (interpolation weights in [0, 1])
        self.sampler = np.linspace(0, 1, self.num_samples)
        self.torch_sampler = torch.linspace(0, 1, self.num_samples)
        # Long line segment suppression configuration
        self.use_candidate_suppression = use_candidate_suppression
        self.nms_dist_tolerance = nms_dist_tolerance
        # Heatmap refinement configuration
        self.use_heatmap_refinement = use_heatmap_refinement
        self.heatmap_refine_cfg = heatmap_refine_cfg
        if self.use_heatmap_refinement and self.heatmap_refine_cfg is None:
            raise ValueError("[Error] Missing heatmap refinement config.")
        # Junction refinement configuration
        self.use_junction_refinement = use_junction_refinement
        self.junction_refine_cfg = junction_refine_cfg
        if self.use_junction_refinement and self.junction_refine_cfg is None:
            raise ValueError("[Error] Missing junction refinement config.")

    def convert_inputs(self, inputs, device):
        """Convert a numpy array or torch tensor to a float32 tensor on `device`.

        Raises:
            ValueError: If `inputs` is neither a numpy ndarray nor a torch tensor.
        """
        if isinstance(inputs, np.ndarray):
            outputs = torch.tensor(inputs, dtype=torch.float32, device=device)
        elif isinstance(inputs, torch.Tensor):
            outputs = inputs.to(torch.float32).to(device)
        else:
            raise ValueError(
                "[Error] Inputs must either be torch tensor or numpy ndarray."
            )
        return outputs

    def _sample_segment_points(self, junc_start, junc_end, H, W, device):
        """Return (h, w) sample coordinates along segments, clipped to the image.

        Both outputs have shape (num_segments, num_samples); sample k is
        t_k * start + (1 - t_k) * end with t_k taken from `self.torch_sampler`.
        """
        sampler = self.torch_sampler.to(device)[None, ...]
        samples_h = junc_start[:, 0:1] * sampler + junc_end[:, 0:1] * (1 - sampler)
        samples_w = junc_start[:, 1:2] * sampler + junc_end[:, 1:2] * (1 - sampler)
        cand_h = torch.clamp(samples_h, min=0, max=H - 1)
        cand_w = torch.clamp(samples_w, min=0, max=W - 1)
        return cand_h, cand_w

    def detect(self, junctions, heatmap, device=torch.device("cpu")):
        """Main function performing line segment detection.

        Every junction pair (i < j) is a candidate segment. The heatmap is
        sampled at `num_samples` points along each candidate. A candidate is
        accepted when the mean sampled value exceeds `detect_thresh` and, if
        `inlier_thresh > 0`, when the fraction of samples above
        `detect_thresh` is at least `inlier_thresh`.

        Args:
            junctions: (num_junctions, 2) array/tensor of (h, w) coordinates.
            heatmap: line heatmap whose first two dimensions are (H, W).
            device: torch device on which all computation runs.

        Returns:
            line_map_pred: (num_junctions, num_junctions) int32 symmetric map,
                1 where a segment was detected.
            junctions: float32 junction tensor (refined if enabled).
            heatmap: float32 heatmap tensor (refined if enabled).

        Raises:
            ValueError: On an unknown sampling method or heatmap refinement mode.
        """
        # Convert inputs to float32 torch tensors on the target device
        junctions = self.convert_inputs(junctions, device=device)
        heatmap = self.convert_inputs(heatmap, device=device)

        # Optional heatmap refinement
        if self.use_heatmap_refinement:
            cfg = self.heatmap_refine_cfg
            if cfg["mode"] == "global":
                heatmap = self.refine_heatmap(
                    heatmap, cfg["ratio"], cfg["valid_thresh"]
                )
            elif cfg["mode"] == "local":
                heatmap = self.refine_heatmap_local(
                    heatmap,
                    cfg["num_blocks"],
                    cfg["overlap_ratio"],
                    cfg["ratio"],
                    cfg["valid_thresh"],
                )
            else:
                raise ValueError("[Error] Unknown heatmap refinement mode.")

        # Initialize empty line map
        num_junctions = junctions.shape[0]
        line_map_pred = torch.zeros(
            [num_junctions, num_junctions], device=device, dtype=torch.int32
        )
        # Stop if there are not enough junctions
        if num_junctions < 2:
            return line_map_pred, junctions, heatmap

        # Every unordered pair (i < j) is a candidate
        candidate_map = torch.triu(
            torch.ones(
                [num_junctions, num_junctions], device=device, dtype=torch.int32
            ),
            diagonal=1,
        )
        # Image boundary (heatmap may carry extra trailing dimensions)
        H, W = heatmap.shape[:2]

        # Optionally break long candidates that pass through other junctions
        if self.use_candidate_suppression:
            candidate_map = self.candidate_suppression(junctions, candidate_map)

        # (num_cand, 2) index pairs in row-major order
        candidate_index_map = torch.nonzero(candidate_map)
        candidate_junc_start = junctions[candidate_index_map[:, 0], :]
        candidate_junc_end = junctions[candidate_index_map[:, 1], :]

        # Sampling locations along every candidate (num_cand x num_samples)
        cand_h, cand_w = self._sample_segment_points(
            candidate_junc_start, candidate_junc_end, H, W, device
        )

        if self.sampling_method == "local_max":
            # Segment lengths normalized by the image diagonal
            segments_length = torch.sqrt(
                torch.sum(
                    (
                        candidate_junc_start.to(torch.float32)
                        - candidate_junc_end.to(torch.float32)
                    )
                    ** 2,
                    dim=-1,
                )
            )
            normalized_seg_length = segments_length / (((H**2) + (W**2)) ** 0.5)
            num_cand = cand_h.shape[0]
            group_size = 10000
            if num_cand > group_size:
                # Process large candidate sets in chunks so the per-call
                # (chunk x num_samples x patch) tensors stay bounded in size.
                sampled_feat = torch.cat(
                    [
                        self.detect_local_max(
                            heatmap, h_chunk, w_chunk, H, W, len_chunk, device
                        )
                        for h_chunk, w_chunk, len_chunk in zip(
                            cand_h.split(group_size),
                            cand_w.split(group_size),
                            normalized_seg_length.split(group_size),
                        )
                    ],
                    dim=0,
                )
            else:
                sampled_feat = self.detect_local_max(
                    heatmap, cand_h, cand_w, H, W, normalized_seg_length, device
                )
        elif self.sampling_method == "bilinear":
            sampled_feat = self.detect_bilinear(heatmap, cand_h, cand_w, H, W, device)
        else:
            raise ValueError("[Error] Unknown sampling method.")

        # [Simple threshold detection] mask over all candidates
        detection_results = torch.mean(sampled_feat, dim=-1) > self.detect_thresh
        # [Inlier threshold detection]
        if self.inlier_thresh > 0.0:
            inlier_ratio = (
                torch.sum(sampled_feat > self.detect_thresh, dim=-1).to(torch.float32)
                / self.num_samples
            )
            detection_results = detection_results & (
                inlier_ratio >= self.inlier_thresh
            )

        # Write detections back into the symmetric line map
        detected_junc_indexes = candidate_index_map[detection_results, :]
        line_map_pred[detected_junc_indexes[:, 0], detected_junc_indexes[:, 1]] = 1
        line_map_pred[detected_junc_indexes[:, 1], detected_junc_indexes[:, 0]] = 1

        # Optional junction refinement
        if self.use_junction_refinement and len(detected_junc_indexes) > 0:
            junctions, line_map_pred = self.refine_junction_perturb(
                junctions, line_map_pred, heatmap, H, W, device
            )
        return line_map_pred, junctions, heatmap

    def refine_heatmap(self, heatmap, ratio=0.2, valid_thresh=1e-2):
        """Global heatmap refinement method.

        Divides the heatmap by the mean of the top `ratio` fraction of its
        values above `valid_thresh` and clips the result to [0, 1]. If no
        value exceeds `valid_thresh`, the heatmap is only clipped to [0, 1]
        (normalizing by the mean of an empty set would produce NaN).
        """
        valid_values = heatmap[heatmap > valid_thresh]
        if valid_values.numel() == 0:
            return torch.clamp(heatmap, min=0.0, max=1.0)
        sorted_values = torch.sort(valid_values, descending=True)[0]
        # Keep at least one value so the mean is always defined
        top_k = max(1, math.ceil(sorted_values.shape[0] * ratio))
        top_mean = torch.mean(sorted_values[:top_k])
        return torch.clamp(heatmap / top_mean, min=0.0, max=1.0)

    def refine_heatmap_local(
        self, heatmap, num_blocks=5, overlap_ratio=0.5, ratio=0.2, valid_thresh=2e-3
    ):
        """Local heatmap refinement method.

        Splits the heatmap into a `num_blocks` x `num_blocks` grid of
        overlapping blocks and applies `refine_heatmap` to each block whose
        maximum exceeds `valid_thresh`. Overlapping results are averaged and
        clipped to [0, 1].
        """
        H, W = heatmap.shape[:2]
        increase_ratio = 1 - overlap_ratio
        h_block = round(H / (1 + (num_blocks - 1) * increase_ratio))
        w_block = round(W / (1 + (num_blocks - 1) * increase_ratio))
        count_map = torch.zeros(heatmap.shape, dtype=torch.int, device=heatmap.device)
        heatmap_output = torch.zeros(
            heatmap.shape, dtype=torch.float, device=heatmap.device
        )
        for h_idx in range(num_blocks):
            h_start = round(h_idx * h_block * increase_ratio)
            # The last block always extends to the image border
            h_end = h_start + h_block if h_idx < num_blocks - 1 else H
            for w_idx in range(num_blocks):
                w_start = round(w_idx * w_block * increase_ratio)
                w_end = w_start + w_block if w_idx < num_blocks - 1 else W
                subheatmap = heatmap[h_start:h_end, w_start:w_end]
                # Tiny heatmaps can yield empty blocks; .max() would raise on them
                if subheatmap.numel() > 0 and subheatmap.max() > valid_thresh:
                    subheatmap = self.refine_heatmap(
                        subheatmap, ratio, valid_thresh=valid_thresh
                    )
                # Aggregate; overlaps are averaged via count_map below
                heatmap_output[h_start:h_end, w_start:w_end] += subheatmap
                count_map[h_start:h_end, w_start:w_end] += 1
        return torch.clamp(heatmap_output / count_map, max=1.0, min=0.0)

    def candidate_suppression(self, junctions, candidate_map):
        """Suppress overlapping long lines in the candidate segments.

        A candidate (i, j) is removed when some junction other than its two
        endpoints lies within `nms_dist_tolerance` of the line through i and
        j and projects onto the segment between them.

        Args:
            junctions: (num_junctions, 2) float tensor of junction coordinates.
            candidate_map: (num_junctions, num_junctions) candidate mask; only
                its strict upper triangle is read. Modified in place.

        Returns:
            The same `candidate_map` with suppressed entries set to 0.
        """
        dist_tolerance = self.nms_dist_tolerance
        device = junctions.device
        # Fetch all candidate segments
        start_point_idxs, end_point_idxs = torch.where(
            torch.triu(candidate_map, diagonal=1)
        )
        start_points = junctions[start_point_idxs, :]
        end_points = junctions[end_point_idxs, :]
        seg_vecs = end_points - start_points
        line_dists = torch.norm(seg_vecs, dim=-1)
        dir_vecs = seg_vecs / line_dists[..., None]

        # Vectors from each segment start to every junction (num_segs x num_junc x 2)
        cand_vecs = junctions[None, ...] - start_points.unsqueeze(dim=1)
        cand_vecs_norm = torch.norm(cand_vecs, dim=-1)
        # Scalar projection on the segment direction (num_segs x num_junc x 1)
        proj_len = torch.einsum("bij,bjk->bik", cand_vecs, dir_vecs[..., None])
        # Fractional position along the segment; in [0, 1] means "between endpoints"
        proj = proj_len / line_dists[..., None, None]
        proj_mask = (proj >= 0) & (proj <= 1)

        # Orthogonal distance to the line. Clamp the cosine: float rounding can
        # push it slightly outside [-1, 1] for collinear junctions, where acos
        # would return NaN and the junction would silently be ignored.
        cand_cos = torch.clamp(
            proj_len / cand_vecs_norm[..., None], min=-1.0, max=1.0
        )
        cand_dists = cand_vecs_norm[..., None] * torch.sin(torch.acos(cand_cos))
        # A junction coinciding with the segment start has no defined angle;
        # it is never counted.
        nonzero_mask = (cand_vecs_norm > 0)[..., None]
        junc_mask = (cand_dists <= dist_tolerance) & nonzero_mask & proj_mask

        # Count on-segment junctions, minus the segment's own endpoints
        num_segs = start_point_idxs.shape[0]
        seg_range = torch.arange(0, num_segs, device=device)
        junc_counts = torch.sum(junc_mask, dim=[1, 2])
        junc_counts -= junc_mask[..., 0][seg_range, start_point_idxs].to(torch.int)
        junc_counts -= junc_mask[..., 0][seg_range, end_point_idxs].to(torch.int)

        # Remove candidates that pass through at least one other junction
        final_mask = junc_counts > 0
        candidate_map[start_point_idxs[final_mask], end_point_idxs[final_mask]] = 0
        return candidate_map

    def refine_junction_perturb(self, junctions, line_map_pred, heatmap, H, W, device):
        """Refine the line endpoints in a similar way as in LSD.

        Each detected segment's four endpoint coordinates are offset by every
        combination of values from a grid of `2 * ((num_perturbs - 1) // 2) + 1`
        offsets spaced `perturb_interval` apart. The perturbed segment with the
        highest mean bilinearly-sampled heatmap value is kept.

        Returns:
            junctions_new: unique endpoints of the refined segments.
            line_map_new: line map over `junctions_new`, with the same dtype
                as `line_map_pred`.
        """
        cfg = self.junction_refine_cfg
        num_perturbs = cfg["num_perturbs"]
        perturb_interval = cfg["perturb_interval"]
        side_perturbs = (num_perturbs - 1) // 2
        # Build offsets from an integer range: a float-step arange can gain or
        # lose an element through rounding of the exclusive end point.
        perturb_vec = (
            torch.arange(
                -side_perturbs, side_perturbs + 1, device=device, dtype=junctions.dtype
            )
            * perturb_interval
        )
        grids = torch.meshgrid(perturb_vec, perturb_vec, perturb_vec, perturb_vec)
        # (num_combinations, 2, 2) offsets for [[start_h, start_w], [end_h, end_w]]
        perturb_tensor_flat = torch.stack(grids, dim=-1).view(-1, 2, 2)

        # Fetch all the detected lines
        seg_start_idxs, seg_end_idxs = torch.where(torch.triu(line_map_pred, diagonal=1))
        if seg_start_idxs.numel() == 0:
            return junctions, line_map_pred
        line_segments = torch.stack(
            [junctions[seg_start_idxs, :], junctions[seg_end_idxs, :]], dim=1
        )
        # (num_segments, num_combinations, 2, 2) perturbed candidates
        line_segment_candidates = (
            line_segments.unsqueeze(dim=1) + perturb_tensor_flat[None, ...]
        )
        # Clip the boundaries
        line_segment_candidates[..., 0] = torch.clamp(
            line_segment_candidates[..., 0], min=0, max=H - 1
        )
        line_segment_candidates[..., 1] = torch.clamp(
            line_segment_candidates[..., 1], min=0, max=W - 1
        )

        # Keep the best-scoring perturbation of each segment
        refined_segment_lst = []
        for segment in line_segment_candidates:
            cand_h, cand_w = self._sample_segment_points(
                segment[:, 0, :], segment[:, 1, :], H, W, device
            )
            segment_feat = self.detect_bilinear(heatmap, cand_h, cand_w, H, W, device)
            segment_results = torch.mean(segment_feat, dim=-1)
            refined_segment_lst.append(segment[torch.argmax(segment_results)])
        refined_segments = torch.stack(refined_segment_lst, dim=0)

        # Convert back to junctions and line_map
        junctions_new = torch.cat(
            [refined_segments[:, 0, :], refined_segments[:, 1, :]], dim=0
        )
        junctions_new = torch.unique(junctions_new, dim=0)
        # Match the dtype of the incoming line map (int32 from detect)
        line_map_new = self.segments_to_line_map(junctions_new, refined_segments).to(
            line_map_pred.dtype
        )
        return junctions_new, line_map_new

    def segments_to_line_map(self, junctions, segments):
        """Convert a (num_segments, 2, 2) segment tensor to a float line map.

        Entries (i, j) and (j, i) are set to 1 when a segment joins the
        junctions at rows i and j of `junctions` (exact coordinate match).
        """
        device = junctions.device
        num_junctions = junctions.shape[0]
        line_map = torch.zeros([num_junctions, num_junctions], device=device)
        for seg in segments:
            idx_junction1 = torch.where((junctions == seg[0]).all(dim=1))[0]
            idx_junction2 = torch.where((junctions == seg[1]).all(dim=1))[0]
            line_map[idx_junction1, idx_junction2] = 1
            line_map[idx_junction2, idx_junction1] = 1
        return line_map

    def detect_bilinear(self, heatmap, cand_h, cand_w, H, W, device):
        """Detection by bilinear sampling.

        Bilinearly interpolates `heatmap` at (cand_h, cand_w). Locations are
        expected to lie in [0, H - 1] x [0, W - 1] (callers clamp them). For a
        2-D heatmap the result has the same shape as `cand_h`.
        """
        h_floor = torch.floor(cand_h)
        w_floor = torch.floor(cand_w)
        # Fractional offsets from the top-left neighbour, in [0, 1)
        dh = cand_h - h_floor
        dw = cand_w - w_floor
        h0 = h_floor.to(torch.long)
        w0 = w_floor.to(torch.long)
        # Use floor + 1 rather than ceil: with ceil == floor at integer
        # coordinates, all four weights would vanish and the sample would
        # wrongly be 0 instead of the pixel value.
        h1 = torch.clamp(h0 + 1, max=H - 1)
        w1 = torch.clamp(w0 + 1, max=W - 1)
        return (
            heatmap[h0, w0] * (1 - dh) * (1 - dw)
            + heatmap[h0, w1] * (1 - dh) * dw
            + heatmap[h1, w0] * dh * (1 - dw)
            + heatmap[h1, w1] * dh * dw
        )

    def detect_local_max(
        self, heatmap, cand_h, cand_w, H, W, normalized_seg_length, device
    ):
        """Detection by local maximum search.

        For every sample location, takes the maximum heatmap value over the
        integer pixels in a disc of radius `local_patch_radius` around the
        rounded location. Only pixels closer to the exact sample than
        `0.5 * sqrt(2) + lambda_radius * normalized_seg_length` are considered.

        Returns:
            (num_cand, num_samples) tensor of per-sample maxima.
        """
        # Empty input: keep the column count and device consistent with callers
        if cand_h.shape[0] == 0:
            return torch.empty(0, cand_h.shape[-1], dtype=heatmap.dtype, device=device)

        # Per-candidate distance threshold, broadcast over samples and patch points
        dist_thresh = 0.5 * (2**0.5) + self.lambda_radius * normalized_seg_length
        dist_thresh = dist_thresh[:, None, None]

        # Exact and rounded sample locations (N x num_samples x 2)
        cand_points = torch.stack([cand_h, cand_w], dim=-1)
        cand_points_round = torch.round(cand_points)

        # Integer offsets within a disc of radius local_patch_radius
        patch_size = int(2 * self.local_patch_radius + 1)
        patch_rows, patch_cols = torch.where(
            torch.zeros([patch_size, patch_size], device=device) >= 0
        )
        patch_points = torch.stack([patch_rows, patch_cols], dim=-1)
        patch_center = torch.tensor(
            [[self.local_patch_radius, self.local_patch_radius]],
            device=device,
            dtype=torch.float32,
        )
        patch_center_dist = torch.sqrt(
            torch.sum((patch_points - patch_center) ** 2, dim=-1)
        )
        patch_offsets = (
            patch_points[patch_center_dist <= self.local_patch_radius, :]
            - self.local_patch_radius
        )

        # Candidate pixels around each sample (N x num_samples x P x 2)
        patch_points_shifted = (
            cand_points_round[:, :, None, :] + patch_offsets[None, None, ...]
        )
        # Keep only pixels close enough to the exact (unrounded) location
        patch_dist = torch.sqrt(
            torch.sum((cand_points[:, :, None, :] - patch_points_shifted) ** 2, dim=-1)
        )
        patch_dist_mask = patch_dist < dist_thresh

        points_H = torch.clamp(patch_points_shifted[..., 0], min=0, max=H - 1).to(
            torch.long
        )
        points_W = torch.clamp(patch_points_shifted[..., 1], min=0, max=W - 1).to(
            torch.long
        )
        sampled_feat = heatmap[points_H, points_W] * patch_dist_mask.to(torch.float32)
        return torch.max(sampled_feat, dim=-1)[0]