Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
import os | |
import sys | |
import json | |
import pickle | |
import numpy as np | |
import pandas as pd | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
from scipy.cluster.hierarchy import linkage | |
from src.backend.envs import Tasks, EVAL_REQUESTS_PATH_BACKEND, EVAL_RESULTS_PATH_BACKEND, DEVICE, LIMIT, Task | |
from src.envs import QUEUE_REPO, RESULTS_REPO, API | |
from src.utils import my_snapshot_download | |
def find_json_files(json_path):
    """Recursively collect the paths of all ``.json`` files under *json_path*.

    Args:
        json_path: Directory to walk. A missing directory yields an empty list.

    Returns:
        A sorted list of file paths (each one ``os.path.join(root, name)``).
        Sorting makes the order independent of the filesystem's directory
        listing order, so callers that let later files override earlier ones
        behave the same on every machine.
    """
    res = []
    for root, _dirs, files in os.walk(json_path):
        res.extend(
            os.path.join(root, file_name)
            for file_name in files
            if file_name.endswith(".json")
        )
    res.sort()
    return res
def sanitise_metric(name: str) -> str:
    """Turn a raw metric identifier into a display label.

    The substitutions are applied one after another in the order listed, so
    later entries see the output of earlier ones (e.g. ``HasAns_exact`` first
    becomes ``HasAns_EM`` and is then shortened to ``HasAns``).
    """
    substitutions = (
        ("prompt_level_strict_acc", "Prompt-Level Accuracy"),
        ("acc", "Accuracy"),
        ("exact_match", "EM"),
        ("avg-selfcheckgpt", "AVG"),
        ("max-selfcheckgpt", "MAX"),
        ("rouge", "ROUGE-"),
        ("bertscore_precision", "BERT-P"),
        ("exact", "EM"),
        ("HasAns_EM", "HasAns"),
        ("NoAns_EM", "NoAns"),
    )
    label = name
    for raw, pretty in substitutions:
        label = label.replace(raw, pretty)
    return label
def sanitise_dataset(name: str) -> str:
    """Turn a raw dataset/task identifier into a display label.

    The substitutions run sequentially in the order listed, each one on the
    result of the previous one; underscores are turned into spaces last.
    """
    substitutions = (
        ("tqa8", "TriviaQA"),
        ("nq8", "NQ"),
        ("truthfulqa", "TruthfulQA"),
        ("ifeval", "IFEval"),
        ("selfcheckgpt", "SelfCheckGPT"),
        ("truefalse_cieacf", "True-False"),
        ("mc", "MC"),
        ("race", "RACE"),
        ("squad", "SQuAD"),
        ("memo-trap", "MemoTrap"),
        ("cnndm", "CNN/DM"),
        ("xsum", "XSum"),
        ("qa", "QA"),
        ("summarization", "Summarization"),
        ("dialogue", "Dialog"),
        ("halueval", "HaluEval"),
        ("_", " "),
    )
    label = name
    for raw, pretty in substitutions:
        label = label.replace(raw, pretty)
    return label
# Pickle file (relative to the working directory) that caches the parsed data_map.
cache_file = "data_map_cache.pkl"
def load_data_map_from_cache(cache_file):
    """Load a previously pickled data map.

    Args:
        cache_file: Path of the pickle file to read.

    Returns:
        The unpickled object, or ``None`` when the file does not exist or
        cannot be unpickled (empty or truncated file). Returning ``None``
        in the latter case lets the caller rebuild the cache instead of
        crashing on every run until the file is removed by hand.
    """
    # NOTE: pickle can execute arbitrary code while loading; only point this
    # at cache files written by this script.
    try:
        with open(cache_file, 'rb') as f:
            return pickle.load(f)
    except FileNotFoundError:
        return None
    except (EOFError, pickle.UnpicklingError) as exc:
        print(f'Ignoring unreadable cache file {cache_file}: {exc!r}')
        return None
def save_data_map_to_cache(data_map, cache_file):
    """Pickle *data_map* to *cache_file* without clobbering a good cache on failure.

    The data is first written to ``<cache_file>.tmp`` and only moved over the
    real cache file once pickling has succeeded, so an exception or an
    interrupted run never leaves a truncated cache behind.

    Args:
        data_map: Object to pickle.
        cache_file: Destination path of the cache file.

    Raises:
        Whatever ``open`` / ``pickle.dump`` / ``os.replace`` raise; the
        temporary file is removed before the exception propagates.
    """
    tmp_file = f'{cache_file}.tmp'
    try:
        with open(tmp_file, 'wb') as f:
            pickle.dump(data_map, f)
        os.replace(tmp_file, cache_file)
    finally:
        # After a successful replace the temporary file no longer exists;
        # it is only still around when something above failed.
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
# Dataset-name fragments whose results are never added to the data map.
_SKIPPED_DATASET_FRAGMENTS = (
    'memo-trap_v2',
    'faithdial',
    'nq_open',
    'triviaqa',
    'truthfulqa_gen',
    'fever',
)


