File size: 8,414 Bytes
424188c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
import torch
import numpy as np
import scipy.ndimage.filters as filters
import cv2
import itertools

NEIGHBOUR_SIZE = 5        # window size (pixels) for local-maximum detection
MATCH_THRESH = 5          # max pixel distance for a pred/g.t. corner match
LOCAL_MAX_THRESH = 0.01   # minimum heatmap score for a corner candidate
viz_count = 0

# Pre-compute every (i, j) index pair with i < j for corner counts up to 350,
# so edge candidates can be generated with a single dict lookup instead of
# re-running itertools.combinations for every sample.
all_combibations = {
    n: np.array(list(itertools.combinations(np.arange(n), 2)))
    for n in range(2, 351)
}


def prepare_edge_data(c_outputs, annots, images, max_corner_num):
    """Build batched edge-candidate training data from corner heatmaps.

    Args:
        c_outputs: per-sample corner heatmap predictions (first dim = batch).
        annots: per-sample g.t. annotations (corner -> connected corners).
        images: per-sample images, forwarded for optional visualization.
        max_corner_num: hard cap on corners kept per sample.

    Returns:
        The padded/batched dict produced by ``collate_edge_info``.
    """
    batch_size = c_outputs.shape[0]
    per_sample = [
        process_each_sample(
            {'annot': annots[b], 'output': c_outputs[b], 'viz_img': images[b]},
            max_corner_num,
        )
        for b in range(batch_size)
    ]

    edge_info = {
        'edge_coords': [r['edges'] for r in per_sample],
        'edge_labels': [r['labels'] for r in per_sample],
        'processed_corners': [r['corners'] for r in per_sample],
    }
    return collate_edge_info(edge_info)


def process_annot(annot, do_round=True):
    """Convert a corner-adjacency annotation into (corners, edges, degrees).

    Args:
        annot: dict mapping a corner (x, y) tuple to its connected corners.
        do_round: round corner coordinates to integers when True.

    Returns:
        corners: (N, 2) array sorted by y then x (fixed order for matching).
        edges: list of (i, j) index pairs, one per directed connection.
        corner_degrees: number of connections of each sorted corner.
    """
    corners = np.array(list(annot.keys()))
    order = np.lexsort(corners.T)  # deterministic order: y first, then x
    corners = corners[order]
    index_of = {tuple(pt): i for i, pt in enumerate(corners)}

    edges = [
        (index_of[src], index_of[tuple(dst)])
        for src, neighbours in annot.items()
        for dst in neighbours
    ]
    corner_degrees = [len(annot[tuple(pt)]) for pt in corners]

    if do_round:
        corners = corners.round()
    return corners, edges, corner_degrees


def process_each_sample(data, max_corner_num):
    """Extract corner peaks from one heatmap and build edge labels for them.

    Args:
        data: dict with 'annot' (g.t. annotation), 'output' (corner heatmap
            tensor — assumed 2D, H x W; TODO confirm against the model head),
            and 'viz_img' (unused here, kept for debugging hooks).
        max_corner_num: hard cap on corners kept for this sample.

    Returns:
        dict with 'corners', 'edges' (candidate coord pairs) and 'labels'.
    """
    annot = data['annot']
    heatmap = data['output'].detach().cpu().numpy()

    # Non-maximum suppression: keep pixels that equal the maximum of their
    # NEIGHBOUR_SIZE window, rejecting flat regions where max == min.
    # NOTE(review): scipy.ndimage.filters is a deprecated namespace; the same
    # functions live in scipy.ndimage — consider updating the import.
    window_max = filters.maximum_filter(heatmap, NEIGHBOUR_SIZE)
    window_min = filters.minimum_filter(heatmap, NEIGHBOUR_SIZE)
    is_peak = (heatmap == window_max) & ((window_max - window_min) > 0)
    ys, xs = np.where(is_peak & (heatmap > LOCAL_MAX_THRESH))
    pred_corners = np.stack([xs, ys], axis=-1)  # (x, y) format

    processed_corners, edges, labels = get_edge_label_mix_gt(pred_corners, annot, max_corner_num)

    return {
        'corners': processed_corners,
        'edges': edges,
        'labels': labels,
    }


