Spaces:
Running
Running
import functools | |
import traceback | |
from copy import deepcopy | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from matplotlib.widgets import Button | |
from omegaconf import OmegaConf | |
from ..datasets.base_dataset import collate | |
# from ..eval.export_predictions import load_predictions | |
from ..models.cache_loader import CacheLoader | |
from .tools import RadioHideTool | |
class GlobalFrame:
    """Interactive scatter plot comparing per-sample metrics of several runs.

    ``results`` maps a run name to a mapping ``metric name -> per-sample
    values``.  One metric is shown on each axis and every run is drawn as its
    own pickable point series.  Hovering over a point links the same sample
    across all runs; clicking a point additionally opens a ``child_frame``
    figure for that sample (when a child frame class was supplied).

    Configuration keys (``default_conf``):
        x, y: names of the metrics plotted on the x / y axis.
        diff: when true, numeric metrics are plotted relative to the first
            run in ``results``.
        child: configuration forwarded to the child frame.
        remove_outliers: not read anywhere in this class.
    """

    default_conf = {
        "x": "???",
        "y": "???",
        "diff": False,
        "child": {},
        "remove_outliers": False,
    }
    child_frame = None  # MatchFrame

    def __init__(
        self, conf, results, loader, predictions, title=None, child_frame=None
    ):
        """Build the figure, resolve the axis metrics and register the tools.

        Args:
            conf: user configuration, merged over ``default_conf``; must
                provide the keys ``"x"`` and ``"y"`` (a falsy value selects
                a default metric).
            results: mapping ``run name -> {metric name -> per-sample values}``.
            loader: data loader; ``loader.dataset[i]`` is used to fetch the
                sample shown in a child frame.
            predictions: mapping ``run name -> prediction file path`` handed
                to ``CacheLoader`` when a child frame is spawned.
            title: optional window title.
            child_frame: optional class instantiated for a picked sample.
        """
        # Per-instance state.  These must not live on the class: a shared
        # list would make every GlobalFrame see the children of all others.
        # They are set before init_frame() because that connects callbacks
        # which read them.
        self.childs = []
        self.lines = []
        self.scatters = {}

        self.child_frame = child_frame
        # Work on a copy so the class-level defaults are never modified.
        defaults = dict(self.default_conf)
        if self.child_frame is not None:
            # We do NOT merge inside the child frame to keep settings across figs
            defaults["child"] = self.child_frame.default_conf
        self.conf = OmegaConf.merge(defaults, conf)

        self.results = results
        self.loader = loader
        self.predictions = predictions

        metrics = set()
        for run_results in results.values():
            metrics.update(run_results.keys())
        self.metrics = sorted(metrics)

        # Fall back to the first / second available metric.  With a single
        # metric the y axis reuses it instead of raising IndexError.
        default_y = self.metrics[min(1, len(self.metrics) - 1)]
        self.conf.x = conf["x"] if conf["x"] else self.metrics[0]
        self.conf.y = conf["y"] if conf["y"] else default_y
        assert self.conf.x in self.metrics
        assert self.conf.y in self.metrics

        self.names = list(results)
        self.fig, self.axes = self.init_frame()
        if title is not None:
            self.fig.canvas.manager.set_window_title(title)

        manager = self.fig.canvas.manager
        self.xradios = manager.toolmanager.add_tool(
            "x",
            RadioHideTool,
            options=self.metrics,
            callback_fn=self.update_x,
            active=self.conf.x,
            keymap="x",
        )
        self.yradios = manager.toolmanager.add_tool(
            "y",
            RadioHideTool,
            options=self.metrics,
            callback_fn=self.update_y,
            active=self.conf.y,
            keymap="y",
        )
        if manager.toolbar is not None:
            manager.toolbar.add_tool("x", "navigation")
            manager.toolbar.add_tool("y", "navigation")

    def init_frame(self):
        """initialize frame

        Creates the figure with a single axes, adds the ``diff_only`` button
        and connects the pick and hover callbacks.

        Returns:
            The ``(figure, axes)`` pair.
        """
        fig, ax = plt.subplots()
        ax.set_title("click on points")
        diffb_ax = fig.add_axes([0.01, 0.02, 0.12, 0.06])
        self.diffb = Button(diffb_ax, label="diff_only")
        self.diffb.on_clicked(self.diff_clicked)
        fig.canvas.mpl_connect("pick_event", self.on_scatter_pick)
        fig.canvas.mpl_connect("motion_notify_event", self.hover)
        return fig, ax

    def draw(self):
        """redraw content in frame

        Plots one point series per run for the currently selected metrics.
        A metric is treated as categorical when its first value in the first
        run is ``str`` or ``bytes``.  For numeric axes a solid line marks the
        mean and a hidden dashed line the median of each run; for a
        categorical x axis with a numeric y axis the per-category mean of y
        is drawn as a step curve.
        """
        self.scatters = {}
        self.axes.clear()
        self.axes.set_xlabel(self.conf.x)
        self.axes.set_ylabel(self.conf.y)

        reference = self.results[self.names[0]]
        x_cat = isinstance(reference[self.conf.x][0], (bytes, str))
        y_cat = isinstance(reference[self.conf.y][0], (bytes, str))

        refx = 0.0
        refy = 0.0
        if self.conf.diff:
            # In diff mode numeric metrics are shown relative to the first run.
            if not x_cat:
                refx = np.array(reference[self.conf.x])
            if not y_cat:
                refy = np.array(reference[self.conf.y])

        # Category labels for the x ticks; stays None unless the last run
        # produced them, so it can never be read before assignment.
        tick_labels = None
        for name in self.results:
            x = np.array(self.results[name][self.conf.x])
            y = np.array(self.results[name][self.conf.y])

            # Categorical values made only of digits are plotted as integers.
            if x_cat and np.char.isdigit(x.astype(str)).all():
                x = x.astype(int)
            if y_cat and np.char.isdigit(y.astype(str)).all():
                y = y.astype(int)

            x = x if x_cat else x - refx
            y = y if y_cat else y - refy

            (s,) = self.axes.plot(
                x, y, "o", markersize=3, label=name, picker=True, pickradius=5
            )
            self.scatters[name] = s
            color = s.get_color()

            tick_labels = None
            if x_cat and not y_cat:
                xunique, ind, xinv, xbin = np.unique(
                    x, return_inverse=True, return_counts=True, return_index=True
                )
                # Sum of y per category, divided below by the category count.
                ybin = np.bincount(xinv, weights=y)
                # Keep categories in order of first appearance.
                sort_ax = np.argsort(ind)
                self.axes.step(
                    xunique[sort_ax],
                    (ybin / xbin)[sort_ax],
                    where="mid",
                    color=color,
                )
                if x.dtype == object and xunique.shape[0] > 5:
                    tick_labels = xunique[sort_ax]

            if not x_cat:
                xavg = np.nan_to_num(x).mean()
                self.axes.axvline(xavg, c=color, zorder=1, alpha=1.0)
                # x is already relative to refx; do not subtract it again.
                xmed = np.median(x)
                self.axes.axvline(
                    xmed,
                    c=color,
                    zorder=0,
                    alpha=0.5,
                    linestyle="dashed",
                    visible=False,
                )

            if not y_cat:
                yavg = np.nan_to_num(y).mean()
                self.axes.axhline(yavg, c=color, zorder=1, alpha=0.5)
                # y is already relative to refy; do not subtract it again.
                ymed = np.median(y)
                self.axes.axhline(
                    ymed,
                    c=color,
                    zorder=0,
                    alpha=0.5,
                    linestyle="dashed",
                    visible=False,
                )

        if tick_labels is not None:
            self.axes.set_xticklabels(tick_labels, rotation=90)
        self.axes.legend()

    def on_scatter_pick(self, handle):
        """Handle a pick event on a point: open the child frame for it.

        Pick events whose mouse event has no ``button.value`` are ignored.
        Any other failure prints the traceback and terminates the process
        with a non-zero exit status.
        """
        try:
            art = handle.artist
            try:
                event = handle.mouseevent.button.value
            except AttributeError:
                return
            name = art.get_label()
            ind = handle.ind[0]
            # draw lines
            self.spawn_child(name, ind, event=event)
        except Exception:
            traceback.print_exc()
            # ``exit`` is an interactive helper injected by ``site`` and may
            # be absent; status 1 reports the failure to the caller.
            raise SystemExit(1) from None

    def _clear_lines(self):
        """Remove the connecting lines currently drawn on the axes."""
        for line in self.lines:
            line.remove()
        self.lines = []

    def _close_child(self, frame):
        """Close a child frame's figure and stop tracking the frame."""
        plt.close(frame.fig)
        if frame in self.childs:
            self.childs.remove(frame)

    def spawn_child(self, model_name, ind, event=None):
        """Highlight sample ``ind`` and open a child frame for it.

        Red lines are drawn from the point of ``model_name`` to the point of
        every run for the same sample.  When a child frame class is set, the
        sample is loaded from ``self.loader.dataset``, the predictions of
        every run are loaded through ``CacheLoader`` and a new child frame is
        created and shown.

        Args:
            model_name: run whose point is the origin of the lines.
            ind: index of the sample.
            event: value forwarded to the child frame as ``event``.
        """
        self._clear_lines()

        source = self.scatters[model_name]
        x_source = source.get_xdata()[ind]
        y_source = source.get_ydata()[ind]
        for oname in self.names:
            xn = self.scatters[oname].get_xdata()[ind]
            yn = self.scatters[oname].get_ydata()[ind]
            (ln,) = self.axes.plot([x_source, xn], [y_source, yn], "r")
            self.lines.append(ln)

        self.fig.canvas.draw_idle()

        if self.child_frame is None:
            return

        data = collate([self.loader.dataset[ind]])

        preds = {}
        for name, pfile in self.predictions.items():
            cache = CacheLoader({"path": str(pfile), "add_data_path": False})
            preds[name] = cache(data)

        # Per-run metric values of this sample, without the "names" entry.
        summaries_i = {
            name: {k: v[ind] for k, v in res.items() if k != "names"}
            for name, res in self.results.items()
        }
        frame = self.child_frame(
            self.conf.child,
            deepcopy(data),
            preds,
            title=str(data["name"][0]),
            event=event,
            summaries=summaries_i,
        )

        frame.fig.canvas.mpl_connect(
            "key_press_event",
            functools.partial(
                self.on_childframe_key_event, frame=frame, ind=ind, event=event
            ),
        )
        self.childs.append(frame)
        # if plt.rcParams['backend'] == 'webagg':
        #     self.fig.canvas.manager_class.refresh_all()
        frame.fig.show()

    def hover(self, event):
        """On mouse motion, link the hovered sample across all runs.

        When the cursor is over a point of any run, the previous connecting
        lines are replaced by translucent black lines from that point to the
        same sample in every run.  Only the first run containing the cursor
        is considered.
        """
        if event.inaxes != self.axes:
            return
        for s in self.scatters.values():
            cont, ind = s.contains(event)
            if not cont:
                continue
            ind = ind["ind"][0]
            xdata, ydata = s.get_data()
            self._clear_lines()
            for oname in self.names:
                xn = self.scatters[oname].get_xdata()[ind]
                yn = self.scatters[oname].get_ydata()[ind]
                (ln,) = self.axes.plot(
                    [xdata[ind], xn],
                    [ydata[ind], yn],
                    "black",
                    zorder=0,
                    alpha=0.5,
                )
                self.lines.append(ln)
            self.fig.canvas.draw_idle()
            break

    def diff_clicked(self, args):
        """Toggle diff mode (button callback) and redraw."""
        self.conf.diff = not self.conf.diff
        self.draw()
        self.fig.canvas.draw_idle()

    def update_x(self, x):
        """Select metric ``x`` for the x axis and redraw."""
        self.conf.x = x
        self.draw()

    def update_y(self, y):
        """Select metric ``y`` for the y axis and redraw."""
        self.conf.y = y
        self.draw()

    def on_childframe_key_event(self, key_event, frame, ind, event):
        """Keyboard navigation inside a child frame.

        ``delete`` closes the frame.  ``left`` / ``right`` close the frame
        and open the previous / next sample.  ``shift+left`` /
        ``shift+right`` open the previous / next sample while keeping the
        current frame open.  Other keys are ignored.

        Args:
            key_event: the key press event; only ``key_event.key`` is read.
            frame: child frame that received the key press.
            ind: index of the sample shown in ``frame``.
            event: value forwarded unchanged to ``spawn_child``.
        """
        pressed = key_event.key
        if pressed == "delete":
            self._close_child(frame)
            return
        if pressed not in ("left", "right", "shift+left", "shift+right"):
            return

        if pressed.startswith("shift+"):
            direction = pressed.replace("shift+", "")
        else:
            direction = pressed
            self._close_child(frame)

        # Decide on the direction with the modifier stripped, so that
        # "shift+right" moves forward exactly like "right".
        new_ind = ind + 1 if direction == "right" else ind - 1
        # NOTE(review): wraps on len(self.loader) while samples are fetched
        # from self.loader.dataset; presumably both have the same length
        # (batch size 1) - confirm against the caller.
        self.spawn_child(
            self.names[0],
            new_ind % len(self.loader),
            event=event,
        )