"""Matplotlib helpers for visualizing image pairs, point and line matches."""

from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, cast

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def _hide_frame(ax: matplotlib.axes.Axes) -> None:
    """Remove the ticks and the surrounding spines of an axes."""
    ax.get_yaxis().set_ticks([])
    ax.get_xaxis().set_ticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)


def _overlay_text(
    fig: matplotlib.figure.Figure, img0: np.ndarray, text: List[str]
) -> None:
    """Write ``text`` (one entry per line) in the top-left corner of the first axes."""
    # Pick black text on a bright corner and white text otherwise.
    txt_color = "k" if img0[:100, :200].mean() > 200 else "w"
    fig.text(
        0.01,
        0.99,
        "\n".join(text),
        transform=fig.axes[0].transAxes,
        fontsize=15,
        va="top",
        ha="left",
        color=txt_color,
    )


def plot_images(
    imgs: List[np.ndarray],
    titles: Optional[List[str]] = None,
    cmaps: Union[str, List[str]] = "gray",
    dpi: int = 100,
    size: Optional[int] = 5,
    pad: float = 0.5,
) -> plt.Figure:
    """Plot a set of images horizontally.

    Args:
        imgs: a list of images accepted by ``Axes.imshow``, RGB (H, W, 3)
            or mono (H, W).
        titles: a list of strings, as titles for each image.
        cmaps: colormaps for monochrome images. If a single string is given,
            it is used for all images.
        dpi: DPI of the figure.
        size: width in inches allotted to each image; the figure is
            ``size * len(imgs)`` wide and ``size * 6 / 5`` high. If None,
            Matplotlib's default figure size is used.
        pad: padding forwarded to ``Figure.tight_layout``.

    Returns:
        The created figure.
    """
    n = len(imgs)
    if not isinstance(cmaps, list):
        cmaps = [cmaps] * n
    figsize = (size * n, size * 6 / 5) if size is not None else None
    # squeeze=False always yields a 2D array of axes, even for a single image.
    fig, axes = plt.subplots(1, n, figsize=figsize, dpi=dpi, squeeze=False)
    for i, (ax, img) in enumerate(zip(axes[0], imgs)):
        ax.imshow(img, cmap=plt.get_cmap(cmaps[i]))
        _hide_frame(ax)
        ax.set_axis_off()
        if titles:
            ax.set_title(titles[i])
    fig.tight_layout(pad=pad)
    return fig


def plot_color_line_matches(
    lines: List[np.ndarray],
    correct_matches: Optional[np.ndarray] = None,
    lw: float = 2.0,
    indices: Tuple[int, int] = (0, 1),
) -> matplotlib.figure.Figure:
    """Plot line matches on the current figure, one color per match.

    The lines are drawn on the figure returned by ``plt.gcf()``, so the
    images must already have been plotted (e.g. with ``plot_images``).

    Args:
        lines: List of ndarrays of size (N, 2, 2) representing line segments,
            one array per image.
        correct_matches: Optional bool array of size (N,) indicating correct
            matches. If not None, wrong matches are drawn with a low alpha.
        lw: Line width as float pixels.
        indices: Indices of the axes (images) to draw the matches on.

    Returns:
        The modified matplotlib figure.
    """
    n_lines = lines[0].shape[0]
    fig = plt.gcf()
    ax = cast(List[matplotlib.axes.Axes], fig.axes)
    assert len(ax) > max(indices)
    if n_lines == 0:
        # Nothing to draw.
        return fig

    colors = sns.color_palette("husl", n_colors=n_lines)
    np.random.shuffle(colors)
    alphas = np.ones(n_lines)
    if correct_matches is not None:
        # Coerce to bool so that `~` is a logical (not bitwise) negation.
        alphas[~np.asarray(correct_matches, dtype=bool)] = 0.2

    axes = [ax[i] for i in indices]
    # The data -> display transforms are only valid after a draw.
    fig.canvas.draw()
    to_figure = fig.transFigure.inverted()

    for a, segments in zip(axes, lines):
        # Transform the endpoints into the figure coordinate system.
        endpoint0 = to_figure.transform(a.transData.transform(segments[:, 0]))
        endpoint1 = to_figure.transform(a.transData.transform(segments[:, 1]))
        fig.lines += [
            matplotlib.lines.Line2D(
                (endpoint0[i, 0], endpoint1[i, 0]),
                (endpoint0[i, 1], endpoint1[i, 1]),
                zorder=1,
                transform=fig.transFigure,
                c=colors[i],
                alpha=alphas[i],
                linewidth=lw,
            )
            for i in range(n_lines)
        ]
    return fig


