# jetclustering/src/plotting/eval_matrix.py
# Last commit: e75a247 (gregorkrzmanc)
import matplotlib.pyplot as plt
import numpy as np
def ax_tiny_histogram(ax, labels, colors, values):
    """Draw a compact, frameless horizontal bar chart on ``ax``.

    Every bar is annotated twice:
    - its label, left-aligned just inside the left edge of the visible range;
    - its value (3 decimals), right-aligned just past the largest value.

    The x-range is zoomed to ``[min(values) - 0.005, max(values) + 0.005]``.
    Only the tips of the bars are visible, so small differences between
    values stand out. The offsets are absolute, which presumably assumes
    values of order 0-1 (e.g. scores) -- TODO confirm.

    Args:
        ax: matplotlib Axes to draw on.
        labels: one text label per bar; bar ``n`` is placed at y = n.
        colors: bar colour(s), forwarded to ``ax.barh``.
        values: bar lengths, indexable and the same length as ``labels``.

    Returns:
        The same ``ax``.
    """
    bars = ax.barh(range(len(labels)), values, color=colors, alpha=0.5)
    lowest = min(values)
    highest = max(values)
    label_x = lowest - 0.004   # 0.001 inside the left x-limit set below
    value_x = highest + 0.001  # right-aligned text ends just past the longest bar
    for bar, name, val in zip(bars, labels, values):
        y_mid = bar.get_y() + bar.get_height() / 2
        ax.text(label_x, y_mid, name,
                va='center', ha='left', fontsize=8, color='black', clip_on=True)
        ax.text(value_x, y_mid, f'{val:.3f}',
                va='center', ha='right', fontsize=8)
    # Minimal look: no ticks and no frame; the annotations carry the numbers.
    ax.set_yticks([])
    ax.set_xticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.set_xlim(lowest - 0.005, highest + 0.005)
    return ax
def _f1_score(pr):
    """Return the F1 score (harmonic mean) of a ``[precision, recall]`` pair.

    Returns 0.0 when precision + recall == 0. The plain formula would raise
    ZeroDivisionError for Python floats in that case.
    """
    precision, recall = pr[0], pr[1]
    denom = precision + recall
    if denom == 0:
        return 0.0
    return 2 * precision * recall / denom


def multiple_matrix_plot(result, labels, colors, custom_val_formula=_f1_score, rename_dict=None):
    """Draw a grid of tiny bar charts, one panel per (m_Z', r_inv) point.

    Rows are mediator masses and columns are r_inv values, both sorted
    ascending. A panel whose (mMed, rinv) entry is absent is left empty.

    Args:
        result: nested dict ``mMed -> rinv -> {label: value}``. With the
            default formula, each value is ``[precision, recall]``.
        labels: labels to plot, in bar order. Every label must be present
            in every existing (mMed, rinv) entry.
        colors: bar colours, forwarded to ``ax_tiny_histogram`` (same order
            as ``labels``).
        custom_val_formula: maps a per-label value to the scalar that is
            plotted. Defaults to the F1 score of ``[precision, recall]``.
            It is only evaluated for the labels being plotted.
        rename_dict: optional ``label -> display name`` mapping. Labels
            missing from it are shown unchanged.

    Returns:
        ``(fig, ax)``, where ``ax`` is always a 2-D array of Axes indexed
        ``[mMed index, rinv index]``. Returns ``(None, None)`` (and creates
        no figure) if ``result`` is empty or a label is missing.
    """
    if rename_dict is None:
        rename_dict = {}
    mediator_masses = sorted(result.keys())
    r_invs = sorted({rinv for mMed in result for rinv in result[mMed]})
    if not mediator_masses or not r_invs:
        # plt.subplots cannot build a grid with zero rows/columns.
        return None, None

    # Validate before creating the figure, so a missing label does not leave
    # an orphaned, half-drawn figure open.
    for mMed in mediator_masses:
        for rinv in r_invs:
            if rinv not in result[mMed]:
                continue
            for label in labels:
                if label not in result[mMed][rinv]:
                    print("Label not in result:", label, " - skipping!")
                    return None, None

    sz = 3
    # squeeze=False always yields a 2-D array, so ax[i, k] also works for a
    # single row or a single column (where a 1-D array used to break indexing).
    fig, ax = plt.subplots(len(mediator_masses), len(r_invs),
                           figsize=(sz * len(r_invs), 0.65 * sz * len(mediator_masses)),
                           squeeze=False)
    display_labels = [rename_dict.get(lbl, lbl) for lbl in labels]
    for i, mMed in enumerate(mediator_masses):
        for k, rinv in enumerate(r_invs):
            if rinv not in result[mMed]:
                continue
            entry = result[mMed][rinv]
            values = [custom_val_formula(entry[label]) for label in labels]
            ax_tiny_histogram(ax[i, k], display_labels, colors, values)
            ax[i, k].set_title(f"$m_{{Z'}}$ = {mMed} GeV, $r_{{inv.}}$ = {rinv}")
    fig.tight_layout()
    return fig, ax
