# HEAT / infer.py
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets.outdoor_buildings import OutdoorBuildingDataset
from datasets.s3d_floorplans import S3DFloorplanDataset
from datasets.data_utils import collate_fn, get_pixel_features
from models.resnet import ResNetBackbone
from models.corner_models import HeatCorner
from models.edge_models import HeatEdge
from models.corner_to_edge import get_infer_edge_pairs
from utils.geometry_utils import corner_eval
import numpy as np
import cv2
import os
import scipy.ndimage as filters  # scipy.ndimage.filters is deprecated in recent SciPy; the functions live in scipy.ndimage
import matplotlib.pyplot as plt
from metrics.get_metric import compute_metrics, get_recall_and_precision
import skimage
import argparse
def visualize_cond_generation(positive_pixels, confs, image, save_path, gt_corners=None, prec=None, recall=None,
image_masks=None, edges=None, edge_confs=None):
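    """Draw predictions onto a copy of `image` and write it to `save_path`.

    Corners are drawn as filled circles (red intensity scaled by `confs`),
    edges as lines whose color scales with `edge_confs`, GT corners in green,
    and precision/recall figures as text overlays when provided.
    """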
image = image.copy() # get a new copy of the original image
if confs is not None:
viz_confs = confs
if edges is not None:
preds = positive_pixels.astype(int)
c_degrees = dict()
for edge_i, edge_pair in enumerate(edges):
            # map edge confidence from [0.5, 1] to [0, 1] for color scaling
            conf = (edge_confs[edge_i] * 2) - 1
cv2.line(image, tuple(preds[edge_pair[0]]), tuple(preds[edge_pair[1]]), (255 * conf, 255 * conf, 0), 2)
c_degrees[edge_pair[0]] = c_degrees.setdefault(edge_pair[0], 0) + 1
c_degrees[edge_pair[1]] = c_degrees.setdefault(edge_pair[1], 0) + 1
for idx, c in enumerate(positive_pixels):
if edges is not None and idx not in c_degrees:
continue
if confs is None:
cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 0, 255), -1)
else:
cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 0, 255 * viz_confs[idx]), -1)
# if edges is not None:
# cv2.putText(image, '{}'.format(c_degrees[idx]), (int(c[0]), int(c[1] - 5)), cv2.FONT_HERSHEY_SIMPLEX,
# 0.5, (255, 0, 0), 1, cv2.LINE_AA)
if gt_corners is not None:
for c in gt_corners:
cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 255, 0), -1)
if image_masks is not None:
mask_ids = np.where(image_masks == 1)[0]
for mask_id in mask_ids:
            # image_masks is flattened over a 64x64 grid (stride 4 on a 256-px image)
            y_idx = mask_id // 64
            x_idx = mask_id - y_idx * 64
            x_coord = x_idx * 4
            y_coord = y_idx * 4
cv2.rectangle(image, (x_coord, y_coord), (x_coord + 3, y_coord + 3), (127, 127, 0), thickness=-1)
# if confs is not None:
# cv2.putText(image, 'max conf: {:.3f}'.format(confs.max()), (20, 20), cv2.FONT_HERSHEY_SIMPLEX,
# 0.5, (255, 255, 0), 1, cv2.LINE_AA)
if prec is not None:
if isinstance(prec, tuple):
cv2.putText(image, 'edge p={:.2f}, edge r={:.2f}'.format(prec[0], recall[0]), (20, 20),
cv2.FONT_HERSHEY_SIMPLEX,
0.5, (255, 255, 0), 1, cv2.LINE_AA)
cv2.putText(image, 'region p={:.2f}, region r={:.2f}'.format(prec[1], recall[1]), (20, 40),
cv2.FONT_HERSHEY_SIMPLEX,
0.5, (255, 255, 0), 1, cv2.LINE_AA)
else:
cv2.putText(image, 'prec={:.2f}, recall={:.2f}'.format(prec, recall), (20, 20), cv2.FONT_HERSHEY_SIMPLEX,
0.5, (255, 255, 0), 1, cv2.LINE_AA)
cv2.imwrite(save_path, image)
def corner_nms(preds, confs, image_size):
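    """Non-maximum suppression over corner predictions.

    Scatters each confidence onto an image_size x image_size grid, keeps only
    local maxima within a 5x5 neighborhood (via a maximum filter), and returns
    the surviving (x, y) corners together with their confidences.
    """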
data = np.zeros([image_size, image_size])
neighborhood_size = 5
threshold = 0
for i in range(len(preds)):
data[preds[i, 1], preds[i, 0]] = confs[i]
data_max = filters.maximum_filter(data, neighborhood_size)
maxima = (data == data_max)
data_min = filters.minimum_filter(data, neighborhood_size)
diff = ((data_max - data_min) > threshold)
maxima[diff == 0] = 0
results = np.where(maxima > 0)
filtered_preds = np.stack([results[1], results[0]], axis=-1)
new_confs = list()
for i, pred in enumerate(filtered_preds):
new_confs.append(data[pred[1], pred[0]])
new_confs = np.array(new_confs)
return filtered_preds, new_confs
def main(dataset, ckpt_path, image_size, viz_base, save_base, infer_times):
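    """Run HEAT inference over the chosen test set.

    Loads the backbone/corner/edge models from `ckpt_path`, predicts corners
    and edges for every test sample, saves visualizations to `viz_base` and
    .npy results to `save_base`, and accumulates corner/edge/region
    precision, recall, and F-scores across the dataset.
    """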
ckpt = torch.load(ckpt_path)
    print('Loaded checkpoint from epoch {}'.format(ckpt['epoch']))
ckpt_args = ckpt['args']
if dataset == 'outdoor':
data_path = './data/outdoor/cities_dataset'
det_path = './data/outdoor/det_final'
test_dataset = OutdoorBuildingDataset(data_path, det_path, phase='test', image_size=image_size, rand_aug=False,
inference=True)
elif dataset == 's3d_floorplan':
data_path = './data/s3d_floorplan'
test_dataset = S3DFloorplanDataset(data_path, phase='test', rand_aug=False, inference=True)
else:
raise ValueError('Unknown dataset type: {}'.format(dataset))
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0,
collate_fn=collate_fn)
backbone = ResNetBackbone()
strides = backbone.strides
num_channels = backbone.num_channels
backbone = nn.DataParallel(backbone)
backbone = backbone.cuda()
backbone.eval()
corner_model = HeatCorner(input_dim=128, hidden_dim=256, num_feature_levels=4, backbone_strides=strides,
backbone_num_channels=num_channels)
corner_model = nn.DataParallel(corner_model)
corner_model = corner_model.cuda()
corner_model.eval()
edge_model = HeatEdge(input_dim=128, hidden_dim=256, num_feature_levels=4, backbone_strides=strides,
backbone_num_channels=num_channels)
edge_model = nn.DataParallel(edge_model)
edge_model = edge_model.cuda()
edge_model.eval()
backbone.load_state_dict(ckpt['backbone'])
corner_model.load_state_dict(ckpt['corner_model'])
edge_model.load_state_dict(ckpt['edge_model'])
print('Loaded saved model from {}'.format(ckpt_path))
if not os.path.exists(viz_base):
os.makedirs(viz_base)
if not os.path.exists(save_base):
os.makedirs(save_base)
all_prec = list()
all_recall = list()
corner_tp = 0.0
corner_fp = 0.0
corner_length = 0.0
edge_tp = 0.0
edge_fp = 0.0
edge_length = 0.0
region_tp = 0.0
region_fp = 0.0
region_length = 0.0
# get the positional encodings for all pixels
pixels, pixel_features = get_pixel_features(image_size=image_size)
for data_i, data in enumerate(test_dataloader):
image = data['img'].cuda()
img_path = data['img_path'][0]
annot_path = data['annot_path'][0]
annot = np.load(annot_path, allow_pickle=True, encoding='latin1').tolist()
with torch.no_grad():
            pred_corners, pred_confs, pos_edges, edge_confs, c_outputs_np = get_results(
                image, annot, backbone, corner_model, edge_model, pixels, pixel_features,
                ckpt_args, infer_times, corner_thresh=0.01, image_size=image_size)
# viz_image = cv2.imread(img_path)
positive_pixels = np.array(list(annot.keys())).round()
viz_image = data['raw_img'][0].cpu().numpy().transpose(1, 2, 0)
viz_image = (viz_image * 255).astype(np.uint8)
# visualize G.T.
gt_path = os.path.join(viz_base, '{}_gt.png'.format(data_i))
visualize_cond_generation(positive_pixels, None, viz_image, gt_path, gt_corners=None, image_masks=None)
if len(pred_corners) > 0:
prec, recall = corner_eval(positive_pixels, pred_corners)
else:
prec = recall = 0
all_prec.append(prec)
all_recall.append(recall)
if pred_confs.shape[0] == 0:
pred_confs = None
        if image_size != 256:
            # predictions live in the 256-px frame; rescale them to the viz image resolution
            pred_corners_viz = pred_corners * (image_size / 256)
        else:
            pred_corners_viz = pred_corners
recon_path = os.path.join(viz_base, '{}_pred_corner.png'.format(data_i))
visualize_cond_generation(pred_corners_viz, pred_confs, viz_image, recon_path, gt_corners=None, prec=prec,
recall=recall)
pred_corners, pred_confs, pos_edges = postprocess_preds(pred_corners, pred_confs, pos_edges)
pred_data = {
'corners': pred_corners,
'edges': pos_edges,
}
if dataset == 's3d_floorplan':
save_filename = os.path.basename(annot_path)
save_npy_path = os.path.join(save_base, save_filename)
np.save(save_npy_path, pred_data)
else:
save_results = {
'corners': pred_corners,
'edges': pos_edges,
'image_path': img_path,
}
save_path = os.path.join(save_base, '{}_results.npy'.format(data_i))
np.save(save_path, save_results)
gt_data = convert_annot(annot)
score = compute_metrics(gt_data, pred_data)
edge_recall, edge_prec = get_recall_and_precision(score['edge_tp'], score['edge_fp'], score['edge_length'])
region_recall, region_prec = get_recall_and_precision(score['region_tp'], score['region_fp'],
score['region_length'])
er_recall = (edge_recall, region_recall)
er_prec = (edge_prec, region_prec)
if image_size != 256:
pred_corners_viz = pred_corners * (image_size / 256)
else:
pred_corners_viz = pred_corners
recon_path = os.path.join(viz_base, '{}_pred_edge.png'.format(data_i))
visualize_cond_generation(pred_corners_viz, pred_confs, viz_image, recon_path, gt_corners=None, prec=er_prec,
recall=er_recall, edges=pos_edges, edge_confs=edge_confs)
corner_tp += score['corner_tp']
corner_fp += score['corner_fp']
corner_length += score['corner_length']
edge_tp += score['edge_tp']
edge_fp += score['edge_fp']
edge_length += score['edge_length']
region_tp += score['region_tp']
region_fp += score['region_fp']
region_length += score['region_length']
        print('Finished inference for sample No.{}'.format(data_i))
avg_prec = np.array(all_prec).mean()
avg_recall = np.array(all_recall).mean()
recall, precision = get_recall_and_precision(corner_tp, corner_fp, corner_length)
f_score = 2.0 * precision * recall / (recall + precision + 1e-8)
print('corners - precision: %.3f recall: %.3f f_score: %.3f' % (precision, recall, f_score))
# edge
recall, precision = get_recall_and_precision(edge_tp, edge_fp, edge_length)
f_score = 2.0 * precision * recall / (recall + precision + 1e-8)
print('edges - precision: %.3f recall: %.3f f_score: %.3f' % (precision, recall, f_score))
# region
recall, precision = get_recall_and_precision(region_tp, region_fp, region_length)
f_score = 2.0 * precision * recall / (recall + precision + 1e-8)
print('regions - precision: %.3f recall: %.3f f_score: %.3f' % (precision, recall, f_score))
print('Avg prec: {}, Avg recall: {}'.format(avg_prec, avg_recall))
def get_results(image, annot, backbone, corner_model, edge_model, pixels, pixel_features,
args, infer_times, corner_thresh=0.5, image_size=256):
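    """Predict corners and edges for a single image.

    Stage 1 detects corner candidates (thresholded at `corner_thresh`, then
    NMS). Candidate corner pairs are then classified as edges over
    `infer_times` rounds of iterative refinement; see the comments on the
    loop below. Returns corners (in the 256-px frame), corner confidences,
    positive edges, edge confidences, and the raw corner heatmap.
    """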
image_feats, feat_mask, all_image_feats = backbone(image)
pixel_features = pixel_features.unsqueeze(0).repeat(image.shape[0], 1, 1, 1)
preds_s1 = corner_model(image_feats, feat_mask, pixel_features, pixels, all_image_feats)
c_outputs = preds_s1
# get predicted corners
c_outputs_np = c_outputs[0].detach().cpu().numpy()
pos_indices = np.where(c_outputs_np >= corner_thresh)
pred_corners = pixels[pos_indices]
pred_confs = c_outputs_np[pos_indices]
pred_corners, pred_confs = corner_nms(pred_corners, pred_confs, image_size=c_outputs.shape[1])
pred_corners, pred_confs, edge_coords, edge_mask, edge_ids = get_infer_edge_pairs(pred_corners, pred_confs)
corner_nums = torch.tensor([len(pred_corners)]).to(image.device)
max_candidates = torch.stack([corner_nums.max() * args.corner_to_edge_multiplier] * len(corner_nums), dim=0)
all_pos_ids = set()
all_edge_confs = dict()
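    # Iterative edge classification: gt_values assigns each candidate edge a
    # label in {0: rejected, 1: accepted, 2: undecided}. Each round feeds the
    # current labels back to the edge model; intermediate rounds only commit
    # high-confidence decisions, and the final round resolves the rest.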
for tt in range(infer_times):
if tt == 0:
gt_values = torch.zeros_like(edge_mask).long()
gt_values[:, :] = 2
# run the edge model
        s1_logits, s2_logits_hb, s2_logits_rel, selected_ids, s2_mask, s2_gt_values = edge_model(
            image_feats, feat_mask, pixel_features, edge_coords, edge_mask,
            gt_values, corner_nums, max_candidates, True)  # last argument: do_inference
num_total = s1_logits.shape[2]
num_selected = selected_ids.shape[1]
num_filtered = num_total - num_selected
s1_preds = s1_logits.squeeze().softmax(0)
s2_preds_rel = s2_logits_rel.squeeze().softmax(0)
s2_preds_hb = s2_logits_hb.squeeze().softmax(0)
s1_preds_np = s1_preds[1, :].detach().cpu().numpy()
s2_preds_rel_np = s2_preds_rel[1, :].detach().cpu().numpy()
s2_preds_hb_np = s2_preds_hb[1, :].detach().cpu().numpy()
selected_ids = selected_ids.squeeze().detach().cpu().numpy()
        if tt != infer_times - 1:
            # intermediate rounds: only lock in decisions the model is confident about
            s2_preds_np = s2_preds_hb_np
            pos_edge_ids = np.where(s2_preds_np >= 0.9)
            neg_edge_ids = np.where(s2_preds_np <= 0.01)
for pos_id in pos_edge_ids[0]:
actual_id = selected_ids[pos_id]
if gt_values[0, actual_id] != 2:
continue
all_pos_ids.add(actual_id)
all_edge_confs[actual_id] = s2_preds_np[pos_id]
gt_values[0, actual_id] = 1
for neg_id in neg_edge_ids[0]:
actual_id = selected_ids[neg_id]
if gt_values[0, actual_id] != 2:
continue
gt_values[0, actual_id] = 0
            num_to_pred = (gt_values == 2).sum()
            if num_to_pred <= num_filtered:
                # few enough undecided candidates remain; stop the refinement early
                break
        else:
            # final round: accept all remaining undecided candidates above 0.5
            s2_preds_np = s2_preds_hb_np
            pos_edge_ids = np.where(s2_preds_np >= 0.5)
for pos_id in pos_edge_ids[0]:
actual_id = selected_ids[pos_id]
                # `x is True` on a tensor element is always False; cast to bool instead
                if bool(s2_mask[0][pos_id]) or gt_values[0, actual_id] != 2:
continue
all_pos_ids.add(actual_id)
all_edge_confs[actual_id] = s2_preds_np[pos_id]
pos_edge_ids = list(all_pos_ids)
edge_confs = [all_edge_confs[idx] for idx in pos_edge_ids]
pos_edges = edge_ids[pos_edge_ids].cpu().numpy()
edge_confs = np.array(edge_confs)
    if image_size != 256:
        # map corners back to the canonical 256-px annotation frame
        pred_corners = pred_corners / (image_size / 256)
return pred_corners, pred_confs, pos_edges, edge_confs, c_outputs_np
def postprocess_preds(corners, confs, edges):
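    """Drop predicted corners that are not an endpoint of any predicted edge,
    and re-index the edges to match the filtered corner list."""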
corner_degrees = dict()
for edge_i, edge_pair in enumerate(edges):
corner_degrees[edge_pair[0]] = corner_degrees.setdefault(edge_pair[0], 0) + 1
corner_degrees[edge_pair[1]] = corner_degrees.setdefault(edge_pair[1], 0) + 1
good_ids = [i for i in range(len(corners)) if i in corner_degrees]
if len(good_ids) == len(corners):
return corners, confs, edges
else:
good_corners = corners[good_ids]
good_confs = confs[good_ids]
id_mapping = {value: idx for idx, value in enumerate(good_ids)}
new_edges = list()
for edge_pair in edges:
new_pair = (id_mapping[edge_pair[0]], id_mapping[edge_pair[1]])
new_edges.append(new_pair)
new_edges = np.array(new_edges)
return good_corners, good_confs, new_edges
def process_image(img):
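    """Normalize an HxWx3 RGB image with ImageNet mean/std and return a
    1x3xHxW CUDA tensor (standalone helper, not called in this script)."""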
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
img = skimage.img_as_float(img)
img = img.transpose((2, 0, 1))
img = (img - np.array(mean)[:, np.newaxis, np.newaxis]) / np.array(std)[:, np.newaxis, np.newaxis]
img = torch.Tensor(img).cuda()
img = img.unsqueeze(0)
return img
def plot_heatmap(results, filename):
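    """Save a 256x256 heatmap (values in [0, 1]) to `filename` as a
    pcolormesh figure (debug helper, not called in this script)."""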
# generate 2 2d grids for the x & y bounds
y, x = np.meshgrid(np.linspace(0, 255, 256), np.linspace(0, 255, 256))
z = results[::-1, :]
# x and y are bounds, so z should be the value *inside* those bounds.
# Therefore, remove the last value from the z array.
z = z[:-1, :-1]
fig, ax = plt.subplots()
c = ax.pcolormesh(y, x, z, cmap='RdBu', vmin=0, vmax=1)
# set the limits of the plot to the limits of the data
ax.axis([x.min(), x.max(), y.min(), y.max()])
fig.colorbar(c, ax=ax)
fig.savefig(filename)
plt.close()
def convert_annot(annot):
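    """Convert a {corner: [connected corners]} annotation dict into the
    {'corners': Nx2 array, 'edges': Mx2 index array} format used by the
    metrics code, deduplicating undirected edges."""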
corners = np.array(list(annot.keys()))
corners_mapping = {tuple(c): idx for idx, c in enumerate(corners)}
edges = set()
for corner, connections in annot.items():
idx_c = corners_mapping[tuple(corner)]
for other_c in connections:
idx_other_c = corners_mapping[tuple(other_c)]
if (idx_c, idx_other_c) not in edges and (idx_other_c, idx_c) not in edges:
edges.add((idx_c, idx_other_c))
edges = np.array(list(edges))
gt_data = {
'corners': corners,
'edges': edges
}
return gt_data
def get_args_parser():
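    """Build the command-line argument parser for HEAT inference."""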
parser = argparse.ArgumentParser('Holistic edge attention transformer', add_help=False)
parser.add_argument('--dataset', default='outdoor',
help='the dataset for experiments, outdoor/s3d_floorplan')
parser.add_argument('--checkpoint_path', default='',
help='path to the checkpoints of the model')
parser.add_argument('--image_size', default=256, type=int)
parser.add_argument('--viz_base', default='./results/viz',
help='path to save the intermediate visualizations')
parser.add_argument('--save_base', default='./results/npy',
help='path to save the prediction results in npy files')
parser.add_argument('--infer_times', default=3, type=int)
return parser
if __name__ == '__main__':
parser = argparse.ArgumentParser('HEAT inference', parents=[get_args_parser()])
args = parser.parse_args()
main(args.dataset, args.checkpoint_path, args.image_size, args.viz_base, args.save_base,
infer_times=args.infer_times)
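
# A minimal usage sketch (the paths below are illustrative; point them at
# your own checkpoint and make sure ./data is laid out as the dataset
# loaders above expect):
#   python infer.py --dataset s3d_floorplan \
#       --checkpoint_path ./checkpoints/checkpoint.pth \
#       --viz_base ./results/viz --save_base ./results/npy --infer_times 3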