Spaces:
Running
Running
File size: 5,666 Bytes
2673dcd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
"""
2D visualization primitives based on Matplotlib.
1) Plot images with `plot_images`.
2) Call `plot_keypoints` or `plot_matches` any number of times.
3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`.
"""
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
import numpy as np
import torch
def cm_RdGn(x):
    """Custom colormap: red (0) -> yellow (0.5) -> green (1).

    Inputs are clamped to [0, 1] before being mapped. For an array of shape
    (N,) the result is an (N, 3) array of RGB values in [0, 1].
    """
    green = np.array([[0, 1., 0]])
    red = np.array([[1., 0, 0]])
    # Stretch to [0, 2] so that both channels saturate at the midpoint,
    # which is what produces yellow at 0.5 after the final clip.
    t = np.clip(x, 0, 1)[..., None] * 2
    blended = t * green + (2 - t) * red
    return np.clip(blended, 0, 1)
def cm_BlRdGn(x_):
    """Custom colormap: blue (-1) -> red (0.0) -> green (1).

    Args:
        x_: scalar or array-like of values; they are clamped to [-1, 1].
    Returns:
        ndarray of RGBA values in [0, 1] with a trailing axis of size 4.
        For an input of shape (N,) the result has shape (N, 4).
    """
    # Coerce first: the indexing below needs an ndarray, and this lets lists
    # and plain Python scalars work the same way they do in cm_RdGn.
    x_ = np.asarray(x_)
    red = np.array([[1., 0, 0, 1.]])
    # Non-negative half: blend red -> green.
    x = np.clip(x_, 0, 1)[..., None]*2
    c = x*np.array([[0, 1., 0, 1.]]) + (2-x)*red
    # Negative half: blend red -> blue, using the magnitude of the value.
    xn = -np.clip(x_, -1, 0)[..., None]*2
    cn = xn*np.array([[0, 0.1, 1, 1.]]) + (2-xn)*red
    # Pick the branch per element, then saturate every channel to [0, 1].
    return np.clip(np.where(x_[..., None] < 0, cn, c), 0, 1)
def cm_prune(x_):
    """Custom colormap to visualize pruning.

    Entries equal to the maximum of `x_` are mapped to -1, the blue end of
    `cm_BlRdGn`; every other entry v is mapped to (v - 1) / 9.

    Args:
        x_: torch.Tensor or array-like of values.
            NOTE(review): presumably one pruning index per keypoint, with
            the divisor 9 tied to the number of layers - confirm with caller.
    Returns:
        RGBA colors as returned by `cm_BlRdGn`.
    """
    if isinstance(x_, torch.Tensor):
        x_ = x_.detach().cpu().numpy()
    x_ = np.asarray(x_)
    if x_.size == 0:
        # No maximum exists for empty input; return an empty color array
        # instead of raising.
        return cm_BlRdGn(x_.astype(float))
    # ndarray.max reduces over every axis in native code, whereas the builtin
    # max() iterates in Python and only works for 1-D input.
    max_i = x_.max()
    norm_x = np.where(x_ == max_i, -1, (x_ - 1) / 9)
    return cm_BlRdGn(norm_x)
def plot_images(imgs, titles=None, cmaps='gray', dpi=100, pad=.5,
                adaptive=True):
    """Plot a set of images horizontally.

    Args:
        imgs: list of NumPy RGB (H, W, 3) or PyTorch RGB (3, H, W) or
            mono (H, W).
        titles: a list of strings, as titles for each image.
        cmaps: colormaps for monochrome images; a single colormap or a
            list/tuple with one entry per image.
        dpi: resolution of the created figure.
        pad: padding passed to `tight_layout`.
        adaptive: whether the figure size should fit the image aspect ratios.
    """
    # Convert every tensor to a NumPy array on the CPU. 3-D tensors are also
    # reordered from (3, H, W) to (H, W, 3); 2-D (mono) tensors keep their
    # layout but must still leave the device / autograd graph before imshow.
    arrays = []
    for img in imgs:
        if isinstance(img, torch.Tensor):
            if img.dim() == 3:
                img = img.permute(1, 2, 0)
            img = img.detach().cpu().numpy()
        arrays.append(img)
    imgs = arrays

    n = len(imgs)
    if not isinstance(cmaps, (list, tuple)):
        cmaps = [cmaps] * n

    if adaptive:
        ratios = [i.shape[1] / i.shape[0] for i in imgs]  # W / H
    else:
        ratios = [4/3] * n
    figsize = [sum(ratios)*4.5, 4.5]
    fig, ax = plt.subplots(
        1, n, figsize=figsize, dpi=dpi, gridspec_kw={'width_ratios': ratios})
    if n == 1:
        # plt.subplots returns a bare Axes for a single panel.
        ax = [ax]
    for i, (axis, img) in enumerate(zip(ax, imgs)):
        axis.imshow(img, cmap=plt.get_cmap(cmaps[i]))
        axis.get_yaxis().set_ticks([])
        axis.get_xaxis().set_ticks([])
        axis.set_axis_off()
        for spine in axis.spines.values():  # remove frame
            spine.set_visible(False)
        if titles:
            axis.set_title(titles[i])
    fig.tight_layout(pad=pad)
