Spaces:
Sleeping
Sleeping
import bisect | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import matplotlib, os, cv2 | |
import matplotlib.cm as cm | |
from PIL import Image | |
import torch.nn.functional as F | |
import torch | |
import seaborn as sns | |
def _compute_conf_thresh(data): | |
dataset_name = data["dataset_name"][0].lower() | |
if dataset_name == "scannet": | |
thr = 5e-4 | |
elif dataset_name == "megadepth": | |
thr = 1e-4 | |
else: | |
raise ValueError(f"Unknown dataset: {dataset_name}") | |
return thr | |
def plot_images(imgs, titles=None, cmaps="gray", dpi=100, size=5, pad=0.5):
    """Show several images side by side in a single-row figure.

    Args:
        imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
        titles: optional list of strings, one title per image.
        cmaps: one colormap shared by every image, or a list/tuple holding
            one colormap per image (used for monochrome images).
        dpi: resolution passed to ``plt.subplots``.
        size: the figure is ``size * n`` wide and ``size * 6 / 5`` tall;
            ``None`` leaves the figure size to matplotlib.
        pad: padding passed to ``fig.tight_layout``.

    Returns:
        The created matplotlib figure.
    """
    n_imgs = len(imgs)
    if isinstance(cmaps, (list, tuple)):
        per_image_cmaps = cmaps
    else:
        per_image_cmaps = [cmaps] * n_imgs

    figsize = None if size is None else (size * n_imgs, size * 6 / 5)
    fig, axes = plt.subplots(1, n_imgs, figsize=figsize, dpi=dpi)
    # With a single panel plt.subplots hands back one Axes, not a sequence.
    if n_imgs == 1:
        axes = [axes]

    for idx in range(n_imgs):
        axis = axes[idx]
        axis.imshow(imgs[idx], cmap=plt.get_cmap(per_image_cmaps[idx]))
        axis.get_yaxis().set_ticks([])
        axis.get_xaxis().set_ticks([])
        axis.set_axis_off()
        # Hide the frame around the image.
        for spine in axis.spines.values():
            spine.set_visible(False)
        if titles:
            axis.set_title(titles[idx])

    fig.tight_layout(pad=pad)
    return fig
def plot_color_line_matches(lines, correct_matches=None, lw=2, indices=(0, 1)):
    """Draw matched line segments on the current figure, one color per match.

    The i-th segment of every entry in ``lines`` shares the same color, so
    corresponding lines can be identified across images.

    Args:
        lines: list of ndarrays of size (N, 2, 2).
        correct_matches: bool array of size (N,) indicating correct matches;
            when given, wrong matches are drawn with a low alpha.
        lw: line width as float pixels.
        indices: indices of the axes (within the current figure) to draw on.

    Returns:
        The current figure, with the new lines appended to ``fig.lines``.
    """
    n_lines = len(lines[0])
    colors = sns.color_palette("husl", n_colors=n_lines)
    np.random.shuffle(colors)

    alphas = np.ones(n_lines)
    if correct_matches is not None:
        wrong = ~np.array(correct_matches)
        alphas[wrong] = 0.2

    fig = plt.gcf()
    all_axes = fig.axes
    assert len(all_axes) > max(indices)
    target_axes = [all_axes[i] for i in indices]
    # Render once so the data -> display transforms are up to date.
    fig.canvas.draw()

    for axis, segments in zip(target_axes, lines):
        # Map both endpoints from data coordinates to figure coordinates.
        to_figure = fig.transFigure.inverted()
        starts = to_figure.transform(axis.transData.transform(segments[:, 0]))
        ends = to_figure.transform(axis.transData.transform(segments[:, 1]))

        new_lines = []
        for k in range(n_lines):
            new_lines.append(
                matplotlib.lines.Line2D(
                    (starts[k, 0], ends[k, 0]),
                    (starts[k, 1], ends[k, 1]),
                    zorder=1,
                    transform=fig.transFigure,
                    c=colors[k],
                    alpha=alphas[k],
                    linewidth=lw,
                )
            )
        fig.lines += new_lines

    return fig
