jetclustering / scripts /metrics_plots_vs_pt_cutoff.py
gregorkrzmanc's picture
.
e75a247
raw
history blame
7.35 kB
import os
from tqdm import tqdm
import argparse
import pickle
from src.plotting.eval_matrix import matrix_plot, scatter_plot
from src.utils.paths import get_path
import matplotlib.pyplot as plt
import numpy as np
# This script produces the pt cutoff vs. f1 score
inputs = {
30: "Delphes_020425_test_PU_PFfix_part0/batch_eval_2k/DelphesPFfix_FullDataset_pt_30.0",
40: "Delphes_020425_test_PU_PFfix_part0/batch_eval_2k/DelphesPFfix_FullDataset_pt_40.0",
50: "Delphes_020425_test_PU_PFfix_part0/batch_eval_2k/DelphesPFfix_FullDataset_pt_50.0",
60: "Delphes_020425_test_PU_PFfix_part0/batch_eval_2k/DelphesPFfix_FullDataset_pt_60.0",
70: "Delphes_020425_test_PU_PFfix_part0/batch_eval_2k/DelphesPFfix_FullDataset_pt_70.0",
80: "Delphes_020425_test_PU_PFfix_part0/batch_eval_2k/DelphesPFfix_FullDataset_pt_80.0",
90: "Delphes_020425_test_PU_PFfix_part0/batch_eval_2k/DelphesPFfix_FullDataset_pt_90.0",
100: "Delphes_020425_test_PU_PFfix_part0/batch_eval_2k/DelphesPFfix_FullDataset"
}
inputs = {
30: "Delphes_020425_test_PU_PFfix_part0/batch_eval_2k/DelphesPFfix_FullDataset_TrainDSstudy_pt_30.0",
40: "Delphes_020425_test_PU_PFfix_part0/batch_eval_2k/DelphesPFfix_FullDataset_TrainDSstudy_pt_40.0",
50: "Delphes_020425_test_PU_PFfix_part0/batch_eval_2k/DelphesPFfix_FullDataset_TrainDSstudy_pt_50.0",
60: "Delphes_020425_test_PU_PFfix_part0/batch_eval_2k/DelphesPFfix_FullDataset_TrainDSstudy_pt_60.0",
70: "Delphes_020425_test_PU_PFfix_part0/batch_eval_2k/DelphesPFfix_FullDataset_TrainDSstudy_pt_70.0",
80: "Delphes_020425_test_PU_PFfix_part0/batch_eval_2k/DelphesPFfix_FullDataset_TrainDSstudy_pt_80.0",
90: "Delphes_020425_test_PU_PFfix_part0/batch_eval_2k/DelphesPFfix_FullDataset_TrainDSstudy_pt_90.0",
100: "Delphes_020425_test_PU_PFfix_part0/batch_eval_2k/DelphesPFfix_FullDataset_TrainDSstudy",
}
'''
print("PLOTTING QCD")
inputs = {
30: "QCD_test_part0/batch_eval_2k/DelphesPFfix_FullDataset_TrainDSstudy_QCD_pt_30.0",
40: "QCD_test_part0/batch_eval_2k/DelphesPFfix_FullDataset_TrainDSstudy_QCD_pt_40.0",
50: "QCD_test_part0/batch_eval_2k/DelphesPFfix_FullDataset_TrainDSstudy_QCD_pt_50.0",
60: "QCD_test_part0/batch_eval_2k/DelphesPFfix_FullDataset_TrainDSstudy_QCD_pt_60.0",
70: "QCD_test_part0/batch_eval_2k/DelphesPFfix_FullDataset_TrainDSstudy_QCD_pt_70.0",
80: "QCD_test_part0/batch_eval_2k/DelphesPFfix_FullDataset_TrainDSstudy_QCD_pt_80.0",
90: "QCD_test_part0/batch_eval_2k/DelphesPFfix_FullDataset_TrainDSstudy_QCD_pt_90.0",
100: "QCD_test_part0/batch_eval_2k/DelphesPFfix_FullDataset_TrainDSstudy_QCD"
}
'''
files = {
key: pickle.load(open(os.path.join(get_path(value, "results"), "precision_recall.pkl"), "rb")) for key, value in inputs.items()
}
titles = {key: set(value.keys()) for key, value in files.items()}
# make a set of the intersections of titles
intersections = sorted(list(set.intersection(*titles.values())))
titles_to_plot = {
"AK, R=0.8": ["AK8", "gray"],
"GT_R=0.8 LGATr_GP_IRC_S_50k_s12900, sc. (aug)": ["LGATr_GP_IRC_S", "red"],
"GT_R=0.8 LGATr_GP_50k_s25020, sc. (aug)": ["LGATr_GP", "purple"],
"GT_R=0.8 base_LGATr_s50000, sc.": ["LGATr", "orange"]
} # To plot different variations of the model
print("QCD") # colors= [{"base_LGATr": "orange", "LGATr_700_07": "red", "LGATr_QCD": "purple", "LGATr_700_07+900_03": "blue", "LGATr_700_07+900_03+QCD": "green", "AK8": "gray"}, {"base_LGATr": "LGATr_900_03"}],
titles_to_plot = {
"AK, R=0.8": ["AK8", "gray"],
"GT_R=0.8 base_LGATr_s50000, sc.": ["LGATr_900_03", "orange"],
"GT_R=0.8 LGATr_QCD_s50000, sc.": ["LGATr_QCD", "purple"],
"GT_R=0.8 LGATr_700_07_s50000, sc.": ["LGATr_700_07", "red"],
"GT_R=0.8 LGATr_700_07+900_03_s50000, sc.": ["LGATr_700_07+900_03", "blue"],
"GT_R=0.8 LGATr_700_07+900_03+QCD_s50000, sc.": ["LGATr_700_07+900_03+QCD", "green"],
}
titles_to_plot = {
"AK, R=0.8": ["AK8", "gray"],
"GT_R=0.8 LGATr_GP_IRC_S_50k_s12900, sc. (aug)": ["LGATr_900_03", "orange"],
"GT_R=0.8 LGATr_GP_IRC_S_QCD_s24000, sc. (aug)": ["LGATr_QCD", "purple"],
"GT_R=0.8 LGATr_GP_IRC_S_700_07_s24000, sc. (aug)": ["LGATr_700_07", "red"],
"GT_R=0.8 LGATr_GP_IRC_S_700_07+900_03_s24000, sc. (aug)": ["LGATr_700_07+900_03", "blue"],
"GT_R=0.8 LGATr_GP_IRC_S_700_07+900_03+QCD_s24000, sc. (aug)": ["LGATr_700_07+900_03+QCD", "green"],
}
intersections = sorted(list(titles_to_plot.keys()))
output_dirs = []
for _, value in inputs.items():
output_dirs.append(get_path(value, "results"))
result = files[100][intersections[0]]
mediator_masses = sorted(list(result.keys()))
r_invs = sorted(list(set([rinv for mMed in result for mDark in result[mMed] for rinv in result[mMed][mDark]])))
sz = 4
#fig, ax = plt.subplots(len(inputs), len(titles_to_plot), figsize=(sz*len(titles_to_plot), sz*len(inputs)))
fig, ax = plt.subplots(len(mediator_masses), len(r_invs), figsize=(sz*len(r_invs), sz*len(mediator_masses)))
figp, axp = plt.subplots(len(mediator_masses), len(r_invs), figsize=(sz*len(r_invs), sz*len(mediator_masses)))
figr, axr = plt.subplots(len(mediator_masses), len(r_invs), figsize=(sz*len(r_invs), sz*len(mediator_masses)))
if len(r_invs) == 1 and len(mediator_masses) == 1:
ax = np.array([[ax]])
axp = np.array([[axp]])
axr = np.array([[axr]])
grids = set()
for i, mMed in enumerate(mediator_masses):
for j, rInv in enumerate(r_invs):
for k, title in enumerate(intersections):
label, color = titles_to_plot[title]
pts = sorted(list(inputs.keys()))
precisions = []
recalls = []
f1_scores = []
for pt in pts:
precision, recall = files[pt][title][mMed][20][rInv]
precisions.append(precision)
recalls.append(recall)
f1_score = 2 * precision * recall / (precision + recall)
f1_scores.append(f1_score)
ax[i, j].plot(pts, f1_scores, ".-", label=label, color=color)
axp[i, j].plot(pts, precisions, ".-", label=label, color=color)
axr[i, j].plot(pts, recalls, ".-", label=label, color=color)
ax[i, j].set_title(f"$m_{{Z'}} = {mMed}$ GeV, $r_{{inv.}}$ = {rInv}")
ax[i, j].set_xlabel("$p_T^{cutoff}$")
axp[i, j].set_title(f"$m_{{Z'}} = {mMed}$ GeV, $r_{{inv.}}$ = {rInv}")
axp[i, j].set_xlabel("$p_T^{cutoff}$")
axr[i, j].set_title(f"$m_{{Z'}} = {mMed}$ GeV, $r_{{inv.}}$ = {rInv}")
axr[i, j].set_xlabel("$p_T^{cutoff}$")
ax[i, j].set_ylabel("$F_1$ score")
axp[i, j].set_ylabel("Precision")
axr[i, j].set_ylabel("Recall")
ax[i, j].legend()
axp[i, j].legend()
axr[i, j].legend()
if (i, j) not in grids:
ax[i, j].grid()
axp[i, j].grid()
axr[i, j].grid()
grids.add((i, j))
for f in output_dirs:
fig.tight_layout()
fname = os.path.join(f, "pt_cutoff_vs_f1_score.pdf")
fig.tight_layout()
fig.savefig(fname)
print("saved to", fname)
fname = os.path.join(f, "pt_cutoff_vs_precision.pdf")
figp.tight_layout()
figp.savefig(fname)
fname = os.path.join(f, "pt_cutoff_vs_recall.pdf")
figr.tight_layout()
figr.savefig(fname)