import gradio as gr
import pandas as pd
import random
import plotly.express as px
from huggingface_hub import snapshot_download
import os
import logging

from config import (
    SETUPS,
    LOCAL_RESULTS_DIR,
    CITATION_BUTTON_TEXT,
    CITATION_BUTTON_LABEL,
)
from parsing import read_all_configs, get_common_langs

# Configure root logging once for the whole app: INFO and above, emitted
# through a single stream handler. Add logging.FileHandler("app.log") to the
# handlers list to also persist records to a file.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    handlers=[logging.StreamHandler()],
)

# Module-level logger used by the helpers in this file.
logger = logging.getLogger(__name__)


# Mirror the results dataset into LOCAL_RESULTS_DIR at import time so the
# tables below can be built from local files. Files matching the "samples"
# and "transcripts" patterns are skipped. The access token is read from the
# TOKEN environment variable (None when it is not set).
try:
    print("Saving results locally at:", LOCAL_RESULTS_DIR)
    snapshot_download(
        repo_id="g8a9/fair-asr-results",
        local_dir=LOCAL_RESULTS_DIR,
        repo_type="dataset",
        tqdm_class=None,
        etag_timeout=30,
        ignore_patterns=["*samples*", "*transcripts*"],
        token=os.environ.get("TOKEN"),
    )
except Exception:
    # The app cannot start without the results: record the failure with its
    # traceback, then propagate the original exception unchanged.
    logger.exception("Failed to download results into %s", LOCAL_RESULTS_DIR)
    raise


def format_dataframe(df, times_100=False):
    if times_100:
        df = df.map(lambda x: (f"{x * 100:.3f}%" if isinstance(x, (int, float)) else x))
    else:
        df = df.map(lambda x: (f"{x:.4f}" if isinstance(x, (int, float)) else x))
    return df


def _build_models_with_nan_md(models_with_nan):
    """Build the Markdown notice listing the models whose results are hidden.

    Args:
        models_with_nan: Iterable of model names; each one is wrapped in
            ``*...*`` (Markdown emphasis) and the names are comma-separated.

    Returns:
        The notice text, with a leading and a trailing newline.
    """
    emphasized = ", ".join(f"*{name}*" for name in models_with_nan)
    return f"""
We are currently hiding the results of {emphasized} because they don't support all languages.
"""


def build_components(show_common_langs):
    """Rebuild the leaderboard widgets for the given language filter.

    Args:
        show_common_langs: Forwarded unchanged to ``_populate_components``.

    Returns:
        A 4-tuple of Gradio components, in this order: the aggregated-gap
        table, the per-language table (values shown as percentages), the bar
        plot, and the "hidden models" notice. The notice is only visible
        when at least one model was hidden.
    """
    agg_table, per_lang_table, figure, hidden_models = _populate_components(
        show_common_langs
    )
    notice_text = _build_models_with_nan_md(hidden_models)

    agg_widget = gr.DataFrame(format_dataframe(agg_table))
    per_lang_widget = gr.DataFrame(format_dataframe(per_lang_table, times_100=True))
    plot_widget = gr.Plot(figure)
    notice_widget = gr.Markdown(notice_text, visible=len(hidden_models) > 0)

    return agg_widget, per_lang_widget, plot_widget, notice_widget


