# jetclustering/scripts/analysis/count_matched_quarks.py
import os
from tqdm import tqdm
import argparse
import numpy as np
import pandas as pd
import pickle
import torch
import time
from src.utils.utils import CPU_Unpickler
from src.dataset.get_dataset import get_iter
from src.plotting.eval_matrix import matrix_plot
from src.utils.paths import get_path
from pathlib import Path
import matplotlib.pyplot as plt
from src.dataset.dataset import EventDataset
# Matches predicted jets to generator-level dark quarks within a cone of
# radius R and accumulates precision/recall, mass-resolution and
# matched-jet-property metrics per (mMed, mDark, rinv) signal point.
R = 0.8  # delta-R matching radius between a jet and a dark quark
parser = argparse.ArgumentParser()
parser.add_argument("--input", type=str, required=True)
parser.add_argument("--dataset-cap", type=int, default=-1)  # -1 = no cap on events per subdataset
parser.add_argument("--output", type=str, default="")  # defaults to --input (see below)
parser.add_argument("--augment-soft-particles", "-aug-soft", action="store_true")
parser.add_argument("--plot-only", action="store_true")  # skip evaluation, reload pickled metrics
parser.add_argument("--jets-object", type=str, default="fatjets")
parser.add_argument("--eval-dir", type=str, default="")
parser.add_argument("--clustering-suffix", type=str, default="") # default: 1020, also want to try 1010 or others...?
parser.add_argument("--pt-jet-cutoff", type=float, default=100.0)
parser.add_argument("--high-eta-only", action="store_true") # eta > 1.5 quarks only
parser.add_argument("--low-eta-only", action="store_true") # eta < 1.5 quarks only
parser.add_argument("--parton-level", "-pl", action="store_true") # To be used together with 'fastjet_jets'
parser.add_argument("--gen-level", "-gl", action="store_true")
args = parser.parse_args()
path = get_path(args.input, "preprocessed_data")
import wandb
api = wandb.Api()
def get_run_by_name(name):
    """Return the single wandb run whose display name equals *name* (stripped).

    Returns None when the query does not resolve to exactly one run.
    """
    # NOTE: the original issued this exact query twice back-to-back and
    # discarded the first result; the duplicate API round-trip is removed.
    runs = api.runs(
        path="fcc_ml/svj_clustering",
        filters={"display_name": {"$eq": name.strip()}}
    )
    if runs.length != 1:
        return None
    return runs[0]
def resolve_preproc_data_path(path):
    """Map an absolute path recorded in an eval file back into the local
    'preprocessed_data' tree, keeping only the part after the last
    '/preprocessed_data/' marker (the whole path if the marker is absent)."""
    _, _, relative = path.rpartition("/preprocessed_data/")
    if not relative:
        relative = path
    return get_path(relative, "preprocessed_data")
if args.eval_dir:
    # Build a map: preprocessed-data directory -> [eval pickle, clustering pickle].
    eval_dir = get_path(args.eval_dir, "results", fallback=True)
    dataset_path_to_eval_file = {}
    top_folder_name = eval_dir.split("/")[-1]
    config = get_run_by_name(top_folder_name).config
    for file in os.listdir(eval_dir):
        if file.startswith("eval_") and file.endswith(".pkl"):
            file_number = file.split("_")[1].split(".")[0]
            clustering_file = "clustering_{}.pkl".format(file_number)
            if args.clustering_suffix:
                clustering_file = "clustering_{}_{}.pkl".format(args.clustering_suffix, file_number)
            # Fix: the original leaked one open file handle per eval file
            # (CPU_Unpickler(open(...)) without close); use a context manager.
            with open(os.path.join(eval_dir, file), "rb") as fh:
                f = CPU_Unpickler(fh).load()
            clustering_file = os.path.join(eval_dir, clustering_file)
            # Model-side clustering already embedded in the eval file.
            if "model_cluster" in f and not args.clustering_suffix:
                clustering_file = None
            dataset_path_to_eval_file[resolve_preproc_data_path(f["filename"])] = [os.path.join(eval_dir, file), clustering_file]
    print(dataset_path_to_eval_file)
if args.output == "":
    args.output = args.input
output_path = os.path.join(get_path(args.output, "results"), "count_matched_quarks")
Path(output_path).mkdir(parents=True, exist_ok=True)
def get_bc_scores_for_jets(event, pt_cutoff=100):
    """Collect per-particle BC scores grouped by predicted cluster, for model
    jets with pt above *pt_cutoff*.

    Generalization: the jet-pt threshold used to be hard-coded to 100; it is
    now a keyword parameter with the same default, so existing callers see
    identical behavior.

    Returns a list with one tensor of scores per selected cluster.
    """
    scores = event.pfcands.bc_scores_pfcands
    clusters = event.pfcands.bc_labels_pfcands
    # Indices of model jets passing the pt cut; cluster labels index jets.
    selected_clusters_idx = torch.where(event.model_jets.pt > pt_cutoff)[0]
    result = []
    for c in selected_clusters_idx:
        # All PF-candidate scores assigned to cluster c.
        result.append(scores[clusters == c.item()])
    return result
def calculate_m(objects, mt=False):
    """Invariant mass (or transverse mass when ``mt=True``) of the two
    highest-pt entries of *objects*.

    Expects tensor attributes ``pt`` (n,), ``E`` (n,) and ``pxyz`` (n, 3);
    returns a Python float.
    """
    # Indices of the two leading-pt objects.
    leading = objects.pt.argsort(descending=True)[:2]
    E_sum = objects.E[leading].sum()
    p_sum = objects.pxyz[leading].sum(dim=0)
    if mt:
        # Transverse mass ignores the longitudinal (z) component.
        m_squared = E_sum**2 - p_sum[0]**2 - p_sum[1]**2
    else:
        m_squared = E_sum**2 - p_sum[2]**2 - p_sum[1]**2 - p_sum[0]**2
    return np.sqrt(m_squared).item()
