File size: 4,813 Bytes
404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
"""
Line segment detection from raw images.
"""
import time
import numpy as np
import torch
from torch.nn.functional import softmax
from .model_util import get_model
from .loss import get_loss_and_weights
from .line_detection import LineSegmentDetectionModule
from ..train import convert_junc_predictions
from ..misc.train_utils import adapt_checkpoint
def line_map_to_segments(junctions, line_map):
    """Convert a junction connectivity map into an Nx2x2 array of segments.

    Every entry ``line_map[i, j] == 1`` becomes one segment whose endpoints
    are ``junctions[i]`` and ``junctions[j]`` (in that order). Each pair is
    emitted at most once: after a pair is recorded, both ``(i, j)`` and
    ``(j, i)`` are cleared in a working copy. If only ``line_map[j, i]`` is
    set (i < j), the pair is emitted as ``(junctions[j], junctions[i])``
    when row ``j`` is visited.

    Parameters:
        junctions: 2D array whose row ``i`` holds the coordinates of
            junction ``i``. Rows must have 2 entries whenever any segment is
            produced. They are in HW order when called from ``LineDetector``;
            verify this for other callers.
        line_map: square array indexed by junction; entries equal to 1 mark
            connected pairs. The caller's array is not modified.

    Returns:
        Array of shape (num_segments, 2, 2), ordered by first junction index
        and then by partner index. The dtype is NumPy's promotion of float64
        with ``junctions.dtype`` (float64 for integer coordinates). With no
        connections, the result has shape (0, 2, 2).
    """
    remaining = line_map.copy()
    # Collect pieces and concatenate once at the end. Concatenating inside
    # the loop copies the accumulated array for every segment (quadratic).
    # The empty float64 seed keeps the original output dtype and shape even
    # when no segment is found.
    pieces = [np.zeros([0, 2, 2])]
    for idx in range(junctions.shape[0]):
        # Rows with no remaining connectivity contribute nothing.
        if remaining[idx, :].sum() == 0:
            continue
        # The partner list is computed once per row, so clearing entries
        # below does not change which partners are visited for this row.
        for idx2 in np.where(remaining[idx, :] == 1)[0]:
            start = junctions[idx, :]  # HW format
            end = junctions[idx2, :]
            pieces.append(np.stack([start, end], axis=0)[None, ...])
            # Clear both directions so the pair is not emitted again when
            # row idx2 is visited.
            remaining[idx, idx2] = 0
            remaining[idx2, idx] = 0
    return np.concatenate(pieces, axis=0)
