#!/usr/bin/env python3

import os
import sys
import json

import numpy as np

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from scipy.cluster.hierarchy import linkage

from src.backend.envs import Tasks, EVAL_REQUESTS_PATH_BACKEND, EVAL_RESULTS_PATH_BACKEND, DEVICE, LIMIT, Task

from src.envs import QUEUE_REPO, RESULTS_REPO, API
from src.utils import my_snapshot_download


def find_json_files(json_path):
    """Recursively collect the paths of every ``.json`` file below *json_path*.

    Each returned path is ``os.path.join(directory, filename)`` and the list is
    ordered as ``os.walk`` visits directories (top-down). Because ``os.walk``
    ignores errors by default, a missing directory simply yields ``[]``.
    """
    return [
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(json_path)
        for filename in filenames
        if filename.endswith(".json")
    ]


# Pull fresh local copies of the results dataset and the request-queue dataset.
my_snapshot_download(
    repo_id=RESULTS_REPO,
    revision="main",
    local_dir=EVAL_RESULTS_PATH_BACKEND,
    repo_type="dataset",
    max_workers=60,
)
my_snapshot_download(
    repo_id=QUEUE_REPO,
    revision="main",
    local_dir=EVAL_REQUESTS_PATH_BACKEND,
    repo_type="dataset",
    max_workers=60,
)

result_path_lst = find_json_files(EVAL_RESULTS_PATH_BACKEND)
request_path_lst = find_json_files(EVAL_REQUESTS_PATH_BACKEND)


def _read_json(json_file):
    """Parse and return the JSON document stored at *json_file*."""
    with open(json_file, 'r') as handle:
        return json.load(handle)


# Index every request document by its "model" field. If two request files
# name the same model, the one visited last overwrites the earlier one.
model_name_to_model_map = {
    request["model"]: request
    for request in map(_read_json, request_path_lst)
}

# Only models whose request document records strictly more likes than this
# are included in the clustermap.
MIN_LIKES = 256


def _keep_metric(dataset_name, metric_name):
    """Return True if this (dataset, metric) pair should be plotted.

    Only metric keys containing a ',' (e.g. "<metric>,<filter>") are
    considered; standard-error and F1 keys are always dropped. The per-dataset
    and per-metric rules below then narrow each task down to one
    representative metric, or exclude it entirely.
    """
    if ',' not in metric_name or '_stderr' in metric_name or 'f1' in metric_name:
        return False
    # Datasets excluded from the plot altogether.
    if any(name in dataset_name for name in ('nq_open', 'triviaqa', 'squad', 'fever')):
        return False
    # Datasets whose name contains 'selfcheck': keep only the 'max' variant.
    if 'selfcheck' in dataset_name and 'max' not in metric_name:
        return False
    # ifeval: keep only prompt-level strict accuracy.
    if 'ifeval' in dataset_name and 'prompt_level_strict_acc' not in metric_name:
        return False
    # BERTScore: keep only precision.
    if 'bertscore' in metric_name and 'precision' not in metric_name:
        return False
    # ROUGE: keep only rougeL.
    if 'rouge' in metric_name and 'rougeL' not in metric_name:
        return False
    # 'correctness' and 'em' (presumably exact match) metrics are dropped.
    if 'correctness,' in metric_name or 'em,' in metric_name:
        return False
    return True


# (model, dataset, metric) -> value, and model -> {(dataset, metric): value}.
model_dataset_metric_to_result_map = {}
data_map = {}

for path in result_path_lst:
    with open(path, 'r') as f:
        data = json.load(f)
    model_name = data["config"]["model_name"]

    # A results file may have no matching request file, or the request may
    # lack "likes"; treat both as zero likes instead of raising KeyError.
    request = model_name_to_model_map.get(model_name, {})
    if (request.get("likes") or 0) <= MIN_LIKES:
        continue

    for dataset_name, results_dict in data["results"].items():
        for metric_name, value in results_dict.items():
            if not _keep_metric(dataset_name, metric_name):
                continue

            # ROUGE values are divided by 100, presumably to bring a 0-100
            # score onto the same scale as the other metrics.
            if 'rouge' in metric_name:
                value /= 100.0

            # Strip the ",<suffix>" part so columns are named by metric only.
            sanitised_metric_name = metric_name.split(',')[0]
            model_dataset_metric_to_result_map[(model_name, dataset_name, sanitised_metric_name)] = value
            data_map.setdefault(model_name, {})[(dataset_name, sanitised_metric_name)] = value

            print('model_name', model_name, 'dataset_name', dataset_name, 'metric_name', metric_name, 'value', value)

# Discard models with fewer than 8 retained (dataset, metric) scores,
# presumably so every plotted row has a reasonable number of filled cells.
data_map = {
    name: scores
    for name, scores in data_map.items()
    if len(scores) >= 8
}

# Rows are model names (the keys of data_map); columns are (dataset, metric) pairs.
df = pd.DataFrame.from_dict(data_map, orient='index')
# Untouched copy, used to mask cells that are imputed below.
o_df = df.copy(deep=True)

print(df)

# Hierarchical clustering cannot handle NaN or +/-inf: turn infinities into
# NaN, then impute every missing value with 0.
df.replace([np.inf, -np.inf], np.nan, inplace=True)
df.fillna(0, inplace=True)

# NOTE: only needed if the optional min-max scaling below is re-enabled.
from sklearn.preprocessing import MinMaxScaler

# Optional per-column min-max scaling (currently disabled):
# scaler = MinMaxScaler()
# df = pd.DataFrame(scaler.fit_transform(df), index=df.index, columns=df.columns)

# clustermap clusters both rows and columns, and linkage needs at least two
# observations along each axis; fail with a clear message instead of a
# traceback from deep inside scipy.
if df.shape[0] < 2 or df.shape[1] < 2:
    sys.exit(f"Need at least 2 models and 2 metrics to build a clustermap, got {df.shape[0]} x {df.shape[1]}")

sns.set_context("notebook", font_scale=1.0)

# Mask every cell that was imputed above (originally NaN or +/-inf) so it
# renders blank instead of as a real 0.
imputed_mask = o_df.replace([np.inf, -np.inf], np.nan).isnull()
fig = sns.clustermap(df, method='ward', metric='euclidean', cmap='coolwarm', figsize=(16, 12), annot=True, mask=imputed_mask)

# Keep row (model) labels horizontal and column labels vertical.
plt.setp(fig.ax_heatmap.get_yticklabels(), rotation=0)
plt.setp(fig.ax_heatmap.get_xticklabels(), rotation=90)

# savefig does not create missing directories.
os.makedirs('plots', exist_ok=True)
fig.savefig('plots/clustermap.pdf')
fig.savefig('plots/clustermap.png')