# Objectness-score scan points: a fine grid of 100 points on [0, 0.1]
# followed by 20 coarser points on [0.1, 1] (120 values in total).
thresholds = np.concatenate([np.linspace(0, 0.1, 100), np.linspace(0.1, 1, 20)])
def get_mc_gt_per_event(event, cone_radius=0.8):
    """Monte-Carlo ground-truth pt per dark quark: the transverse momentum of
    the vector sum of PF-candidate momenta within *cone_radius* (eta-phi
    distance) of each matrix-element dark quark.

    Generalization: the cone radius used to be hard-coded to 0.8 (the module
    constant R); it is now a keyword parameter with the same default.

    NOTE(review): this eta-phi distance does not wrap phi at +-pi, unlike the
    jet-quark matching loops below — kept as-is to preserve behavior.
    """
    result = []
    dq = [event.matrix_element_gen_particles.eta, event.matrix_element_gen_particles.phi]
    for i in range(len(dq[0])):
        dq_coords = [dq[0][i], dq[1][i]]
        # Particles inside the cone around this dark quark.
        cone_filter = torch.sqrt((event.pfcands.eta - dq_coords[0])**2 + (event.pfcands.phi - dq_coords[1])**2) < cone_radius
        eta_cone, phi_cone, pt_cone = event.pfcands.eta[cone_filter], event.pfcands.phi[cone_filter], event.pfcands.pt[cone_filter]
        # Vector-sum the in-cone momenta; report the transverse component.
        # (The unused pz sum computed by the original has been dropped.)
        px_cone = torch.sum(pt_cone * np.cos(phi_cone))
        py_cone = torch.sum(pt_cone * np.sin(phi_cone))
        pt_cone = torch.sqrt(px_cone**2 + py_cone**2)
        result.append(pt_cone.item())
    return result
