""" 2D visualization primitives based on Matplotlib. 1) Plot images with `plot_images`. 2) Call `plot_keypoints` or `plot_matches` any number of times. 3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`. """ import matplotlib import matplotlib.patheffects as path_effects import matplotlib.pyplot as plt import numpy as np import seaborn as sns def cm_ranking(sc, ths=[512, 1024, 2048, 4096]): ls = sc.shape[0] colors = ["red", "yellow", "lime", "cyan", "blue"] out = ["gray"] * ls for i in range(ls): for c, th in zip(colors[: len(ths) + 1], ths + [ls]): if i < th: out[i] = c break sid = np.argsort(sc, axis=0).flip(0) out = np.array(out)[sid] return out def cm_RdBl(x): """Custom colormap: red (0) -> yellow (0.5) -> green (1).""" x = np.clip(x, 0, 1)[..., None] * 2 c = x * np.array([[0, 0, 1.0]]) + (2 - x) * np.array([[1.0, 0, 0]]) return np.clip(c, 0, 1) def cm_RdGn(x): """Custom colormap: red (0) -> yellow (0.5) -> green (1).""" x = np.clip(x, 0, 1)[..., None] * 2 c = x * np.array([[0, 1.0, 0]]) + (2 - x) * np.array([[1.0, 0, 0]]) return np.clip(c, 0, 1) def cm_BlRdGn(x_): """Custom colormap: blue (-1) -> red (0.0) -> green (1).""" x = np.clip(x_, 0, 1)[..., None] * 2 c = x * np.array([[0, 1.0, 0, 1.0]]) + (2 - x) * np.array([[1.0, 0, 0, 1.0]]) xn = -np.clip(x_, -1, 0)[..., None] * 2 cn = xn * np.array([[0, 1.0, 0, 1.0]]) + (2 - xn) * np.array([[1.0, 0, 0, 1.0]]) out = np.clip(np.where(x_[..., None] < 0, cn, c), 0, 1) return out def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True): """Plot a set of images horizontally. Args: imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W). titles: a list of strings, as titles for each image. cmaps: colormaps for monochrome images. adaptive: whether the figure size should fit the image aspect ratios. """ n = len(imgs) if not isinstance(cmaps, (list, tuple)): cmaps = [cmaps] * n if adaptive: ratios = [i.shape[1] / i.shape[0] for i in imgs] # W / H else: ratios = [4 / 3] * n figsize = [sum(ratios) * 4.5, 4.5] fig, axs = plt.subplots( 1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios} ) if n == 1: axs = [axs] for i, (img, ax) in enumerate(zip(imgs, axs)): ax.imshow(img, cmap=plt.get_cmap(cmaps[i])) ax.set_axis_off() if titles: ax.set_title(titles[i]) fig.tight_layout(pad=pad) def plot_image_grid( imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, fig=None, adaptive=True, figs=2.0, return_fig=False, set_lim=False, ): """Plot a grid of images. Args: imgs: a list of lists of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W). titles: a list of strings, as titles for each image. cmaps: colormaps for monochrome images. adaptive: whether the figure size should fit the image aspect ratios. """ nr, n = len(imgs), len(imgs[0]) if not isinstance(cmaps, (list, tuple)): cmaps = [cmaps] * n if adaptive: ratios = [i.shape[1] / i.shape[0] for i in imgs[0]] # W / H else: ratios = [4 / 3] * n figsize = [sum(ratios) * figs, nr * figs] if fig is None: fig, axs = plt.subplots( nr, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios} ) else: axs = fig.subplots(nr, n, gridspec_kw={"width_ratios": ratios}) fig.figure.set_size_inches(figsize) if nr == 1: axs = [axs] for j in range(nr): for i in range(n): ax = axs[j][i] ax.imshow(imgs[j][i], cmap=plt.get_cmap(cmaps[i])) ax.set_axis_off() if set_lim: ax.set_xlim([0, imgs[j][i].shape[1]]) ax.set_ylim([imgs[j][i].shape[0], 0]) if titles: ax.set_title(titles[j][i]) if isinstance(fig, plt.Figure): fig.tight_layout(pad=pad) if return_fig: return fig, axs else: return axs def plot_keypoints(kpts, colors="lime", ps=4, axes=None, a=1.0): """Plot keypoints for existing images. Args: kpts: list of ndarrays of size (N, 2). colors: string, or list of list of tuples (one for each keypoints). ps: size of the keypoints as float. """ if not isinstance(colors, list): colors = [colors] * len(kpts) if not isinstance(a, list): a = [a] * len(kpts) if axes is None: axes = plt.gcf().axes for ax, k, c, alpha in zip(axes, kpts, colors, a): ax.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0, alpha=alpha) def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1.0, labels=None, axes=None): """Plot matches for a pair of existing images. Args: kpts0, kpts1: corresponding keypoints of size (N, 2). color: color of each match, string or RGB tuple. Random if not given. lw: width of the lines. ps: size of the end points (no endpoint if ps=0) indices: indices of the images to draw the matches on. a: alpha opacity of the match lines. """ fig = plt.gcf() if axes is None: ax = fig.axes ax0, ax1 = ax[0], ax[1] else: ax0, ax1 = axes assert len(kpts0) == len(kpts1) if color is None: color = sns.color_palette("husl", n_colors=len(kpts0)) elif len(color) > 0 and not isinstance(color[0], (tuple, list)): color = [color] * len(kpts0) if lw > 0: for i in range(len(kpts0)): line = matplotlib.patches.ConnectionPatch( xyA=(kpts0[i, 0], kpts0[i, 1]), xyB=(kpts1[i, 0], kpts1[i, 1]), coordsA=ax0.transData, coordsB=ax1.transData, axesA=ax0, axesB=ax1, zorder=1, color=color[i], linewidth=lw, clip_on=True, alpha=a, label=None if labels is None else labels[i], picker=5.0, ) line.set_annotation_clip(True) fig.add_artist(line) # freeze the axes to prevent the transform to change ax0.autoscale(enable=False) ax1.autoscale(enable=False) if ps > 0: ax0.scatter( kpts0[:, 0], kpts0[:, 1], c=color, s=ps, label=None if labels is None or len(labels) == 0 else labels[0], ) ax1.scatter( kpts1[:, 0], kpts1[:, 1], c=color, s=ps, label=None if labels is None or len(labels) == 0 else labels[1], ) def add_text( idx, text, pos=(0.01, 0.99), fs=15, color="w", lcolor="k", lwidth=2, ha="left", va="top", axes=None, **kwargs, ): if axes is None: axes = plt.gcf().axes ax = axes[idx] t = ax.text( *pos, text, fontsize=fs, ha=ha, va=va, color=color, transform=ax.transAxes, **kwargs, ) if lcolor is not None: t.set_path_effects( [ path_effects.Stroke(linewidth=lwidth, foreground=lcolor), path_effects.Normal(), ] ) return t def draw_epipolar_line( line, axis, imshape=None, color="b", label=None, alpha=1.0, visible=True ): if imshape is not None: h, w = imshape[:2] else: _, w = axis.get_xlim() h, _ = axis.get_ylim() imshape = (h + 0.5, w + 0.5) # Intersect line with lines representing image borders. X1 = np.cross(line, [1, 0, -1]) X1 = X1[:2] / X1[2] X2 = np.cross(line, [1, 0, -w]) X2 = X2[:2] / X2[2] X3 = np.cross(line, [0, 1, -1]) X3 = X3[:2] / X3[2] X4 = np.cross(line, [0, 1, -h]) X4 = X4[:2] / X4[2] # Find intersections which are not outside the image, # which will therefore be on the image border. Xs = [X1, X2, X3, X4] Ps = [] for p in range(4): X = Xs[p] if (0 <= X[0] <= (w + 1e-6)) and (0 <= X[1] <= (h + 1e-6)): Ps.append(X) if len(Ps) == 2: break # Plot line, if it's visible in the image. if len(Ps) == 2: art = axis.plot( [Ps[0][0], Ps[1][0]], [Ps[0][1], Ps[1][1]], color, linestyle="dashed", label=label, alpha=alpha, visible=visible, )[0] return art else: return None def get_line(F, kp): hom_kp = np.array([list(kp) + [1.0]]).transpose() return np.dot(F, hom_kp) def plot_epipolar_lines( pts0, pts1, F, color="b", axes=None, labels=None, a=1.0, visible=True ): if axes is None: axes = plt.gcf().axes assert len(axes) == 2 for ax, kps in zip(axes, [pts1, pts0]): _, w = ax.get_xlim() h, _ = ax.get_ylim() imshape = (h + 0.5, w + 0.5) for i in range(kps.shape[0]): if ax == axes[0]: line = get_line(F.transpose(0, 1), kps[i])[:, 0] else: line = get_line(F, kps[i])[:, 0] draw_epipolar_line( line, ax, imshape, color=color, label=None if labels is None else labels[i], alpha=a, visible=visible, ) def plot_heatmaps(heatmaps, vmin=0.0, vmax=None, cmap="Spectral", a=0.5, axes=None): if axes is None: axes = plt.gcf().axes artists = [] for i in range(len(axes)): a_ = a if isinstance(a, float) else a[i] art = axes[i].imshow( heatmaps[i], alpha=(heatmaps[i] > vmin).float() * a_, vmin=vmin, vmax=vmax, cmap=cmap, ) artists.append(art) return artists def plot_lines( lines, line_colors="orange", point_colors="cyan", ps=4, lw=2, alpha=1.0, indices=(0, 1), ): """Plot lines and endpoints for existing images. Args: lines: list of ndarrays of size (N, 2, 2). colors: string, or list of list of tuples (one for each keypoints). ps: size of the keypoints as float pixels. lw: line width as float pixels. alpha: transparency of the points and lines. indices: indices of the images to draw the matches on. """ if not isinstance(line_colors, list): line_colors = [line_colors] * len(lines) if not isinstance(point_colors, list): point_colors = [point_colors] * len(lines) fig = plt.gcf() ax = fig.axes assert len(ax) > max(indices) axes = [ax[i] for i in indices] # Plot the lines and junctions for a, l, lc, pc in zip(axes, lines, line_colors, point_colors): for i in range(len(l)): line = matplotlib.lines.Line2D( (l[i, 0, 0], l[i, 1, 0]), (l[i, 0, 1], l[i, 1, 1]), zorder=1, c=lc, linewidth=lw, alpha=alpha, ) a.add_line(line) pts = l.reshape(-1, 2) a.scatter(pts[:, 0], pts[:, 1], c=pc, s=ps, linewidths=0, zorder=2, alpha=alpha) def plot_color_line_matches(lines, correct_matches=None, lw=2, indices=(0, 1)): """Plot line matches for existing images with multiple colors. Args: lines: list of ndarrays of size (N, 2, 2). correct_matches: bool array of size (N,) indicating correct matches. lw: line width as float pixels. indices: indices of the images to draw the matches on. """ n_lines = len(lines[0]) colors = sns.color_palette("husl", n_colors=n_lines) np.random.shuffle(colors) alphas = np.ones(n_lines) # If correct_matches is not None, display wrong matches with a low alpha if correct_matches is not None: alphas[~np.array(correct_matches)] = 0.2 fig = plt.gcf() ax = fig.axes assert len(ax) > max(indices) axes = [ax[i] for i in indices] # Plot the lines for a, img_lines in zip(axes, lines): for i, line in enumerate(img_lines): fig.add_artist( matplotlib.patches.ConnectionPatch( xyA=tuple(line[0]), coordsA=a.transData, xyB=tuple(line[1]), coordsB=a.transData, zorder=1, color=colors[i], linewidth=lw, alpha=alphas[i], ) ) def save_plot(path, **kw): """Save the current figure without any white margin.""" plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw) def plot_cumulative( errors: dict, thresholds: list, colors=None, title="", unit="-", logx=False, ): thresholds = np.linspace(min(thresholds), max(thresholds), 100) plt.figure(figsize=[5, 8]) for method in errors: recall = [] errs = np.array(errors[method]) for th in thresholds: recall.append(np.mean(errs <= th)) plt.plot( thresholds, np.array(recall) * 100, label=method, c=colors[method] if colors else None, linewidth=3, ) plt.grid() plt.xlabel(unit, fontsize=25) if logx: plt.semilogx() plt.ylim([0, 100]) plt.yticks(ticks=[0, 20, 40, 60, 80, 100]) plt.ylabel(title + "Recall [%]", rotation=0, fontsize=25) plt.gca().yaxis.set_label_coords(x=0.45, y=1.02) plt.tick_params(axis="both", which="major", labelsize=20) plt.yticks(rotation=0) plt.legend( bbox_to_anchor=(0.45, -0.12), ncol=2, loc="upper center", fontsize=20, handlelength=3, ) plt.tight_layout() return plt.gcf()