# HEAT: models/corner_to_edge.py
# Author: Egrt
# Commit 424188c: "init"
import torch
import numpy as np
import scipy.ndimage.filters as filters
import cv2
import itertools
NEIGHBOUR_SIZE = 5  # window size of the max/min filters used to detect corner peaks
MATCH_THRESH = 5  # a pred corner matches a g.t. corner only if their distance is below this
LOCAL_MAX_THRESH = 0.01  # minimum heatmap score for a local maximum to count as a corner
viz_count = 0  # counter for the (commented-out) training-data visualisation dumps


def _pair_indices(length):
    """Return every index pair ``(i, j)`` with ``0 <= i < j < length``.

    The result is an int array of shape ``(length * (length - 1) // 2, 2)``
    in the same lexicographic order as
    ``itertools.combinations(range(length), 2)``; it is empty (shape
    ``(0, 2)``) for ``length < 2``.
    """
    rows, cols = np.triu_indices(length, k=1)
    return np.stack([rows, cols], axis=-1)


class _CombinationCache(dict):
    """dict of pre-computed pair indices that computes missing lengths on demand."""

    def __missing__(self, length):
        combs = _pair_indices(length)
        self[length] = combs
        return combs


# pre-compute all combinations to generate edge candidates faster.
# np.triu_indices avoids building ~7M Python tuples at import time; lengths
# outside 2..350 are computed (and cached) lazily on first access instead of
# raising KeyError. The (misspelled) name is kept because callers index it.
all_combibations = _CombinationCache()
for length in range(2, 351):
    all_combibations[length] = _pair_indices(length)
def prepare_edge_data(c_outputs, annots, images, max_corner_num):
    """Turn a batch of corner predictions into batched edge-classifier inputs.

    Args:
        c_outputs: batched corner predictions, indexed per sample along dim 0
            (presumably a (B, H, W) torch tensor of heatmaps -- TODO confirm).
        annots: per-sample g.t. corner graphs (format of ``process_annot``).
        images: per-sample images, passed through as 'viz_img' (only referenced
            by the commented-out visualisation in ``process_each_sample``).
        max_corner_num: cap on corners per sample (see ``get_edge_label_mix_gt``).

    Returns:
        The batched dict produced by ``collate_edge_info``.
    """
    per_sample = [
        process_each_sample(
            {'annot': annots[b_i], 'output': c_outputs[b_i], 'viz_img': images[b_i]},
            max_corner_num,
        )
        for b_i in range(c_outputs.shape[0])
    ]
    edge_info = {
        'edge_coords': [res['edges'] for res in per_sample],
        'edge_labels': [res['labels'] for res in per_sample],
        'processed_corners': [res['corners'] for res in per_sample],
    }
    return collate_edge_info(edge_info)
def process_annot(annot, do_round=True):
    """Convert a g.t. corner graph into a sorted corner array and index pairs.

    Args:
        annot: dict mapping each g.t. corner (a coordinate tuple, column 1
            being y per the sort below) to an iterable of the corners it
            connects to.
        do_round: if True, round the returned corner coordinates.

    Returns:
        corners: array of the annotation's corners sorted by y, then x
            (shape ``(0, 2)`` for an empty annotation).
        edges: list of ``(i, j)`` index pairs into ``corners``, one per listed
            connection (both directions appear if ``annot`` lists both).
        corner_degrees: number of connections listed for each corner, in
            ``corners`` order.
    """
    if not annot:
        # np.lexsort would fail on an empty key array; return a 2-D (0, 2)
        # array so downstream broadcasting / concatenation keeps working
        return np.zeros((0, 2)), [], []
    corners = np.array(list(annot.keys()))
    # sort the g.t. corners to fix the order for the matching later;
    # np.lexsort uses the last column as primary key -> sorted by y, then x
    order = np.lexsort(corners.T)
    corners = corners[order]
    corner_mapping = {tuple(k): v for v, k in enumerate(corners)}
    edges = [
        (corner_mapping[c], corner_mapping[tuple(other_c)])
        for c, connections in annot.items()
        for other_c in connections
    ]
    corner_degrees = [len(annot[tuple(c)]) for c in corners]
    if do_round:
        corners = corners.round()
    return corners, edges, corner_degrees
