|  | """ | 
					
						
						|  | 2D visualization primitives based on Matplotlib. | 
					
						
						|  | 1) Plot images with `plot_images`. | 
					
						
						|  | 2) Call `plot_keypoints` or `plot_matches` any number of times. | 
					
						
						|  | 3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | import matplotlib | 
					
						
						|  | import matplotlib.patheffects as path_effects | 
					
						
						|  | import matplotlib.pyplot as plt | 
					
						
						|  | import numpy as np | 
					
						
						|  | import torch | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def cm_RdGn(x): | 
					
						
						|  | """Custom colormap: red (0) -> yellow (0.5) -> green (1).""" | 
					
						
						|  | x = np.clip(x, 0, 1)[..., None] * 2 | 
					
						
						|  | c = x * np.array([[0, 1.0, 0]]) + (2 - x) * np.array([[1.0, 0, 0]]) | 
					
						
						|  | return np.clip(c, 0, 1) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def cm_BlRdGn(x_): | 
					
						
						|  | """Custom colormap: blue (-1) -> red (0.0) -> green (1).""" | 
					
						
						|  | x = np.clip(x_, 0, 1)[..., None] * 2 | 
					
						
						|  | c = x * np.array([[0, 1.0, 0, 1.0]]) + (2 - x) * np.array([[1.0, 0, 0, 1.0]]) | 
					
						
						|  |  | 
					
						
						|  | xn = -np.clip(x_, -1, 0)[..., None] * 2 | 
					
						
						|  | cn = xn * np.array([[0, 0.1, 1, 1.0]]) + (2 - xn) * np.array([[1.0, 0, 0, 1.0]]) | 
					
						
						|  | out = np.clip(np.where(x_[..., None] < 0, cn, c), 0, 1) | 
					
						
						|  | return out | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def cm_prune(x_): | 
					
						
						|  | """Custom colormap to visualize pruning""" | 
					
						
						|  | if isinstance(x_, torch.Tensor): | 
					
						
						|  | x_ = x_.cpu().numpy() | 
					
						
						|  | max_i = max(x_) | 
					
						
						|  | norm_x = np.where(x_ == max_i, -1, (x_ - 1) / 9) | 
					
						
						|  | return cm_BlRdGn(norm_x) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True): | 
					
						
						|  | """Plot a set of images horizontally. | 
					
						
						|  | Args: | 
					
						
						|  | imgs: list of NumPy RGB (H, W, 3) or PyTorch RGB (3, H, W) or mono (H, W). | 
					
						
						|  | titles: a list of strings, as titles for each image. | 
					
						
						|  | cmaps: colormaps for monochrome images. | 
					
						
						|  | adaptive: whether the figure size should fit the image aspect ratios. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | imgs = [ | 
					
						
						|  | img.permute(1, 2, 0).cpu().numpy() | 
					
						
						|  | if (isinstance(img, torch.Tensor) and img.dim() == 3) | 
					
						
						|  | else img | 
					
						
						|  | for img in imgs | 
					
						
						|  | ] | 
					
						
						|  |  | 
					
						
						|  | n = len(imgs) | 
					
						
						|  | if not isinstance(cmaps, (list, tuple)): | 
					
						
						|  | cmaps = [cmaps] * n | 
					
						
						|  |  | 
					
						
						|  | if adaptive: | 
					
						
						|  | ratios = [i.shape[1] / i.shape[0] for i in imgs] | 
					
						
						|  | else: | 
					
						
						|  | ratios = [4 / 3] * n | 
					
						
						|  | figsize = [sum(ratios) * 4.5, 4.5] | 
					
						
						|  | fig, ax = plt.subplots( | 
					
						
						|  | 1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios} | 
					
						
						|  | ) | 
					
						
						|  | if n == 1: | 
					
						
						|  | ax = [ax] | 
					
						
						|  | for i in range(n): | 
					
						
						|  | ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i])) | 
					
						
						|  | ax[i].get_yaxis().set_ticks([]) | 
					
						
						|  | ax[i].get_xaxis().set_ticks([]) | 
					
						
						|  | ax[i].set_axis_off() | 
					
						
						|  | for spine in ax[i].spines.values(): | 
					
						
						|  | spine.set_visible(False) | 
					
						
						|  | if titles: | 
					
						
						|  | ax[i].set_title(titles[i]) | 
					
						
						|  | fig.tight_layout(pad=pad) | 
					
						
						|  | return fig, ax | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def plot_keypoints(kpts, colors="lime", ps=4, axes=None, a=1.0): | 
					
						
						|  | """Plot keypoints for existing images. | 
					
						
						|  | Args: | 
					
						
						|  | kpts: list of ndarrays of size (N, 2). | 
					
						
						|  | colors: string, or list of list of tuples (one for each keypoints). | 
					
						
						|  | ps: size of the keypoints as float. | 
					
						
						|  | """ | 
					
						
						|  | if not isinstance(colors, list): | 
					
						
						|  | colors = [colors] * len(kpts) | 
					
						
						|  | if not isinstance(a, list): | 
					
						
						|  | a = [a] * len(kpts) | 
					
						
						|  | if axes is None: | 
					
						
						|  | axes = plt.gcf().axes | 
					
						
						|  | for ax, k, c, alpha in zip(axes, kpts, colors, a): | 
					
						
						|  | if isinstance(k, torch.Tensor): | 
					
						
						|  | k = k.cpu().numpy() | 
					
						
						|  | ax.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0, alpha=alpha) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1.0, labels=None, axes=None): | 
					
						
						|  | """Plot matches for a pair of existing images. | 
					
						
						|  | Args: | 
					
						
						|  | kpts0, kpts1: corresponding keypoints of size (N, 2). | 
					
						
						|  | color: color of each match, string or RGB tuple. Random if not given. | 
					
						
						|  | lw: width of the lines. | 
					
						
						|  | ps: size of the end points (no endpoint if ps=0) | 
					
						
						|  | indices: indices of the images to draw the matches on. | 
					
						
						|  | a: alpha opacity of the match lines. | 
					
						
						|  | """ | 
					
						
						|  | fig = plt.gcf() | 
					
						
						|  | if axes is None: | 
					
						
						|  | ax = fig.axes | 
					
						
						|  | ax0, ax1 = ax[0], ax[1] | 
					
						
						|  | else: | 
					
						
						|  | ax0, ax1 = axes | 
					
						
						|  | if isinstance(kpts0, torch.Tensor): | 
					
						
						|  | kpts0 = kpts0.cpu().numpy() | 
					
						
						|  | if isinstance(kpts1, torch.Tensor): | 
					
						
						|  | kpts1 = kpts1.cpu().numpy() | 
					
						
						|  | assert len(kpts0) == len(kpts1) | 
					
						
						|  | if color is None: | 
					
						
						|  | color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist() | 
					
						
						|  | elif len(color) > 0 and not isinstance(color[0], (tuple, list)): | 
					
						
						|  | color = [color] * len(kpts0) | 
					
						
						|  |  | 
					
						
						|  | if lw > 0: | 
					
						
						|  | for i in range(len(kpts0)): | 
					
						
						|  | line = matplotlib.patches.ConnectionPatch( | 
					
						
						|  | xyA=(kpts0[i, 0], kpts0[i, 1]), | 
					
						
						|  | xyB=(kpts1[i, 0], kpts1[i, 1]), | 
					
						
						|  | coordsA=ax0.transData, | 
					
						
						|  | coordsB=ax1.transData, | 
					
						
						|  | axesA=ax0, | 
					
						
						|  | axesB=ax1, | 
					
						
						|  | zorder=1, | 
					
						
						|  | color=color[i], | 
					
						
						|  | linewidth=lw, | 
					
						
						|  | clip_on=True, | 
					
						
						|  | alpha=a, | 
					
						
						|  | label=None if labels is None else labels[i], | 
					
						
						|  | picker=5.0, | 
					
						
						|  | ) | 
					
						
						|  | line.set_annotation_clip(True) | 
					
						
						|  | fig.add_artist(line) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | ax0.autoscale(enable=False) | 
					
						
						|  | ax1.autoscale(enable=False) | 
					
						
						|  |  | 
					
						
						|  | if ps > 0: | 
					
						
						|  | ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps) | 
					
						
						|  | ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def add_text( | 
					
						
						|  | idx, | 
					
						
						|  | text, | 
					
						
						|  | pos=(0.01, 0.99), | 
					
						
						|  | fs=15, | 
					
						
						|  | color="w", | 
					
						
						|  | lcolor="k", | 
					
						
						|  | lwidth=2, | 
					
						
						|  | ha="left", | 
					
						
						|  | va="top", | 
					
						
						|  | ): | 
					
						
						|  | ax = plt.gcf().axes[idx] | 
					
						
						|  | t = ax.text( | 
					
						
						|  | *pos, text, fontsize=fs, ha=ha, va=va, color=color, transform=ax.transAxes | 
					
						
						|  | ) | 
					
						
						|  | if lcolor is not None: | 
					
						
						|  | t.set_path_effects( | 
					
						
						|  | [ | 
					
						
						|  | path_effects.Stroke(linewidth=lwidth, foreground=lcolor), | 
					
						
						|  | path_effects.Normal(), | 
					
						
						|  | ] | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def save_plot(path, **kw): | 
					
						
						|  | """Save the current figure without any white margin.""" | 
					
						
						|  | plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw) | 
					
						
						|  |  |