File size: 4,870 Bytes
a80d6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""
Line segment detection from raw images.
"""
import time
import numpy as np
import torch
from torch.nn.functional import softmax

from .model_util import get_model
from .loss import get_loss_and_weights
from .line_detection import LineSegmentDetectionModule
from ..train import convert_junc_predictions
from ..misc.train_utils import adapt_checkpoint


def line_map_to_segments(junctions, line_map):
    """ Convert a line map to a Nx2x2 array of segments.

    Parameters:
        junctions: array of junction coordinates (HW format), one row per
                   junction, indexed by the rows/columns of line_map.
        line_map: square connectivity matrix; entry [i, j] == 1 marks a
                  connection between junctions i and j.
    Returns:
        float array of shape (num_segments, 2, 2) where segment k goes from
        junctions[i] to junctions[j]. Each connected pair is emitted once,
        in row-major order of line_map. line_map itself is not modified.
    """
    junctions = np.asarray(junctions)
    connected = np.asarray(line_map) == 1

    # An entry (i, j) with j < i is a duplicate when (j, i) is also set,
    # because that pair was already emitted while scanning row j. Drop those
    # duplicates; keep the diagonal, the upper triangle and any lower-triangle
    # entry whose mirror is not set.
    keep = connected & ~np.tril(connected.T, k=-1)
    rows, cols = np.nonzero(keep)  # row-major order, like a row-by-row scan
    if rows.size == 0:
        return np.zeros([0, 2, 2])

    # Build all segments at once instead of growing an array inside a loop,
    # which would re-copy the whole result for every segment (quadratic).
    segments = np.stack([junctions[rows], junctions[cols]], axis=1)
    return segments.astype(np.result_type(junctions.dtype, np.float64))


class LineDetector(object):
    """ SOLD² line detector: runs a CNN backbone on a raw image tensor and
    turns its junction and heatmap predictions into line segments. """

    def __init__(self, model_cfg, ckpt_path, device, line_detector_cfg,
                 junc_detect_thresh=None):
        """ SOLD² line detector taking raw images as input.
        Parameters:
            model_cfg: config for CNN model; must contain "grid_size", may
                       contain "detection_thresh" (default 1/65) and
                       "max_num_junctions" (default 300).
            ckpt_path: path to the weights; the loaded checkpoint must hold
                       a "model_state_dict" entry.
            device: torch device the model (and later the inputs) is moved to.
            line_detector_cfg: config file for the line detection module,
                               passed as keyword arguments to
                               LineSegmentDetectionModule.
            junc_detect_thresh: optional junction detection threshold; when
                                given it overrides the value from model_cfg.
        """
        # Get loss weights if dynamic weighting
        _, loss_weights = get_loss_and_weights(model_cfg, device)
        self.device = device

        # Initialize the cnn backbone and load the trained weights.
        self.model = get_model(model_cfg, loss_weights)
        checkpoint = torch.load(ckpt_path, map_location=self.device)
        checkpoint = adapt_checkpoint(checkpoint["model_state_dict"])
        self.model.load_state_dict(checkpoint)
        self.model = self.model.to(self.device)
        self.model = self.model.eval()

        self.grid_size = model_cfg["grid_size"]

        if junc_detect_thresh is not None:
            self.junc_detect_thresh = junc_detect_thresh
        else:
            self.junc_detect_thresh = model_cfg.get("detection_thresh", 1/65)
        self.max_num_junctions = model_cfg.get("max_num_junctions", 300)

        # Initialize the line detector
        self.line_detector_cfg = line_detector_cfg
        self.line_detector = LineSegmentDetectionModule(**line_detector_cfg)

    def __call__(self, input_image, valid_mask=None,
                 return_heatmap=False, profile=False):
        """ Detect line segments in an image.
        Parameters:
            input_image: 4D torch tensor. Only the first batch element's
                         heatmap is used, so a batch size of 1 is presumably
                         expected (TODO confirm).
            valid_mask: optional array multiplied with the NMS junction map;
                        junctions where the product is zero are discarded.
            return_heatmap: if True, also return the line heatmap under the
                            key "heatmap".
            profile: if True, also return the elapsed seconds (forward pass
                     through segment extraction) under the key "time".
        Returns:
            dict with "line_segments": Nx2x2 array of segment endpoints in
            junction (HW) coordinates, plus the optional keys above.
        Raises:
            ValueError: if input_image is not a 4D torch tensor.
        """
        # Check the type first: non-tensor inputs may not even have a
        # .shape attribute and must raise the documented ValueError.
        if (not isinstance(input_image, torch.Tensor)
                or input_image.dim() != 4):
            raise ValueError(
                "[Error] the input image should be a 4D torch tensor.")

        # Move the input to corresponding device
        input_image = input_image.to(self.device)

        # Forward of the CNN backbone. perf_counter is monotonic and meant
        # for measuring short durations, unlike time.time().
        start_time = time.perf_counter()
        with torch.no_grad():
            net_outputs = self.model(input_image)

        # Decode junction predictions and collect (row, col) coordinates of
        # the surviving junctions as a Kx2 array.
        junc_np = convert_junc_predictions(
            net_outputs["junctions"], self.grid_size,
            self.junc_detect_thresh, self.max_num_junctions)
        junc_map = junc_np["junc_pred_nms"].squeeze()
        if valid_mask is not None:
            junc_map = junc_map * valid_mask
        junctions = np.where(junc_map)
        junctions = np.concatenate(
            [junctions[0][..., None], junctions[1][..., None]], axis=-1)

        if net_outputs["heatmap"].shape[1] == 2:
            # Two-channel output: softmax and keep channel 1 only.
            heatmap = softmax(net_outputs["heatmap"], dim=1)[:, 1:, :, :]
        else:
            heatmap = torch.sigmoid(net_outputs["heatmap"])
        # Keep the first batch element and first channel as an HxW array.
        heatmap = heatmap.cpu().numpy().transpose(0, 2, 3, 1)[0, :, :, 0]

        # Run the line detector; its outputs may be torch tensors, so bring
        # each one back to numpy only when needed.
        line_map, junctions, heatmap = self.line_detector.detect(
            junctions, heatmap, device=self.device)
        if isinstance(heatmap, torch.Tensor):
            heatmap = heatmap.cpu().numpy()
        if isinstance(line_map, torch.Tensor):
            line_map = line_map.cpu().numpy()
        if isinstance(junctions, torch.Tensor):
            junctions = junctions.cpu().numpy()
        line_segments = line_map_to_segments(junctions, line_map)
        end_time = time.perf_counter()

        outputs = {"line_segments": line_segments}

        if return_heatmap:
            outputs["heatmap"] = heatmap
        if profile:
            outputs["time"] = end_time - start_time

        return outputs