def make_matching_figure(
    img0,
    img1,
    mkpts0,
    mkpts1,
    color,
    titles=None,
    kpts0=None,
    kpts1=None,
    text=(),
    dpi=75,
    path=None,
    pad=0,
):
    """Draw an image pair side by side with lines joining matched keypoints.

    Args:
        img0, img1: the two images, passed straight to ``imshow``.
        mkpts0, mkpts1: matched keypoints, indexed as ``[:, 0]`` (x) and
            ``[:, 1]`` (y); row i of ``mkpts0`` is joined to row i of
            ``mkpts1``.
        color: per-match colors; ``color[i]`` colors line i and
            ``color[..., :3]`` colors the match end points.
        titles: optional pair of axes titles.
        kpts0, kpts1: optional extra keypoints drawn as white dots; if
            ``kpts0`` is given ``kpts1`` must be given too.
        text: iterable of strings, joined by newlines and drawn at the
            top-left corner of the first image.
        dpi: figure resolution.
        path: if truthy, the figure is saved there and closed.
        pad: padding passed to ``tight_layout``.

    Returns:
        The matplotlib figure, or ``None`` when ``path`` is given.
    """
    # draw image pair
    fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
    axes[0].imshow(img0)
    axes[1].imshow(img1)
    for i in range(2):  # clear all frames
        axes[i].get_yaxis().set_ticks([])
        axes[i].get_xaxis().set_ticks([])
        for spine in axes[i].spines.values():
            spine.set_visible(False)
        if titles is not None:
            axes[i].set_title(titles[i])
    fig.tight_layout(pad=pad)

    if kpts0 is not None:
        assert kpts1 is not None
        axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c="w", s=5)
        axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c="w", s=5)

    # draw matches
    if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
        # Render once so the data -> figure transforms are up to date.
        fig.canvas.draw()
        transFigure = fig.transFigure.inverted()
        fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
        fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
        fig.lines = [
            matplotlib.lines.Line2D(
                (fkpts0[i, 0], fkpts1[i, 0]),
                (fkpts0[i, 1], fkpts1[i, 1]),
                transform=fig.transFigure,
                c=color[i],
                linewidth=2,
            )
            for i in range(len(mkpts0))
        ]
        # freeze the axes to prevent the transform to change
        axes[0].autoscale(enable=False)
        axes[1].autoscale(enable=False)
        axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color[..., :3], s=4)
        axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color[..., :3], s=4)

    # put txts: black text when the top-left corner of img0 is bright
    # (mean above 200), white otherwise.
    txt_color = "k" if img0[:100, :200].mean() > 200 else "w"
    fig.text(
        0.01,
        0.99,
        "\n".join(text),
        transform=fig.axes[0].transAxes,
        fontsize=15,
        va="top",
        ha="left",
        color=txt_color,
    )

    # save or return figure; operate on `fig` explicitly rather than on
    # pyplot's notion of the current figure.
    if path:
        fig.savefig(str(path), bbox_inches="tight", pad_inches=0)
        plt.close(fig)
        return None
    return fig
