Realcat
add: GIM (https://github.com/xuelunshen/gim)
4d4dd90
import functools
import traceback
from copy import deepcopy
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.widgets import Button
from omegaconf import OmegaConf
from ..datasets.base_dataset import collate
# from ..eval.export_predictions import load_predictions
from ..models.cache_loader import CacheLoader
from .tools import RadioHideTool
class GlobalFrame:
    """Interactive scatter plot comparing per-sample metrics across models.

    Each model in ``results`` is drawn as one point series of metric
    ``conf.y`` against metric ``conf.x``. The metrics can be switched with
    the ``x``/``y`` toolbar tools, and the ``diff_only`` button toggles
    plotting numeric values relative to the first model. Hovering a point
    links it to the same sample of every model; picking a point also opens a
    ``child_frame`` view of that sample, if a child frame class was given.
    """

    default_conf = {
        "x": "???",
        "y": "???",
        "diff": False,
        "child": {},
        "remove_outliers": False,
    }

    child_frame = None  # MatchFrame

    def __init__(
        self, conf, results, loader, predictions, title=None, child_frame=None
    ):
        """Build the figure and register the metric-selection tools.

        Args:
            conf: config merged over ``default_conf``. Falsy or absent
                ``x``/``y`` select the first/second metric (sorted by name).
            results: mapping model name -> mapping metric name -> per-sample
                values (indexed by sample position in ``draw``/``spawn_child``).
            loader: object exposing an indexable ``dataset`` and ``len()``.
            predictions: mapping model name -> prediction file path, passed to
                ``CacheLoader`` when a child frame is opened.
            title: optional window title.
            child_frame: optional class used to open per-sample views. Its
                ``default_conf`` provides the defaults of ``conf.child``.
        """
        # Per-instance state: class-level mutable defaults would be shared
        # between every GlobalFrame instance.
        self.childs = []  # open child frames
        self.lines = []  # link lines currently drawn on the axes
        self.scatters = {}  # model name -> scatter Line2D

        self.child_frame = child_frame
        # Copy so the class-level default_conf is never mutated.
        default_conf = dict(self.default_conf)
        if self.child_frame is not None:
            # We do NOT merge inside the child frame to keep settings across
            # figs: every child receives this same ``self.conf.child``.
            default_conf["child"] = self.child_frame.default_conf
        self.conf = OmegaConf.merge(default_conf, conf)
        self.results = results
        self.loader = loader
        self.predictions = predictions

        # Union of the metric names reported by any model, sorted.
        self.metrics = sorted({metric for res in results.values() for metric in res})
        self.conf.x = conf.get("x") or self.metrics[0]
        # Fall back to the only metric when fewer than two exist.
        self.conf.y = conf.get("y") or self.metrics[min(1, len(self.metrics) - 1)]
        assert self.conf.x in self.metrics, f"unknown x metric {self.conf.x!r}"
        assert self.conf.y in self.metrics, f"unknown y metric {self.conf.y!r}"
        self.names = list(results)

        self.fig, self.axes = self.init_frame()
        if title is not None:
            self.fig.canvas.manager.set_window_title(title)

        # NOTE(review): toolmanager is assumed to be enabled (matplotlib
        # rcParams["toolbar"] == "toolmanager") — confirm at the call site.
        toolmanager = self.fig.canvas.manager.toolmanager
        self.xradios = toolmanager.add_tool(
            "x",
            RadioHideTool,
            options=self.metrics,
            callback_fn=self.update_x,
            active=self.conf.x,
            keymap="x",
        )
        self.yradios = toolmanager.add_tool(
            "y",
            RadioHideTool,
            options=self.metrics,
            callback_fn=self.update_y,
            active=self.conf.y,
            keymap="y",
        )
        toolbar = self.fig.canvas.manager.toolbar
        if toolbar is not None:
            toolbar.add_tool("x", "navigation")
            toolbar.add_tool("y", "navigation")

    def init_frame(self):
        """Create the figure, the ``diff_only`` button and the event hooks.

        Returns:
            The ``(figure, axes)`` pair of the new scatter plot.
        """
        fig, ax = plt.subplots()
        ax.set_title("click on points")
        diffb_ax = fig.add_axes([0.01, 0.02, 0.12, 0.06])
        self.diffb = Button(diffb_ax, label="diff_only")
        self.diffb.on_clicked(self.diff_clicked)
        fig.canvas.mpl_connect("pick_event", self.on_scatter_pick)
        fig.canvas.mpl_connect("motion_notify_event", self.hover)
        return fig, ax

    def draw(self):
        """(Re)draw metric ``conf.y`` against metric ``conf.x`` for all models.

        For every numeric axis, a line is drawn at the mean (NaNs counted as
        0) and a hidden dashed line at the median. For a categorical x axis
        with a numeric y axis, a step curve shows the mean y per category.
        When ``conf.diff`` is set, numeric values are plotted relative to
        the first model's values.
        """
        self.scatters = {}
        self.axes.clear()
        self.axes.set_xlabel(self.conf.x)
        self.axes.set_ylabel(self.conf.y)

        first = self.results[self.names[0]]
        # A metric whose first entry is text is treated as categorical.
        x_cat = isinstance(first[self.conf.x][0], (bytes, str))
        y_cat = isinstance(first[self.conf.y][0], (bytes, str))

        refx = 0.0
        refy = 0.0
        if self.conf.diff:
            if not x_cat:
                refx = np.array(first[self.conf.x])
            if not y_cat:
                refy = np.array(first[self.conf.y])

        for name, res in self.results.items():
            x = np.array(res[self.conf.x])
            y = np.array(res[self.conf.y])

            # Categorical labels made only of digits are plotted as ints.
            if x_cat and np.char.isdigit(x.astype(str)).all():
                x = x.astype(int)
            if y_cat and np.char.isdigit(y.astype(str)).all():
                y = y.astype(int)
            if not x_cat:
                x = x - refx
            if not y_cat:
                y = y - refy

            (s,) = self.axes.plot(
                x, y, "o", markersize=3, label=name, picker=True, pickradius=5
            )
            self.scatters[name] = s
            color = s.get_color()

            if x_cat:
                # Unique categories, reordered by first appearance.
                xunique, first_idx, xinv, xcount = np.unique(
                    x, return_index=True, return_inverse=True, return_counts=True
                )
                order = np.argsort(first_idx)
                if not y_cat:
                    ysum = np.bincount(xinv, weights=y)
                    self.axes.step(
                        xunique[order],
                        (ysum / xcount)[order],
                        where="mid",
                        color=color,
                    )
                if x.dtype == object and xunique.shape[0] > 5:
                    self.axes.set_xticklabels(xunique[order], rotation=90)

            # x and y already have refx/refy subtracted above: the mean and
            # median lines must not subtract the reference a second time.
            if not x_cat:
                self.axes.axvline(
                    np.nan_to_num(x).mean(), c=color, zorder=1, alpha=1.0
                )
                self.axes.axvline(
                    np.median(x),
                    c=color,
                    zorder=0,
                    alpha=0.5,
                    linestyle="dashed",
                    visible=False,
                )

            if not y_cat:
                self.axes.axhline(
                    np.nan_to_num(y).mean(), c=color, zorder=1, alpha=0.5
                )
                self.axes.axhline(
                    np.median(y),
                    c=color,
                    zorder=0,
                    alpha=0.5,
                    linestyle="dashed",
                    visible=False,
                )
        self.axes.legend()

    def on_scatter_pick(self, handle):
        """Handle a matplotlib pick event on one of the scatter series.

        Picks whose mouse button cannot be read (``AttributeError``) are
        ignored. Otherwise ``spawn_child`` is called for the first picked
        index. Any error is printed and the program exits with status 1.
        """
        try:
            art = handle.artist
            try:
                event = handle.mouseevent.button.value
            except AttributeError:
                return
            name = art.get_label()
            ind = handle.ind[0]
            self.spawn_child(name, ind, event=event)
        except Exception:
            traceback.print_exc()
            # The ``exit`` builtin is meant for the interactive shell; raise
            # SystemExit directly and signal failure with a non-zero status.
            raise SystemExit(1)

    def _clear_lines(self):
        """Remove every link line currently drawn on the axes."""
        for line in self.lines:
            line.remove()
        self.lines = []

    def _link_models(self, source_name, ind, fmt, **plot_kwargs):
        """Draw segments from sample ``ind`` of ``source_name`` to sample
        ``ind`` of every model, replacing previous link lines."""
        self._clear_lines()
        source = self.scatters[source_name]
        x_src = source.get_xdata()[ind]
        y_src = source.get_ydata()[ind]
        for oname in self.names:
            other = self.scatters[oname]
            (ln,) = self.axes.plot(
                [x_src, other.get_xdata()[ind]],
                [y_src, other.get_ydata()[ind]],
                fmt,
                **plot_kwargs,
            )
            self.lines.append(ln)
        self.fig.canvas.draw_idle()

    def spawn_child(self, model_name, ind, event=None):
        """Highlight sample ``ind`` and, if configured, open a child frame.

        Red lines link the point of ``model_name`` to sample ``ind`` of every
        model. When ``child_frame`` is set, sample ``ind`` is taken from
        ``loader.dataset``, each prediction file is run through a
        ``CacheLoader`` on it, and a child frame is opened with the per-model
        metric values of that sample. Arrow keys pressed in the child frame
        are routed to ``on_childframe_key_event``.
        """
        self._link_models(model_name, ind, "r")
        if self.child_frame is None:
            return

        data = collate([self.loader.dataset[ind]])
        preds = {
            name: CacheLoader({"path": str(pfile), "add_data_path": False})(data)
            for name, pfile in self.predictions.items()
        }
        summaries_i = {
            name: {k: v[ind] for k, v in res.items() if k != "names"}
            for name, res in self.results.items()
        }
        frame = self.child_frame(
            self.conf.child,
            deepcopy(data),
            preds,
            title=str(data["name"][0]),
            event=event,
            summaries=summaries_i,
        )
        frame.fig.canvas.mpl_connect(
            "key_press_event",
            functools.partial(
                self.on_childframe_key_event, frame=frame, ind=ind, event=event
            ),
        )
        self.childs.append(frame)
        frame.fig.show()

    def hover(self, event):
        """Link the hovered point to the same sample of every model."""
        if event.inaxes != self.axes:
            return
        for name, scatter in self.scatters.items():
            hit, info = scatter.contains(event)
            if hit:
                self._link_models(
                    name, info["ind"][0], "black", zorder=0, alpha=0.5
                )
                break

    def diff_clicked(self, args):
        """Toggle plotting relative to the first model and redraw."""
        self.conf.diff = not self.conf.diff
        self.draw()
        self.fig.canvas.draw_idle()

    def update_x(self, x):
        """Switch the x-axis metric to ``x`` and redraw."""
        self.conf.x = x
        self.draw()
        self.fig.canvas.draw_idle()

    def update_y(self, y):
        """Switch the y-axis metric to ``y`` and redraw."""
        self.conf.y = y
        self.draw()
        self.fig.canvas.draw_idle()

    def on_childframe_key_event(self, key_event, frame, ind, event):
        """Handle a key press inside child ``frame`` showing sample ``ind``.

        ``delete`` closes the child. ``left``/``right`` close it and open the
        previous/next sample. With ``shift`` held, the current child stays
        open and the neighbour opens alongside it. Other keys are ignored.
        """
        key = key_event.key
        if key == "delete":
            plt.close(frame.fig)
            self.childs.remove(frame)
            return
        if key not in ("left", "right", "shift+left", "shift+right"):
            return
        if key.startswith("shift+"):
            key = key[len("shift+"):]
        else:
            plt.close(frame.fig)
            self.childs.remove(frame)
        # Compare the normalized key so that shift+right also moves forward.
        step = 1 if key == "right" else -1
        # NOTE(review): wraps by len(self.loader) while spawn_child indexes
        # self.loader.dataset; these agree only with one sample per batch.
        self.spawn_child(
            self.names[0],
            (ind + step) % len(self.loader),
            event=event,
        )