def _populate_components(show_common_langs):
    """Compute the tables and the bar plot for the first setup in ``SETUPS``.

    Models that have at least one row containing a missing value are dropped
    from every output and reported separately.

    Args:
        show_common_langs: When truthy, only rows whose ``Language`` is in
            ``get_common_langs()`` are considered.

    Returns:
        A tuple ``(aggregated_df, lang_df, barplot_fig, models_with_nan)``:

        * ``aggregated_df``: one row per model with ``Gap`` equal to
          100 x the sum of absolute gaps, sorted ascending.
        * ``lang_df``: one row per model and one column per language holding
          the (unscaled) gap, aggregated with pivot_table's default.
        * ``barplot_fig``: grouped bar chart of the per-language gaps (x100)
          for the three models with the smallest aggregated gap; languages
          are ordered by the best model's gap, largest first.
        * ``models_with_nan``: names of the models that were dropped.

    Raises:
        ValueError: If no rows are left once incomplete models are removed.
    """
    fm = SETUPS[0]
    setup = fm["majority_group"] + "_" + fm["minority_group"]
    results = read_all_configs(setup)

    if show_common_langs:
        common_langs = get_common_langs()
        logger.info("Common langs: %s", common_langs)
        results = results[results["Language"].isin(common_langs)]

    # Rows with any missing value: computed once, used both for the per-model
    # log lines and for the list of models to hide.
    incomplete_rows = results[results.isna().any(axis=1)]

    missing_langs = (
        incomplete_rows.groupby("Model")["Language"].apply(list).to_dict()
    )
    for model, langs in missing_langs.items():
        logger.info(
            "Model %s is missing results for languages: %s",
            model,
            ", ".join(langs),
        )

    models_with_nan = incomplete_rows["Model"].unique().tolist()
    logger.info("Models with NaN values: %s", models_with_nan)
    results = results[~results["Model"].isin(models_with_nan)]

    # Everything below needs a "best" model; fail with a clear message rather
    # than an opaque indexing error further down.
    if results.empty:
        raise ValueError(
            "No model has complete results for the selected languages."
        )

    aggregated_df = (
        results.pivot_table(
            index="Model", values="Gap", aggfunc=lambda x: 100 * x.abs().sum()
        )
        .reset_index()
        .sort_values("Gap")
    )
    best_model = aggregated_df.iloc[0]["Model"]
    top_3_models = aggregated_df["Model"].head(3).tolist()

    lang_df = results.pivot_table(
        index="Model",
        values="Gap",
        columns="Language",
    ).reset_index()

    # Scale gaps to percentages on a derived frame instead of assigning into
    # the filtered `results`, which would trigger SettingWithCopyWarning.
    plot_df = results.loc[results["Model"].isin(top_3_models)].assign(
        Gap=lambda frame: frame["Gap"] * 100
    )
    # NOTE(review): the y-axis label says "Sum of Absolute Gaps" while the
    # plotted values are per-language gaps; confirm the intended wording.
    barplot_fig = px.bar(
        plot_df,
        x="Language",
        y="Gap",
        color="Model",
        title="Gaps by Language and Model (top 3, sorted by the best model)",
        labels={
            "Gap": "Sum of Absolute Gaps (%)",
            "Language": "Language",
            "Model": "Model",
        },
        barmode="group",
    )

    # Order the x-axis by the best model's gap, largest first.
    lang_order = (
        lang_df.set_index("Model").loc[best_model].sort_values(ascending=False).index
    )
    logger.info("Lang order: %s", lang_order)

    barplot_fig.update_layout(
        xaxis={"categoryorder": "array", "categoryarray": lang_order}
    )

    return aggregated_df, lang_df, barplot_fig, models_with_nan


# F-M setup tab: built once at import time with every language included.
# The components are kept in module-level names so that event handlers
# defined elsewhere in this file can use them as outputs.
with gr.Blocks() as fm_interface:
    aggregated_df, lang_df, barplot_fig, model_with_nan = _populate_components(
        show_common_langs=False
    )
    # Show the "hidden models" notice only when a model was actually dropped,
    # using the same visibility rule as build_components.
    model_with_nans_md = gr.Markdown(
        _build_models_with_nan_md(model_with_nan),
        visible=len(model_with_nan) > 0,
    )

    gr.Markdown("### Sum of Absolute Gaps ⬇️")
    aggregated_df_comp = gr.DataFrame(format_dataframe(aggregated_df))

    gr.Markdown("#### F-M gaps by language")
    lang_df_comp = gr.DataFrame(format_dataframe(lang_df, times_100=True))

    barplot_fig_comp = gr.Plot(barplot_fig)

###################
# LIST MAIN TABS
###################
# Sub-interfaces shown as tabs in the main app and their titles; the two
# lists are passed together to gr.TabbedInterface, so keep them in the same
# order.
tabs = [fm_interface]
titles = ["F-M Setup"]

# Raw HTML rendered at the top of the page: a small stylesheet plus the
# banner image, which is scaled to the full width of its container.
banner = """
<style>
    .full-width-image {
        width: 100%;
        height: auto;
        margin: 0;
        padding: 0;
    }
</style>
<div>
    <img src="https://huggingface.co/spaces/g8a9/fair-asr-leaderboard/raw/main/twists_banner.png" alt="Twists Banner" class="full-width-image">
</div>
"""

###################
# MAIN INTERFACE
###################
# Top-level app: banner, a configuration row, the tabbed sub-interfaces and
# a citation box.
with gr.Blocks() as demo:
    gr.HTML(banner)

    with gr.Row() as config_row:
        # Single-choice checkbox group; its current value (the list of
        # selected choices) is what build_components receives.
        show_common_langs = gr.CheckboxGroup(
            choices=["Show only common languages"],
            label="Main configuration",
        )
        # Displayed with its only choice pre-selected; interactive=False
        # means the user cannot change it.
        include_datasets = gr.CheckboxGroup(
            choices=["Mozilla CV 17"],
            label="Include datasets",
            value=["Mozilla CV 17"],
            interactive=False,
        )

        # Recompute the tables, plot and notice on user input. The outputs
        # are listed in the same order as the tuple returned by
        # build_components.
        show_common_langs.input(
            build_components,
            inputs=[show_common_langs],
            outputs=[
                aggregated_df_comp,
                lang_df_comp,
                barplot_fig_comp,
                model_with_nans_md,
            ],
        )

    gr.TabbedInterface(tabs, titles)

    # Citation text with a copy button, limited to six visible lines.
    gr.Textbox(
        value=CITATION_BUTTON_TEXT,
        label=CITATION_BUTTON_LABEL,
        max_lines=6,
        show_copy_button=True,
    )

# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()