Spaces:
Running
Running
import pprint | |
import numpy as np | |
from . import viz2d | |
from .tools import RadioHideTool, ToggleTool, __plot_dict__ | |
class FormatPrinter(pprint.PrettyPrinter):
    """PrettyPrinter that renders selected types with %-style format strings.

    Example: ``FormatPrinter({np.float32: "%.4f"})`` prints float32 scalars
    with four decimals. Types are matched exactly (``type(obj)``), so
    subclasses of a registered type are not affected.
    """

    def __init__(self, formats, **kwargs):
        """Create the printer.

        Args:
            formats: mapping from a type to the %-format string used for
                objects of exactly that type.
            **kwargs: forwarded unchanged to ``pprint.PrettyPrinter``
                (e.g. ``width``, ``indent``); defaults match the original
                stock PrettyPrinter settings.
        """
        super().__init__(**kwargs)
        self.formats = formats

    def format(self, obj, ctx, maxlvl, lvl):
        """Return ``(repr_string, readable, recursive)`` as PrettyPrinter expects.

        Objects whose exact type is registered in ``self.formats`` use that
        format string; everything else falls back to the default behavior.
        """
        obj_type = type(obj)
        if obj_type in self.formats:
            # Wrap in a 1-tuple so tuple-typed objects are formatted as a
            # single value instead of being unpacked into %-arguments.
            return self.formats[obj_type] % (obj,), True, False
        return super().format(obj, ctx, maxlvl, lvl)
