|
import pprint |
|
|
|
import numpy as np |
|
|
|
from . import viz2d |
|
from .tools import RadioHideTool, ToggleTool, __plot_dict__ |
|
|
|
|
|
|
|
|
|
|
|
class FormatPrinter(pprint.PrettyPrinter):
    """PrettyPrinter that renders selected types with printf-style formats.

    Args:
        formats: mapping from an exact type (subclasses are NOT matched, so
            e.g. ``bool`` is not caught by an ``int`` entry) to a ``%``-style
            format string, e.g. ``{np.float32: "%.4f"}``.
        **kwargs: forwarded unchanged to :class:`pprint.PrettyPrinter`
            (``indent``, ``width``, ``depth``, ...).
    """

    def __init__(self, formats, **kwargs):
        super().__init__(**kwargs)
        self.formats = formats

    def format(self, obj, ctx, maxlvl, lvl):
        """Return the ``(repr, readable, recursive)`` triple pprint expects.

        Objects whose exact type has an entry in ``self.formats`` are rendered
        with that format string and reported as readable and non-recursive;
        everything else uses the default :class:`pprint.PrettyPrinter` logic.
        """
        obj_type = type(obj)
        if obj_type in self.formats:
            return self.formats[obj_type] % obj, True, False
        return super().format(obj, ctx, maxlvl, lvl)
|
|
|
|
|
class TwoViewFrame:
    """Interactive matplotlib frame with one image column per prediction set.

    A radio tool ("switch plot", key ``R``) selects which entry of
    ``plot_dict`` is drawn on top of the images. A toggle tool
    ("toggle summary", key ``t``) shows or hides the per-prediction summary
    text overlays.
    """

    default_conf = {
        "default": "image",
        "summary_visible": False,
    }

    # Registry of available plots (imported from ``.tools``). Each value is
    # expected to expose ``required_keys`` and to be callable as
    # ``(fig, axes, data, preds)``.
    plot_dict = __plot_dict__

    # NOTE(review): mutable class attribute shared by every instance and not
    # used inside this class -- confirm external users before changing it.
    childs = []

    # Maps the ``event`` constructor argument to a plot name (index 0 -> None).
    event_to_image = [None, "image", "horizon_line", "lat_pred", "lat_gt"]

    def __init__(self, conf, data, preds, title=None, event=1, summaries=None):
        """Build the figure, register the tools and draw the default plot.

        Args:
            conf: config with attribute access to ``default`` (name of the
                initial plot) and ``summary_visible``. Both are updated in
                place as the user interacts with the frame.
            data: mapping of inputs; ``data["image"]`` is shown in every column.
            preds: mapping from prediction-set name to a mapping of outputs;
                one column is created per entry.
            title: optional window title.
            event: index into ``event_to_image``; the result is stored as
                ``self.plot``.
            summaries: optional mapping from prediction-set name to an object
                that is pretty-printed as a text overlay on that column.
        """
        self.conf = conf
        self.data = data
        self.preds = preds
        self.names = list(preds.keys())
        self.plot = self.event_to_image[event]
        self.summaries = summaries
        self.fig, self.axes, self.summary_arts = self.init_frame()
        if title is not None:
            self.fig.canvas.manager.set_window_title(title)

        # A plot is offered only if its required keys are present in every
        # prediction set, or in the data. An empty ``preds`` yields an empty
        # intersection instead of calling ``.union`` on ``None``.
        pred_key_sets = [set(pred.keys()) for pred in preds.values()]
        keys = set.intersection(*pred_key_sets) if pred_key_sets else set()
        keys |= set(data.keys())

        self.options = [
            name
            for name, plot in self.plot_dict.items()
            if set(plot.required_keys).issubset(keys)
        ]
        self.handle = None

        tool_manager = self.fig.canvas.manager.toolmanager
        self.radios = tool_manager.add_tool(
            "switch plot",
            RadioHideTool,
            options=self.options,
            callback_fn=self.draw,
            active=conf.default,
            keymap="R",
        )
        self.toggle_summary = tool_manager.add_tool(
            "toggle summary",
            ToggleTool,
            toggled=self.conf.summary_visible,
            callback_fn=self.set_summary_visible,
            keymap="t",
        )

        if self.fig.canvas.manager.toolbar is not None:
            self.fig.canvas.manager.toolbar.add_tool("switch plot", "navigation")
        self.draw(conf.default)

    def init_frame(self):
        """Create the image grid, the name labels and optional summary overlays.

        Returns:
            ``(fig, axes, summary_artists)``. ``summary_artists`` is an empty
            list when no summaries were given.
        """
        # NOTE(review): assumes data["image"] is a batched channel-first
        # tensor; only the first item is shown, permuted to channel-last.
        imgs = [[self.data["image"][0].permute(1, 2, 0) for _ in self.names]]
        fig, axes = viz2d.plot_image_grid(imgs, return_fig=True, titles=None, figs=5)
        for i, name in enumerate(self.names):
            viz2d.add_text(i, name, axes=axes[0])

        fig.canvas.mpl_connect("pick_event", self.click_artist)

        if self.summaries is None:
            return fig, axes, []

        font_size = 7
        formatter = FormatPrinter({np.float32: "%.4f", np.float64: "%.4f"})
        summary_artists = [
            viz2d.add_text(
                i,
                formatter.pformat(self.summaries[name]),
                axes=axes[0],
                pos=(0.01, 0.01),
                va="bottom",
                backgroundcolor=(0, 0, 0, 0.5),
                visible=self.conf.summary_visible,
                fs=font_size,
            )
            for i, name in enumerate(self.names)
        ]
        return fig, axes, summary_artists

    def draw(self, value):
        """Clear the frame and draw the plot registered under ``value``.

        The choice is stored in ``conf.default``. The plot's return value is
        kept in ``self.handle`` and also returned.
        """
        self.clear()
        self.conf.default = value
        self.handle = self.plot_dict[value](self.fig, self.axes, self.data, self.preds)
        return self.handle

    def clear(self):
        """Remove lines, collections and figure artists added by the last plot."""
        if self.handle is not None:
            # Plot handles are not required to implement ``clear``.
            try:
                self.handle.clear()
            except AttributeError:
                pass
        self.handle = None
        for row in self.axes:
            for ax in row:
                # Iterate over snapshots: ``Artist.remove`` mutates the very
                # containers being iterated, which would skip elements.
                for line in list(ax.lines):
                    line.remove()
                for collection in list(ax.collections):
                    collection.remove()
        self.fig.artists.clear()
        self.fig.canvas.draw_idle()

    def click_artist(self, event):
        """Toggle the highlight of a picked arrow and forward the pick event.

        A picked arrow whose style is ``"-"`` becomes ``"<|-|>"`` and is moved
        to zorder 1. Picking it again sets the style back to ``"-"``. The event
        is then passed to the current plot handle if it has ``click_artist``.
        """
        art = event.artist
        # Only arrow-like artists have an arrow style; other pickable
        # artists are just forwarded to the current plot.
        if hasattr(art, "get_arrowstyle"):
            select = art.get_arrowstyle().arrow == "-"
            art.set_arrowstyle("<|-|>" if select else "-")
            if select:
                art.set_zorder(1)
        if hasattr(self.handle, "click_artist"):
            self.handle.click_artist(event)
        self.fig.canvas.draw_idle()

    def set_summary_visible(self, visible):
        """Show or hide all summary overlays and record the state in ``conf``."""
        self.conf.summary_visible = visible
        for artist in self.summary_arts:
            artist.set_visible(visible)
        self.fig.canvas.draw_idle()
|
|