def _make_evaluation_figure(data, b_id, alpha="dynamic"):
    """Build the match-evaluation figure for sample ``b_id`` of a batch.

    Matches whose ``epi_errs`` value is below the dataset threshold count as
    correct; precision and recall are derived from that and written on the
    figure.

    Args:
        data: batch dict; reads "m_bids", "image0", "image1", "mkpts0_f",
            "mkpts1_f", "epi_errs", "conf_matrix_gt", "dataset_name" and,
            when present, "scale0"/"scale1".
        b_id: index of the sample within the batch.
        alpha: line opacity, or "dynamic" to derive it from the match count.

    Returns:
        The figure produced by ``make_matching_figure``.
    """
    sample_mask = data["m_bids"] == b_id
    conf_thr = _compute_conf_thresh(data)

    def _as_int_image(key):
        # First channel of the sample, rescaled by 255 and rounded to int32.
        return (data[key][b_id][0].cpu().numpy() * 255).round().astype(np.int32)

    img0 = _as_int_image("image0")
    img1 = _as_int_image("image1")
    pts0 = data["mkpts0_f"][sample_mask].cpu().numpy()
    pts1 = data["mkpts1_f"][sample_mask].cpu().numpy()

    # for megadepth, we visualize matches on the resized image
    if "scale0" in data:
        pts0 = pts0 / data["scale0"][b_id].cpu().numpy()[[1, 0]]
        pts1 = pts1 / data["scale1"][b_id].cpu().numpy()[[1, 0]]

    epi_errs = data["epi_errs"][sample_mask].cpu().numpy()
    is_correct = epi_errs < conf_thr
    n_matches = len(is_correct)
    precision = np.mean(is_correct) if n_matches > 0 else 0
    n_correct = np.sum(is_correct)
    n_gt_matches = int(data["conf_matrix_gt"][b_id].sum().cpu())
    # recall might be larger than 1, since the calculation of conf_matrix_gt
    # uses groundtruth depths and camera poses, but epipolar distance is used here.
    recall = n_correct / (n_gt_matches) if n_gt_matches != 0 else 0

    # matching info
    if alpha == "dynamic":
        alpha = dynamic_alpha(n_matches)
    color = error_colormap(epi_errs, conf_thr, alpha=alpha)
    text = [
        f"#Matches {len(pts0)}",
        f"Precision({conf_thr:.2e}) ({100 * precision:.1f}%):"
        f" {n_correct}/{len(pts0)}",
        f"Recall({conf_thr:.2e}) ({100 * recall:.1f}%):"
        f" {n_correct}/{n_gt_matches}",
    ]

    return make_matching_figure(img0, img1, pts0, pts1, color, text=text)
def _make_confidence_figure(data, b_id): | |
# TODO: Implement confidence figure | |
raise NotImplementedError() | |
def make_matching_figures(data, config, mode="evaluation"):
    """Create one matching figure per sample of a batch.

    Args:
        data (Dict): a batch updated by PL_LoFTR; the batch size is read
            from ``data["image0"].size(0)``.
        config (Dict): matcher config; ``config.TRAINER.PLOT_MATCHES_ALPHA``
            is read in "evaluation" mode.
        mode: "evaluation" or "confidence".

    Returns:
        figures (Dict[str, List[plt.figure]]): a single-entry dict mapping
        ``mode`` to the list of figures.
    """
    assert mode in ["evaluation", "confidence"]  # 'confidence'
    batch_size = data["image0"].size(0)
    per_sample = []
    for b_id in range(batch_size):
        if mode == "confidence":
            fig = _make_confidence_figure(data, b_id)
        elif mode == "evaluation":
            fig = _make_evaluation_figure(
                data, b_id, alpha=config.TRAINER.PLOT_MATCHES_ALPHA
            )
        else:
            raise ValueError(f"Unknown plot mode: {mode}")
        per_sample.append(fig)
    return {mode: per_sample}
def dynamic_alpha(
    n_matches, milestones=(0, 300, 1000, 2000), alphas=(1.0, 0.8, 0.4, 0.2)
):
    """Pick a line opacity that decreases as the number of matches grows.

    ``alphas[i]`` is the opacity at ``milestones[i]``; between two
    milestones the opacity is interpolated linearly, and at or beyond the
    last milestone the last alpha is returned.

    Args:
        n_matches: number of matches to be drawn.
        milestones: ascending match counts (any sequence).
        alphas: opacity at each milestone (any sequence, same length as
            ``milestones``).

    Returns:
        The opacity as a float; ``1.0`` when ``n_matches`` is 0.
    """
    if n_matches == 0:
        return 1.0
    # Work on a list copy so tuples (or other sequences) are accepted too.
    alphas = list(alphas)
    # Pair every alpha with the next one; the last one has no successor.
    ranges = list(zip(alphas, alphas[1:] + [None]))
    loc = bisect.bisect_right(milestones, n_matches) - 1
    start_alpha, end_alpha = ranges[loc]
    if end_alpha is None:
        return start_alpha
    span = milestones[loc + 1] - milestones[loc]
    return end_alpha + (milestones[loc + 1] - n_matches) / span * (
        start_alpha - end_alpha
    )
