Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
import os | |
import sys | |
import json | |
import numpy as np | |
import pandas as pd | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
from scipy.cluster.hierarchy import linkage | |
from src.backend.envs import Tasks, EVAL_REQUESTS_PATH_BACKEND, EVAL_RESULTS_PATH_BACKEND, DEVICE, LIMIT, Task | |
from src.envs import QUEUE_REPO, RESULTS_REPO, API | |
from src.utils import my_snapshot_download | |
def find_json_files(json_path):
    """Recursively collect the paths of every ``.json`` file under *json_path*.

    Each path is the ``os.walk`` directory joined with the file name, listed in
    the order ``os.walk`` yields them. The suffix match is case-sensitive.
    A missing *json_path* yields an empty list (``os.walk`` ignores the error).
    """
    return [
        os.path.join(directory, name)
        for directory, _subdirs, names in os.walk(json_path)
        for name in names
        if name.endswith(".json")
    ]
# Mirror the results repo and the request-queue repo locally (same order as
# before: results first), then list the JSON files found in each.
for _repo_id, _local_dir in (
    (RESULTS_REPO, EVAL_RESULTS_PATH_BACKEND),
    (QUEUE_REPO, EVAL_REQUESTS_PATH_BACKEND),
):
    my_snapshot_download(repo_id=_repo_id, revision="main", local_dir=_local_dir, repo_type="dataset", max_workers=60)

result_path_lst = find_json_files(EVAL_RESULTS_PATH_BACKEND)
request_path_lst = find_json_files(EVAL_REQUESTS_PATH_BACKEND)

# Index each request record by its "model" field. When several request files
# name the same model, the one read last overwrites the earlier ones.
model_name_to_model_map = {}
for request_path in request_path_lst:
    with open(request_path, 'r') as request_file:
        request_record = json.load(request_file)
    model_name_to_model_map[request_record["model"]] = request_record
# Only models whose request record reports strictly more than this many likes
# are included in the plot.
MIN_LIKES = 256


def _is_selected_metric(dataset_name, metric_name):
    """Return True if the score ``(dataset_name, metric_name)`` should be plotted.

    Only metric keys containing a ',' are considered; ``_stderr`` and ``f1`` keys
    are always dropped. The remaining rules keep a single metric variant for
    several datasets and exclude some datasets entirely.
    """
    if ',' not in metric_name or '_stderr' in metric_name or 'f1' in metric_name:
        return False
    if 'selfcheck' in dataset_name and 'max' not in metric_name:
        return False
    if 'nq_open' in dataset_name or 'triviaqa' in dataset_name:
        return False
    if 'bertscore' in metric_name and 'precision' not in metric_name:
        return False
    if 'correctness,' in metric_name or 'em,' in metric_name:
        return False
    if 'rouge' in metric_name and 'rougeL' not in metric_name:
        return False
    if 'ifeval' in dataset_name and 'prompt_level_strict_acc' not in metric_name:
        return False
    if 'squad' in dataset_name or 'fever' in dataset_name:
        return False
    return True


# (model, dataset, metric) -> score, plus a per-model view used to build the
# DataFrame: model -> {(dataset, metric): score}.
model_dataset_metric_to_result_map = {}
data_map = {}
for path in result_path_lst:
    with open(path, 'r') as f:
        data = json.load(f)
    model_name = data["config"]["model_name"]
    # A results file may belong to a model that has no request record (or whose
    # record lacks "likes"). Treat that as 0 likes instead of raising KeyError.
    likes = model_name_to_model_map.get(model_name, {}).get("likes") or 0
    if likes <= MIN_LIKES:
        continue
    for dataset_name, results_dict in data["results"].items():
        for metric_name, value in results_dict.items():
            if not _is_selected_metric(dataset_name, metric_name):
                continue
            if 'rouge' in metric_name:
                # Assumes ROUGE is reported on a 0-100 scale and rescales it to 0-1.
                # TODO confirm against the results files.
                value /= 100.0
            # Keep only the part of the key before the first ','.
            sanitised_metric_name = metric_name.split(',')[0]
            model_dataset_metric_to_result_map[(model_name, dataset_name, sanitised_metric_name)] = value
            data_map.setdefault(model_name, {})[(dataset_name, sanitised_metric_name)] = value
            print('model_name', model_name, 'dataset_name', dataset_name, 'metric_name', metric_name, 'value', value)
# Drop models with fewer than this many selected (dataset, metric) scores.
MIN_SCORES_PER_MODEL = 8
model_name_lst = list(data_map)
for m in model_name_lst:
    if len(data_map[m]) < MIN_SCORES_PER_MODEL:
        del data_map[m]

# One row per model; one column per (dataset_name, metric_name) key.
df = pd.DataFrame.from_dict(data_map, orient='index')
o_df = df.copy(deep=True)
print(df)

# Cells with no finite score (NaN or +/-inf) are hidden in the heatmap, so the
# zeros imputed below only influence clustering and are never drawn.
missing_mask = o_df.isnull() | o_df.isin([np.inf, -np.inf])

# Linkage cannot handle non-finite values: turn infinities into NaN, then impute 0.
df.replace([np.inf, -np.inf], np.nan, inplace=True)
df.fillna(0, inplace=True)

from sklearn.preprocessing import MinMaxScaler  # only used by the optional scaling below
# scaler = MinMaxScaler()
# df = pd.DataFrame(scaler.fit_transform(df), index=df.index, columns=df.columns)

# clustermap clusters both rows and columns; linkage needs at least two
# observations on each axis, otherwise scipy raises an opaque error.
if df.shape[0] < 2 or df.shape[1] < 2:
    sys.exit(f'Not enough data to cluster: got a {df.shape[0]}x{df.shape[1]} score matrix, need at least 2x2.')

sns.set_context("notebook", font_scale=1.0)
# Alternative: fig = sns.clustermap(df, method='average', metric='cosine', cmap='coolwarm', figsize=(16, 12), annot=True)
fig = sns.clustermap(df, method='ward', metric='euclidean', cmap='coolwarm', figsize=(16, 12), annot=True, mask=missing_mask)
# Keep row (model) labels horizontal and rotate column labels vertically.
plt.setp(fig.ax_heatmap.get_yticklabels(), rotation=0)
plt.setp(fig.ax_heatmap.get_xticklabels(), rotation=90)

# savefig does not create missing parent directories.
os.makedirs('plots', exist_ok=True)
fig.savefig('plots/clustermap.pdf')
fig.savefig('plots/clustermap.png')