import glob
import json
import os
from typing import List

from huggingface_hub import HfApi
from tqdm import tqdm

from src.display_models.model_metadata_flags import DO_NOT_SUBMIT_MODELS, FLAGGED_MODELS
from src.display_models.model_metadata_type import MODEL_TYPE_METADATA, ModelType, model_type_from_str
from src.display_models.utils import AutoEvalColumn, model_hyperlink

# Module-level Hub API client, built with the token read from the H4_TOKEN
# environment variable (None is passed when the variable is not set).
api = HfApi(token=os.getenv("H4_TOKEN"))


def _read_request_file(path: str):
    """Load one eval-request JSON file.

    Returns the decoded JSON object as a dict, or ``None`` when the file
    cannot be opened, is not valid JSON, or does not contain a JSON object.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            content = json.load(f)
    except (OSError, ValueError):
        # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
        return None
    return content if isinstance(content, dict) else None


def _select_request_file(model_data: dict) -> str:
    """Return the path of the request file to use for one row, or "" if none.

    Candidates are the files matching
    ``eval-queue/<model_name_for_query>_eval_request_*.json`` (relative to the
    current working directory).
    """
    pattern = os.path.join(
        "eval-queue",
        # Escape the model name so glob metacharacters in it are matched literally.
        glob.escape(model_data["model_name_for_query"]) + "_eval_request_*" + ".json",
    )
    candidates = glob.glob(pattern)

    if len(candidates) == 1:
        # A lone request file is used as-is; its status/precision are not checked.
        return candidates[0]

    selected = ""
    for candidate in sorted(candidates, reverse=True):
        req_content = _read_request_file(candidate)
        if req_content is None:
            # An unreadable candidate must not abort the whole leaderboard pass.
            continue
        if (
            req_content.get("status") == "FINISHED"
            and req_content.get("precision") == model_data["Precision"].split(".")[-1]
        ):
            # Deliberately no ``break``: as before, when several candidates
            # match, the last one in reverse-sorted order is the one kept.
            selected = candidate
    return selected


def _apply_request_metadata(model_data: dict, request: dict) -> bool:
    """Copy model type, license, likes and params from ``request`` into the row.

    Returns True when every field was written, False as soon as one step fails.
    Fields written before the failing step are left in place.
    """
    try:
        model_type = model_type_from_str(request["model_type"])
        model_data[AutoEvalColumn.model_type.name] = model_type.value.name
        model_data[AutoEvalColumn.model_type_symbol.name] = model_type.value.symbol
        model_data[AutoEvalColumn.license.name] = request["license"]
        model_data[AutoEvalColumn.likes.name] = request["likes"]
        model_data[AutoEvalColumn.params.name] = request["params"]
    except Exception:
        # Intentionally broad: this step is best-effort (a request may lack
        # fields, and model_type_from_str may reject the stored string); the
        # caller then falls back to the static model-type table.
        return False
    return True


def _apply_fallback_model_type(model_data: dict) -> None:
    """Set the model-type columns from MODEL_TYPE_METADATA, else ModelType.Unknown."""
    model_name = model_data["model_name_for_query"]
    if model_name in MODEL_TYPE_METADATA:
        fallback_type = MODEL_TYPE_METADATA[model_name]
    else:
        fallback_type = ModelType.Unknown
    model_data[AutoEvalColumn.model_type.name] = fallback_type.value.name
    model_data[AutoEvalColumn.model_type_symbol.name] = fallback_type.value.symbol


def get_model_metadata(leaderboard_data: List[dict]):
    """Add model-type, license, likes and params columns to each row, in place.

    For every row, a matching eval-request file under ``eval-queue`` is looked
    up and its ``model_type``, ``license``, ``likes`` and ``params`` fields are
    copied into the row. When no usable request file exists, or copying fails,
    only the model-type columns are set: from MODEL_TYPE_METADATA if the model
    is listed there, otherwise to ModelType.Unknown.

    Args:
        leaderboard_data: Rows to update. Each row must have a
            ``model_name_for_query`` key; ``Precision`` is read only when
            several request files have to be told apart.

    Returns:
        None. The dicts in ``leaderboard_data`` are mutated.
    """
    for model_data in tqdm(leaderboard_data):
        request_file = _select_request_file(model_data)
        # Explicit guard instead of opening "" and relying on the resulting error.
        request = _read_request_file(request_file) if request_file else None

        if request is not None and _apply_request_metadata(model_data, request):
            continue
        _apply_fallback_model_type(model_data)


def flag_models(leaderboard_data: List[dict]):
    """Append a "has been flagged!" notice to the model cell of flagged rows.

    A row is flagged when its ``model_name_for_query`` is a key of
    FLAGGED_MODELS; the mapped value is used as the link target and its last
    "/"-separated segment as the discussion number. Rows are edited in place.
    """
    for row in leaderboard_data:
        query_name = row["model_name_for_query"]
        if query_name not in FLAGGED_MODELS:
            continue

        discussion_url = FLAGGED_MODELS[query_name]
        # Text after the final "/" (the whole string if there is no "/").
        issue_num = discussion_url.rpartition("/")[2]
        issue_link = model_hyperlink(discussion_url, f"See discussion #{issue_num}")

        model_column = AutoEvalColumn.model.name
        row[model_column] = f"{row[model_column]} has been flagged! {issue_link}"


def remove_forbidden_models(leaderboard_data: List[dict]):
    """Drop every row whose ``model_name_for_query`` is in DO_NOT_SUBMIT_MODELS.

    The list is edited in place (remaining rows keep their order) and the same
    list object is returned.
    """
    # Slice assignment keeps the caller's list object while replacing its contents.
    leaderboard_data[:] = [
        row for row in leaderboard_data if row["model_name_for_query"] not in DO_NOT_SUBMIT_MODELS
    ]
    return leaderboard_data


def apply_metadata(leaderboard_data: List[dict]):
    """Run the metadata pipeline over the leaderboard rows.

    Steps, in order: remove forbidden models, attach model metadata, then mark
    flagged models. Nothing is returned; the rows are updated by the steps.
    """
    remaining_rows = remove_forbidden_models(leaderboard_data)
    for step in (get_model_metadata, flag_models):
        step(remaining_rows)