def _keep_metric(dataset_name, metric_name):
    """Return True when a raw (dataset, metric) result should enter the data map."""
    # Only comma-separated metric keys are considered; stderr and F1 entries are dropped.
    if ',' not in metric_name or '_stderr' in metric_name or 'f1' in metric_name:
        return False
    if any(fragment in dataset_name for fragment in _SKIPPED_DATASET_FRAGMENTS):
        return False
    if ('xsum' in dataset_name or 'cnn' in dataset_name) and 'v2' in dataset_name:
        return False
    # Of the BERTScore variants only precision is kept.
    if 'bertscore' in metric_name and 'precision' not in metric_name:
        return False
    if 'correctness,' in metric_name or 'em,' in metric_name:
        return False
    if 'ifeval' in dataset_name and 'prompt_level_strict_acc' not in metric_name:
        return False
    if 'squad' in dataset_name and 'best_exact' in metric_name:
        return False
    # All SelfCheckGPT and ROUGE variants are kept at this stage.
    return True


# Try to load the data_map from the cache file
data_map = load_data_map_from_cache(cache_file)

if data_map is None:
    # Cache miss: fetch the results and requests datasets and rebuild the map.
    my_snapshot_download(repo_id=RESULTS_REPO, revision="main", local_dir=EVAL_RESULTS_PATH_BACKEND, repo_type="dataset", max_workers=60)
    my_snapshot_download(repo_id=QUEUE_REPO, revision="main", local_dir=EVAL_REQUESTS_PATH_BACKEND, repo_type="dataset", max_workers=60)

    result_path_lst = find_json_files(EVAL_RESULTS_PATH_BACKEND)
    request_path_lst = find_json_files(EVAL_REQUESTS_PATH_BACKEND)

    # Request record per model name; when several request files name the same
    # model, the one read last wins.
    model_name_to_model_map = {}
    for path in request_path_lst:
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        model_name_to_model_map[data["model"]] = data

    model_dataset_metric_to_result_map = {}

    # data_map[model_name][(dataset_name, sanitised_metric_name)] = value
    data_map = {}

    for path in result_path_lst:
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        model_name = data["config"]["model_name"]

        # Only models with more than 128 likes are analysed. The check does not
        # depend on the metric, so it is done once per result file; a result
        # without a matching request record (or without a "likes" field) is
        # skipped instead of raising KeyError.
        request = model_name_to_model_map.get(model_name)
        if request is None:
            print('Skipping results without a matching request for model', model_name)
            continue
        if (request.get("likes") or 0) <= 128:
            continue

        for dataset_name, results_dict in data["results"].items():
            for metric_name, value in results_dict.items():
                if not _keep_metric(dataset_name, metric_name):
                    continue

                # ROUGE metrics and SQuAD results are divided by 100
                # (presumably reported as percentages -- TODO confirm).
                if 'rouge' in metric_name:
                    value /= 100.0
                if 'squad' in dataset_name:
                    value /= 100.0

                sanitised_metric_name = sanitise_metric(metric_name.split(',')[0])
                sanitised_dataset_name = sanitise_dataset(dataset_name)
                model_dataset_metric_to_result_map[(model_name, sanitised_dataset_name, sanitised_metric_name)] = value
                data_map.setdefault(model_name, {})[(sanitised_dataset_name, sanitised_metric_name)] = value
                print('model_name', model_name, 'dataset_name', sanitised_dataset_name, 'metric_name', sanitised_metric_name, 'value', value)

    save_data_map_to_cache(data_map, cache_file)
# Drop models whose number of recorded metrics is more than five below the
# best-covered model. The names are snapshotted into a list first so entries
# can be deleted from data_map while looping.
model_name_lst = list(data_map.keys())
nb_max_metrics = max(len(data_map[name]) for name in model_name_lst)
for model_name in model_name_lst:
    nb_metrics = len(data_map[model_name])
    if nb_metrics < nb_max_metrics - 5:
        del data_map[model_name]
# Dataset-label fragments that select the rows of each focused plot.
_PLOT_DATASET_FRAGMENTS = {
    'summ': ('CNN', 'XSum'),
    'qa': ('TriviaQA', 'NQ', 'TruthfulQA'),
    'instr': ('MemoTrap', 'IFEval'),
    'detect': ('HaluEval', 'SelfCheck'),
    'rc': ('RACE', 'SQuAD'),
}