def error_colormap(err, thr, alpha=1.0):
    """Map errors to RGBA colors running from green (low) to red (high).

    Errors are scaled so that ``0`` gives pure green, ``thr`` gives yellow
    and ``2 * thr`` or more gives pure red.

    Args:
        err: array of errors.
        thr: error threshold that sets the color scale.
        alpha: opacity written to the fourth channel; must be in (0, 1].

    Returns:
        Array with a trailing axis of size 4 (R, G, B, A), values in [0, 1].
    """
    assert 0 < alpha <= 1.0, f"Invaid alpha value: {alpha}"
    # 1 for a zero error, 0 once the error reaches twice the threshold.
    closeness = 1 - np.clip(err / (thr * 2), 0, 1)
    red = 2 - closeness * 2
    green = closeness * 2
    blue = np.zeros_like(closeness)
    opacity = np.ones_like(closeness) * alpha
    rgba = np.stack([red, green, blue, opacity], -1)
    return np.clip(rgba, 0, 1)
# Fixed pseudo-random permutation of the integers 0..99; draw_topics uses it
# to turn a topic index into the label value that is later color-mapped.
# NOTE: this seeds NumPy's *global* RNG as an import-time side effect, so the
# statement order below matters.
np.random.seed(1995)
color_map = np.arange(100)
np.random.shuffle(color_map)
def _save_topic_overlay(img, topic_map, n_topics, save_path):
    """Blend a topic-label map over ``img`` and write the result to disk."""
    plt.imshow(img)
    # Cells labelled below 0 (not among the displayed topics) stay transparent.
    masked_map = np.ma.masked_where(topic_map < 0, topic_map)
    plt.imshow(
        masked_map,
        cmap=plt.cm.jet,
        vmin=0,
        vmax=n_topics - 1,
        alpha=0.3,
        interpolation="bilinear",
    )
    plt.axis("off")
    plt.savefig(save_path, bbox_inches="tight", pad_inches=0, dpi=250)
    plt.close()


def draw_topics(
    data,
    img0,
    img1,
    saved_folder="viz_topics",
    show_n_topics=8,
    saved_name=None,
):
    """Compute per-image topic label maps and optionally save them as overlays.

    Every coarse cell is assigned its arg-max topic. Only the
    ``show_n_topics`` topics with the largest product of normalized topic
    mass in both images keep a label (taken from the module-level
    ``color_map``); all other cells are set to -1.

    Args:
        data: dict providing "topic_matrix" (with "img0"/"img1" entries),
            "hw0_c", "hw1_c", "hw0_i", "hw1_i" and optionally
            "scale0"/"scale1". Only batch element 0 is used.
            (assumes topic matrices are (batch, cells, topics) - TODO confirm)
        img0, img1: images the overlays are drawn on; only used when
            ``saved_name`` is given.
        saved_folder: output directory, created if missing.
        show_n_topics: number of topics to display.
        saved_name: file stem; when ``None`` nothing is written.

    Returns:
        ``(map_topic0, map_topic1)`` float arrays when ``saved_name`` is
        ``None``; otherwise ``None`` after writing
        ``<saved_name>_0.png`` and ``<saved_name>_1.png``.
    """
    topic0, topic1 = data["topic_matrix"]["img0"], data["topic_matrix"]["img1"]
    hw0_c, hw1_c = data["hw0_c"], data["hw1_c"]
    hw0_i, hw1_i = data["hw0_i"], data["hw1_i"]

    # Ratio between the input and coarse heights, optionally multiplied by
    # the per-image "scale" entry; ends up as a 2-element (scale[0], scale[1]).
    scale0, scale1 = hw0_i[0] // hw0_c[0], hw1_i[0] // hw1_c[0]
    if "scale0" in data:
        scale0 *= data["scale0"][0]
    else:
        scale0 = (scale0, scale0)
    if "scale1" in data:
        scale1 *= data["scale1"][0]
    else:
        scale1 = (scale1, scale1)

    n_topics = topic0.shape[-1]
    # Normalized total mass of every topic in each image.
    theta0 = topic0[0].sum(dim=0)
    theta0 /= theta0.sum().float()
    theta1 = topic1[0].sum(dim=0)
    theta1 /= theta1.sum().float()
    # Topics that are prominent in both images come first.
    top_topics = torch.argsort(theta0 * theta1, descending=True)[:show_n_topics]

    # Hard assignment: index of the strongest topic per cell.
    topic0 = topic0[0].argmax(dim=-1, keepdim=True)
    topic1 = topic1[0].argmax(dim=-1, keepdim=True)

    # -1 marks cells whose topic is not displayed.
    label_img0 = torch.zeros_like(topic0) - 1
    label_img1 = torch.zeros_like(topic1) - 1
    for k in top_topics:
        label_img0[topic0 == k] = color_map[k]
        label_img1[topic1 == k] = color_map[k]

    # Reshape to the coarse grid and resize by the scale factors.
    map_topic0 = label_img0.float().view(hw0_c).cpu().numpy()
    map_topic0 = cv2.resize(
        map_topic0, (int(hw0_c[1] * scale0[0]), int(hw0_c[0] * scale0[1]))
    )
    map_topic1 = label_img1.float().view(hw1_c).cpu().numpy()
    map_topic1 = cv2.resize(
        map_topic1, (int(hw1_c[1] * scale1[0]), int(hw1_c[0] * scale1[1]))
    )

    if saved_name is None:
        return map_topic0, map_topic1

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(saved_folder, exist_ok=True)
    path_saved_img0 = os.path.join(saved_folder, "{}_0.png".format(saved_name))
    _save_topic_overlay(img0, map_topic0, n_topics, path_saved_img0)
    path_saved_img1 = os.path.join(saved_folder, "{}_1.png".format(saved_name))
    _save_topic_overlay(img1, map_topic1, n_topics, path_saved_img1)