# Main measurement pass: iterate every subdataset under `path`, match jets to
# dark quarks, fill the per-subdataset accumulators, then regroup everything
# by (mMed, mDark, rinv) and pickle the results into `output_path`.
if not args.plot_only:
    n_matched_quarks = {}
    unmatched_quarks = {}
    n_fake_jets = {} # Number of jets that have not been matched to a quark
    bc_scores_matched = {}
    bc_scores_unmatched = {}
    precision_and_recall = {} # Array of [n_relevant_retrieved, all_retrieved, all_relevant], or in our language, [n_matched_dark_quarks, n_jets, n_dark_quarks]
    precision_and_recall_fastjets = {}
    pr_obj_score_thresholds = {} # same as precision_and_recall, except it gives a dictionary instead of the array, and the keys are the thresholds for objectness score
    mass_resolution = {} # Contains {'m_true': [], 'm_pred': [], 'mt_true': [], 'mt_pred': []} # mt = transverse mass, m = invariant mass
    matched_jet_properties = {} # contains {'pt_gen_particle': [], 'pt_mc_truth': [], 'pt_pred': [], 'eta_gen_particle': [], 'eta_mc_truth': [], 'eta_pred': [], 'phi_gen_particle': [], 'phi_mc_truth': [], 'phi_pred': []}
    matched_jet_properties_fastjets = {}
    is_dq_matched_per_event = {}
    dq_pt_per_event = {}
    gt_pt_per_event = {}
    gt_props_per_event = {"eta": {}, "phi": {}}
    print("LISTING DIRECTORY", path, ":", os.listdir(path))
    for subdataset in os.listdir(path):
        print("-----", subdataset, "-----")
        current_path = os.path.join(path, subdataset)
        model_clusters_file = None
        model_output_file = None
        # First time this subdataset is seen: initialize its accumulators.
        # In fastjet mode the per-event containers are dicts keyed by jet
        # radius; otherwise they are flat lists.
        if subdataset not in precision_and_recall:
            precision_and_recall[subdataset] = [0, 0, 0]
            precision_and_recall_fastjets[subdataset] = {}
            matched_jet_properties_fastjets[subdataset] = {}
            is_dq_matched_per_event[subdataset] = []
            dq_pt_per_event[subdataset] = []
            gt_pt_per_event[subdataset] = []
            if args.jets_object == "fastjet_jets":
                is_dq_matched_per_event[subdataset] = {}
                dq_pt_per_event[subdataset] = {}
                gt_pt_per_event[subdataset] = {}
                for key in gt_props_per_event:
                    if subdataset not in gt_props_per_event[key]:
                        gt_props_per_event[key][subdataset] = {}
            else:
                for key in gt_props_per_event:
                    if subdataset not in gt_props_per_event[key]:
                        gt_props_per_event[key][subdataset] = []
            pr_obj_score_thresholds[subdataset] = {}
            for i in range(len(thresholds)):
                pr_obj_score_thresholds[subdataset][i] = [0, 0, 0]
        if subdataset not in mass_resolution:
            mass_resolution[subdataset] = {'m_true': [], 'm_pred': [], 'mt_true': [], 'mt_pred': [], 'n_jets': []}
        if args.eval_dir:
            # Only evaluate subdatasets that have a matching eval file.
            if current_path not in dataset_path_to_eval_file:
                print("Skipping", current_path)
                print(dataset_path_to_eval_file)
                continue
            model_clusters_file = dataset_path_to_eval_file[current_path][1]
            model_output_file = dataset_path_to_eval_file[current_path][0]
        #dataset = get_iter(current_path, model_clusters_file=model_clusters_file, model_output_file=model_output_file,
        #                   include_model_jets_unfiltered=True)
        fastjet_R = None
        if args.jets_object == "fastjet_jets":
            fastjet_R = np.array([0.8])
        config = {"parton_level": args.parton_level, "gen_level": args.gen_level}
        print("Config:", config)
        dataset = EventDataset.from_directory(current_path, model_clusters_file=model_clusters_file,
                                              model_output_file=model_output_file,
                                              include_model_jets_unfiltered=True, fastjet_R=fastjet_R,
                                              parton_level=config.get("parton_level", False), gen_level=config.get("gen_level", False),
                                              aug_soft=args.augment_soft_particles, seed=1000000, pt_jet_cutoff=args.pt_jet_cutoff)
        n = 0  # events actually processed from this subdataset
        for x in tqdm(range(len(dataset))):
            data = dataset[x]
            if data is None:
                print("Skipping", x)
                continue
            #try:
            #    data = dataset[x]
            #except:
            #    print("Exception")
            #    break # skip this event
            jets_object = data.__dict__[args.jets_object]
            n += 1
            if args.dataset_cap != -1 and n > args.dataset_cap:
                break
            # Optional eta-region filters on the dark quarks.
            if args.high_eta_only and torch.max(torch.abs(data.matrix_element_gen_particles.eta)) < 1.5:
                continue
            if args.low_eta_only and torch.max(torch.abs(data.matrix_element_gen_particles.eta)) > 1.5:
                continue
            if not args.jets_object == "fastjet_jets":
                # ---- single jet collection (model jets / fatjets) ----
                jets = [jets_object.eta, jets_object.phi]
                dq = [data.matrix_element_gen_particles.eta, data.matrix_element_gen_particles.phi]
                # calculate deltaR between each jet and each quark
                distance_matrix = np.zeros((len(jets_object), len(data.matrix_element_gen_particles)))
                for i in range(len(jets_object)):
                    for j in range(len(data.matrix_element_gen_particles)):
                        deta = jets[0][i] - dq[0][j]
                        dphi = abs(jets[1][i] - dq[1][j])
                        # wrap phi difference into (-pi, pi]; sign is
                        # irrelevant since dphi is squared below
                        if dphi > np.pi:
                            dphi -= 2 * np.pi #- dphi
                        distance_matrix[i, j] = np.sqrt(deta**2 + dphi**2)
                # row-wise argmin (transpose: rows become quarks)
                distance_matrix = distance_matrix.T
                #min_distance = np.min(distance_matrix, axis=1)
                n_jets = len(jets_object)
                precision_and_recall[subdataset][1] += n_jets
                precision_and_recall[subdataset][2] += len(data.matrix_element_gen_particles)
                if "obj_score" in jets_object.__dict__:
                    print("Also evaluating using objectness score")
                    # Count retrieved jets / relevant quarks per threshold.
                    for i in range(len(thresholds)):
                        filt = torch.sigmoid(jets_object.obj_score) >= thresholds[i]
                        pr_obj_score_thresholds[subdataset][i][1] += torch.sum(filt).item()
                        pr_obj_score_thresholds[subdataset][i][2] += len(data.matrix_element_gen_particles)
                mass_resolution[subdataset]['m_true'].append(calculate_m(data.matrix_element_gen_particles))
                mass_resolution[subdataset]['m_pred'].append(calculate_m(jets_object))
                mass_resolution[subdataset]['mt_true'].append(calculate_m(data.matrix_element_gen_particles, mt=True))
                mass_resolution[subdataset]['mt_pred'].append(calculate_m(jets_object, mt=True))
                mass_resolution[subdataset]['n_jets'].append(n_jets)
                if len(jets_object):
                    if subdataset not in matched_jet_properties:
                        matched_jet_properties[subdataset] = {'pt_gen_particle': [], 'pt_mc_truth': [], 'pt_pred': [],
                                                              'eta_gen_particle': [], 'eta_pred': [],
                                                              'phi_gen_particle': [], 'phi_pred': []}
                    # For each quark: distance to and index of its nearest jet;
                    # -1 marks quarks with no jet within R.
                    quark_to_jet = np.min(distance_matrix, axis=1)
                    quark_to_jet_idx = np.argmin(distance_matrix, axis=1)
                    quark_to_jet[quark_to_jet > R] = -1
                    n_matched_quarks[subdataset] = n_matched_quarks.get(subdataset, []) + [np.sum(quark_to_jet != -1)]
                    n_fake_jets[subdataset] = n_fake_jets.get(subdataset, []) + [n_jets - np.sum(quark_to_jet != -1)]
                    f = quark_to_jet != -1
                    matched_jet_properties[subdataset]["pt_gen_particle"] += data.matrix_element_gen_particles.pt[f].tolist()
                    matched_jet_properties[subdataset]["pt_pred"] += jets_object.pt[quark_to_jet_idx[f]].tolist()
                    matched_jet_properties[subdataset]["eta_gen_particle"] += data.matrix_element_gen_particles.eta[f].tolist()
                    matched_jet_properties[subdataset]["eta_pred"] += jets_object.eta[quark_to_jet_idx[f]].tolist()
                    matched_jet_properties[subdataset]["phi_gen_particle"] += data.matrix_element_gen_particles.phi[f].tolist()
                    matched_jet_properties[subdataset]["phi_pred"] += jets_object.phi[quark_to_jet_idx[f]].tolist()
                    precision_and_recall[subdataset][0] += np.sum(quark_to_jet != -1)
                    if "obj_score" in jets_object.__dict__:
                        # Re-match after dropping jets below each threshold.
                        for i in range(len(thresholds)):
                            filt = torch.sigmoid(jets_object.obj_score) >= thresholds[i]
                            dist_matrix_filt = distance_matrix[:, filt.numpy()]
                            if filt.sum() == 0:
                                continue
                            quark_to_jet_filt = np.min(dist_matrix_filt, axis=1)
                            quark_to_jet_filt[quark_to_jet_filt > R] = -1
                            pr_obj_score_thresholds[subdataset][i][0] += np.sum(quark_to_jet_filt != -1)
                    filt = quark_to_jet == -1
                    #if args.jets_object == "model_jets":
                    #matched_jet_idx = sorted(np.argmin(distance_matrix, axis=1)[quark_to_jet != -1])
                    #unmatched_jet_idx = sorted(list(set(list(range(n_jets))) - set(matched_jet_idx)))
                    #scores = get_bc_scores_for_jets(data)
                    #for i in matched_jet_idx:
                    #    bc_scores_matched[subdataset] = bc_scores_matched.get(subdataset, []) + [torch.mean(scores[i]).item()]
                    #for i in unmatched_jet_idx:
                    #    bc_scores_unmatched[subdataset] = bc_scores_unmatched.get(subdataset, []) + [torch.mean(scores[i]).item()]
                else:
                    # No jets in the event: every quark is unmatched.
                    n_matched_quarks[subdataset] = n_matched_quarks.get(subdataset, []) + [0]
                    n_fake_jets[subdataset] = n_fake_jets.get(subdataset, []) + [n_jets]
                    filt = torch.ones(len(data.matrix_element_gen_particles)).bool()
                    quark_to_jet = torch.ones(len(data.matrix_element_gen_particles)).long() * -1
                is_dq_matched_per_event[subdataset].append(quark_to_jet.tolist())
                dq_pt_per_event[subdataset].append(data.matrix_element_gen_particles.pt.tolist())
                gt_pt_per_event[subdataset].append(get_mc_gt_per_event(data))
                gt_props_per_event["eta"][subdataset].append(data.matrix_element_gen_particles.eta.tolist())
                gt_props_per_event["phi"][subdataset].append(data.matrix_element_gen_particles.phi.tolist())
                if subdataset not in unmatched_quarks:
                    unmatched_quarks[subdataset] = {"pt": [], "eta": [], "phi": [], "pt_all": [], "frac_evt_E_matched": [], "frac_evt_E_unmatched": []}
                # `filt` selects the unmatched quarks at this point.
                unmatched_quarks[subdataset]["pt"] += data.matrix_element_gen_particles.pt[filt].tolist()
                unmatched_quarks[subdataset]["pt_all"] += data.matrix_element_gen_particles.pt.tolist()
                unmatched_quarks[subdataset]["eta"] += data.matrix_element_gen_particles.eta[filt].tolist()
                unmatched_quarks[subdataset]["phi"] += data.matrix_element_gen_particles.phi[filt].tolist()
                visible_E_event = torch.sum(data.pfcands.E) #+ torch.sum(data.special_pfcands.E)
                matched_quarks = np.where(quark_to_jet != -1)[0]
                # Fraction of the event's visible energy inside each quark's
                # R-cone, split by matched / unmatched quark.
                for i in range(len(data.matrix_element_gen_particles)):
                    dq_coords = [dq[0][i], dq[1][i]]
                    cone_filter = torch.sqrt((data.pfcands.eta - dq_coords[0])**2 + (data.pfcands.phi - dq_coords[1])**2) < R
                    #cone_filter_special = torch.sqrt(
                    #    (data.special_pfcands.eta - dq_coords[0]) ** 2 + (data.special_pfcands.phi - dq_coords[1]) ** 2) < R
                    E_in_cone = data.pfcands.E[cone_filter].sum()# + data.special_pfcands.E[cone_filter_special].sum()
                    if i in matched_quarks:
                        unmatched_quarks[subdataset]["frac_evt_E_matched"].append(E_in_cone / visible_E_event)
                    else:
                        unmatched_quarks[subdataset]["frac_evt_E_unmatched"].append(E_in_cone / visible_E_event)
                #print("Number of matched quarks:", np.sum(quark_to_jet != -1))
            else:
                # ---- fastjet mode: one jet collection per radius `key` ----
                for key in jets_object:
                    jets = [jets_object[key].eta, jets_object[key].phi]
                    dq = [data.matrix_element_gen_particles.eta, data.matrix_element_gen_particles.phi]
                    # calculate deltaR between each jet and each quark
                    distance_matrix = np.zeros((len(jets_object[key]), len(data.matrix_element_gen_particles)))
                    for i in range(len(jets_object[key])):
                        for j in range(len(data.matrix_element_gen_particles)):
                            deta = jets[0][i] - dq[0][j]
                            dphi = abs(jets[1][i] - dq[1][j])
                            if dphi > np.pi:
                                dphi -= 2 * np.pi
                            #elif dphi < -np.pi:
                            #    dphi += 2 * np.pi
                            assert abs(dphi) <= np.pi, "dphi is not in [-pi, pi] range: {}".format(dphi)
                            distance_matrix[i, j] = np.sqrt(deta ** 2 + dphi ** 2)
                    # Row-wise argmin (rows = quarks after transpose)
                    distance_matrix = distance_matrix.T
                    # min_distance = np.min(distance_matrix, axis=1)
                    n_jets = len(jets_object[key])
                    if key not in precision_and_recall_fastjets[subdataset]:
                        precision_and_recall_fastjets[subdataset][key] = [0, 0, 0]
                    if key not in matched_jet_properties_fastjets[subdataset]:
                        is_dq_matched_per_event[subdataset][key] = []
                        dq_pt_per_event[subdataset][key] = []
                        gt_pt_per_event[subdataset][key] = []
                        for prop in gt_props_per_event:
                            if key not in gt_props_per_event[prop][subdataset]:
                                gt_props_per_event[prop][subdataset][key] = []
                        matched_jet_properties_fastjets[subdataset][key] = {"pt_gen_particle": [], "pt_pred": [],
                                                                            "eta_gen_particle": [], "eta_pred": [],
                                                                            "phi_gen_particle": [], "phi_pred": []}
                    precision_and_recall_fastjets[subdataset][key][1] += n_jets
                    precision_and_recall_fastjets[subdataset][key][2] += len(data.matrix_element_gen_particles)
                    if len(jets_object[key]):
                        quark_to_jet = np.min(distance_matrix, axis=1)
                        quark_to_jet_idx = np.argmin(distance_matrix, axis=1)
                        quark_to_jet[quark_to_jet > R] = -1
                        precision_and_recall_fastjets[subdataset][key][0] += np.sum(quark_to_jet != -1)
                        f = quark_to_jet != -1
                        matched_jet_properties_fastjets[subdataset][key]["pt_gen_particle"] += data.matrix_element_gen_particles.pt[f].tolist()
                        matched_jet_properties_fastjets[subdataset][key]["pt_pred"] += jets_object[key].pt[quark_to_jet_idx[f]].tolist()
                        matched_jet_properties_fastjets[subdataset][key]["eta_gen_particle"] += data.matrix_element_gen_particles.eta[f].tolist()
                        matched_jet_properties_fastjets[subdataset][key]["eta_pred"] += jets_object[key].eta[quark_to_jet_idx[f]].tolist()
                        matched_jet_properties_fastjets[subdataset][key]["phi_gen_particle"] += data.matrix_element_gen_particles.phi[f].tolist()
                        matched_jet_properties_fastjets[subdataset][key]["phi_pred"] += jets_object[key].phi[quark_to_jet_idx[f]].tolist()
                    else:
                        quark_to_jet = torch.ones(len(data.matrix_element_gen_particles)).long() * -1
                    is_dq_matched_per_event[subdataset][key].append(quark_to_jet.tolist())
                    dq_pt_per_event[subdataset][key].append(data.matrix_element_gen_particles.pt.tolist())
                    gt_pt_per_event[subdataset][key].append(get_mc_gt_per_event(data))
                    gt_props_per_event["eta"][subdataset][key].append(data.matrix_element_gen_particles.eta.tolist())
                    gt_props_per_event["phi"][subdataset][key].append(data.matrix_element_gen_particles.phi.tolist())
    # Per-subdataset averages over events.
    avg_n_matched_quarks = {}
    avg_n_fake_jets = {}
    for key in n_matched_quarks:
        avg_n_matched_quarks[key] = np.mean(n_matched_quarks[key])
        avg_n_fake_jets[key] = np.mean(n_fake_jets[key])
    def get_properties(name):
        """Parse (mMed, mDark, rinv) out of a subdataset directory name.
        QCD samples carry no such parameters and map to (0, 0, 0)."""
        if "qcd" in name.lower():
            print("QCD file! Not using mMed, mDark, rinv")
            return 0, 0, 0
        # get mediator mass, dark quark mass, r_inv from the filename
        parts = name.strip().strip("/").split("/")[-1].split("_")
        try:
            mMed = int(parts[1].split("-")[1])
            mDark = int(parts[2].split("-")[1])
            rinv = float(parts[3].split("-")[1])
        except:
            # another convention (fields shifted by one)
            mMed = int(parts[2].split("-")[1])
            mDark = int(parts[3].split("-")[1])
            rinv = float(parts[4].split("-")[1])
        return mMed, mDark, rinv
    # Regroup the flat per-subdataset accumulators into nested
    # result[mMed][mDark][rinv] dictionaries for plotting / pickling.
    result = {}
    result_unmatched = {}
    result_fakes = {}
    result_bc = {}
    result_PR = {}
    result_PR_AKX = {}
    result_PR_thresholds = {}
    result_m = {}
    result_jet_properties = {}
    result_jet_properties_AKX = {}
    result_quark_to_jet ={}
    result_pt_mc_gt = {}
    result_pt_dq = {}
    result_props_dq = {"eta": {}, "phi": {}}
    if args.jets_object != "fastjet_jets":
        for key in avg_n_matched_quarks:
            mMed, mDark, rinv = get_properties(key)
            if mMed not in result:
                result[mMed] = {}
                result_unmatched[mMed] = {}
                result_fakes[mMed] = {}
                result_bc[mMed] = {}
                result_PR[mMed] = {}
                result_PR_AKX[mMed] = {}
                result_PR_thresholds[mMed] = {}
                result_m[mMed] = {}
                result_jet_properties[mMed] = {}
                result_jet_properties_AKX[mMed] = {}
                result_quark_to_jet[mMed] = {}
                result_pt_mc_gt[mMed] = {}
                result_pt_dq[mMed] = {}
                for prop in gt_props_per_event:
                    if mMed not in result_props_dq[prop]:
                        result_props_dq[prop][mMed] = {}
            if mDark not in result[mMed]:
                result[mMed][mDark] = {}
                result_unmatched[mMed][mDark] = {}
                result_fakes[mMed][mDark] = {}
                result_bc[mMed][mDark] = {}
                result_PR[mMed][mDark] = {}
                result_PR_thresholds[mMed][mDark] = {}
                result_PR_AKX[mMed][mDark] = {}
                result_m[mMed][mDark] = {}
                result_jet_properties[mMed][mDark] = {}
                result_jet_properties_AKX[mMed][mDark] = {}
                result_quark_to_jet[mMed][mDark] = {}
                result_pt_mc_gt[mMed][mDark] = {}
                result_pt_dq[mMed][mDark] = {}
                for prop in gt_props_per_event:
                    if mDark not in result_props_dq[prop][mMed]:
                        result_props_dq[prop][mMed][mDark] = {}
            result[mMed][mDark][rinv] = avg_n_matched_quarks[key]
            result_unmatched[mMed][mDark][rinv] = unmatched_quarks[key]
            result_fakes[mMed][mDark][rinv] = avg_n_fake_jets[key]
            result_jet_properties[mMed][mDark][rinv] = matched_jet_properties[key]
            result_quark_to_jet[mMed][mDark][rinv] = is_dq_matched_per_event[key]
            result_pt_mc_gt[mMed][mDark][rinv] = gt_pt_per_event[key]
            result_pt_dq[mMed][mDark][rinv] = dq_pt_per_event[key]
            for prop in gt_props_per_event:
                result_props_dq[prop][mMed][mDark][rinv] = gt_props_per_event[prop][key]
            #result_bc[mMed][mDark][rinv] = {
            #    "matched": bc_scores_matched[key],
            #    "unmatched": bc_scores_unmatched[key]
            #}
            result_PR_thresholds[mMed][mDark][rinv] = pr_obj_score_thresholds[key]
            if precision_and_recall[key][1] == 0 or precision_and_recall[key][2] == 0:
                result_PR[mMed][mDark][rinv] = [0, 0]
                print(mMed, mDark, rinv)
                print("PR zero", key, precision_and_recall[key])
            else:
                result_PR[mMed][mDark][rinv] = [precision_and_recall[key][0] / precision_and_recall[key][1], precision_and_recall[key][0] / precision_and_recall[key][2]]
            result_m[mMed][mDark][rinv] = {key: np.array(val) for key, val in mass_resolution[key].items()}
            # NOTE(review): this branch is unreachable — the enclosing `if`
            # already guarantees args.jets_object != "fastjet_jets".
            if args.jets_object == "fastjet_jets":
                r = precision_and_recall_fastjets[key]
                if rinv not in result_PR_AKX[mMed][mDark]:
                    result_PR_AKX[mMed][mDark][rinv] = {}
                for k in r:
                    if r[k][1] == 0 or r[k][2] == 0:
                        result_PR_AKX[mMed][mDark][rinv][k] = [0, 0]
                    else:
                        result_PR_AKX[mMed][mDark][rinv][k] = [r[k][0] / r[k][1], r[k][0] / r[k][2]]
    else:
        for key in precision_and_recall_fastjets: # key = subdataset; k below = AK radius
            mMed, mDark, rinv = get_properties(key)
            if mMed not in result_PR_AKX:
                result_PR_AKX[mMed] = {}
                result_jet_properties_AKX[mMed] = {}
                result_quark_to_jet[mMed] = {}
                result_pt_mc_gt[mMed] = {}
                result_pt_dq[mMed] = {}
                for prop in result_props_dq:
                    result_props_dq[prop][mMed] = {}
            if mDark not in result_PR_AKX[mMed]:
                result_PR_AKX[mMed][mDark] = {}
                result_jet_properties_AKX[mMed][mDark] = {}
                result_quark_to_jet[mMed][mDark] = {}
                result_pt_mc_gt[mMed][mDark] = {}
                result_pt_dq[mMed][mDark] = {}
                for prop in result_props_dq:
                    result_props_dq[prop][mMed][mDark] = {}
            r = precision_and_recall_fastjets[key]
            if rinv not in result_PR_AKX[mMed][mDark]:
                result_PR_AKX[mMed][mDark][rinv] = {}
                result_jet_properties_AKX[mMed][mDark][rinv] = {}
                result_quark_to_jet[mMed][mDark][rinv] = {}
                result_pt_mc_gt[mMed][mDark][rinv] = {}
                result_pt_dq[mMed][mDark][rinv] = {}
                for prop in result_props_dq:
                    result_props_dq[prop][mMed][mDark][rinv] = {}
            for k in r:
                result_quark_to_jet[mMed][mDark][rinv][k] = is_dq_matched_per_event[key][k]
                result_pt_mc_gt[mMed][mDark][rinv][k] = gt_pt_per_event[key][k]
                result_pt_dq[mMed][mDark][rinv][k] = dq_pt_per_event[key][k]
                for prop in result_props_dq:
                    result_props_dq[prop][mMed][mDark][rinv][k] = gt_props_per_event[prop][key][k]
                result_jet_properties_AKX[mMed][mDark][rinv][k] = matched_jet_properties_fastjets[key][k]
                if r[k][1] == 0 or r[k][2] == 0:
                    result_PR_AKX[mMed][mDark][rinv][k] = [0, 0]
                else:
                    result_PR_AKX[mMed][mDark][rinv][k] = [r[k][0] / r[k][1], r[k][0] / r[k][2]]
    # Persist everything for later --plot-only runs.
    # NOTE(review): these pickle.dump(..., open(...)) calls leave the file
    # handles to be closed by GC.
    pickle.dump(result_quark_to_jet, open(os.path.join(output_path, "result_quark_to_jet.pkl"), "wb"))
    pickle.dump(result_pt_mc_gt, open(os.path.join(output_path, "result_pt_mc_gt.pkl"), "wb"))
    pickle.dump(result_pt_dq, open(os.path.join(output_path, "result_pt_dq.pkl"), "wb"))
    pickle.dump(result, open(os.path.join(output_path, "result.pkl"), "wb"))
    pickle.dump(result_unmatched, open(os.path.join(output_path, "result_unmatched.pkl"), "wb"))
    pickle.dump(result_fakes, open(os.path.join(output_path, "result_fakes.pkl"), "wb"))
    pickle.dump(result_bc, open(os.path.join(output_path, "result_bc.pkl"), "wb"))
    pickle.dump(result_props_dq, open(os.path.join(output_path, "result_props_dq.pkl"), "wb"))
    if args.jets_object == "fastjet_jets":
        pickle.dump(result_PR_AKX, open(os.path.join(output_path, "result_PR_AKX.pkl"), "wb"))
        pickle.dump(result_jet_properties_AKX, open(os.path.join(output_path, "result_jet_properties_AKX.pkl"), "wb"))
    pickle.dump(result_PR, open(os.path.join(output_path, "result_PR.pkl"), "wb"))
    pickle.dump(result_PR_thresholds, open(os.path.join(output_path, "result_PR_thresholds.pkl"), "wb"))
    pickle.dump(result_m, open(os.path.join(output_path, "result_m.pkl"), "wb"))
    pickle.dump(result_jet_properties, open(os.path.join(output_path, "result_jet_properties.pkl"), "wb"))
    with open(os.path.join(output_path, "eval_done.txt"), "w") as f:
        f.write("True")
    # Write the number of events to n_events.txt
    # NOTE(review): `n` is the event counter of the LAST subdataset processed,
    # not the total — confirm this is the intended meaning.
    with open(os.path.join(output_path, "n_events.txt"), "w") as f:
        f.write(str(n))