def process_each_sample(data, max_corner_num):
    """Extract corner peaks from one sample's prediction and build its edge data.

    Args:
        data: dict with 'annot' (g.t. corner graph), 'output' (torch tensor of
            corner scores for one sample) and 'viz_img' (only used by the
            commented-out visualisation below).
        max_corner_num: cap on the number of corners (see ``get_edge_label_mix_gt``).

    Returns:
        dict with 'corners', 'edges' and 'labels' from ``get_edge_label_mix_gt``.
    """
    heatmap = data['output'].detach().cpu().numpy()
    # a pixel is a peak when it equals the max of its NEIGHBOUR_SIZE window,
    # the window is not flat (max > min) and its score exceeds LOCAL_MAX_THRESH
    window_max = filters.maximum_filter(heatmap, NEIGHBOUR_SIZE)
    window_min = filters.minimum_filter(heatmap, NEIGHBOUR_SIZE)
    is_peak = (heatmap == window_max) & ((window_max - window_min) > 0)
    peak_idx = np.where(is_peak & (heatmap > LOCAL_MAX_THRESH))
    pred_corners = np.stack(peak_idx, axis=-1)[:, [1, 0]]  # (row, col) -> (x, y)

    # produce edge labels from the pred corners
    processed_corners, edges, labels = get_edge_label_mix_gt(
        pred_corners, data['annot'], max_corner_num
    )

    # Optional training-data visualisation (disabled):
    # global viz_count
    # viz_img = data['viz_img']
    # output_path = './viz_training/{}_example_gt.png'.format(viz_count)
    # _visualize_edge_training_data(processed_corners, edges, labels, viz_img, output_path)
    # viz_count += 1

    return {'corners': processed_corners, 'edges': edges, 'labels': labels}
def get_edge_label_mix_gt(pred_corners, annot, max_corner_num):
    """Build the corner set and edge labels used to train the edge classifier.

    Predicted corners are greedily matched to g.t. corners: each g.t. corner
    (in sorted order) takes the nearest still-unmatched prediction if it lies
    closer than ``MATCH_THRESH``. Matched g.t. corners are replaced by their
    predicted positions. Unmatched predictions are appended in random order,
    capped so the total does not exceed ``max_corner_num``. The g.t. corners
    themselves are always kept, even if they alone exceed the cap.

    Args:
        pred_corners: array of predicted corners, shape ``(P, 2)``, (x, y) order.
        annot: g.t. corner graph in the format accepted by ``process_annot``.
        max_corner_num: maximum number of corners to return (see above).

    Returns:
        processed_corners, edges, labels as produced by ``_get_edges``.
    """
    # sort the pred corners to fix the order for matching (by y, then x)
    order = np.lexsort(pred_corners.T)
    pred_corners = pred_corners[order]
    gt_corners, edge_pairs, _ = process_annot(annot)

    output_to_gt = dict()
    gt_to_output = dict()
    # diff[g, p] = Euclidean distance between g.t. corner g and pred corner p
    diff = np.sqrt(((pred_corners[:, None] - gt_corners) ** 2).sum(-1)).T
    if len(pred_corners) > 0:
        for target_i in range(len(gt_corners)):
            dist = diff[target_i]
            if output_to_gt:
                dist[list(output_to_gt.keys())] = 1000  # ignore already matched pred corners
            min_dist = dist.min()
            min_idx = dist.argmin()
            if min_dist < MATCH_THRESH and min_idx not in output_to_gt:  # a positive match
                output_to_gt[min_idx] = (target_i, min_dist)
                gt_to_output[target_i] = min_idx

    # replace matched g.t. corners with pred corners
    all_corners = gt_corners.copy()
    for gt_i, pred_i in gt_to_output.items():
        all_corners[gt_i] = pred_corners[pred_i]

    nm_pred_ids = [i for i in range(len(pred_corners)) if i not in output_to_gt]
    nm_pred_ids = np.random.permutation(nm_pred_ids)
    if len(nm_pred_ids) > 0:
        nm_pred_corners = pred_corners[nm_pred_ids]
        # Clamp at 0: when the g.t. corners alone exceed the cap, a negative
        # slice bound would keep almost all unmatched predictions instead of none.
        budget = max(max_corner_num - len(all_corners), 0)
        all_corners = np.concatenate([all_corners, nm_pred_corners[:budget]], axis=0)

    processed_corners, edges, _, labels = _get_edges(all_corners, edge_pairs)
    return processed_corners, edges, labels