class LineDetector(object):
    """SOLD² line detector: CNN backbone plus line segment detection module.

    Construct once with the model/detector configs and a checkpoint, then
    call the instance on a 4D image tensor to get line segments.
    """

    def __init__(
        self, model_cfg, ckpt_path, device, line_detector_cfg, junc_detect_thresh=None
    ):
        """SOLD² line detector taking raw images as input.

        Parameters:
            model_cfg: config for the CNN model. Must contain "grid_size".
                "detection_thresh" (default 1 / 65) and "max_num_junctions"
                (default 300) are optional.
            ckpt_path: path to the weights. The loaded object must contain a
                "model_state_dict" entry.
            device: device the model and the inputs are moved to. It is also
                used as ``map_location`` when loading the checkpoint.
            line_detector_cfg: config for the line detection module, passed
                as keyword arguments to ``LineSegmentDetectionModule``.
            junc_detect_thresh: optional junction detection threshold. When
                given, it overrides ``model_cfg["detection_thresh"]``.
        """
        # Get loss weights if dynamic weighting
        _, loss_weights = get_loss_and_weights(model_cfg, device)
        self.device = device

        # Initialize the cnn backbone and load the trained weights.
        self.model = get_model(model_cfg, loss_weights)
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(ckpt_path, map_location=self.device)
        checkpoint = adapt_checkpoint(checkpoint["model_state_dict"])
        self.model.load_state_dict(checkpoint)
        self.model = self.model.to(self.device)
        self.model = self.model.eval()

        # Parameters used to decode junctions from the network output.
        self.grid_size = model_cfg["grid_size"]
        if junc_detect_thresh is not None:
            self.junc_detect_thresh = junc_detect_thresh
        else:
            self.junc_detect_thresh = model_cfg.get("detection_thresh", 1 / 65)
        self.max_num_junctions = model_cfg.get("max_num_junctions", 300)

        # Initialize the line detector
        self.line_detector_cfg = line_detector_cfg
        self.line_detector = LineSegmentDetectionModule(**line_detector_cfg)

    @staticmethod
    def _to_numpy(value):
        """Return ``value`` as a NumPy array if it is a torch tensor, else unchanged."""
        if isinstance(value, torch.Tensor):
            return value.cpu().numpy()
        return value

    def __call__(
        self, input_image, valid_mask=None, return_heatmap=False, profile=False
    ):
        """Detect line segments in an image.

        Parameters:
            input_image: 4D torch tensor fed to the CNN backbone. Only the
                first batch element's heatmap is used, so a batch size of 1
                is presumably expected (TODO confirm).
            valid_mask: optional mask multiplied elementwise with the NMS
                junction map. Junctions where the product is zero are dropped.
            return_heatmap: if True, also return the (post-detection) heatmap.
            profile: if True, also return the elapsed time in seconds, from
                the backbone forward pass to segment extraction.

        Returns:
            dict with "line_segments" (output of ``line_map_to_segments``),
            plus "heatmap" if ``return_heatmap`` and "time" if ``profile``.

        Raises:
            ValueError: if ``input_image`` is not a 4D torch tensor.
        """
        # Check the type first: non-tensor inputs may have no ``.shape`` and
        # would otherwise raise AttributeError instead of ValueError.
        if not isinstance(input_image, torch.Tensor) or len(input_image.shape) != 4:
            raise ValueError("[Error] the input image should be a 4D torch tensor.")

        # Move the input to corresponding device
        input_image = input_image.to(self.device)

        # Forward of the CNN backbone. perf_counter is monotonic, which makes
        # it suitable for measuring durations.
        start_time = time.perf_counter()
        with torch.no_grad():
            net_outputs = self.model(input_image)

        junc_np = convert_junc_predictions(
            net_outputs["junctions"],
            self.grid_size,
            self.junc_detect_thresh,
            self.max_num_junctions,
        )
        junc_map = junc_np["junc_pred_nms"].squeeze()
        if valid_mask is not None:
            junc_map = junc_map * valid_mask
        # Stack the first two index arrays from np.where into an (N, 2) array
        # of (row, col) junction coordinates.
        coords = np.where(junc_map)
        junctions = np.concatenate(
            [coords[0][..., None], coords[1][..., None]], axis=-1
        )

        if net_outputs["heatmap"].shape[1] == 2:
            # Two channels: softmax over channels and keep channel 1 as the
            # single-channel heatmap.
            heatmap = softmax(net_outputs["heatmap"], dim=1)[:, 1:, :, :]
        else:
            heatmap = torch.sigmoid(net_outputs["heatmap"])
        # Keep batch element 0, channel 0 as a 2D NumPy map.
        heatmap = heatmap.cpu().numpy().transpose(0, 2, 3, 1)[0, :, :, 0]

        # Run the line detector.
        line_map, junctions, heatmap = self.line_detector.detect(
            junctions, heatmap, device=self.device
        )
        # detect() may hand back torch tensors; normalize all outputs to NumPy.
        heatmap = self._to_numpy(heatmap)
        line_map = self._to_numpy(line_map)
        junctions = self._to_numpy(junctions)

        line_segments = line_map_to_segments(junctions, line_map)
        end_time = time.perf_counter()

        outputs = {"line_segments": line_segments}
        if return_heatmap:
            outputs["heatmap"] = heatmap
        if profile:
            outputs["time"] = end_time - start_time
        return outputs
|