# Per-plot layout: (figure width offset, figure height offset, dendrogram_ratio, row_cluster).
# Offsets are in inches and are added to the size derived from the cell grid.
_PLOT_LAYOUT = {
    'all': (0, 0, (.1, .1), True),
    'detect': (-2, 5.2, (.1, .2), True),
    'instr': (-2, 5.2, (.1, .4), True),
    'qa': (-2, 4, (.1, .2), True),
    'summ': (-2, 2.0, (.1, .1), False),
    'rc': (-2, 5.2, (.1, .4), True),
}


def _in_plot(plot_type, dataset_metric):
    """Return True when the (dataset, metric) pair belongs in the given plot type."""
    dataset, metric = dataset_metric
    if plot_type == 'all':
        # The overview keeps a single metric for the multi-metric tasks:
        # ROUGE-L for ROUGE, EM for SQuAD and MAX for SelfCheckGPT.
        if 'ROUGE' in metric and 'ROUGE-L' not in metric:
            return False
        if 'SQuAD' in dataset and 'EM' not in metric:
            return False
        if 'SelfCheckGPT' in dataset and 'MAX' not in metric:
            return False
        return True
    if plot_type not in _PLOT_DATASET_FRAGMENTS:
        raise ValueError(f"Unknown plot type: {plot_type}")
    return any(fragment in dataset for fragment in _PLOT_DATASET_FRAGMENTS[plot_type])


plot_type_lst = ['all', 'summ', 'qa', 'instr', 'detect', 'rc']

# The JSON dumps and figures below are written into this directory.
os.makedirs('plots', exist_ok=True)

sns.set_context("notebook", font_scale=1.3)

cell_height = 1.0  # Height of each cell in inches
cell_width = 1.0  # Width of each cell in inches
col_cluster = True

for plot_type in plot_type_lst:
    # data_map_v2[(dataset, metric)][model_name] = value, i.e. data_map transposed
    # and restricted to the rows of this plot type.
    data_map_v2 = {}
    for model_name, model_metrics in data_map.items():
        for dataset_metric, value in model_metrics.items():
            # NOTE(review): as in the original, the row is registered even when the
            # pair is filtered out below -- confirm how pandas treats the empty rows.
            row = data_map_v2.setdefault(dataset_metric, {})
            if _in_plot(plot_type, dataset_metric):
                row[model_name] = value

    # Keys of data_map_v2 become the rows, model names the columns.
    df = pd.DataFrame.from_dict(data_map_v2, orient='index')
    df.index = [', '.join(map(str, idx)) for idx in df.index]

    # Untouched copy: dumped to JSON and used to mask missing cells in the heatmap.
    o_df = df.copy(deep=True)

    print(df)

    if df.empty:
        print(f'No data for plot type {plot_type}, skipping')
        continue

    # Infinite and missing values are replaced by 0 before the frame is clustered.
    df = df.replace([np.inf, -np.inf], np.nan).fillna(0)

    # Calculate the figure size from the DataFrame size.
    n_rows = len(df.index)  # Datasets and Metrics
    n_cols = len(df.columns)  # Models
    width_offset, height_offset, dendrogram_ratio, row_cluster = _PLOT_LAYOUT[plot_type]
    fig_width = cell_width * n_cols + width_offset
    fig_height = cell_height * n_rows + height_offset

    print('figsize', (fig_width, fig_height))

    o_df.to_json(f'plots/clustermap_{plot_type}.json', orient='split')

    print(f'Generating the clustermaps for {plot_type}')

    for cmap in [None, 'coolwarm', 'viridis']:
        grid = sns.clustermap(df,
                              method='ward',
                              metric='euclidean',
                              cmap=cmap,
                              figsize=(fig_width, fig_height),
                              annot=True,
                              mask=o_df.isnull(),
                              dendrogram_ratio=dendrogram_ratio,
                              fmt='.2f',
                              col_cluster=col_cluster,
                              row_cluster=row_cluster)

        plt.setp(grid.ax_heatmap.get_yticklabels(), rotation=0)
        plt.setp(grid.ax_heatmap.get_xticklabels(), rotation=90)

        cmap_suffix = '' if cmap is None else f'_{cmap}'

        # Save the clustermap to file
        grid.savefig(f'plots/clustermap_{plot_type}{cmap_suffix}.pdf')
        grid.savefig(f'plots/clustermap_{plot_type}{cmap_suffix}.png')
        grid.savefig(f'plots/clustermap_{plot_type}{cmap_suffix}_t.png', transparent=True, facecolor="none")

        # Each clustermap call creates a new figure; close it once saved so
        # figures do not accumulate across plot types and colour maps.
        plt.close(grid.ax_heatmap.figure)