def draw_topicfm_demo(
    data,
    img0,
    img1,
    mkpts0,
    mkpts1,
    mcolor,
    text,
    show_n_topics=8,
    topic_alpha=0.3,
    margin=5,
    path=None,
    opencv_display=False,
    opencv_title="",
):
    """Compose a 2x2 demo image: topic overlays on top, matches below.

    The top row shows both images blended with their topic maps; the bottom
    row shows the plain images with matching lines and end points drawn by
    OpenCV. Image 1 is vertically centred next to image 0, or centre-cropped
    to image 0's height when it is taller.

    Args:
        data: dict forwarded to ``draw_topics``.
        img0, img1: images as arrays; they are multiplied by 255 and cast to
            uint8 for display (presumably float32 BGR in [0, 1] - TODO confirm).
        mkpts0, mkpts1: matched keypoints as (x, y) rows, in pixels.
        mcolor: per-match colors; columns ``[2, 1, 0]`` are taken and scaled
            by 255 (presumably RGB(A) in [0, 1] - TODO confirm).
        text: list of strings; string ``i`` is written at the top-left of
            row ``i`` of the composition.
        show_n_topics: number of topics forwarded to ``draw_topics``.
        topic_alpha: blending weight of the topic overlay.
        margin: gap in pixels between the panels.
        path: if not ``None`` the composition is written there with
            ``cv2.imwrite``.
        opencv_display: if true, show the result in an OpenCV window.
        opencv_title: title of that window.

    Returns:
        The composed uint8 image of shape (h, w, 3).
    """
    topic_map0, topic_map1 = draw_topics(
        data, img0, img1, show_n_topics=show_n_topics
    )
    # Pixels with a non-negative label belong to a displayed topic.
    mask_tm0 = np.expand_dims(topic_map0 >= 0, axis=-1)
    mask_tm1 = np.expand_dims(topic_map1 >= 0, axis=-1)
    topic_cm0, topic_cm1 = cm.jet(topic_map0 / 99.0), cm.jet(topic_map1 / 99.0)
    topic_cm0 = cv2.cvtColor(
        topic_cm0[..., :3].astype(np.float32), cv2.COLOR_RGB2BGR
    )
    topic_cm1 = cv2.cvtColor(
        topic_cm1[..., :3].astype(np.float32), cv2.COLOR_RGB2BGR
    )
    # Topic color where a topic is shown, original pixel elsewhere; then
    # blend that with the original image (in place, into overlayX).
    overlay0 = (mask_tm0 * topic_cm0 + (1 - mask_tm0) * img0).astype(np.float32)
    overlay1 = (mask_tm1 * topic_cm1 + (1 - mask_tm1) * img1).astype(np.float32)
    cv2.addWeighted(overlay0, topic_alpha, img0, 1 - topic_alpha, 0, overlay0)
    cv2.addWeighted(overlay1, topic_alpha, img1, 1 - topic_alpha, 0, overlay1)
    overlay0 = (overlay0 * 255).astype(np.uint8)
    overlay1 = (overlay1 * 255).astype(np.uint8)

    h0, w0 = img0.shape[:2]
    h1, w1 = img1.shape[:2]
    h = h0 * 2 + margin * 2
    # The right column must hold image 1; keep the historical width
    # (w0 * 2 + margin) whenever image 1 is not wider than image 0.
    w = w0 + margin + max(w0, w1)
    out_fig = 255 * np.ones((h, w, 3), dtype=np.uint8)

    step_h = h0 + margin * 2  # first row of the bottom panels
    x1_start = w0 + margin  # first column of the right panels
    x1_stop = x1_start + w1

    # Top row: topic overlays.
    out_fig[:h0, :w0] = overlay0
    if h0 >= h1:
        # Image 1 is shorter: centre it vertically.
        y_shift1 = (h0 - h1) // 2
        out_fig[y_shift1 : y_shift1 + h1, x1_start:x1_stop] = overlay1
    else:
        # Image 1 is taller: keep its central h0 rows.
        crop1 = (h1 - h0) // 2
        y_shift1 = -crop1
        out_fig[:h0, x1_start:x1_stop] = overlay1[crop1 : crop1 + h0]

    # Bottom row: plain images.
    out_fig[step_h : step_h + h0, :w0] = (img0 * 255).astype(np.uint8)
    if h0 >= h1:
        top1 = step_h + y_shift1
        out_fig[top1 : top1 + h1, x1_start:x1_stop] = (img1 * 255).astype(
            np.uint8
        )
    else:
        out_fig[step_h : step_h + h0, x1_start:x1_stop] = (
            img1[crop1 : crop1 + h0] * 255
        ).astype(np.uint8)

    # draw matching lines, this is inspired from
    # https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/master/models/utils.py
    mkpts0, mkpts1 = np.round(mkpts0).astype(int), np.round(mkpts1).astype(int)
    mcolor = (np.array(mcolor[:, [2, 1, 0]]) * 255).astype(int)
    for (x0, y0), (x1, y1), c in zip(mkpts0, mkpts1, mcolor):
        c = c.tolist()
        # Plain Python ints for OpenCV; y_shift1 is the same vertical offset
        # that was applied to image 1 above (centring or cropping).
        pt0 = (int(x0), int(y0) + step_h)
        pt1 = (int(x1) + x1_start, int(y1) + step_h + y_shift1)
        cv2.line(
            out_fig,
            pt0,
            pt1,
            color=c,
            thickness=1,
            lineType=cv2.LINE_AA,
        )
        # display line end-points as circles
        cv2.circle(out_fig, pt0, 2, c, -1, lineType=cv2.LINE_AA)
        cv2.circle(out_fig, pt1, 2, c, -1, lineType=cv2.LINE_AA)

    # Scale factor for consistent visualization across scales.
    sc = min(h / 960.0, 2.0)
    # Big text.
    Ht = int(30 * sc)  # text height
    txt_color_fg = (255, 255, 255)
    txt_color_bg = (0, 0, 0)
    for i, t in enumerate(text):
        origin = (int(8 * sc), Ht + step_h * i)
        # Thick dark pass first, thin bright pass on top: outlined text.
        cv2.putText(
            out_fig,
            t,
            origin,
            cv2.FONT_HERSHEY_DUPLEX,
            1.0 * sc,
            txt_color_bg,
            2,
            cv2.LINE_AA,
        )
        cv2.putText(
            out_fig,
            t,
            origin,
            cv2.FONT_HERSHEY_DUPLEX,
            1.0 * sc,
            txt_color_fg,
            1,
            cv2.LINE_AA,
        )

    if path is not None:
        cv2.imwrite(str(path), out_fig)
    if opencv_display:
        cv2.imshow(opencv_title, out_fig)
        cv2.waitKey(1)
    return out_fig
