import os
import json
import pickle

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from src.backend.envs import EVAL_REQUESTS_PATH_BACKEND, EVAL_RESULTS_PATH_BACKEND
from src.envs import QUEUE_REPO, RESULTS_REPO
from src.utils import my_snapshot_download
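# Aggregate the evaluation results into clustermaps: load the cached data map
# (or rebuild it from fresh snapshots), filter and normalise the per-dataset
# metrics, and render one clustered heatmap per task family.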
|
|
|
|
|
def is_float(string: str) -> bool:
    try:
        float(string)
        return True
    except ValueError:
        return False
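# e.g. is_float("0.42") -> True, is_float("n/a") -> False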
|
|
|
|
|
def find_json_files(json_path: str) -> list:
    # Recursively collect every .json file under `json_path`.
    res = []
    for root, _dirs, files in os.walk(json_path):
        for file in files:
            if file.endswith(".json"):
                res.append(os.path.join(root, file))
    return res
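# e.g. find_json_files(EVAL_RESULTS_PATH_BACKEND) lists every result file in
# the downloaded snapshot.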
|
|
|
|
|
def sanitise_metric(name: str) -> str:
    # Map raw metric identifiers to human-readable labels. The order of the
    # replacements matters: longer identifiers such as "prompt_level_strict_acc"
    # and "exact_match" must be rewritten before their substrings "acc" and "exact".
    res = name
    res = res.replace("prompt_level_strict_acc", "Prompt-Level Accuracy")
    res = res.replace("acc", "Accuracy")
    res = res.replace("exact_match", "EM")
    res = res.replace("avg-selfcheckgpt", "AVG")
    res = res.replace("max-selfcheckgpt", "MAX")
    res = res.replace("rouge", "ROUGE-")
    res = res.replace("bertscore_precision", "BERT-P")
    res = res.replace("exact", "EM")
    res = res.replace("HasAns_EM", "HasAns")
    res = res.replace("NoAns_EM", "NoAns")
    res = res.replace("em", "EM")
    return res
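# e.g. sanitise_metric("rougeL") -> "ROUGE-L", sanitise_metric("exact_match") -> "EM"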
|
|
|
|
|
def sanitise_dataset(name: str) -> str:
    res = name
    res = res.replace("tqa8", "TriviaQA (8-shot)")
    res = res.replace("nq8", "NQ (8-shot)")
    res = res.replace("nq_open", "NQ (64-shot)")
    res = res.replace("triviaqa", "TriviaQA (64-shot)")
    res = res.replace("truthfulqa", "TruthfulQA")
    res = res.replace("ifeval", "IFEval")
    res = res.replace("selfcheckgpt", "SelfCheckGPT")
    res = res.replace("truefalse_cieacf", "True-False")
    res = res.replace("mc", "MC")
    res = res.replace("race", "RACE")
    res = res.replace("squad", "SQuAD")
    res = res.replace("memo-trap", "MemoTrap")
    res = res.replace("cnndm", "CNN/DM")
    res = res.replace("xsum", "XSum")
    res = res.replace("qa", "QA")
    res = res.replace("summarization", "Summarization")
    res = res.replace("dialogue", "Dialog")
    res = res.replace("halueval", "HaluEval")
    res = res.replace("_v2", "")
    res = res.replace("_", " ")
    return res
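# e.g. sanitise_dataset("nq8") -> "NQ (8-shot)", sanitise_dataset("halueval_qa") -> "HaluEval QA"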
|
|
|
|
|
cache_file = 'data_map_cache.pkl'


def load_data_map_from_cache(cache_file):
    # Return the cached data map, or None when no cache file exists.
    if os.path.exists(cache_file):
        with open(cache_file, 'rb') as f:
            return pickle.load(f)
    return None


def save_data_map_to_cache(data_map, cache_file):
    with open(cache_file, 'wb') as f:
        pickle.dump(data_map, f)
|
|
data_map = load_data_map_from_cache(cache_file)

if data_map is None:
    # Cache miss: download the results and requests snapshots and rebuild the map.
    my_snapshot_download(repo_id=RESULTS_REPO, revision="main", local_dir=EVAL_RESULTS_PATH_BACKEND, repo_type="dataset", max_workers=60)
    my_snapshot_download(repo_id=QUEUE_REPO, revision="main", local_dir=EVAL_REQUESTS_PATH_BACKEND, repo_type="dataset", max_workers=60)

    result_path_lst = find_json_files(EVAL_RESULTS_PATH_BACKEND)
    request_path_lst = find_json_files(EVAL_REQUESTS_PATH_BACKEND)

    # Index the request entries by model name.
    model_name_to_model_map = {}
    for path in request_path_lst:
        with open(path, 'r') as f:
            data = json.load(f)
            model_name_to_model_map[data["model"]] = data

    model_dataset_metric_to_result_map = {}

    # data_map: model name -> {(dataset, metric): value}
    data_map = {}

    for path in result_path_lst:
        with open(path, 'r') as f:
            data = json.load(f)
            model_name = data["config"]["model_name"]
            for dataset_name, results_dict in data["results"].items():
                for metric_name, value in results_dict.items():

                    # Keep only popular models; .get() guards against results
                    # that have no matching request entry.
                    if model_name_to_model_map.get(model_name, {}).get("likes", 0) > 128:

                        to_add = True

                        # Drop auxiliary metrics and datasets that are not plotted.
                        if 'f1' in metric_name:
                            to_add = False
                        if 'stderr' in metric_name:
                            to_add = False
                        if 'memo-trap_v2' in dataset_name:
                            to_add = False
                        if 'faithdial' in dataset_name:
                            to_add = False
                        if 'truthfulqa_gen' in dataset_name:
                            to_add = False
                        if 'bertscore' in metric_name and 'precision' not in metric_name:
                            to_add = False
                        if 'halueval' in dataset_name and 'acc' not in metric_name:
                            to_add = False
                        if 'ifeval' in dataset_name and 'prompt_level_strict_acc' not in metric_name:
                            to_add = False
                        if 'squad' in dataset_name and 'best_exact' in metric_name:
                            to_add = False
                        if 'fever' in dataset_name:
                            to_add = False
                        if ('xsum' in dataset_name or 'cnn' in dataset_name) and 'v2' not in dataset_name:
                            to_add = False

                        if isinstance(value, str):
                            if is_float(value):
                                value = float(value)
                            else:
                                to_add = False

                        if to_add:
                            # ROUGE and SQuAD scores are reported on a 0-100 scale;
                            # rescale them to [0, 1] like the other metrics.
                            if 'rouge' in metric_name:
                                value /= 100.0
                            if 'squad' in dataset_name:
                                value /= 100.0

                            # Strip filter suffixes such as "acc,none" before sanitising.
                            sanitised_metric_name = metric_name
                            if "," in sanitised_metric_name:
                                sanitised_metric_name = sanitised_metric_name.split(',')[0]
                            sanitised_metric_name = sanitise_metric(sanitised_metric_name)
                            sanitised_dataset_name = sanitise_dataset(dataset_name)

                            model_dataset_metric_to_result_map[(model_name, sanitised_dataset_name, sanitised_metric_name)] = value

                            if model_name not in data_map:
                                data_map[model_name] = {}
                            data_map[model_name][(sanitised_dataset_name, sanitised_metric_name)] = value

                            print('model_name', model_name, 'dataset_name', sanitised_dataset_name, 'metric_name', sanitised_metric_name, 'value', value)

    save_data_map_to_cache(data_map, cache_file)
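# Delete 'data_map_cache.pkl' to force a re-download and rebuild on the next run.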
|
|
|
# Materialise the key list so entries can be deleted while iterating.
model_name_lst = list(data_map.keys())

nb_max_metrics = max(len(data_map[model_name]) for model_name in model_name_lst)

# Drop models that cover more than five fewer metrics than the most complete
# model, so the clustermaps stay dense.
for model_name in model_name_lst:
    if len(data_map[model_name]) < nb_max_metrics - 5:
        del data_map[model_name]
|
|
|
# Plot families: overview, summarisation, question answering, instruction
# following, hallucination detection, and reading comprehension.
plot_type_lst = ['all', 'summ', 'qa', 'instr', 'detect', 'rc']
|
|
|
for plot_type in plot_type_lst:

    # Re-key the data as {(dataset, metric): {model: value}}, restricted to the
    # datasets that belong to this plot family.
    data_map_v2 = {}
    for model_name in data_map.keys():
        for dataset_metric in data_map[model_name].keys():
            if dataset_metric not in data_map_v2:
                data_map_v2[dataset_metric] = {}

            if plot_type == 'all':
                # For the overview, keep one representative metric per dataset.
                to_add = True
                if 'ROUGE' in dataset_metric[1] and 'ROUGE-L' not in dataset_metric[1]:
                    to_add = False
                if 'SQuAD' in dataset_metric[0] and 'EM' not in dataset_metric[1]:
                    to_add = False
                if 'SelfCheckGPT' in dataset_metric[0] and 'MAX' not in dataset_metric[1]:
                    to_add = False
                if '64-shot' in dataset_metric[0]:
                    to_add = False
                if to_add:
                    data_map_v2[dataset_metric][model_name] = data_map[model_name][dataset_metric]
            elif plot_type == 'summ':
                if 'CNN' in dataset_metric[0] or 'XSum' in dataset_metric[0]:
                    data_map_v2[dataset_metric][model_name] = data_map[model_name][dataset_metric]
            elif plot_type == 'qa':
                if 'TriviaQA' in dataset_metric[0] or 'NQ' in dataset_metric[0] or 'TruthfulQA' in dataset_metric[0]:
                    data_map_v2[dataset_metric][model_name] = data_map[model_name][dataset_metric]
            elif plot_type == 'instr':
                if 'MemoTrap' in dataset_metric[0] or 'IFEval' in dataset_metric[0]:
                    data_map_v2[dataset_metric][model_name] = data_map[model_name][dataset_metric]
            elif plot_type == 'detect':
                if 'HaluEval' in dataset_metric[0] or 'SelfCheck' in dataset_metric[0]:
                    data_map_v2[dataset_metric][model_name] = data_map[model_name][dataset_metric]
            elif plot_type == 'rc':
                if 'RACE' in dataset_metric[0] or 'SQuAD' in dataset_metric[0]:
                    data_map_v2[dataset_metric][model_name] = data_map[model_name][dataset_metric]
            else:
                assert False, f"Unknown plot type: {plot_type}"
|
|
    df = pd.DataFrame.from_dict(data_map_v2, orient='index')
    df.index = [', '.join(map(str, idx)) for idx in df.index]

    # Keep an untouched copy: its NaNs are used below to mask missing cells.
    o_df = df.copy(deep=True)

    print(df)

    # Ward clustering cannot handle NaN/inf, so zero-fill here; the original
    # missing cells are re-masked in the plot via `o_df.isnull()`.
    df.replace([np.inf, -np.inf], np.nan, inplace=True)
    df.fillna(0, inplace=True)
|
|
    cell_height = 1.0
    cell_width = 1.0

    n_rows = len(df.index)
    n_cols = len(df.columns)

    fig_width = cell_width * n_cols
    fig_height = cell_height * n_rows

    col_cluster = True
    row_cluster = True

    sns.set_context("notebook", font_scale=1.3)

    dendrogram_ratio = (.1, .1)
|
|
|
    if plot_type == 'detect':
        fig_width = cell_width * n_cols - 2
        fig_height = cell_height * n_rows + 5.2
        dendrogram_ratio = (.1, .2)

    if plot_type == 'instr':
        fig_width = cell_width * n_cols - 2
        fig_height = cell_height * n_rows + 5.2
        dendrogram_ratio = (.1, .4)

    if plot_type == 'qa':
        fig_width = cell_width * n_cols - 2
        fig_height = cell_height * n_rows + 4
        dendrogram_ratio = (.1, .2)

    if plot_type == 'summ':
        fig_width = cell_width * n_cols - 2
        fig_height = cell_height * n_rows + 2.0
        dendrogram_ratio = (.1, .1)
        row_cluster = False

    if plot_type == 'rc':
        fig_width = cell_width * n_cols - 2
        fig_height = cell_height * n_rows + 5.2
        dendrogram_ratio = (.1, .4)

    print('figsize', (fig_width, fig_height))
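    # The output directories may not exist on a fresh checkout; creating them
    # here is an addition, assuming 'plots/' and 'blog/figures/' are the
    # intended locations used below.
    os.makedirs('plots', exist_ok=True)
    os.makedirs('blog/figures', exist_ok=True)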
|
|
|
    # Save the underlying table alongside the figures.
    o_df.to_json(f'plots/clustermap_{plot_type}.json', orient='split')

    print(f'Generating the clustermaps for {plot_type}')

    # Render one clustermap per colour map (seaborn default, coolwarm, viridis).
    for cmap in [None, 'coolwarm', 'viridis']:
        # sns.clustermap returns a ClusterGrid (not a matplotlib Figure), which
        # still exposes savefig() and the heatmap axes.
        fig = sns.clustermap(df,
                             method='ward',
                             metric='euclidean',
                             cmap=cmap,
                             figsize=(fig_width, fig_height),
                             annot=True,
                             mask=o_df.isnull(),
                             dendrogram_ratio=dendrogram_ratio,
                             fmt='.2f',
                             col_cluster=col_cluster,
                             row_cluster=row_cluster)

        plt.setp(fig.ax_heatmap.get_yticklabels(), rotation=0)
        plt.setp(fig.ax_heatmap.get_xticklabels(), rotation=90)

        cmap_suffix = '' if cmap is None else f'_{cmap}'

        fig.savefig(f'blog/figures/clustermap_{plot_type}{cmap_suffix}.pdf')
        fig.savefig(f'blog/figures/clustermap_{plot_type}{cmap_suffix}.png')
        fig.savefig(f'blog/figures/clustermap_{plot_type}{cmap_suffix}_t.png', transparent=True, facecolor="none")
|
|