if args.plot_only:
    # Re-load previously computed metrics instead of re-running the evaluation.
    def _load(name):
        """Unpickle one metrics file from output_path, closing the handle.

        Fix: the original used pickle.load(open(...)) six times, leaking the
        file handles; the context manager guarantees they are closed.
        """
        with open(os.path.join(output_path, name), "rb") as fh:
            return pickle.load(fh)
    result = _load("result.pkl")
    result_unmatched = _load("result_unmatched.pkl")
    result_fakes = _load("result_fakes.pkl")
    result_bc = _load("result_bc.pkl")
    result_PR = _load("result_PR.pkl")
    result_PR_thresholds = _load("result_PR_thresholds.pkl")
if args.jets_object == "fastjet_jets":
    # Fastjet mode only produces the pickled metrics; no plots here.
    print("Only computing fastjet jets - exiting now, the metrics have been saved to disk")
    import sys
    sys.exit(0)
# Three stacked panels: precision / recall / F1 vs. objectness threshold.
fig, ax = plt.subplots(3, 1, figsize=(4, 12))
def get_plots_for_params(mMed, mDark, rInv):
    """Return (precisions, recalls, f1_scores) across all objectness-score
    thresholds for one (mMed, mDark, rInv) point, reporting 0 whenever a
    denominator is zero."""
    counts = result_PR_thresholds[mMed][mDark][rInv]
    def _ratio(num, den):
        # Guard against empty bins: a zero denominator yields metric 0.
        return num / den if den != 0 else 0
    precisions = [_ratio(counts[i][0], counts[i][1]) for i in range(len(thresholds))]
    recalls = [_ratio(counts[i][0], counts[i][2]) for i in range(len(thresholds))]
    f1_scores = [
        2 * p * r / (p + r) if p + r != 0 else 0
        for p, r in zip(precisions, recalls)
    ]
    return precisions, recalls, f1_scores
