Spaces:
Runtime error
Runtime error
| import inspect | |
| import sys | |
| import warnings | |
| import matplotlib.pyplot as plt | |
| import torch | |
| from matplotlib.backend_tools import ToolToggleBase | |
| from matplotlib.widgets import RadioButtons, Slider | |
| from ..geometry.epipolar import T_to_F, generalized_epi_dist | |
| from ..geometry.homography import sym_homography_error | |
| from ..visualization.viz2d import ( | |
| cm_ranking, | |
| cm_RdGn, | |
| draw_epipolar_line, | |
| get_line, | |
| plot_color_line_matches, | |
| plot_heatmaps, | |
| plot_keypoints, | |
| plot_lines, | |
| plot_matches, | |
| ) | |
# Select matplotlib's "toolmanager" toolbar at import time (presumably needed
# by the ToolToggleBase subclasses defined in this module -- TODO confirm).
# Every warning raised while the rcParam is assigned is ignored; the filter is
# restored as soon as the ``with`` block exits.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    plt.rcParams["toolbar"] = "toolmanager"
class RadioHideTool(ToolToggleBase):
    """Toggle tool that shows a radio-button selector on the figure.

    While the tool is enabled, a ``RadioButtons`` widget listing ``options`` is
    drawn in an extra axes near the right edge of the figure. Clicking an entry
    records its index in ``self.active`` and forwards the clicked label to
    ``callback_fn`` (when one was given).
    """

    default_keymap = "R"
    description = "Show by gid"
    default_toggled = False
    radio_group = "default"

    def __init__(
        self,
        *args,
        options=None,
        active=None,
        callback_fn=None,
        keymap="R",
        **kwargs,
    ):
        """Create the tool.

        Args:
            *args: Forwarded to ``ToolToggleBase.__init__``.
            options: Labels offered by the radio buttons. ``None`` (default)
                means an empty list; a fresh list is created per instance so
                that instances never share a default.
            active: Label that is selected initially. A falsy value selects
                the first entry (index 0).
            callback_fn: Optional callable invoked with the clicked label.
            keymap: Key binding stored as ``default_keymap`` on the instance.
            **kwargs: Forwarded to ``ToolToggleBase.__init__``.

        Raises:
            ValueError: If ``active`` is truthy but not contained in
                ``options``.
        """
        super().__init__(*args, **kwargs)
        self.f = 1.0
        # Avoid the shared-mutable-default pitfall of ``options=[]``.
        self.options = [] if options is None else options
        self.callback_fn = callback_fn
        self.active = self.options.index(active) if active else 0
        self.default_keymap = keymap
        self.enabled = self.default_toggled
        # Created lazily by build_radios(); None while the panel is hidden.
        self.radios_ax = None
        self.radios = None

    def build_radios(self):
        """Create the radio-button axes and widget and hook up the callback."""
        w = 0.2
        self.radios_ax = self.figure.add_axes([1.0 - w, 0.7, w, 0.2], zorder=1)
        self.radios = RadioButtons(self.radios_ax, self.options, active=self.active)
        self.radios.on_clicked(self.on_radio_clicked)

    def enable(self, *args):
        """Show the radio-button panel and request a redraw."""
        # NOTE(review): ``size`` is only modified locally and never written
        # back to the figure; with ``self.f == 1.0`` this is a no-op anyway.
        size = self.figure.get_size_inches()
        size[0] *= self.f
        self.build_radios()
        self.figure.canvas.draw_idle()
        self.enabled = True

    def disable(self, *args):
        """Remove the radio-button panel (if present) and request a redraw."""
        # NOTE(review): same local-only size computation as in enable().
        size = self.figure.get_size_inches()
        size[0] /= self.f
        # Guard against disable() being called before the panel was ever built.
        if self.radios_ax is not None:
            self.radios_ax.remove()
            self.radios_ax = None
        self.radios = None
        self.figure.canvas.draw_idle()
        self.enabled = False

    def on_radio_clicked(self, value):
        """Handle a click on the radio entry labelled ``value``.

        The panel is torn down before the callback runs and rebuilt afterwards
        when it was visible, so the callback operates on a figure without the
        selector axes.
        """
        self.active = self.options.index(value)
        enabled = self.enabled
        if enabled:
            self.disable()
        if self.callback_fn is not None:
            self.callback_fn(value)
        if enabled:
            self.enable()