class TwoViewFrame:
    """Interactive matplotlib frame comparing several models on one view pair.

    For every model name in ``preds`` the image pair (``view0``, ``view1``)
    is drawn once; ``axes[i]`` is used for the i-th model (presumably one
    grid row per model). A radio tool ("switch plot", key ``R``) selects
    which entry of ``plot_dict`` is drawn on top of the images, and a toggle
    tool ("toggle summary", key ``t``) shows or hides per-model summary text.
    """

    # Configuration defaults. Instances read ``conf.default`` and
    # ``conf.summary_visible`` via attribute access, so callers presumably
    # pass an attribute-style config built from this dict — confirm.
    default_conf = {
        "default": "matches",
        "summary_visible": False,
    }

    # Registry of plot callables keyed by option name. Each entry exposes
    # ``required_keys`` and is called as ``fn(fig, axes, data, preds)``.
    plot_dict = __plot_dict__

    # NOTE(review): mutable class attribute shared by all instances and not
    # referenced inside this class; kept only for backward compatibility.
    childs = []

    # Maps the ``event`` constructor argument to the background image mode.
    event_to_image = [None, "color", "depth", "color+depth"]

    def __init__(self, conf, data, preds, title=None, event=1, summaries=None):
        """Build the figure, register the tools and draw ``conf.default``.

        Args:
            conf: attribute-style config with ``default`` (initial plot name)
                and ``summary_visible`` (see ``default_conf``).
            data: mapping with ``"view0"`` and ``"view1"`` entries, each
                holding ``"image"`` and/or ``"depth"``.
            preds: mapping from model name to that model's prediction mapping.
            title: optional window title.
            event: index into ``event_to_image`` selecting the background
                (1 = color, 2 = depth, 3 = color with depth heatmap);
                0 makes ``init_frame`` raise ``ValueError``.
            summaries: optional mapping from model name to an object that is
                pretty-printed as text on that model's axes.
        """
        self.conf = conf
        self.data = data
        self.preds = preds
        self.names = list(preds.keys())
        self.plot = self.event_to_image[event]
        self.summaries = summaries
        # Initialize before any callback (pick events, draw/clear) can run.
        self.handle = None
        self.fig, self.axes, self.summary_arts = self.init_frame()
        if title is not None:
            self.fig.canvas.manager.set_window_title(title)

        # A plot is offered only if all of its required keys are available,
        # either in every model's predictions or in the input data.
        pred_key_sets = [set(pred.keys()) for pred in preds.values()]
        shared_keys = set.intersection(*pred_key_sets) if pred_key_sets else set()
        keys = shared_keys.union(data.keys())
        self.options = [
            k for k, v in self.plot_dict.items() if set(v.required_keys).issubset(keys)
        ]

        # NOTE(review): conf.default is not checked against self.options; if
        # the default plot's required keys are missing, draw() will likely
        # fail — confirm whether a fallback is wanted.
        manager = self.fig.canvas.manager
        self.radios = manager.toolmanager.add_tool(
            "switch plot",
            RadioHideTool,
            options=self.options,
            callback_fn=self.draw,
            active=conf.default,
            keymap="R",
        )
        self.toggle_summary = manager.toolmanager.add_tool(
            "toggle summary",
            ToggleTool,
            toggled=self.conf.summary_visible,
            callback_fn=self.set_summary_visible,
            keymap="t",
        )
        # Only the plot switcher is added to the toolbar (when one exists).
        if manager.toolbar is not None:
            manager.toolbar.add_tool("switch plot", "navigation")
        self.draw(conf.default)

    def init_frame(self):
        """Create the image grid, model labels and summary text artists.

        Returns:
            ``(fig, axes, toggle_artists)``, where ``toggle_artists`` holds one
            summary text artist per model (empty if ``summaries`` is None).

        Raises:
            ValueError: if ``self.plot`` is not one of the known image modes.
        """
        view0, view1 = self.data["view0"], self.data["view1"]
        if self.plot in ("color", "color+depth"):
            # presumably CHW image tensors; permuted to HWC for display.
            pair = [
                view0["image"][0].permute(1, 2, 0),
                view1["image"][0].permute(1, 2, 0),
            ]
        elif self.plot == "depth":
            pair = [view0["depth"][0], view1["depth"][0]]
        else:
            raise ValueError(self.plot)
        imgs = [pair for _ in self.names]  # repeat the pair for each model

        fig, axes = viz2d.plot_image_grid(imgs, return_fig=True, titles=None, figs=5)
        for i, name in enumerate(self.names):
            viz2d.add_text(0, name, axes=axes[i])

        # Overlay depth as a heatmap only when it is actually available.
        if (
            self.plot == "color+depth"
            and "depth" in view0.keys()
            and view0["depth"] is not None
        ):
            for i, _name in enumerate(self.names):
                viz2d.plot_heatmaps(
                    [view0["depth"][0], view1["depth"][0]],
                    axes=axes[i],
                    cmap="Spectral",
                )

        fig.canvas.mpl_connect("pick_event", self.click_artist)

        if self.summaries is None:
            return fig, axes, []
        formatter = FormatPrinter({np.float32: "%.4f", np.float64: "%.4f"})
        toggle_artists = [
            viz2d.add_text(
                0,
                formatter.pformat(self.summaries[name]),
                axes=axes[i],
                pos=(0.01, 0.01),
                va="bottom",
                backgroundcolor=(0, 0, 0, 0.5),
                visible=self.conf.summary_visible,
            )
            for i, name in enumerate(self.names)
        ]
        return fig, axes, toggle_artists

    def draw(self, value):
        """Redraw the frame with plot ``value`` and store it as ``conf.default``.

        Returns:
            Whatever the selected ``plot_dict`` entry returns; it is kept as
            ``self.handle`` for later ``clear``/``click_artist`` calls.
        """
        self.clear()
        self.conf.default = value
        self.handle = self.plot_dict[value](self.fig, self.axes, self.data, self.preds)
        return self.handle

    def clear(self):
        """Remove the current plot overlay.

        Calls ``handle.clear()`` if the handle provides it, removes all lines
        and collections from every axes and empties ``fig.artists``. Images
        and text on the axes are not touched.
        """
        # Probe for the method instead of catching AttributeError, so errors
        # raised *inside* handle.clear() are not silently swallowed.
        clear_handle = getattr(self.handle, "clear", None)
        if clear_handle is not None:
            clear_handle()
        self.handle = None
        for row in self.axes:
            for ax in row:
                # Snapshot before removing: on older Matplotlib ``ax.lines``
                # and ``ax.collections`` are the live lists mutated by
                # ``remove()``, so iterating them directly skips artists.
                for line in list(ax.lines):
                    line.remove()
                for collection in list(ax.collections):
                    collection.remove()
        self.fig.artists.clear()
        self.fig.canvas.draw_idle()

    def click_artist(self, event):
        """Pick-event callback: toggle an arrow's highlight and forward it.

        A picked artist with a plain ``"-"`` arrow style is switched to
        ``"<|-|>"`` and raised to zorder 1; any other style is reset to
        ``"-"`` (zorder is not reset). Artists without an arrow style are
        left unchanged. The event is then forwarded to the current plot
        handle if it defines ``click_artist``.
        """
        art = event.artist
        if hasattr(art, "get_arrowstyle"):
            select = art.get_arrowstyle().arrow == "-"
            art.set_arrowstyle("<|-|>" if select else "-")
            if select:
                art.set_zorder(1)
        if hasattr(self.handle, "click_artist"):
            self.handle.click_artist(event)
        self.fig.canvas.draw_idle()

    def set_summary_visible(self, visible):
        """Show or hide every summary text artist and persist the choice."""
        self.conf.summary_visible = visible
        for artist in self.summary_arts:
            artist.set_visible(visible)
        self.fig.canvas.draw_idle()