def plot_for_params(a, b, c):
    """Draw precision / recall / F1 vs. threshold curves for the signal point
    (mMed=a, mDark=b, rInv=c) onto the three stacked axes."""
    curves = get_plots_for_params(a, b, c)
    label = f"mMed={a},rInv={c}"
    # ax[0] <- precision, ax[1] <- recall, ax[2] <- F1.
    for axis, curve in zip(ax, curves):
        axis.plot(thresholds, curve, ".--", label=label)
# Threshold curves are parameterized by (mMed, mDark, rinv); QCD samples have
# none, so stop before the per-point plots.
if "qcd" in args.input.lower():
    print("QCD dataset - not plotting thresholds")
    import sys
    sys.exit(0)
# Benchmark signal points for the precision/recall/F1 threshold curves.
plot_for_params(900, 20, 0.3)
plot_for_params(700, 20, 0.7)
#plot_for_params(3000, 20, 0.3)
plot_for_params(900, 20, 0.7)
plot_for_params(1000, 20, 0.3)
ax[0].grid()
ax[1].grid()
ax[2].grid()
ax[0].set_ylabel("Precision")
ax[1].set_ylabel("Recall")
ax[2].set_ylabel("F1 score")
ax[0].legend()
ax[1].legend()
ax[2].legend()
# Log x-scale: 100 of the 120 scan points lie in [0, 0.1].
ax[0].set_xscale("log")
ax[1].set_xscale("log")
ax[2].set_xscale("log")
fig.tight_layout()
fig.savefig(os.path.join(output_path, "pr_thresholds.pdf"))
# Summary matrices over the (mMed, rinv) signal grid.
matrix_plot(result, "Blues", "Avg. matched dark quarks / event").savefig(os.path.join(output_path, "avg_matched_dark_quarks.pdf"))
matrix_plot(result_fakes, "Greens", "Avg. unmatched jets / event").savefig(os.path.join(output_path, "avg_unmatched_jets.pdf"))
matrix_plot(result_PR, "Reds", "Precision (N matched dark quarks / N predicted jets)", metric_comp_func = lambda r: r[0]).savefig(os.path.join(output_path, "precision.pdf"))
matrix_plot(result_PR, "Reds", "Recall (N matched dark quarks / N dark quarks)", metric_comp_func = lambda r: r[1]).savefig(os.path.join(output_path, "recall.pdf"))
# NOTE(review): this F1 lambda divides by r[0] + r[1]; a point with precision
# and recall both exactly 0 would raise ZeroDivisionError — confirm upstream
# guarantees at least one is nonzero.
matrix_plot(result_PR, "Purples", "F_1 score", metric_comp_func = lambda r: 2 * r[0] * r[1] / (r[0] + r[1])).savefig(os.path.join(output_path, "f1_score.pdf"))
# Histogram grids: one panel per (rinv, mMed) combination, at fixed mDark.
dark_masses = [20]  # only mDark = 20 is plotted below
mediator_masses = sorted(list(result.keys()))
r_invs = sorted(list(set([rinv for mMed in result for mDark in result[mMed] for rinv in result[mMed][mDark]])))
# 1) pt of unmatched dark quarks vs. all dark quarks.
fig, ax = plt.subplots(len(r_invs), len(mediator_masses), figsize=(3*len(mediator_masses), 3 * len(r_invs)))
for i in range(len(r_invs)):
    for j in range(len(mediator_masses)):
        data = result_unmatched[mediator_masses[j]][dark_masses[0]][r_invs[i]]["pt"]
        ax[i, j].hist(data, bins=50, histtype="step", label="Unmatched")
        ax[i, j].hist(result_unmatched[mediator_masses[j]][dark_masses[0]][r_invs[i]]["pt_all"], bins=50, histtype="step", label="All")
        ax[i, j].set_title(f"mMed = {mediator_masses[j]}, rinv = {r_invs[i]}")
        ax[i, j].set_xlabel("pt")
        ax[i, j].legend()