def make_matching_figure(
    img0: np.ndarray,
    img1: np.ndarray,
    mkpts0: np.ndarray,
    mkpts1: np.ndarray,
    color: np.ndarray,
    titles: Optional[List[str]] = None,
    kpts0: Optional[np.ndarray] = None,
    kpts1: Optional[np.ndarray] = None,
    text: Optional[List[str]] = None,
    dpi: int = 75,
    path: Optional[Path] = None,
    pad: float = 0.0,
) -> Optional[plt.Figure]:
    """Draw an image pair side by side with its matches.

    Args:
        img0: image0 as HxWx3 numpy array.
        img1: image1 as HxWx3 numpy array.
        mkpts0: matched points in image0 as Nx2 numpy array.
        mkpts1: matched points in image1 as Nx2 numpy array.
        color: colors for the matches as Nx4 numpy array.
        titles: titles for the two subplots. None or an empty list draws
            no titles.
        kpts0: keypoints in image0 as Kx2 numpy array.
        kpts1: keypoints in image1 as Kx2 numpy array. Required whenever
            ``kpts0`` is given.
        text: list of strings to display in the top-left corner of the
            first image, one per line.
        dpi: dots per inch of the figure.
        path: if truthy, save the figure to this path and close it.
        pad: padding forwarded to ``tight_layout``.

    Returns:
        The matplotlib Figure if ``path`` is not given, otherwise None.
    """
    text = [] if text is None else text

    # draw image pair
    fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
    axes[0].imshow(img0)
    axes[1].imshow(img1)
    for i in range(2):  # clear all frames
        _hide_frame(axes[i])
        # Truthiness check: an empty title list means "no titles".
        if titles:
            axes[i].set_title(titles[i])
    fig.tight_layout(pad=pad)

    if kpts0 is not None:
        assert kpts1 is not None
        axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c="w", s=5)
        axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c="w", s=5)

    # draw matches
    if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
        # The data -> display transforms are only valid after a draw.
        fig.canvas.draw()
        to_figure = fig.transFigure.inverted()
        fkpts0 = to_figure.transform(axes[0].transData.transform(mkpts0))
        fkpts1 = to_figure.transform(axes[1].transData.transform(mkpts1))
        fig.lines = [
            matplotlib.lines.Line2D(
                (fkpts0[i, 0], fkpts1[i, 0]),
                (fkpts0[i, 1], fkpts1[i, 1]),
                transform=fig.transFigure,
                c=color[i],
                linewidth=2,
            )
            for i in range(len(mkpts0))
        ]

        # freeze the axes to prevent the transform from changing
        axes[0].autoscale(enable=False)
        axes[1].autoscale(enable=False)

        axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color[..., :3], s=4)
        axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color[..., :3], s=4)

    _overlay_text(fig, img0, text)

    # save or return figure
    if path:
        fig.savefig(str(path), bbox_inches="tight", pad_inches=0)
        plt.close(fig)
        return None
    return fig


def error_colormap(
    err: np.ndarray, thr: float, alpha: float = 1.0
) -> np.ndarray:
    """Map error values to RGBA colors going from green (0) to red (>= 2 * thr).

    Args:
        err: Error values as a numpy array of shape (N,).
        thr: Threshold value for the error; errors of ``2 * thr`` or more
            are mapped to pure red.
        alpha: Alpha value for the colors, in (0, 1].

    Returns:
        Colors as a numpy array of shape (N, 4) with values in [0, 1].
    """
    assert alpha <= 1.0 and alpha > 0, f"Invalid alpha value: {alpha}"
    x = 1 - np.clip(err / (thr * 2), 0, 1)
    return np.clip(
        np.stack(
            [2 - x * 2, x * 2, np.zeros_like(x), np.ones_like(x) * alpha], -1
        ),
        0,
        1,
    )


