|
""" |
|
Line segment detection from raw images. |
|
""" |
|
import time |
|
import numpy as np |
|
import torch |
|
from torch.nn.functional import softmax |
|
|
|
from .model_util import get_model |
|
from .loss import get_loss_and_weights |
|
from .line_detection import LineSegmentDetectionModule |
|
from ..train import convert_junc_predictions |
|
from ..misc.train_utils import adapt_checkpoint |
|
|
|
|
|
def line_map_to_segments(junctions, line_map):
    """ Convert a line map to a Nx2x2 list of segments.

    Parameters:
        junctions: array indexed as junctions[i, :], one row of two
            coordinates per junction.
        line_map: square adjacency array indexed by junction index; an
            entry equal to 1 at [i, j] marks a segment between junction i
            and junction j. The input is not modified.
    Returns:
        A float array of shape (N, 2, 2) where out[k, 0] and out[k, 1] are
        the two endpoints of the k-th segment. Segments are listed in
        row-major order of the line map, and a pair flagged in both
        directions ([i, j] and [j, i]) is reported only once.
    """
    # Work on a copy: entries are cleared as segments get recorded.
    remaining = line_map.copy()

    # Gather the endpoint pairs in a list and build the array once at the
    # end; concatenating onto the output inside the loop re-copies every
    # previously stored segment for each new one.
    endpoint_pairs = []
    for idx in range(junctions.shape[0]):
        # A row summing to zero has no connectivity: skip it.
        if remaining[idx, :].sum() == 0:
            continue

        # np.where is evaluated once, before the row gets modified below.
        for idx2 in np.where(remaining[idx, :] == 1)[0]:
            endpoint_pairs.append((junctions[idx, :], junctions[idx2, :]))

            # Clear both directions so that the mirrored entry does not
            # produce the same segment again when its row is visited.
            remaining[idx, idx2] = 0
            remaining[idx2, idx] = 0

    empty_segments = np.zeros([0, 2, 2])
    if not endpoint_pairs:
        return empty_segments

    # Concatenating with the empty float array keeps the dtype promotion
    # (and the 2x2 shape check) of an incremental build.
    return np.concatenate(
        (empty_segments, np.asarray(endpoint_pairs)), axis=0)
|
|
|
|
|
class LineDetector(object):
    """ Line segment detector running a CNN model followed by a line
    detection module on raw images. """

    def __init__(self, model_cfg, ckpt_path, device, line_detector_cfg,
                 junc_detect_thresh=None):
        """ SOLD² line detector taking raw images as input.
        Parameters:
            model_cfg: config for CNN model. "grid_size" is required;
                "detection_thresh" (default 1/65) and "max_num_junctions"
                (default 300) are optional.
            ckpt_path: path to the weights
            device: device the checkpoint is mapped to and on which the
                model and the line detection module are run
            line_detector_cfg: config file for the line detection module,
                expanded as keyword arguments of LineSegmentDetectionModule
            junc_detect_thresh: junction detection threshold; when None,
                the value is taken from model_cfg
        """
        # The model constructor takes the loss weights; the loss function
        # returned alongside them is not needed here.
        _, loss_weights = get_loss_and_weights(model_cfg, device)
        self.device = device

        # Build the model and restore its weights in evaluation mode.
        # NOTE(review): torch.load may unpickle arbitrary objects depending
        # on the torch version's defaults - only load trusted checkpoints.
        self.model = get_model(model_cfg, loss_weights)
        checkpoint = torch.load(ckpt_path, map_location=self.device)
        checkpoint = adapt_checkpoint(checkpoint["model_state_dict"])
        self.model.load_state_dict(checkpoint)
        self.model = self.model.to(self.device)
        self.model = self.model.eval()

        self.grid_size = model_cfg["grid_size"]

        # An explicit threshold takes precedence over the model config.
        if junc_detect_thresh is not None:
            self.junc_detect_thresh = junc_detect_thresh
        else:
            self.junc_detect_thresh = model_cfg.get("detection_thresh", 1/65)
        self.max_num_junctions = model_cfg.get("max_num_junctions", 300)

        # Line detection module applied on top of the network outputs.
        self.line_detector_cfg = line_detector_cfg
        self.line_detector = LineSegmentDetectionModule(**line_detector_cfg)

    def __call__(self, input_image, valid_mask=None,
                 return_heatmap=False, profile=False):
        """ Detect line segments in an image.
        Parameters:
            input_image: 4D torch tensor fed to the model.
                NOTE(review): only element 0 of the first heatmap axis is
                kept below, so presumably a batch size of 1 is expected -
                TODO confirm.
            valid_mask: optional mask multiplied with the squeezed junction
                map before the junction coordinates are extracted
            return_heatmap: if True, add the heatmap returned by the line
                detection module under the key "heatmap"
            profile: if True, add the elapsed time in seconds (model
                forward pass up to segment extraction) under the key "time"
        Returns:
            A dict with the key "line_segments" (Nx2x2 array) and, on
            request, the keys "heatmap" and "time".
        Raises:
            ValueError: if input_image is not a 4D torch tensor.
        """
        # Check the type before the shape: objects without a `shape`
        # attribute must also end up in the ValueError below.
        if ((not isinstance(input_image, torch.Tensor))
                or (not len(input_image.shape) == 4)):
            raise ValueError(
                "[Error] the input image should be a 4D torch tensor.")

        # Move the input to the device of the model.
        input_image = input_image.to(self.device)

        # Forward pass of the model. perf_counter is monotonic, so the
        # reported duration is not affected by system clock adjustments.
        start_time = time.perf_counter()
        with torch.no_grad():
            net_outputs = self.model(input_image)

        # Junction map, optionally masked, turned into one row of indices
        # per nonzero entry (first index, second index).
        junc_np = convert_junc_predictions(
            net_outputs["junctions"], self.grid_size,
            self.junc_detect_thresh, self.max_num_junctions)
        if valid_mask is None:
            junctions = np.where(junc_np["junc_pred_nms"].squeeze())
        else:
            junctions = np.where(junc_np["junc_pred_nms"].squeeze()
                                 * valid_mask)
        junctions = np.concatenate(
            [junctions[0][..., None], junctions[1][..., None]], axis=-1)

        if net_outputs["heatmap"].shape[1] == 2:
            # Two-channel output: softmax over the channels and keep the
            # second one, which leaves a single channel.
            heatmap = softmax(net_outputs["heatmap"], dim=1)[:, 1:, :, :]
        else:
            heatmap = torch.sigmoid(net_outputs["heatmap"])
        # Channels-last numpy array, reduced to the 2D map of the first
        # element and first channel.
        heatmap = heatmap.cpu().numpy().transpose(0, 2, 3, 1)[0, :, :, 0]

        # Run the line detector.
        line_map, junctions, heatmap = self.line_detector.detect(
            junctions, heatmap, device=self.device)

        # Bring every output back to numpy, whatever type was returned.
        if isinstance(heatmap, torch.Tensor):
            heatmap = heatmap.cpu().numpy()
        if isinstance(line_map, torch.Tensor):
            line_map = line_map.cpu().numpy()
        if isinstance(junctions, torch.Tensor):
            junctions = junctions.cpu().numpy()
        line_segments = line_map_to_segments(junctions, line_map)
        end_time = time.perf_counter()

        outputs = {"line_segments": line_segments}

        if return_heatmap:
            outputs["heatmap"] = heatmap
        if profile:
            outputs["time"] = end_time - start_time

        return outputs
|
|