fig.tight_layout()
fig.savefig(os.path.join(output_path, "unmatched_dark_quarks_pt.pdf"))
# 2) eta-phi map of unmatched dark quarks.
fig, ax = plt.subplots(len(r_invs), len(mediator_masses), figsize=(3*len(mediator_masses), 3 * len(r_invs)))
for i in range(len(r_invs)):
    for j in range(len(mediator_masses)):
        data_x = result_unmatched[mediator_masses[j]][dark_masses[0]][r_invs[i]]["eta"]
        data_y = result_unmatched[mediator_masses[j]][dark_masses[0]][r_invs[i]]["phi"]
        # 2d histogram
        ax[i, j].hist2d(data_x, data_y, bins=10, cmap="Blues")
        ax[i, j].set_title(f"mMed = {mediator_masses[j]}, rinv = {r_invs[i]}")
        ax[i, j].set_xlabel("unmatched dark quark eta")
        ax[i, j].set_ylabel("unmatched dark quark phi")
fig.tight_layout()
fig.savefig(os.path.join(output_path, "unmatched_dark_quarks_eta_phi.pdf"))
# 3) fraction of event energy in the quark cone, matched vs unmatched (counts).
fig, ax = plt.subplots(len(r_invs), len(mediator_masses), figsize=(3*len(mediator_masses), 3 * len(r_invs)))
for i in range(len(r_invs)):
    for j in range(len(mediator_masses)):
        data = result_unmatched[mediator_masses[j]][dark_masses[0]][r_invs[i]]["frac_evt_E_matched"]
        data_unmatched = result_unmatched[mediator_masses[j]][dark_masses[0]][r_invs[i]]["frac_evt_E_unmatched"]
        bins = np.linspace(0, 1, 100)
        ax[i, j].hist(data_unmatched, bins=bins, histtype="step", label="Unmatched")
        ax[i, j].hist(data, bins=bins, histtype="step", label="Matched")
        ax[i, j].set_title(f"mMed = {mediator_masses[j]}, rinv = {r_invs[i]}")
        ax[i, j].set_xlabel("E (R<0.8) / event E")
        ax[i, j].legend()
