""" Line segment detection from raw images. """ import time import numpy as np import torch from torch.nn.functional import softmax from .model_util import get_model from .loss import get_loss_and_weights from .line_detection import LineSegmentDetectionModule from ..train import convert_junc_predictions from ..misc.train_utils import adapt_checkpoint def line_map_to_segments(junctions, line_map): """ Convert a line map to a Nx2x2 list of segments. """ line_map_tmp = line_map.copy() output_segments = np.zeros([0, 2, 2]) for idx in range(junctions.shape[0]): # if no connectivity, just skip it if line_map_tmp[idx, :].sum() == 0: continue # Record the line segment else: for idx2 in np.where(line_map_tmp[idx, :] == 1)[0]: p1 = junctions[idx, :] # HW format p2 = junctions[idx2, :] single_seg = np.concatenate([p1[None, ...], p2[None, ...]], axis=0) output_segments = np.concatenate( (output_segments, single_seg[None, ...]), axis=0) # Update line_map line_map_tmp[idx, idx2] = 0 line_map_tmp[idx2, idx] = 0 return output_segments class LineDetector(object): def __init__(self, model_cfg, ckpt_path, device, line_detector_cfg, junc_detect_thresh=None): """ SOLDĀ² line detector taking raw images as input. Parameters: model_cfg: config for CNN model ckpt_path: path to the weights line_detector_cfg: config file for the line detection module """ # Get loss weights if dynamic weighting _, loss_weights = get_loss_and_weights(model_cfg, device) self.device = device # Initialize the cnn backbone self.model = get_model(model_cfg, loss_weights) checkpoint = torch.load(ckpt_path, map_location=self.device) checkpoint = adapt_checkpoint(checkpoint["model_state_dict"]) self.model.load_state_dict(checkpoint) self.model = self.model.to(self.device) self.model = self.model.eval() self.grid_size = model_cfg["grid_size"] if junc_detect_thresh is not None: self.junc_detect_thresh = junc_detect_thresh else: self.junc_detect_thresh = model_cfg.get("detection_thresh", 1/65) self.max_num_junctions = model_cfg.get("max_num_junctions", 300) # Initialize the line detector self.line_detector_cfg = line_detector_cfg self.line_detector = LineSegmentDetectionModule(**line_detector_cfg) def __call__(self, input_image, valid_mask=None, return_heatmap=False, profile=False): # Now we restrict input_image to 4D torch tensor if ((not len(input_image.shape) == 4) or (not isinstance(input_image, torch.Tensor))): raise ValueError( "[Error] the input image should be a 4D torch tensor.") # Move the input to corresponding device input_image = input_image.to(self.device) # Forward of the CNN backbone start_time = time.time() with torch.no_grad(): net_outputs = self.model(input_image) junc_np = convert_junc_predictions( net_outputs["junctions"], self.grid_size, self.junc_detect_thresh, self.max_num_junctions) if valid_mask is None: junctions = np.where(junc_np["junc_pred_nms"].squeeze()) else: junctions = np.where(junc_np["junc_pred_nms"].squeeze() * valid_mask) junctions = np.concatenate( [junctions[0][..., None], junctions[1][..., None]], axis=-1) if net_outputs["heatmap"].shape[1] == 2: # Convert to single channel directly from here heatmap = softmax(net_outputs["heatmap"], dim=1)[:, 1:, :, :] else: heatmap = torch.sigmoid(net_outputs["heatmap"]) heatmap = heatmap.cpu().numpy().transpose(0, 2, 3, 1)[0, :, :, 0] # Run the line detector. line_map, junctions, heatmap = self.line_detector.detect( junctions, heatmap, device=self.device) heatmap = heatmap.cpu().numpy() if isinstance(line_map, torch.Tensor): line_map = line_map.cpu().numpy() if isinstance(junctions, torch.Tensor): junctions = junctions.cpu().numpy() line_segments = line_map_to_segments(junctions, line_map) end_time = time.time() outputs = {"line_segments": line_segments} if return_heatmap: outputs["heatmap"] = heatmap if profile: outputs["time"] = end_time - start_time return outputs