def get_edge_label_mix_gt(pred_corners, annot, max_corner_num):
    """Match predicted corners to g.t. corners and build a mixed corner set.

    Matched g.t. corners are replaced by their predicted counterparts, then
    unmatched predictions are shuffled in as extra corners (capped at
    ``max_corner_num`` total), and every corner pair becomes a labelled edge
    candidate via ``_get_edges``.

    Args:
        pred_corners: (P, 2) array of predicted corner (x, y) coordinates.
        annot: dict mapping a corner (x, y) tuple to its connected corners.
        max_corner_num: hard cap on the number of corners kept per sample.

    Returns:
        (processed_corners, edges, labels) — sorted/rounded corners, all
        candidate edge coordinate pairs, and their 0/1 g.t. labels.
    """
    ind = np.lexsort(pred_corners.T)  # sort the pred corners to fix the order for matching
    pred_corners = pred_corners[ind]  # sorted by y, then x
    gt_corners, edge_pairs, corner_degrees = process_annot(annot)

    output_to_gt = dict()
    gt_to_output = dict()
    # pairwise distances; after the transpose, row i holds the distance from
    # g.t. corner i to every prediction
    diff = np.sqrt(((pred_corners[:, None] - gt_corners) ** 2).sum(-1))
    diff = diff.T

    if len(pred_corners) > 0:
        # greedy one-to-one matching: each g.t. corner (in sorted order) takes
        # the nearest still-unmatched prediction. Order-sensitive: rows of
        # `diff` are overwritten in place to mask already-claimed predictions.
        for target_i, target in enumerate(gt_corners):
            dist = diff[target_i]
            if len(output_to_gt) > 0:
                dist[list(output_to_gt.keys())] = 1000  # ignore already matched pred corners
            min_dist = dist.min()
            min_idx = dist.argmin()
            if min_dist < MATCH_THRESH and min_idx not in output_to_gt:  # a positive match
                output_to_gt[min_idx] = (target_i, min_dist)
                gt_to_output[target_i] = min_idx

    all_corners = gt_corners.copy()

    # replace matched g.t. corners with pred corners
    for gt_i in range(len(gt_corners)):
       if gt_i in gt_to_output:
            all_corners[gt_i] = pred_corners[gt_to_output[gt_i]]

    # unmatched predictions are appended in random order as extra corners
    # (these only ever produce negative edge candidates)
    nm_pred_ids = [i for i in range(len(pred_corners)) if i not in output_to_gt]
    nm_pred_ids = np.random.permutation(nm_pred_ids)
    if len(nm_pred_ids) > 0:
        nm_pred_corners = pred_corners[nm_pred_ids]
        #if len(nm_pred_ids) + len(all_corners) <= 150:
        if len(nm_pred_ids) + len(all_corners) <= max_corner_num:
            all_corners = np.concatenate([all_corners, nm_pred_corners], axis=0)
        else:
            #all_corners = np.concatenate([all_corners, nm_pred_corners[:(150 - len(gt_corners)), :]], axis=0)
            all_corners = np.concatenate([all_corners, nm_pred_corners[:(max_corner_num - len(gt_corners)), :]], axis=0)

    processed_corners, edges, edge_ids, labels = _get_edges(all_corners, edge_pairs)

    return processed_corners, edges, labels


def _get_edges(corners, edge_pairs):
    """Enumerate all candidate edges among ``corners`` and mark the true ones.

    Args:
        corners: (N, 2) corner coordinates (2 <= N <= 350).
        edge_pairs: g.t. (i, j) index pairs into the *unsorted* corners.

    Returns:
        corners: sorted (y then x) and rounded corner array.
        edges: (C, 2, 2) coordinates of every corner pair, C = N*(N-1)/2.
        edge_ids: (C, 2) index pairs matching ``edges``.
        labels: (C,) array, 1 where the pair is a g.t. edge.
    """
    order = np.lexsort(corners.T)  # canonical order: y first, then x
    corners = corners[order].round()
    # map an original corner index to its position in the sorted array
    remap = {old: new for new, old in enumerate(order)}

    pair_ids = all_combibations[len(corners)]  # every (i, j) with i < j
    edges = corners[pair_ids]
    labels = np.zeros(len(pair_ids))

    n = len(corners)
    # g.t. pairs come in both directions; keep the a < b orientation and mark
    # its position in the lexicographic combination list.
    for a, b in ((remap[i], remap[j]) for i, j in edge_pairs):
        if a < b:
            labels[int((2 * n - 1 - a) * a / 2 + b - a - 1)] = 1

    return corners, edges, np.array(pair_ids), labels


