Vincentqyw
fix: roma
c74a070
raw
history blame
No virus
4.81 kB
"""
Line segment detection from raw images.
"""
import time
import numpy as np
import torch
from torch.nn.functional import softmax
from .model_util import get_model
from .loss import get_loss_and_weights
from .line_detection import LineSegmentDetectionModule
from ..train import convert_junc_predictions
from ..misc.train_utils import adapt_checkpoint
def line_map_to_segments(junctions, line_map):
"""Convert a line map to a Nx2x2 list of segments."""
line_map_tmp = line_map.copy()
output_segments = np.zeros([0, 2, 2])
for idx in range(junctions.shape[0]):
# if no connectivity, just skip it
if line_map_tmp[idx, :].sum() == 0:
continue
# Record the line segment
else:
for idx2 in np.where(line_map_tmp[idx, :] == 1)[0]:
p1 = junctions[idx, :] # HW format
p2 = junctions[idx2, :]
single_seg = np.concatenate([p1[None, ...], p2[None, ...]], axis=0)
output_segments = np.concatenate(
(output_segments, single_seg[None, ...]), axis=0
)
# Update line_map
line_map_tmp[idx, idx2] = 0
line_map_tmp[idx2, idx] = 0
return output_segments
class LineDetector(object):
def __init__(
self, model_cfg, ckpt_path, device, line_detector_cfg, junc_detect_thresh=None
):
"""SOLD² line detector taking raw images as input.
Parameters:
model_cfg: config for CNN model
ckpt_path: path to the weights
line_detector_cfg: config file for the line detection module
"""
# Get loss weights if dynamic weighting
_, loss_weights = get_loss_and_weights(model_cfg, device)
self.device = device
# Initialize the cnn backbone
self.model = get_model(model_cfg, loss_weights)
checkpoint = torch.load(ckpt_path, map_location=self.device)
checkpoint = adapt_checkpoint(checkpoint["model_state_dict"])
self.model.load_state_dict(checkpoint)
self.model = self.model.to(self.device)
self.model = self.model.eval()
self.grid_size = model_cfg["grid_size"]
if junc_detect_thresh is not None:
self.junc_detect_thresh = junc_detect_thresh
else:
self.junc_detect_thresh = model_cfg.get("detection_thresh", 1 / 65)
self.max_num_junctions = model_cfg.get("max_num_junctions", 300)
# Initialize the line detector
self.line_detector_cfg = line_detector_cfg
self.line_detector = LineSegmentDetectionModule(**line_detector_cfg)
def __call__(
self, input_image, valid_mask=None, return_heatmap=False, profile=False
):
# Now we restrict input_image to 4D torch tensor
if (not len(input_image.shape) == 4) or (
not isinstance(input_image, torch.Tensor)
):
raise ValueError("[Error] the input image should be a 4D torch tensor.")
# Move the input to corresponding device
input_image = input_image.to(self.device)
# Forward of the CNN backbone
start_time = time.time()
with torch.no_grad():
net_outputs = self.model(input_image)
junc_np = convert_junc_predictions(
net_outputs["junctions"],
self.grid_size,
self.junc_detect_thresh,
self.max_num_junctions,
)
if valid_mask is None:
junctions = np.where(junc_np["junc_pred_nms"].squeeze())
else:
junctions = np.where(junc_np["junc_pred_nms"].squeeze() * valid_mask)
junctions = np.concatenate(
[junctions[0][..., None], junctions[1][..., None]], axis=-1
)
if net_outputs["heatmap"].shape[1] == 2:
# Convert to single channel directly from here
heatmap = softmax(net_outputs["heatmap"], dim=1)[:, 1:, :, :]
else:
heatmap = torch.sigmoid(net_outputs["heatmap"])
heatmap = heatmap.cpu().numpy().transpose(0, 2, 3, 1)[0, :, :, 0]
# Run the line detector.
line_map, junctions, heatmap = self.line_detector.detect(
junctions, heatmap, device=self.device
)
heatmap = heatmap.cpu().numpy()
if isinstance(line_map, torch.Tensor):
line_map = line_map.cpu().numpy()
if isinstance(junctions, torch.Tensor):
junctions = junctions.cpu().numpy()
line_segments = line_map_to_segments(junctions, line_map)
end_time = time.time()
outputs = {"line_segments": line_segments}
if return_heatmap:
outputs["heatmap"] = heatmap
if profile:
outputs["time"] = end_time - start_time
return outputs