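# Builds clustermap figures from the leaderboard evaluation results: downloads
# the requests/results snapshots, filters and renames the (dataset, metric)
# pairs, and renders seaborn clustermaps (one per plot type and colour map)
# under blog/figures/, caching the parsed result map locally as a pickle.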
import os
import sys
import json
import pickle

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from scipy.cluster.hierarchy import linkage

from src.backend.envs import Tasks, EVAL_REQUESTS_PATH_BACKEND, EVAL_RESULTS_PATH_BACKEND, DEVICE, LIMIT, Task

from src.envs import QUEUE_REPO, RESULTS_REPO, API
from src.utils import my_snapshot_download


def is_float(string):
    try:
        float(string)
        return True
    except ValueError:
        return False


def find_json_files(json_path):
    res = []
    for root, dirs, files in os.walk(json_path):
        for file in files:
            if file.endswith(".json"):
                res.append(os.path.join(root, file))
    return res


def sanitise_metric(name: str) -> str:
    res = name
    res = res.replace("prompt_level_strict_acc", "Prompt-Level Accuracy")
    res = res.replace("acc", "Accuracy")
    res = res.replace("exact_match", "EM")
    res = res.replace("avg-selfcheckgpt", "AVG")
    res = res.replace("max-selfcheckgpt", "MAX")
    res = res.replace("rouge", "ROUGE-")
    res = res.replace("bertscore_precision", "BERT-P")
    res = res.replace("exact", "EM")
    res = res.replace("HasAns_EM", "HasAns")
    res = res.replace("NoAns_EM", "NoAns")
    res = res.replace("em", "EM")
    return res


def sanitise_dataset(name: str) -> str:
    res = name
    res = res.replace("tqa8", "TriviaQA (8-shot)")
    res = res.replace("nq8", "NQ (8-shot)")
    res = res.replace("nq_open", "NQ (64-shot)")
    res = res.replace("triviaqa", "TriviaQA (64-shot)")
    res = res.replace("truthfulqa", "TruthfulQA")
    res = res.replace("ifeval", "IFEval")
    res = res.replace("selfcheckgpt", "SelfCheckGPT")
    res = res.replace("truefalse_cieacf", "True-False")
    res = res.replace("mc", "MC")
    res = res.replace("race", "RACE")
    res = res.replace("squad", "SQuAD")
    res = res.replace("memo-trap", "MemoTrap")
    res = res.replace("cnndm", "CNN/DM")
    res = res.replace("xsum", "XSum")
    res = res.replace("qa", "QA")
    res = res.replace("summarization", "Summarization")
    res = res.replace("dialogue", "Dialog")
    res = res.replace("halueval", "HaluEval")
    res = res.replace("_v2", "")
    res = res.replace("_", " ")
    return res

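# The parsed result map is cached as a local pickle so repeated runs can skip
# re-downloading and re-parsing the leaderboard snapshots.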
cache_file = "data_map_cache.pkl"


def load_data_map_from_cache(cache_file):
    if os.path.exists(cache_file):
        with open(cache_file, "rb") as f:
            return pickle.load(f)
    else:
        return None


def save_data_map_to_cache(data_map, cache_file):
    with open(cache_file, "wb") as f:
        pickle.dump(data_map, f)

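# Load the cached data map if available; otherwise download the results and
# requests snapshots and rebuild it from the raw JSON files.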
data_map = load_data_map_from_cache(cache_file)

if data_map is None:
    my_snapshot_download(
        repo_id=RESULTS_REPO, revision="main", local_dir=EVAL_RESULTS_PATH_BACKEND, repo_type="dataset", max_workers=60
    )
    my_snapshot_download(
        repo_id=QUEUE_REPO, revision="main", local_dir=EVAL_REQUESTS_PATH_BACKEND, repo_type="dataset", max_workers=60
    )

    result_path_lst = find_json_files(EVAL_RESULTS_PATH_BACKEND)
    request_path_lst = find_json_files(EVAL_REQUESTS_PATH_BACKEND)

    model_name_to_model_map = {}

    for path in request_path_lst:
        with open(path, "r") as f:
            data = json.load(f)
            model_name_to_model_map[data["model"]] = data

    model_dataset_metric_to_result_map = {}

    data_map = {}

    for path in result_path_lst:
        with open(path, "r") as f:
            data = json.load(f)
            model_name = data["config"]["model_name"]
            for dataset_name, results_dict in data["results"].items():
                for metric_name, value in results_dict.items():
                    # Only keep reasonably popular models (more than 128 likes);
                    # results without a matching request entry are skipped.
                    if model_name_to_model_map.get(model_name, {}).get("likes", 0) > 128:
                        to_add = True

                        if "f1" in metric_name:
                            to_add = False

                        if "stderr" in metric_name:
                            to_add = False

                        if "memo-trap_v2" in dataset_name:
                            to_add = False

                        if "faithdial" in dataset_name:
                            to_add = False

                        if "truthfulqa_gen" in dataset_name:
                            to_add = False

                        if "bertscore" in metric_name:
                            if "precision" not in metric_name:
                                to_add = False

                        if "halueval" in dataset_name:
                            if "acc" not in metric_name:
                                to_add = False

                        if "ifeval" in dataset_name:
                            if "prompt_level_strict_acc" not in metric_name:
                                to_add = False

                        if "squad" in dataset_name:
                            if "best_exact" in metric_name:
                                to_add = False

                        if "fever" in dataset_name:
                            to_add = False

                        if ("xsum" in dataset_name or "cnn" in dataset_name) and "v2" not in dataset_name:
                            to_add = False

                        if isinstance(value, str):
                            if is_float(value):
                                value = float(value)
                            else:
                                to_add = False

                        if to_add:
                            if "rouge" in metric_name:
                                value /= 100.0

                            if "squad" in dataset_name:
                                value /= 100.0

                            sanitised_metric_name = metric_name
                            if "," in sanitised_metric_name:
                                sanitised_metric_name = sanitised_metric_name.split(",")[0]
                            sanitised_metric_name = sanitise_metric(sanitised_metric_name)
                            sanitised_dataset_name = sanitise_dataset(dataset_name)

                            model_dataset_metric_to_result_map[
                                (model_name, sanitised_dataset_name, sanitised_metric_name)
                            ] = value

                            if model_name not in data_map:
                                data_map[model_name] = {}
                            data_map[model_name][(sanitised_dataset_name, sanitised_metric_name)] = value

                            print(
                                "model_name",
                                model_name,
                                "dataset_name",
                                sanitised_dataset_name,
                                "metric_name",
                                sanitised_metric_name,
                                "value",
                                value,
                            )

    save_data_map_to_cache(data_map, cache_file)

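# Drop models that are missing more than five of the collected (dataset, metric)
# pairs, so the clustered models cover (almost) the same set of benchmarks.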
model_name_lst = [m for m in data_map.keys()]

nb_max_metrics = max(len(data_map[model_name]) for model_name in model_name_lst)

for model_name in model_name_lst:
    if len(data_map[model_name]) < nb_max_metrics - 5:
        del data_map[model_name]

plot_type_lst = ["all", "summ", "qa", "instr", "detect", "rc"]

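# Each plot type selects a subset of (dataset, metric) pairs: "all" keeps one
# representative metric per dataset and drops the 64-shot variants, while the
# other plot types restrict to a single task family (summarisation, QA,
# instruction following, hallucination detection, reading comprehension).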
for plot_type in plot_type_lst:

    data_map_v2 = {}
    for model_name in data_map.keys():
        for dataset_metric in data_map[model_name].keys():
            if dataset_metric not in data_map_v2:
                data_map_v2[dataset_metric] = {}

            if plot_type in {"all"}:
                to_add = True
                if "ROUGE" in dataset_metric[1] and "ROUGE-L" not in dataset_metric[1]:
                    to_add = False
                if "SQuAD" in dataset_metric[0] and "EM" not in dataset_metric[1]:
                    to_add = False
                if "SelfCheckGPT" in dataset_metric[0] and "MAX" not in dataset_metric[1]:
                    to_add = False
                if "64-shot" in dataset_metric[0]:
                    to_add = False
                if to_add is True:
                    data_map_v2[dataset_metric][model_name] = data_map[model_name][dataset_metric]
            elif plot_type in {"summ"}:
                if "CNN" in dataset_metric[0] or "XSum" in dataset_metric[0]:
                    data_map_v2[dataset_metric][model_name] = data_map[model_name][dataset_metric]
            elif plot_type in {"qa"}:
                if "TriviaQA" in dataset_metric[0] or "NQ" in dataset_metric[0] or "TruthfulQA" in dataset_metric[0]:
                    data_map_v2[dataset_metric][model_name] = data_map[model_name][dataset_metric]
            elif plot_type in {"instr"}:
                if "MemoTrap" in dataset_metric[0] or "IFEval" in dataset_metric[0]:
                    data_map_v2[dataset_metric][model_name] = data_map[model_name][dataset_metric]
            elif plot_type in {"detect"}:
                if "HaluEval" in dataset_metric[0] or "SelfCheck" in dataset_metric[0]:
                    data_map_v2[dataset_metric][model_name] = data_map[model_name][dataset_metric]
            elif plot_type in {"rc"}:
                if "RACE" in dataset_metric[0] or "SQuAD" in dataset_metric[0]:
                    data_map_v2[dataset_metric][model_name] = data_map[model_name][dataset_metric]
            else:
                assert False, f"Unknown plot type: {plot_type}"

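    # Rows are (dataset, metric) pairs and columns are models. o_df keeps the
    # raw values (with gaps) so missing cells can be masked in the heatmap,
    # while df has NaN/inf replaced by 0 for clustering.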
    df = pd.DataFrame.from_dict(data_map_v2, orient="index")
    # Drop (dataset, metric) rows that received no values for this plot type.
    df = df.dropna(axis=0, how="all")
    df.index = [", ".join(map(str, idx)) for idx in df.index]

    o_df = df.copy(deep=True)

    print(df)

    df.replace([np.inf, -np.inf], np.nan, inplace=True)
    df.fillna(0, inplace=True)

    from sklearn.preprocessing import MinMaxScaler

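    # Figure size scales with the matrix dimensions; the per-plot-type blocks
    # below tweak the size and dendrogram ratios so labels stay readable.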
    cell_height = 1.0
    cell_width = 1.0

    n_rows = len(df.index)
    n_cols = len(df.columns)

    fig_width = cell_width * n_cols + 0
    fig_height = cell_height * n_rows + 0

    col_cluster = True
    row_cluster = True

    sns.set_context("notebook", font_scale=1.3)

    dendrogram_ratio = (0.1, 0.1)

    if plot_type in {"detect"}:
        fig_width = cell_width * n_cols - 2
        fig_height = cell_height * n_rows + 5.2
        dendrogram_ratio = (0.1, 0.2)

    if plot_type in {"instr"}:
        fig_width = cell_width * n_cols - 2
        fig_height = cell_height * n_rows + 5.2
        dendrogram_ratio = (0.1, 0.4)

    if plot_type in {"qa"}:
        fig_width = cell_width * n_cols - 2
        fig_height = cell_height * n_rows + 4
        dendrogram_ratio = (0.1, 0.2)

    if plot_type in {"summ"}:
        fig_width = cell_width * n_cols - 2
        fig_height = cell_height * n_rows + 2.0
        dendrogram_ratio = (0.1, 0.1)
        row_cluster = False

    if plot_type in {"rc"}:
        fig_width = cell_width * n_cols - 2
        fig_height = cell_height * n_rows + 5.2
        dendrogram_ratio = (0.1, 0.4)

print("figsize", (fig_width, fig_height)) |
|
|
|
o_df.to_json(f"plots/clustermap_{plot_type}.json", orient="split") |
|
|
|
print(f"Generating the clustermaps for {plot_type}") |
|
|
|
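    # Render one clustermap per colour map, clustering rows/columns with Ward
    # linkage on Euclidean distances.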
    for cmap in [None, "coolwarm", "viridis"]:
        fig = sns.clustermap(
            df,
            method="ward",
            metric="euclidean",
            cmap=cmap,
            figsize=(fig_width, fig_height),
            annot=True,
            mask=o_df.isnull(),
            dendrogram_ratio=dendrogram_ratio,
            fmt=".2f",
            col_cluster=col_cluster,
            row_cluster=row_cluster,
        )

        plt.setp(fig.ax_heatmap.get_yticklabels(), rotation=0)
        plt.setp(fig.ax_heatmap.get_xticklabels(), rotation=90)

        cmap_suffix = "" if cmap is None else f"_{cmap}"

        fig.savefig(f"blog/figures/clustermap_{plot_type}{cmap_suffix}.pdf")
        fig.savefig(f"blog/figures/clustermap_{plot_type}{cmap_suffix}.png")
        fig.savefig(f"blog/figures/clustermap_{plot_type}{cmap_suffix}_t.png", transparent=True, facecolor="none")