# Fixed, reproducible permutation of 0..99 (seeds NumPy's global RNG on import).
np.random.seed(1995)
color_map = np.arange(100)
np.random.shuffle(color_map)


def fig2im(fig: matplotlib.figure.Figure) -> np.ndarray:
    """Render a matplotlib figure into an RGB image.

    Args:
        fig: A matplotlib figure.

    Returns:
        A numpy array with shape (height, width, 3) and dtype uint8
        containing the RGB values of the figure.
    """
    fig.canvas.draw()
    canvas = fig.canvas
    if hasattr(canvas, "buffer_rgba"):
        # `tostring_rgb` was deprecated in Matplotlib 3.8 and later removed;
        # `buffer_rgba` is the supported accessor. Drop the alpha channel.
        rgba = np.asarray(canvas.buffer_rgba())
        return np.ascontiguousarray(rgba[..., :3])
    # Fallback for canvases that only expose the legacy API.
    (width, height) = canvas.get_width_height()
    buf_ndarray = np.frombuffer(canvas.tostring_rgb(), dtype="u1")
    return buf_ndarray.reshape(height, width, 3)


def draw_matches(
    mkpts0: Union[np.ndarray, List[np.ndarray]],
    mkpts1: Union[np.ndarray, List[np.ndarray]],
    img0: np.ndarray,
    img1: np.ndarray,
    conf: np.ndarray,
    titles: Optional[List[str]] = None,
    texts: Optional[List[str]] = None,
    dpi: int = 150,
    path: Optional[str] = None,
    pad: float = 0.5,
) -> Optional[np.ndarray]:
    """Draw matches between two images.

    Args:
        mkpts0: Matches from the first image, with shape (N, 2)
        mkpts1: Matches from the second image, with shape (N, 2)
        img0: First image, with shape (H, W, 3)
        img1: Second image, with shape (H, W, 3)
        conf: Per-match values, with shape (N,), mapped to colors with
            ``error_colormap`` (0 is green, 1 or more is red).
        titles: Optional list of title strings for the plot
        texts: Currently unused; kept for interface compatibility.
        dpi: DPI of the figure
        path: Optional path to save the image to. If None, the image is not
            saved.
        pad: Padding between subplots

    Returns:
        The figure as a numpy array with shape (height, width, 3) and dtype
        uint8 containing the RGB values of the figure, or None if the figure
        was saved to ``path``.
    """
    # Accept lists of points as well as arrays.
    mkpts0 = np.asarray(mkpts0)
    mkpts1 = np.asarray(mkpts1)

    thr = 0.5
    color = error_colormap(conf, thr, alpha=0.1)
    text = [
        "image name",
        f"#Matches: {len(mkpts0)}",
    ]
    fig = make_matching_figure(
        img0,
        img1,
        mkpts0,
        mkpts1,
        color,
        titles=titles,
        text=text,
        path=path,
        dpi=dpi,
        pad=pad,
    )
    if fig is None:
        # The figure was saved to `path` and already closed.
        return None
    image = fig2im(fig)
    # Release the figure, otherwise pyplot keeps it alive forever.
    plt.close(fig)
    return image