def fig2im(fig):
    """Render a matplotlib figure and return it as an RGB image array.

    ``FigureCanvasAgg.tostring_rgb`` was deprecated in Matplotlib 3.8 and
    removed in 3.10, so the RGBA buffer is used when the canvas offers it
    and ``tostring_rgb`` only serves as a fallback for canvases without
    ``buffer_rgba``.

    Args:
        fig: a figure whose ``canvas`` supports ``draw()`` and either
            ``buffer_rgba()`` or ``tostring_rgb()``.

    Returns:
        uint8 array of shape (height, width, 3).
    """
    canvas = fig.canvas
    canvas.draw()
    if hasattr(canvas, "buffer_rgba"):
        rgba = np.asarray(canvas.buffer_rgba())
        # Drop alpha and copy, so the result does not alias the canvas buffer.
        return rgba[..., :3].copy()
    width, height = canvas.get_width_height()
    rgb = np.frombuffer(canvas.tostring_rgb(), dtype="u1")
    return rgb.reshape(height, width, 3)
def draw_matches(
    mkpts0, mkpts1, img0, img1, conf, titles=None, dpi=150, path=None, pad=0.5
):
    """Draw the matches between two images, colored by ``conf``.

    ``conf`` is passed through ``error_colormap`` with a threshold of 0.5
    and an opacity of 0.1 to obtain the per-match colors.

    Args:
        mkpts0, mkpts1: matched keypoints of the two images.
        img0, img1: the two images.
        conf: per-match values fed to ``error_colormap``.
        titles: optional pair of axes titles.
        dpi: figure resolution.
        path: if truthy, the figure is saved there instead of being returned.
        pad: padding passed to ``make_matching_figure``.

    Returns:
        The rendered figure as an RGB array, or ``None`` when ``path`` is
        given (the figure is then saved and closed by
        ``make_matching_figure``).
    """
    thr = 0.5
    color = error_colormap(conf, thr, alpha=0.1)
    text = [
        "image name",
        f"#Matches: {len(mkpts0)}",
    ]
    fig = make_matching_figure(
        img0,
        img1,
        mkpts0,
        mkpts1,
        color,
        titles=titles,
        text=text,
        path=path,
        dpi=dpi,
        pad=pad,
    )
    if path:
        # make_matching_figure returns None after saving; there is nothing
        # left to rasterize (fig2im(None) used to raise AttributeError here).
        return None
    image = fig2im(fig)
    # Only the array is handed back, so release the figure.
    plt.close(fig)
    return image