def matrix_plot(result, color_scheme, cbar_label, ax=None, metric_comp_func=None, is_qcd=False):
    """Plot a metric as an annotated heat map over (m_Z', r_inv).

    Only one dark-matter mass is plotted: 20, or 0 when ``is_qcd``. Rows are
    mediator masses and columns are r_inv values, both sorted ascending.
    Each cell is annotated with its value to three decimals. Cells with no
    entry in ``result`` are shown as 0.

    Args:
        result: nested dict ``mMed -> mDark -> rinv -> value``.
        color_scheme: matplotlib colormap name for the heat map.
        cbar_label: label of the colour bar.
        ax: optional Axes to draw into. When None, a new 5x5-inch figure is
            created and tight-laid-out.
        metric_comp_func: optional callable that turns each stored value into
            the plotted scalar. If it raises, that cell is shown as 0.
        is_qcd: plot dark mass 0 instead of 20.

    Returns:
        The figure holding the plot: the newly created one, or ``ax.figure``
        when ``ax`` was supplied.
    """
    make_fig = ax is None
    dark_masses = [0] if is_qcd else [20]
    if make_fig:
        fig, ax = plt.subplots(len(dark_masses), 1, figsize=(5, 5))
    else:
        # Previously `fig` was never bound on this path, so `return fig`
        # raised UnboundLocalError whenever a caller passed `ax`.
        fig = ax.figure
    axes = [ax] if len(dark_masses) == 1 else ax

    mediator_masses = sorted(result.keys())
    r_invs = sorted({rinv
                     for mMed in result
                     for mDark in result[mMed]
                     for rinv in result[mMed][mDark]})

    for i, mDark in enumerate(dark_masses):
        data = np.zeros((len(mediator_masses), len(r_invs)))
        for j, mMed in enumerate(mediator_masses):
            if mDark not in result[mMed]:
                continue
            by_rinv = result[mMed][mDark]
            for k, rinv in enumerate(r_invs):
                if rinv not in by_rinv:
                    continue
                value = by_rinv[rinv]
                if metric_comp_func is not None:
                    try:
                        value = metric_comp_func(value)
                    except Exception:
                        # Best effort: a cell whose metric cannot be computed is shown as 0.
                        value = 0
                data[j, k] = value

        # Draw the image once. It used to be drawn first with "Blues" and then
        # again with `color_scheme` on top.
        # NOTE(review): an old comment suggested fixing the "Greens" scale to
        # [0, 1], but vmin/vmax were never set; behaviour kept -- confirm intent.
        image = axes[i].imshow(data, cmap=color_scheme)
        for (j, k), val in np.ndenumerate(data):
            axes[i].text(k, j, f'{val:.3f}', ha='center', va='center', color='black')
        axes[i].set_xticks(range(len(r_invs)))
        axes[i].set_xticklabels(r_invs)
        axes[i].set_yticks(range(len(mediator_masses)))
        axes[i].set_yticklabels(mediator_masses)
        axes[i].set_xlabel("$r_{inv}$")
        axes[i].set_ylabel("$m_{Z'}$ [GeV]")
        cbar = plt.colorbar(image, ax=axes[i])
        cbar.set_label(cbar_label)
    if make_fig:
        fig.tight_layout()
    return fig
def scatter_plot(ax, xs, ys, label, color=None, pattern=".--"):
    """Plot ``ys`` against ``xs`` on ``ax``, with the points ordered by x.

    Sorting by x makes any connecting line in ``pattern`` run left to right
    instead of zig-zagging in input order. The default ".--" draws point
    markers joined by a dashed line.

    Args:
        ax: matplotlib Axes to draw on.
        xs, ys: equal-length sequences of x and y values.
        label: legend label for the line.
        color: passed straight to ``ax.plot``; None leaves the choice to
            matplotlib.
        pattern: matplotlib format string.
    """
    order = np.argsort(xs)
    xs_sorted = np.asarray(xs)[order]
    ys_sorted = np.asarray(ys)[order]
    # The old if/else on `color` had two identical arms; a single call suffices.
    ax.plot(xs_sorted, ys_sorted, pattern, label=label, color=color)