def draw_image_pairs(
    img0: np.ndarray,
    img1: np.ndarray,
    text: Optional[List[str]] = None,
    dpi: int = 75,
    path: Optional[str] = None,
    pad: float = 0.5,
) -> Optional[np.ndarray]:
    """Draw image pair horizontally.

    Args:
        img0: First image, with shape (H, W, 3)
        img1: Second image, with shape (H, W, 3)
        text: List of strings to print. Each string is a new line.
        dpi: DPI of the figure.
        path: Path to save the image to. If None, the image is not saved and
            the function returns the figure as a numpy array.
        pad: Padding between subplots

    Returns:
        The figure as a numpy array with shape (height, width, 3) and dtype
        uint8 containing the RGB values of the figure, or None if path is
        given.
    """
    text = [] if text is None else text

    # draw image pair
    fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
    axes[0].imshow(img0)
    axes[1].imshow(img1)
    for i in range(2):  # clear all frames
        _hide_frame(axes[i])
    fig.tight_layout(pad=pad)

    _overlay_text(fig, img0, text)

    # save or return figure
    if path:
        fig.savefig(str(path), bbox_inches="tight", pad_inches=0)
        plt.close(fig)
        return None
    image = fig2im(fig)
    # Release the figure, otherwise pyplot keeps it alive forever.
    plt.close(fig)
    return image


def display_matches(
    pred: Dict[str, np.ndarray],
    titles: Optional[List[str]] = None,
    texts: Optional[List[str]] = None,
    dpi: int = 300,
) -> Tuple[np.ndarray, int]:
    """Display the matches between two images.

    Args:
        pred: Dictionary containing the original images ("image0_orig",
            "image1_orig") and the matches: point matches under
            "keypoints0_orig"/"keypoints1_orig" and/or line matches under
            "line0_orig"/"line1_orig" (optionally with
            "line_keypoints0_orig"/"line_keypoints1_orig"). "mconf", if
            present, is used to color the point matches.
        titles: Optional titles for the point-match plot.
        texts: Optional texts forwarded to ``draw_matches``.
        dpi: Resolution of the point-match plot.

    Returns:
        The resulting (possibly concatenated) plot and the number of inliers.

    Raises:
        ValueError: If ``pred`` contains neither point nor line matches.
    """
    img0 = pred["image0_orig"]
    img1 = pred["image1_orig"]

    fig = None
    num_inliers = 0
    if (
        pred.get("keypoints0_orig") is not None
        and pred.get("keypoints1_orig") is not None
    ):
        mkpts0 = pred["keypoints0_orig"]
        mkpts1 = pred["keypoints1_orig"]
        num_inliers = len(mkpts0)
        mconf = pred["mconf"] if "mconf" in pred else np.ones(len(mkpts0))
        fig = draw_matches(
            mkpts0,
            mkpts1,
            img0,
            img1,
            mconf,
            dpi=dpi,
            titles=titles,
            texts=texts,
        )
    if (
        pred.get("line0_orig") is not None
        and pred.get("line1_orig") is not None
    ):
        # lines
        mtlines0 = pred["line0_orig"]
        mtlines1 = pred["line1_orig"]
        num_inliers = len(mtlines0)
        # NOTE(review): the line plots use a fixed dpi of 300 and ignore `dpi`.
        lines_figure = plot_images(
            [img0.squeeze(), img1.squeeze()],
            ["Image 0 - matched lines", "Image 1 - matched lines"],
            dpi=300,
        )
        # Draws on the current figure, i.e. the one just created above.
        plot_color_line_matches([mtlines0, mtlines1], lw=2)
        fig_lines = fig2im(lines_figure)
        plt.close(lines_figure)

        # keypoints
        mkpts0 = pred.get("line_keypoints0_orig")
        mkpts1 = pred.get("line_keypoints1_orig")
        if mkpts0 is not None and mkpts1 is not None:
            num_inliers = len(mkpts0)
            mconf = pred["mconf"] if "mconf" in pred else np.ones(len(mkpts0))
            fig_mkpts = draw_matches(mkpts0, mkpts1, img0, img1, mconf, dpi=300)
            if fig_lines.shape[:2] != fig_mkpts.shape[:2]:
                # OpenCV is only needed when the two renders differ in size.
                import cv2

                fig_lines = cv2.resize(
                    fig_lines, (fig_mkpts.shape[1], fig_mkpts.shape[0])
                )
            fig = np.concatenate([fig_mkpts, fig_lines], axis=0)
        else:
            fig = fig_lines

    if fig is None:
        raise ValueError(
            "pred contains neither keypoint matches nor line matches to display"
        )
    return fig, num_inliers