fig.tight_layout()
fig.savefig(os.path.join(output_path, "frac_E_in_cone.pdf"))
# 4) same distributions, normalized to unit area.
fig, ax = plt.subplots(len(r_invs), len(mediator_masses), figsize=(3*len(mediator_masses), 3 * len(r_invs)))
for i in range(len(r_invs)):
    for j in range(len(mediator_masses)):
        data = result_unmatched[mediator_masses[j]][dark_masses[0]][r_invs[i]]["frac_evt_E_matched"]
        data_unmatched = result_unmatched[mediator_masses[j]][dark_masses[0]][r_invs[i]]["frac_evt_E_unmatched"]
        bins = np.linspace(0, 1, 100)
        ax[i, j].hist(data_unmatched, bins=bins, histtype="step", label="Unmatched dark quark", density=True)
        ax[i, j].hist(data, bins=bins, histtype="step", label="Matched dark quark", density=True)
        ax[i, j].set_title(f"mMed = {mediator_masses[j]}, rinv = {r_invs[i]}")
        ax[i, j].set_xlabel("E (R<0.8) / event E")
        ax[i, j].legend()
fig.tight_layout()
fig.savefig(os.path.join(output_path, "frac_E_in_cone_density.pdf"))
# Disabled diagnostic (kept for reference): histogram of mean BC scores for
# matched vs. unmatched jets. It depends on result_bc, whose filling is also
# commented out above.
'''
fig, ax = plt.subplots(figsize=(5, 5))
unmatched = result_bc[900][20][0.3]["unmatched"]
matched = result_bc[900][20][0.3]["matched"]
bins = np.linspace(0, 1, 100)
ax.hist(unmatched, bins=bins, histtype="step", label="Unmatched jet")
ax.hist(matched, bins=bins, histtype="step", label="Matched jet")
ax.set_title("mMed = 900, mDark = 20, rinv = 0.3")
ax.set_xlabel("BC score")
ax.set_ylabel("count")
ax.set_yscale("log")
ax.legend()
fig.tight_layout()
fig.savefig(os.path.join(output_path, "avg_scores_matched_vs_unmatched_jet.pdf"))
'''