def collate_edge_info(data):
    """Pad per-sample sequences to a common length and batch them.

    Args:
        data: dict of field -> list (one entry per sample) of variable-length
            arrays.

    Returns:
        dict of batched arrays/tensors; for 'processed_corners' and
        'edge_coords' it also contains '<field>_lengths' and '<field>_mask'
        entries (mask follows the Transformer convention: True = padding).
    """
    batched_data = {}
    lengths_info = {}

    for field, values in data.items():
        lengths = [len(v) for v in values]
        target_len = max(lengths)
        padded = np.stack(
            [pad_sequence(v, target_len, 0) for v in values], axis=0)

        if field in ('edge_coords', 'edge_labels', 'gt_values'):
            padded = torch.Tensor(padded).long()
        if field in ('processed_corners', 'edge_coords'):
            lengths_info[field] = lengths
        batched_data[field] = padded

    # Add lengths and padding masks for the variable-length fields.
    for field, lengths in lengths_info.items():
        length_tensor = torch.Tensor(lengths).long()
        batched_data[field + '_lengths'] = length_tensor
        positions = torch.arange(max(lengths))
        positions = positions.unsqueeze(0).repeat(batched_data[field].shape[0], 1)
        batched_data[field + '_mask'] = positions >= length_tensor.unsqueeze(-1)

    return batched_data


def pad_sequence(seq, length, pad_value=0):
    """Pad ``seq`` along axis 0 to ``length`` with ``pad_value``.

    Args:
        seq: numpy array of shape (L, ...) with L <= length.
        length: target length of the first axis.
        pad_value: scalar used for the padded rows.

    Returns:
        ``seq`` unchanged when already at ``length``; otherwise a new array of
        shape (length, ...). Padding is float64 (as the original zeros/ones
        implementation produced), so integer input is promoted on concatenate.
    """
    if len(seq) == length:
        return seq
    pad_len = length - len(seq)
    # np.full replaces the four zeros/ones branches of the original version;
    # dtype=float keeps the historical float64 padding behavior.
    paddings = np.full((pad_len,) + tuple(seq.shape[1:]), pad_value, dtype=float)
    return np.concatenate([seq, paddings], axis=0)


def get_infer_edge_pairs(corners, confs):
    """Enumerate every candidate edge among detected corners for inference.

    Args:
        corners: (N, 2) detected corner coordinates.
        confs: (N,) per-corner confidences, reordered alongside ``corners``.

    Returns:
        corners, confs (both sorted by y then x), edge_coords as a
        (1, C, 2, 2) long tensor, an all-False (1, C) padding mask, and the
        (C, 2) edge index tensor.
    """
    order = np.lexsort(corners.T)  # sort by y, then x
    corners = corners[order]
    confs = confs[order]

    edge_ids = all_combibations[len(corners)]  # every (i, j) with i < j
    edge_coords = torch.tensor(np.array(corners[edge_ids])).unsqueeze(0).long()
    mask = torch.zeros(edge_coords.shape[:2]).bool()
    edge_ids = torch.tensor(np.array(edge_ids))
    return corners, confs, edge_coords, mask, edge_ids


def _visualize_edge_training_data(corners, edges, edge_labels, image, save_path):
    """Render positive edges and corners onto ``image`` and write it to disk.

    Args:
        corners: iterable of (x, y) corner coordinates.
        edges: iterable of coordinate pairs, one per candidate edge.
        edge_labels: 0/1 labels aligned with ``edges``; only 1s are drawn.
        image: CHW float image in [0, 1] (converted to HWC uint8 here).
        save_path: output file path for cv2.imwrite.
    """
    image = image.transpose([1, 2, 0])  # CHW -> HWC for OpenCV
    image = (image * 255).astype(np.uint8)
    image = np.ascontiguousarray(image)

    for edge, label in zip(edges, edge_labels):
        if label == 1:
            # np.int was removed in NumPy 1.24; the builtin int is the
            # documented replacement and behaves identically here.
            pt1 = tuple(edge[0].astype(int))
            pt2 = tuple(edge[1].astype(int))
            cv2.line(image, pt1, pt2, (255, 255, 0), 2)

    for c in corners:
        cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 0, 255), -1)

    cv2.imwrite(save_path, image)