def draw_image_pairs(img0, img1, text=(), dpi=75, path=None, pad=0.5):
    """Draw two images side by side with an optional text overlay.

    Args:
        img0, img1: the two images, passed straight to ``imshow``.
        text: iterable of strings, joined by newlines and drawn at the
            top-left corner of the first image.
        dpi: figure resolution.
        path: if truthy, the figure is saved there instead of being returned.
        pad: padding passed to ``tight_layout``.

    Returns:
        The rendered figure as an RGB array, or ``None`` when ``path`` is
        given.
    """
    # draw image pair
    fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
    for axis, img in zip(axes, (img0, img1)):
        axis.imshow(img)
        # clear all frames
        axis.get_yaxis().set_ticks([])
        axis.get_xaxis().set_ticks([])
        for spine in axis.spines.values():
            spine.set_visible(False)
    fig.tight_layout(pad=pad)

    # put txts: black text when the top-left corner of img0 is bright
    # (mean above 200), white otherwise.
    txt_color = "k" if img0[:100, :200].mean() > 200 else "w"
    fig.text(
        0.01,
        0.99,
        "\n".join(text),
        transform=fig.axes[0].transAxes,
        fontsize=15,
        va="top",
        ha="left",
        color=txt_color,
    )

    # save or return figure; either way the figure itself is released.
    if path:
        fig.savefig(str(path), bbox_inches="tight", pad_inches=0)
        plt.close(fig)
        return None
    image = fig2im(fig)
    plt.close(fig)
    return image