class ToggleTool(ToolToggleBase):
    """Toggle tool that reports its on/off state to a callback.

    ``callback_fn`` is called with ``True`` when the tool is enabled and with
    ``False`` when it is disabled.
    """

    default_keymap = "t"
    description = "Show by gid"

    def __init__(self, *args, callback_fn=None, keymap="t", **kwargs):
        """Create the tool.

        Args:
            *args: Forwarded to ``ToolToggleBase.__init__``.
            callback_fn: Optional callable taking a single bool. When it is
                ``None`` toggling the tool does nothing.
            keymap: Key binding stored as ``default_keymap`` on the instance.
            **kwargs: Forwarded to ``ToolToggleBase.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.f = 1.0
        self.callback_fn = callback_fn
        self.default_keymap = keymap
        self.enabled = self.default_toggled

    def enable(self, *args):
        """Notify the callback that the tool was switched on."""
        # ``callback_fn`` defaults to None, so it must be checked before use.
        if self.callback_fn is not None:
            self.callback_fn(True)

    def disable(self, *args):
        """Notify the callback that the tool was switched off."""
        if self.callback_fn is not None:
            self.callback_fn(False)
def add_whitespace_left(fig, factor):
    """Widen ``fig`` and move the left subplot margin to the right.

    The figure width is multiplied by ``1 + factor`` (the height is kept) and
    the ``left`` subplot parameter is set to ``(factor + left) / (1 + factor)``,
    where ``left`` is the value read before resizing.

    Args:
        fig: Figure-like object providing ``get_size_inches``,
            ``set_size_inches``, ``subplotpars`` and ``subplots_adjust``.
        factor: Relative amount of width to add.
    """
    width, height = fig.get_size_inches()
    previous_left = fig.subplotpars.left
    scale = 1 + factor
    fig.set_size_inches([width * scale, height])
    fig.subplots_adjust(left=(factor + previous_left) / scale)
def add_whitespace_bottom(fig, factor):
    """Make ``fig`` taller and raise the bottom subplot margin, then redraw.

    The figure height is multiplied by ``1 + factor`` (the width is kept), the
    ``bottom`` subplot parameter is set to ``(factor + bottom) / (1 + factor)``
    using the value read before resizing, and an idle redraw is requested on
    the figure canvas.

    Args:
        fig: Figure-like object providing ``get_size_inches``,
            ``set_size_inches``, ``subplotpars``, ``subplots_adjust`` and
            ``canvas``.
        factor: Relative amount of height to add.
    """
    width, height = fig.get_size_inches()
    previous_bottom = fig.subplotpars.bottom
    scale = 1 + factor
    fig.set_size_inches([width, height * scale])
    fig.subplots_adjust(bottom=(factor + previous_bottom) / scale)
    fig.canvas.draw_idle()
class KeypointPlot:
    """Plot the keypoints of both views for every prediction.

    The i-th entry of ``preds`` is drawn on ``axes[i]``. Only element ``[0]``
    of each keypoint entry is used (presumably the first batch item -- TODO
    confirm).
    """

    plot_name = "keypoints"
    required_keys = ["keypoints0", "keypoints1"]

    def __init__(self, fig, axes, data, preds):
        for row, pred in enumerate(preds.values()):
            kpts0 = pred["keypoints0"][0]
            kpts1 = pred["keypoints1"][0]
            plot_keypoints([kpts0, kpts1], axes=axes[row])
class LinePlot:
    """Plot the line segments of both views for every prediction.

    ``plot_lines`` is called without an ``axes`` argument, so the target axes
    are chosen by that helper. Only element ``[0]`` of each line entry is used
    (presumably the first batch item -- TODO confirm).
    """

    plot_name = "lines"
    required_keys = ["lines0", "lines1"]

    def __init__(self, fig, axes, data, preds):
        for pred in preds.values():
            segments0 = pred["lines0"][0]
            segments1 = pred["lines1"][0]
            plot_lines([segments0, segments1])
class KeypointRankingPlot:
    """Plot keypoints of both views colored by ``cm_ranking`` of their scores.

    The i-th entry of ``preds`` is drawn on ``axes[i]``.
    """

    plot_name = "keypoint_ranking"
    required_keys = ["keypoints0", "keypoints1", "keypoint_scores0", "keypoint_scores1"]

    def __init__(self, fig, axes, data, preds):
        for row, pred in enumerate(preds.values()):
            kpts = [pred["keypoints0"][0], pred["keypoints1"][0]]
            scores0 = pred["keypoint_scores0"][0]
            scores1 = pred["keypoint_scores1"][0]
            palette = [cm_ranking(scores0), cm_ranking(scores1)]
            plot_keypoints(kpts, axes=axes[row], colors=palette)
class KeypointScoresPlot:
    """Plot keypoints of both views colored by ``cm_RdGn`` of their scores.

    The i-th entry of ``preds`` is drawn on ``axes[i]``.
    """

    plot_name = "keypoint_scores"
    required_keys = ["keypoints0", "keypoints1", "keypoint_scores0", "keypoint_scores1"]

    def __init__(self, fig, axes, data, preds):
        for row, pred in enumerate(preds.values()):
            kpts = [pred["keypoints0"][0], pred["keypoints1"][0]]
            scores0 = pred["keypoint_scores0"][0]
            scores1 = pred["keypoint_scores1"][0]
            palette = [cm_RdGn(scores0), cm_RdGn(scores1)]
            plot_keypoints(kpts, axes=axes[row], colors=palette)
class HeatmapPlot:
    """Overlay the heatmaps of both views for every prediction.

    For each prediction, element ``[0, 0]`` of ``heatmap0`` / ``heatmap1`` is
    taken. A heatmap whose minimum is negative is passed through
    ``torch.sigmoid`` first (presumably because it holds logits -- TODO
    confirm). The artists returned by ``plot_heatmaps`` are collected in
    ``self.artists`` so that ``clear`` can remove them again.
    """

    plot_name = "heatmaps"
    required_keys = ["heatmap0", "heatmap1"]

    def __init__(self, fig, axes, data, preds):
        self.artists = []
        for row, pred in enumerate(preds.values()):
            raw_maps = [pred["heatmap0"][0, 0], pred["heatmap1"][0, 0]]
            shown = []
            for hmap in raw_maps:
                if hmap.min() < 0.0:
                    hmap = torch.sigmoid(hmap)
                shown.append(hmap)
            self.artists.extend(plot_heatmaps(shown, axes=axes[row], cmap="rainbow"))

    def clear(self):
        """Remove every heatmap artist created by this plot."""
        for artist in self.artists:
            artist.remove()
class ImagePlot:
    """Plot type that adds nothing on top of the images."""

    plot_name = "images"
    required_keys = ["view0", "view1"]

    def __init__(self, fig, axes, data, preds):
        """Accept the common plot arguments and draw nothing."""
class MatchesPlot:
    """Plot keypoints and predicted matches colored by matching score.

    For every prediction, all keypoints are drawn in blue on ``axes[i]``.
    Entries of ``matches0`` greater than ``-1`` are treated as matches and used
    as indices into ``keypoints1``; each match is colored with ``cm_RdGn`` of
    its ``matching_scores0`` value, which is also passed as its label.

    The figure and its left/right/top/bottom subplot parameters at
    construction time are stored in ``self.fig`` and ``self.sbpars``.
    """

    plot_name = "matches"
    required_keys = ["keypoints0", "keypoints1", "matches0", "matching_scores0"]

    def __init__(self, fig, axes, data, preds):
        self.fig = fig
        margin_names = ["left", "right", "top", "bottom"]
        pars = vars(fig.subplotpars)
        self.sbpars = {key: pars[key] for key in pars if key in margin_names}

        for row, pred in enumerate(preds.values()):
            plot_keypoints(
                [pred["keypoints0"][0], pred["keypoints1"][0]],
                axes=axes[row],
                colors="blue",
            )
            kpts0 = pred["keypoints0"][0]
            kpts1 = pred["keypoints1"][0]
            matches = pred["matches0"][0]
            is_matched = matches > -1
            matched0 = kpts0[is_matched]
            matched1 = kpts1[matches[is_matched]]
            scores = pred["matching_scores0"][0][is_matched]
            plot_matches(
                matched0,
                matched1,
                color=cm_RdGn(scores).tolist(),
                axes=axes[row],
                labels=scores,
                lw=0.5,
            )
class LineMatchesPlot:
    """Plot the predicted line matches of every prediction.

    Entries of ``line_matches0`` greater than ``-1`` are treated as matches and
    used as indices into ``lines1``. The matched line pairs are handed to
    ``plot_color_line_matches`` without an ``axes`` argument.

    The figure and its left/right/top/bottom subplot parameters at
    construction time are stored in ``self.fig`` and ``self.sbpars``.
    """

    plot_name = "line_matches"
    required_keys = ["lines0", "lines1", "line_matches0"]

    def __init__(self, fig, axes, data, preds):
        self.fig = fig
        margin_names = ["left", "right", "top", "bottom"]
        pars = vars(fig.subplotpars)
        self.sbpars = {key: pars[key] for key in pars if key in margin_names}

        for pred in preds.values():
            segments0 = pred["lines0"][0]
            segments1 = pred["lines1"][0]
            matches = pred["line_matches0"][0]
            is_matched = matches > -1
            matched0 = segments0[is_matched]
            matched1 = segments1[matches[is_matched]]
            plot_color_line_matches([matched0, matched1])
class GtMatchesPlot:
    """Plot predicted matches colored by agreement with the ground truth.

    For every prediction, all keypoints are drawn in blue on ``axes[i]``. A
    match is shown when its ``matches0`` entry is greater than ``-1`` and its
    ``gt_matches0`` entry is at least ``-1``. It counts as correct when both
    entries are equal; the boolean result is mapped through ``cm_RdGn`` for
    the color and is also passed as the label.

    The figure and its left/right/top/bottom subplot parameters at
    construction time are stored in ``self.fig`` and ``self.sbpars``.
    """

    plot_name = "gt_matches"
    required_keys = ["keypoints0", "keypoints1", "matches0", "gt_matches0"]

    def __init__(self, fig, axes, data, preds):
        self.fig = fig
        margin_names = ["left", "right", "top", "bottom"]
        pars = vars(fig.subplotpars)
        self.sbpars = {key: pars[key] for key in pars if key in margin_names}

        for row, pred in enumerate(preds.values()):
            plot_keypoints(
                [pred["keypoints0"][0], pred["keypoints1"][0]],
                axes=axes[row],
                colors="blue",
            )
            kpts0 = pred["keypoints0"][0]
            kpts1 = pred["keypoints1"][0]
            matches = pred["matches0"][0]
            gt_matches = pred["gt_matches0"][0]
            shown = (matches > -1) & (gt_matches >= -1)
            matched0 = kpts0[shown]
            matched1 = kpts1[matches[shown]]
            is_correct = gt_matches[shown] == matches[shown]
            plot_matches(
                matched0,
                matched1,
                color=cm_RdGn(is_correct).tolist(),
                axes=axes[row],
                labels=is_correct,
                lw=0.5,
            )
class GtLineMatchesPlot:
    """Plot predicted line matches that have a usable ground-truth entry.

    A line match is shown when its ``line_matches0`` entry is greater than
    ``-1`` and its ``gt_line_matches0`` entry is at least ``-1``. The matched
    line pairs are handed to ``plot_color_line_matches`` without an ``axes``
    argument.

    The figure and its left/right/top/bottom subplot parameters at
    construction time are stored in ``self.fig`` and ``self.sbpars``.
    """

    plot_name = "gt_line_matches"
    # The ground-truth key must be the one __init__ actually reads
    # ("gt_line_matches0"); a differently spelled name here would advertise a
    # requirement that does not protect the lookup below from a KeyError.
    required_keys = ["lines0", "lines1", "line_matches0", "gt_line_matches0"]

    def __init__(self, fig, axes, data, preds):
        self.fig = fig
        self.sbpars = {
            k: v
            for k, v in vars(fig.subplotpars).items()
            if k in ["left", "right", "top", "bottom"]
        }

        for pred in preds.values():
            lines0, lines1 = pred["lines0"][0], pred["lines1"][0]
            m0 = pred["line_matches0"][0]
            gtm0 = pred["gt_line_matches0"][0]
            valid = (m0 > -1) & (gtm0 >= -1)
            m_lines0 = lines0[valid]
            m_lines1 = lines1[m0[valid]]
            plot_color_line_matches([m_lines0, m_lines1])
class HomographyMatchesPlot:
    """Plot predicted matches colored by their homography error.

    The figure gets extra space at the bottom for a slider ("Homography
    Error", 0 to 5 in steps of 1, initially 3). For every prediction, all
    keypoints are drawn in blue on ``axes[i]`` and every match (``matches0``
    entry greater than ``-1``) is drawn with its ``sym_homography_error`` under
    ``data["H_0to1"][0]`` as label. A match is colored via ``cm_RdGn`` according
    to whether its error is below the slider value; moving the slider recolors
    the figure artists.
    """

    plot_name = "homography"
    required_keys = ["keypoints0", "keypoints1", "matches0", "H_0to1"]

    def __init__(self, fig, axes, data, preds):
        self.fig = fig
        # Remember the margins so that clear() can restore them.
        self.sbpars = {
            k: v
            for k, v in vars(fig.subplotpars).items()
            if k in ["left", "right", "top", "bottom"]
        }

        # Make room for the threshold slider below the subplots.
        add_whitespace_bottom(fig, 0.1)

        self.range_ax = fig.add_axes([0.3, 0.02, 0.4, 0.06])
        self.range = Slider(
            self.range_ax,
            label="Homography Error",
            valmin=0,
            valmax=5,
            valinit=3.0,
            valstep=1.0,
        )
        self.range.on_changed(self.color_matches)

        for i, name in enumerate(preds):
            pred = preds[name]
            plot_keypoints(
                [pred["keypoints0"][0], pred["keypoints1"][0]],
                axes=axes[i],
                colors="blue",
            )
            kp0, kp1 = pred["keypoints0"][0], pred["keypoints1"][0]
            m0 = pred["matches0"][0]
            valid = m0 > -1
            kpm0 = kp0[valid]
            kpm1 = kp1[m0[valid]]
            errors = sym_homography_error(kpm0, kpm1, data["H_0to1"][0])
            plot_matches(
                kpm0,
                kpm1,
                color=cm_RdGn(errors < self.range.val).tolist(),
                axes=axes[i],
                labels=errors.numpy(),
                lw=0.5,
            )

    def clear(self):
        """Undo the layout changes made by the constructor.

        The height is divided by 1.1 to revert the 0.1 bottom whitespace, the
        stored subplot margins are restored and the slider axes is removed.
        """
        w, h = self.fig.get_size_inches()
        self.fig.set_size_inches(w, h / 1.1)
        self.fig.subplots_adjust(**self.sbpars)
        self.range_ax.remove()

    def color_matches(self, args):
        """Recolor the figure artists for the error threshold ``args``.

        Each labeled artist is colored with ``cm_RdGn`` of whether its label,
        read as a float, is below ``args``. Artists whose label is ``None`` are
        skipped, matching ``EpipolarMatchesPlot.color_matches``; without the
        check ``float(None)`` would raise ``TypeError``.
        """
        for line in self.fig.artists:
            label = line.get_label()
            if label is not None:
                line.set_color(cm_RdGn([float(label) < args])[0])
class EpipolarMatchesPlot:
    """Plot predicted matches colored by their epipolar error.

    The figure gets extra space at the bottom for a slider ("Epipolar Error
    [px]", 0 to 5 in steps of 1, initially 3). For every prediction, all
    keypoints are drawn in blue on ``axes[i]`` and every match (``matches0``
    entry greater than ``-1``) is drawn with its ``generalized_epi_dist`` value
    as label. A match is colored via ``cm_RdGn`` according to whether its error
    is below the slider value. ``self.F`` holds the result of ``T_to_F`` for
    the two cameras and the relative pose and is used by ``click_artist`` to
    draw epipolar lines.
    """

    plot_name = "epipolar_matches"
    required_keys = ["keypoints0", "keypoints1", "matches0", "T_0to1", "view0", "view1"]

    def __init__(self, fig, axes, data, preds):
        self.fig = fig
        self.axes = axes
        # Remember the margins so that clear() can restore them.
        margin_names = ["left", "right", "top", "bottom"]
        pars = vars(fig.subplotpars)
        self.sbpars = {key: pars[key] for key in pars if key in margin_names}

        # Make room for the threshold slider below the subplots.
        add_whitespace_bottom(fig, 0.1)

        self.range_ax = fig.add_axes([0.3, 0.02, 0.4, 0.06])
        self.range = Slider(
            self.range_ax,
            label="Epipolar Error [px]",
            valmin=0,
            valmax=5,
            valinit=3.0,
            valstep=1.0,
        )
        self.range.on_changed(self.color_matches)

        cam0 = data["view0"]["camera"][0]
        cam1 = data["view1"]["camera"][0]
        pose_0to1 = data["T_0to1"][0]

        for row, pred in enumerate(preds.values()):
            plot_keypoints(
                [pred["keypoints0"][0], pred["keypoints1"][0]],
                axes=axes[row],
                colors="blue",
            )
            kpts0 = pred["keypoints0"][0]
            kpts1 = pred["keypoints1"][0]
            matches = pred["matches0"][0]
            is_matched = matches > -1
            matched0 = kpts0[is_matched]
            matched1 = kpts1[matches[is_matched]]

            epi_errors = generalized_epi_dist(
                matched0,
                matched1,
                cam0,
                cam1,
                pose_0to1,
                all=False,
                essential=False,
            )
            plot_matches(
                matched0,
                matched1,
                color=cm_RdGn(epi_errors < self.range.val).tolist(),
                axes=axes[row],
                labels=epi_errors.numpy(),
                lw=0.5,
            )

        self.F = T_to_F(cam0, cam1, pose_0to1)

    def clear(self):
        """Undo the layout changes made by the constructor.

        The height is divided by 1.1 to revert the 0.1 bottom whitespace, the
        stored subplot margins are restored and the slider axes is removed.
        """
        width, height = self.fig.get_size_inches()
        self.fig.set_size_inches(width, height / 1.1)
        self.fig.subplots_adjust(**self.sbpars)
        self.range_ax.remove()

    def color_matches(self, args):
        """Recolor the labeled figure artists for the error threshold ``args``.

        Artists whose label is ``None`` are left untouched.
        """
        for artist in self.fig.artists:
            label = artist.get_label()
            if label is None:
                continue
            artist.set_color(cm_RdGn([float(label) < args])[0])

    def click_artist(self, event):
        """Toggle or create the epipolar lines of the clicked match artist.

        Artists whose label is ``None`` are ignored. On the first click the two
        epipolar lines are computed from ``self.F`` and the artist's ``xy1`` /
        ``xy2`` and stored on the artist as ``epilines``; later clicks flip the
        visibility of the stored lines.
        """
        artist = event.artist
        if artist.get_label() is None:
            return

        if hasattr(artist, "epilines"):
            for epiline in artist.epilines:
                if epiline is not None:
                    epiline.set_visible(not epiline.get_visible())
            return

        xy1 = artist.xy1
        xy2 = artist.xy2
        line0 = get_line(self.F.transpose(0, 1), xy2)[:, 0]
        line1 = get_line(self.F, xy1)[:, 0]
        artist.epilines = [
            draw_epipolar_line(line0, artist.axesA),
            draw_epipolar_line(line1, artist.axesB),
        ]
# Registry mapping each ``plot_name`` to its class. It is built by scanning
# the classes visible in this module's namespace for a ``plot_name`` attribute,
# so it must stay below all plot class definitions. When two classes share a
# ``plot_name``, the one visited later by ``inspect.getmembers`` wins.
__plot_dict__ = {
    obj.plot_name: obj
    for _, obj in inspect.getmembers(sys.modules[__name__], predicate=inspect.isclass)
    if hasattr(obj, "plot_name")
}