""" Implements the full pipeline from raw images to line matches. """ import time import cv2 import numpy as np import torch import torch.nn.functional as F from torch.nn.functional import softmax from .model_util import get_model from .loss import get_loss_and_weights from .metrics import super_nms from .line_detection import LineSegmentDetectionModule from .line_matching import WunschLineMatcher from ..train import convert_junc_predictions from ..misc.train_utils import adapt_checkpoint from .line_detector import line_map_to_segments class LineMatcher(object): """Full line matcher including line detection and matching with the Needleman-Wunsch algorithm.""" def __init__( self, model_cfg, ckpt_path, device, line_detector_cfg, line_matcher_cfg, multiscale=False, scales=[1.0, 2.0], ): # Get loss weights if dynamic weighting _, loss_weights = get_loss_and_weights(model_cfg, device) self.device = device # Initialize the cnn backbone self.model = get_model(model_cfg, loss_weights) checkpoint = torch.load(ckpt_path, map_location=self.device) checkpoint = adapt_checkpoint(checkpoint["model_state_dict"]) self.model.load_state_dict(checkpoint) self.model = self.model.to(self.device) self.model = self.model.eval() self.grid_size = model_cfg["grid_size"] self.junc_detect_thresh = model_cfg["detection_thresh"] self.max_num_junctions = model_cfg.get("max_num_junctions", 300) # Initialize the line detector self.line_detector = LineSegmentDetectionModule(**line_detector_cfg) self.multiscale = multiscale self.scales = scales # Initialize the line matcher self.line_matcher = WunschLineMatcher(**line_matcher_cfg) # Print some debug messages for key, val in line_detector_cfg.items(): print(f"[Debug] {key}: {val}") # print("[Debug] detect_thresh: %f" % (line_detector_cfg["detect_thresh"])) # print("[Debug] num_samples: %d" % (line_detector_cfg["num_samples"])) # Perform line detection and descriptor inference on a single image def line_detection( self, input_image, valid_mask=None, desc_only=False, profile=False ): # Restrict input_image to 4D torch tensor if (not len(input_image.shape) == 4) or ( not isinstance(input_image, torch.Tensor) ): raise ValueError("[Error] the input image should be a 4D torch tensor") # Move the input to corresponding device input_image = input_image.to(self.device) # Forward of the CNN backbone start_time = time.time() with torch.no_grad(): net_outputs = self.model(input_image) outputs = {"descriptor": net_outputs["descriptors"]} if not desc_only: junc_np = convert_junc_predictions( net_outputs["junctions"], self.grid_size, self.junc_detect_thresh, self.max_num_junctions, ) if valid_mask is None: junctions = np.where(junc_np["junc_pred_nms"].squeeze()) else: junctions = np.where(junc_np["junc_pred_nms"].squeeze() * valid_mask) junctions = np.concatenate( [junctions[0][..., None], junctions[1][..., None]], axis=-1 ) if net_outputs["heatmap"].shape[1] == 2: # Convert to single channel directly from here heatmap = ( softmax(net_outputs["heatmap"], dim=1)[:, 1:, :, :] .cpu() .numpy() .transpose(0, 2, 3, 1) ) else: heatmap = ( torch.sigmoid(net_outputs["heatmap"]) .cpu() .numpy() .transpose(0, 2, 3, 1) ) heatmap = heatmap[0, :, :, 0] # Run the line detector. 
            line_map, junctions, heatmap = self.line_detector.detect(
                junctions, heatmap, device=self.device
            )
            if isinstance(line_map, torch.Tensor):
                line_map = line_map.cpu().numpy()
            if isinstance(junctions, torch.Tensor):
                junctions = junctions.cpu().numpy()
            outputs["heatmap"] = heatmap.cpu().numpy()
            outputs["junctions"] = junctions

            # If the line map was computed for multiple detect_thresh
            # and inlier_thresh values
            if len(line_map.shape) > 2:
                num_detect_thresh = line_map.shape[0]
                num_inlier_thresh = line_map.shape[1]
                line_segments = []
                for detect_idx in range(num_detect_thresh):
                    line_segments_inlier = []
                    for inlier_idx in range(num_inlier_thresh):
                        line_map_tmp = line_map[detect_idx, inlier_idx, :, :]
                        line_segments_tmp = line_map_to_segments(
                            junctions, line_map_tmp
                        )
                        line_segments_inlier.append(line_segments_tmp)
                    line_segments.append(line_segments_inlier)
            else:
                line_segments = line_map_to_segments(junctions, line_map)

            outputs["line_segments"] = line_segments

        end_time = time.time()
        if profile:
            outputs["time"] = end_time - start_time

        return outputs

    # Perform line detection and descriptor inference at multiple scales
    def multiscale_line_detection(
        self,
        input_image,
        valid_mask=None,
        desc_only=False,
        profile=False,
        scales=[1.0, 2.0],
        aggregation="mean",
    ):
        # Check that input_image is a 4D torch tensor
        if not isinstance(input_image, torch.Tensor) or len(input_image.shape) != 4:
            raise ValueError("[Error] the input image should be a 4D torch tensor")

        # Move the input to the corresponding device
        input_image = input_image.to(self.device)
        img_size = input_image.shape[2:4]
        desc_size = tuple(np.array(img_size) // 4)

        # Run the inference at multiple image scales
        start_time = time.time()
        junctions, heatmaps, descriptors = [], [], []
        for s in scales:
            # Resize the image
            resized_img = F.interpolate(input_image, scale_factor=s, mode="bilinear")

            # Forward pass of the CNN backbone
            with torch.no_grad():
                net_outputs = self.model(resized_img)
            descriptors.append(
                F.interpolate(
                    net_outputs["descriptors"], size=desc_size, mode="bilinear"
                )
            )

            if not desc_only:
                junc_prob = convert_junc_predictions(
                    net_outputs["junctions"], self.grid_size
                )["junc_pred"]
                junctions.append(
                    cv2.resize(
                        junc_prob.squeeze(),
                        (img_size[1], img_size[0]),
                        interpolation=cv2.INTER_LINEAR,
                    )
                )

                if net_outputs["heatmap"].shape[1] == 2:
                    # Convert to a single channel directly from here
                    heatmap = softmax(net_outputs["heatmap"], dim=1)[:, 1:, :, :]
                else:
                    heatmap = torch.sigmoid(net_outputs["heatmap"])
                heatmaps.append(F.interpolate(heatmap, size=img_size, mode="bilinear"))

        # Aggregate the results across scales
        if aggregation == "mean":
            # Aggregation through the mean activation
            descriptors = torch.stack(descriptors, dim=0).mean(0)
        else:
            # Aggregation through the max activation
            descriptors = torch.stack(descriptors, dim=0).max(0)[0]
        outputs = {"descriptor": descriptors}

        if not desc_only:
            if aggregation == "mean":
                junctions = np.stack(junctions, axis=0).mean(0)[None]
                heatmap = torch.stack(heatmaps, dim=0).mean(0)[0, 0, :, :]
                heatmap = heatmap.cpu().numpy()
            else:
                junctions = np.stack(junctions, axis=0).max(0)[None]
                heatmap = torch.stack(heatmaps, dim=0).max(0)[0][0, 0, :, :]
                heatmap = heatmap.cpu().numpy()

            # Extract junctions with NMS on the aggregated junction map
            junc_pred_nms = super_nms(
                junctions[..., None],
                self.grid_size,
                self.junc_detect_thresh,
                self.max_num_junctions,
            )
            if valid_mask is None:
                junctions = np.where(junc_pred_nms.squeeze())
            else:
                junctions = np.where(junc_pred_nms.squeeze() * valid_mask)
            junctions = np.concatenate(
                [junctions[0][..., None], junctions[1][..., None]], axis=-1
            )

            # Run the line detector.
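            # Same post-processing as in the single-scale path, now applied
            # to the scale-aggregated junctions and heatmap.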
            line_map, junctions, heatmap = self.line_detector.detect(
                junctions, heatmap, device=self.device
            )
            if isinstance(line_map, torch.Tensor):
                line_map = line_map.cpu().numpy()
            if isinstance(junctions, torch.Tensor):
                junctions = junctions.cpu().numpy()
            outputs["heatmap"] = heatmap.cpu().numpy()
            outputs["junctions"] = junctions

            # If the line map was computed for multiple detect_thresh
            # and inlier_thresh values
            if len(line_map.shape) > 2:
                num_detect_thresh = line_map.shape[0]
                num_inlier_thresh = line_map.shape[1]
                line_segments = []
                for detect_idx in range(num_detect_thresh):
                    line_segments_inlier = []
                    for inlier_idx in range(num_inlier_thresh):
                        line_map_tmp = line_map[detect_idx, inlier_idx, :, :]
                        line_segments_tmp = line_map_to_segments(
                            junctions, line_map_tmp
                        )
                        line_segments_inlier.append(line_segments_tmp)
                    line_segments.append(line_segments_inlier)
            else:
                line_segments = line_map_to_segments(junctions, line_map)

            outputs["line_segments"] = line_segments

        end_time = time.time()
        if profile:
            outputs["time"] = end_time - start_time

        return outputs

    def __call__(self, images, valid_masks=[None, None], profile=False):
        # Line detection and descriptor inference on both images
        if self.multiscale:
            forward_outputs = [
                self.multiscale_line_detection(
                    images[0], valid_masks[0], profile=profile, scales=self.scales
                ),
                self.multiscale_line_detection(
                    images[1], valid_masks[1], profile=profile, scales=self.scales
                ),
            ]
        else:
            forward_outputs = [
                self.line_detection(images[0], valid_masks[0], profile=profile),
                self.line_detection(images[1], valid_masks[1], profile=profile),
            ]
        line_seg1 = forward_outputs[0]["line_segments"]
        line_seg2 = forward_outputs[1]["line_segments"]
        desc1 = forward_outputs[0]["descriptor"]
        desc2 = forward_outputs[1]["descriptor"]

        # Match the lines of both images
        start_time = time.time()
        matches = self.line_matcher.forward(line_seg1, line_seg2, desc1, desc2)
        end_time = time.time()

        outputs = {"line_segments": [line_seg1, line_seg2], "matches": matches}
        if profile:
            outputs["line_detection_time"] = (
                forward_outputs[0]["time"] + forward_outputs[1]["time"]
            )
            outputs["line_matching_time"] = end_time - start_time

        return outputs
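

# ---------------------------------------------------------------------------
# Minimal usage sketch. Everything below is illustrative: the YAML config
# path, its "model_cfg" / "line_detector_cfg" / "line_matcher_cfg" keys, the
# checkpoint path, and the image paths are hypothetical placeholders, not
# values shipped with this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import yaml

    with open("config/line_matcher_config.yaml", "r") as f:  # hypothetical path
        cfg = yaml.safe_load(f)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    matcher = LineMatcher(
        cfg["model_cfg"],
        "checkpoints/pretrained_model.tar",  # hypothetical checkpoint
        device,
        cfg["line_detector_cfg"],
        cfg["line_matcher_cfg"],
    )

    # Load two grayscale images and convert them to 1x1xHxW float tensors
    # in [0, 1], the 4D format expected by line_detection()
    images = []
    for path in ["img1.png", "img2.png"]:  # hypothetical image paths
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.0
        images.append(torch.from_numpy(img)[None, None])

    outputs = matcher(images, profile=True)
    print("Lines in image 1:", len(outputs["line_segments"][0]))
    print("Lines in image 2:", len(outputs["line_segments"][1]))
    print("Matching time: %.3f s" % outputs["line_matching_time"])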