def _get_edges(corners, edge_pairs):
    """Enumerate every candidate edge between ``corners`` and label the g.t. ones.

    Args:
        corners: array of corner positions, shape ``(N, 2)``.
        edge_pairs: ``(i, j)`` index pairs into the *unsorted* ``corners`` that
            are true edges. An undirected edge may be listed in either
            direction or in both; self-loops are ignored.

    Returns:
        corners: the input corners sorted by y, then x, and rounded.
        edges: array ``(E, 2, 2)`` of endpoint coordinates for all
            ``E = N * (N - 1) / 2`` corner pairs, in ``all_combibations[N]`` order.
        edge_ids: ``(E, 2)`` array of corner indices for each candidate edge.
        labels: float array ``(E,)``, 1 for g.t. edges and 0 otherwise.
    """
    order = np.lexsort(corners.T)
    corners = corners[order].round()  # sorted by y, then x
    id_mapping = {old: new for new, old in enumerate(order)}  # old index -> sorted position
    n = len(corners)
    all_ids = all_combibations[n]
    edges = corners[all_ids]
    labels = np.zeros(edges.shape[0])
    for p in edge_pairs:
        # normalise direction so an edge listed only as (j, i) is not dropped
        i, j = sorted((id_mapping[p[0]], id_mapping[p[1]]))
        if i == j:
            continue
        # index of pair (i, j), i < j, in the lexicographic list of all pairs;
        # i * (2n - 1 - i) is always even, so integer division is exact
        labels[i * (2 * n - 1 - i) // 2 + j - i - 1] = 1
    return corners, edges, np.array(all_ids), labels
def collate_edge_info(data):
    """Pad and stack per-sample edge data into batch arrays / tensors.

    Each field of ``data`` holds one sequence per sample. Sequences are padded
    with 0 (via ``pad_sequence``) to the longest one and stacked along a new
    leading batch axis. 'edge_coords', 'edge_labels' and 'gt_values' are
    converted to torch long tensors; other fields stay numpy arrays.

    For 'processed_corners' and 'edge_coords', two extra entries are added:
    ``<field>_lengths`` (long tensor of unpadded lengths) and ``<field>_mask``
    (bool tensor ``(batch, max_len)``), the mask being in the Transformers'
    input format where True means padding.
    """
    batched_data = {}
    lengths_info = {}
    for field, batch_values in data.items():
        lens = [len(item) for item in batch_values]
        longest = max(lens)
        stacked = np.stack([pad_sequence(item, longest, 0) for item in batch_values], axis=0)
        if field in ('edge_coords', 'edge_labels', 'gt_values'):
            stacked = torch.Tensor(stacked).long()
        if field in ('processed_corners', 'edge_coords'):
            lengths_info[field] = lens
        batched_data[field] = stacked

    # lengths and padding masks (True == padding) for the Transformer inputs
    for field, lens in lengths_info.items():
        lengths = torch.Tensor(lens).long()
        batched_data[field + '_lengths'] = lengths
        positions = torch.arange(max(lens)).unsqueeze(0)
        batched_data[field + '_mask'] = positions >= lengths.unsqueeze(-1)
    return batched_data
def pad_sequence(seq, length, pad_value=0):
    """Pad ``seq`` along its first axis up to ``length`` rows.

    Args:
        seq: numpy array of any rank with ``len(seq) <= length``.
        length: target size of the first axis.
        pad_value: value written into the padding rows.

    Returns:
        ``seq`` itself when it already has ``length`` rows; otherwise a new
        array with padding rows appended. The padding is float64 for real pad
        values, so ``np.concatenate`` promotes integer input to float.
    """
    missing = length - len(seq)
    if missing == 0:
        return seq
    # works for 1-D (shape[1:] is empty) and N-D arrays alike
    pad_shape = [missing] + list(seq.shape[1:])
    if pad_value == 0:
        paddings = np.zeros(pad_shape)
    else:
        paddings = np.ones(pad_shape) * pad_value
    return np.concatenate([seq, paddings], axis=0)
def get_infer_edge_pairs(corners, confs):
    """Enumerate all candidate edges between corners for inference.

    Args:
        corners: array of corner positions, presumably ``(N, 2)`` in (x, y).
        confs: array aligned with ``corners`` (presumably per-corner
            confidence scores -- TODO confirm).

    Returns:
        corners, confs: both reordered so corners are sorted by y, then x.
        edge_coords: long tensor ``(1, E, 2, 2)`` with the endpoint coordinates
            of every corner pair (``E = N * (N - 1) / 2``).
        mask: all-False bool tensor ``(1, E)``, i.e. nothing is padding.
        edge_ids: tensor ``(E, 2)`` of corner indices for each candidate edge.
    """
    order = np.lexsort(corners.T)
    corners, confs = corners[order], confs[order]  # sorted by y, then x
    pair_ids = all_combibations[len(corners)]
    edge_coords = torch.tensor(np.array(corners[pair_ids])).unsqueeze(0).long()
    mask = torch.zeros(edge_coords.shape[:2]).bool()
    return corners, confs, edge_coords, mask, torch.tensor(np.array(pair_ids))
def _visualize_edge_training_data(corners, edges, edge_labels, image, save_path):
    """Draw the positive edges and all corners of one sample and save the image.

    Args:
        corners: iterable of (x, y) corner positions.
        edges: iterable of edges, each a pair of endpoint coordinate arrays.
        edge_labels: per-edge labels; only edges labelled 1 are drawn.
        image: presumably a channel-first float image in [0, 1] (it is
            transposed with [1, 2, 0] and scaled by 255) -- TODO confirm.
        save_path: output path passed to ``cv2.imwrite``.
    """
    image = image.transpose([1, 2, 0])  # presumably CHW -> HWC
    image = (image * 255).astype(np.uint8)
    image = np.ascontiguousarray(image)  # the transposed copy is not C-contiguous
    for edge, label in zip(edges, edge_labels):
        if label == 1:
            # np.int was removed in NumPy 1.24; convert each coordinate with
            # int() (truncation, as astype did), like the circle call below
            pt1 = tuple(int(v) for v in edge[0])
            pt2 = tuple(int(v) for v in edge[1])
            cv2.line(image, pt1, pt2, (255, 255, 0), 2)
    for c in corners:
        cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 0, 255), -1)
    cv2.imwrite(save_path, image)