def plot_keypoints(kpts, colors='lime', ps=4, axes=None, a=1.0):
    """Scatter keypoints on top of already plotted images.

    Args:
        kpts: list of ndarrays or torch tensors of size (N, 2), one per image.
        colors: a single color shared by all images, or a list with one
            entry per image.
        ps: size of the keypoints as float.
        axes: axes to draw on; defaults to the axes of the current figure.
        a: alpha opacity, a single value or a list with one value per image.
    """
    n_sets = len(kpts)
    per_image_colors = colors if isinstance(colors, list) else [colors] * n_sets
    per_image_alphas = a if isinstance(a, list) else [a] * n_sets
    targets = plt.gcf().axes if axes is None else axes

    for axis, pts, col, opacity in zip(
            targets, kpts, per_image_colors, per_image_alphas):
        xy = pts.cpu().numpy() if isinstance(pts, torch.Tensor) else pts
        axis.scatter(
            xy[:, 0], xy[:, 1], c=col, s=ps, linewidths=0, alpha=opacity)
def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1., labels=None,
                 axes=None):
    """Plot matches for a pair of existing images.

    Args:
        kpts0, kpts1: corresponding keypoints of size (N, 2), as ndarrays or
            torch tensors.
        color: a single color (string or RGB tuple) used for every match, or
            one color per match given as a list of tuples/lists or as a 2-D
            ndarray with one row per match. Random if not given.
        lw: width of the lines (no line if lw=0).
        ps: size of the end points (no endpoint if ps=0).
        a: alpha opacity of the match lines.
        labels: optional sequence with one label per match line.
        axes: pair of axes to draw on; defaults to the first two axes of the
            current figure.
    """
    fig = plt.gcf()
    if axes is None:
        ax = fig.axes
        ax0, ax1 = ax[0], ax[1]
    else:
        ax0, ax1 = axes
    if isinstance(kpts0, torch.Tensor):
        kpts0 = kpts0.cpu().numpy()
    if isinstance(kpts1, torch.Tensor):
        kpts1 = kpts1.cpu().numpy()
    assert len(kpts0) == len(kpts1)
    n_matches = len(kpts0)

    if color is None:
        color = matplotlib.cm.hsv(np.random.rand(n_matches)).tolist()
    elif isinstance(color, np.ndarray) and color.ndim == 2:
        # A 2-D array already holds one color per row. Without this branch
        # its rows (ndarrays, not tuples/lists) would be mistaken for a
        # single color and the whole array replicated for every match.
        color = color.tolist()
    elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
        # A single color (string or flat RGB sequence): share it.
        color = [color] * n_matches

    if lw > 0:
        for i, (p0, p1) in enumerate(zip(kpts0, kpts1)):
            line = matplotlib.patches.ConnectionPatch(
                xyA=(p0[0], p0[1]), xyB=(p1[0], p1[1]),
                coordsA=ax0.transData, coordsB=ax1.transData,
                axesA=ax0, axesB=ax1,
                zorder=1, color=color[i], linewidth=lw, clip_on=True,
                alpha=a, label=None if labels is None else labels[i],
                picker=5.0)
            line.set_annotation_clip(True)
            fig.add_artist(line)

    # freeze the axes to prevent the transform to change
    ax0.autoscale(enable=False)
    ax1.autoscale(enable=False)

    if ps > 0:
        ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
        ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)
def add_text(idx, text, pos=(0.01, 0.99), fs=15, color='w',
             lcolor='k', lwidth=2, ha='left', va='top'):
    """Write a text label on one of the axes of the current figure.

    Args:
        idx: index of the target axes in the current figure.
        text: string to display.
        pos: position of the text in axes coordinates.
        fs: font size.
        color: color of the text.
        lcolor: color of the outline drawn around the text; None disables it.
        lwidth: line width of the outline.
        ha, va: horizontal and vertical alignment of the text.
    """
    target = plt.gcf().axes[idx]
    label = target.text(
        *pos, text, transform=target.transAxes, color=color,
        fontsize=fs, ha=ha, va=va)
    if lcolor is None:
        return
    # Stroke draws the outline, Normal then redraws the text on top of it.
    outline = path_effects.Stroke(linewidth=lwidth, foreground=lcolor)
    label.set_path_effects([outline, path_effects.Normal()])
def save_plot(path, **kw):
    """Save the current figure without any white margin.

    Args:
        path: destination forwarded to `plt.savefig`.
        **kw: extra keyword arguments forwarded to `plt.savefig`. They take
            precedence over the defaults `bbox_inches='tight'` and
            `pad_inches=0`.
    """
    # Merge rather than pass the defaults explicitly: a caller-supplied
    # bbox_inches / pad_inches would otherwise raise a duplicate-keyword
    # TypeError instead of overriding the default.
    options = {'bbox_inches': 'tight', 'pad_inches': 0, **kw}
    plt.savefig(path, **options)
|