import json
import logging
import os
import subprocess
import time

import datasets
import gradio as gr
import huggingface_hub
from transformers.pipelines import TextClassificationPipeline

from io_utils import (
    convert_column_mapping_to_json,
    read_inference_type,
    read_scanners,
    write_inference_type,
    write_scanners,
)
from text_classification import (
    check_column_mapping_keys_validity,
    text_classification_fix_column_mapping,
)
from wordings import CONFIRM_MAPPING_DETAILS_FAIL_MD, CONFIRM_MAPPING_DETAILS_MD

# Names of environment variables read by try_submit when launching a local scan.
HF_REPO_ID = "HF_REPO_ID"  # preferred value for giskard_scanner --discussion_repo
HF_SPACE_ID = "SPACE_ID"  # fallback for --discussion_repo when HF_REPO_ID is unset/empty
HF_WRITE_TOKEN = "HF_WRITE_TOKEN"  # value passed to giskard_scanner --hf_token


def check_model(model_id):
    """Look up a Hub model and try to build a ``transformers`` pipeline for it.

    Returns:
        ``(None, None)`` when the model's Hub metadata cannot be fetched;
        ``(model_id, exception)`` when the metadata was fetched but building the
        pipeline raised; otherwise ``(model_id, pipeline)``.
    """
    try:
        info = huggingface_hub.model_info(model_id)
        task = info.pipeline_tag
    except Exception:
        # Model is missing, or private and no usable token is configured.
        return None, None

    try:
        from transformers import pipeline as build_pipeline

        loaded = build_pipeline(task=task, model=model_id)
    except Exception as err:
        # Report the failure to the caller instead of raising.
        return model_id, err
    return model_id, loaded


def check_dataset(dataset_id, dataset_config="default", dataset_split="test"):
    """Check that a Hub dataset, config and split can be used.

    The shape of the returned ``(dataset_id, config, split)`` triple tells the
    caller what to do next:

    * ``(None, dataset_config, dataset_split)``: the dataset could not be
      reached or loaded, or it reports no configs at all.
    * ``(dataset_id, [configs...], dataset_split)``: ``dataset_config`` is
      not one of the dataset's configs; the list holds the valid choices.
    * ``(dataset_id, None, [splits...])``: ``dataset_split`` is not one of the
      loaded splits; the list holds the valid choices.
    * ``(dataset_id, None, None)``: ``load_dataset`` returned neither a
      ``DatasetDict`` nor a ``Dataset``.
    * ``(dataset_id, dataset_config, dataset_split)``: everything checks out.
    """
    try:
        configs = datasets.get_dataset_config_names(dataset_id)
    except Exception as err:
        # Dataset may not exist, or may be private without a usable token.
        logging.warning("Cannot list configs of dataset %s: %s", dataset_id, err)
        return None, dataset_config, dataset_split

    if not configs:
        # An empty choice list would be indexed by the caller; report the
        # dataset as unusable instead.
        return None, dataset_config, dataset_split

    if dataset_config not in configs:
        # Need to choose dataset subset (config)
        return dataset_id, configs, dataset_split

    try:
        ds = datasets.load_dataset(dataset_id, dataset_config)
    except Exception as err:
        # The config exists but loading failed (gated/private data, network or
        # loading-script errors). Report it like an unreachable dataset rather
        # than raising into the UI callback.
        logging.warning(
            "Cannot load dataset %s (config %s): %s", dataset_id, dataset_config, err
        )
        return None, dataset_config, dataset_split

    if isinstance(ds, datasets.DatasetDict):
        # Need to choose dataset split
        if dataset_split not in ds.keys():
            return dataset_id, None, list(ds.keys())
    elif not isinstance(ds, datasets.Dataset):
        # Unknown type
        return dataset_id, None, None
    return dataset_id, dataset_config, dataset_split


def _validation_blocked_outputs():
    """Return the UI updates used whenever validation cannot proceed.

    The order matches the 7-tuple returned by ``try_validate``. The submit
    button is disabled, the "please validate" loading row is shown, and every
    preview component is hidden.
    """
    return (
        gr.update(interactive=False),  # Submit button
        gr.update(visible=True),  # Loading row
        gr.update(visible=False),  # Preview row
        gr.update(visible=False),  # Model prediction input
        gr.update(visible=False),  # Model prediction preview
        gr.update(visible=False),  # Label mapping preview
        gr.update(visible=False),  # feature mapping preview
    )


def try_validate(
    m_id, ppl, dataset_id, dataset_config, dataset_split, column_mapping="{}"
):
    """Validate a model/dataset pair and build the preview UI updates.

    Args:
        m_id: Model id as returned by ``check_model``; ``None`` means the model
            could not be reached.
        ppl: Pipeline returned by ``check_model``, or the exception raised
            while building it.
        dataset_id: Hub dataset id.
        dataset_config: Dataset config (subset) name.
        dataset_split: Dataset split name.
        column_mapping: JSON string with user-provided label/feature mappings.
            Invalid JSON is treated as an empty mapping.

    Returns:
        A 7-tuple of ``gr.update`` objects, in this order: submit button,
        loading row, preview row, model prediction input, model prediction
        preview, label mapping preview, feature mapping preview.
    """
    # Validate model. ``ppl is None`` also means check_model could not reach
    # the model, even if the caller passed the raw model id as ``m_id``.
    if m_id is None or ppl is None:
        gr.Warning(
            "Model is not accessible. Please set your HF_TOKEN if it is a private model."
        )
        return _validation_blocked_outputs()
    if isinstance(ppl, Exception):
        gr.Warning(f"Failed to load model: {ppl}")
        return _validation_blocked_outputs()

    # Validate dataset
    d_id, config, split = check_dataset(
        dataset_id=dataset_id,
        dataset_config=dataset_config,
        dataset_split=dataset_split,
    )

    if d_id is None:
        gr.Warning(
            f'Dataset "{dataset_id}" is not accessible. Please set your HF_TOKEN if it is a private dataset.'
        )
        return _validation_blocked_outputs()
    if isinstance(config, list):
        gr.Warning(
            f'Dataset "{dataset_id}" does not have "{dataset_config}" config. Please choose a valid config.'
        )
        return _validation_blocked_outputs()
    if isinstance(split, list):
        gr.Warning(
            f'Dataset "{dataset_id}" does not have "{dataset_split}" split. Please choose a valid split.'
        )
        return _validation_blocked_outputs()
    if config is None or split is None:
        # check_dataset returns (id, None, None) for an unexpected dataset type;
        # it must not be treated as a valid dataset.
        gr.Warning(
            f'Dataset "{dataset_id}" could not be loaded in a supported format.'
        )
        return _validation_blocked_outputs()

    # Only text classification is handled below. Other pipelines would leave
    # the sample input and feature table undefined.
    if not isinstance(ppl, TextClassificationPipeline):
        gr.Warning(
            f'Model "{m_id}" is not a text classification model. Only text classification models are currently supported.'
        )
        return _validation_blocked_outputs()

    # TODO: Validate column mapping by running once
    try:
        column_mapping = json.loads(column_mapping)
    except (TypeError, ValueError):
        column_mapping = {}

    (
        _fixed_column_mapping,
        prediction_input,
        prediction_result,
        id2label_df,
        feature_df,
    ) = text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split)

    sample_input_md = f"**Sample Input**: {prediction_input}"

    # NOTE(review): the first positional argument of gr.update is passed for the
    # preview row; confirm the intended semantics for a Row in this Gradio version.
    if prediction_result is None and id2label_df is not None:
        gr.Warning(
            'The model failed to predict with the first row in the dataset. Please provide feature mappings in "Advance" settings.'
        )
        return (
            gr.update(interactive=False),  # Submit button
            gr.update(visible=False),  # Loading row
            gr.update(CONFIRM_MAPPING_DETAILS_MD, visible=True),  # Preview row
            gr.update(value=sample_input_md, visible=True),  # Model prediction input
            gr.update(visible=False),  # Model prediction preview
            gr.update(
                value=id2label_df, visible=True, interactive=True
            ),  # Label mapping preview
            gr.update(
                value=feature_df, visible=True, interactive=True
            ),  # feature mapping preview
        )
    if id2label_df is None:
        gr.Warning(
            'The prediction result does not conform to the labels in the dataset. Please provide label mappings in "Advance" settings.'
        )
        return (
            gr.update(interactive=False),  # Submit button
            gr.update(visible=False),  # Loading row
            gr.update(CONFIRM_MAPPING_DETAILS_MD, visible=True),  # Preview row
            gr.update(value=sample_input_md, visible=True),  # Model prediction input
            gr.update(
                value=prediction_result, visible=True
            ),  # Model prediction preview
            gr.update(visible=True, interactive=True),  # Label mapping preview
            gr.update(visible=True, interactive=True),  # feature mapping preview
        )

    gr.Info(
        "Model and dataset validations passed. You can submit the evaluation task."
    )

    return (
        gr.update(interactive=True),  # Submit button
        gr.update(visible=False),  # Loading row
        gr.update(CONFIRM_MAPPING_DETAILS_MD, visible=True),  # Preview row
        gr.update(value=sample_input_md, visible=True),  # Model prediction input
        gr.update(value=prediction_result, visible=True),  # Model prediction preview
        gr.update(
            value=id2label_df, visible=True, interactive=True
        ),  # Label mapping preview
        gr.update(
            value=feature_df, visible=True, interactive=True
        ),  # feature mapping preview
    )


def try_submit(
    m_id,
    d_id,
    config,
    split,
    id2label_mapping_dataframe,
    feature_mapping_dataframe,
    local,
):
    """Launch the Giskard scan for a validated model/dataset pair.

    Args:
        m_id: Hub model id.
        d_id: Hub dataset id.
        config: Dataset config name.
        split: Dataset split name.
        id2label_mapping_dataframe: Table with a "Model Prediction Labels"
            column; its index values (as strings) become the label-mapping keys.
        feature_mapping_dataframe: Table with "Model Input Features" and
            "Dataset Features" columns, paired row by row.
        local: If true, run ``giskard_scanner`` in a subprocess and block until
            it exits. Otherwise only show a placeholder message.

    Returns:
        A ``gr.update`` for the submit button. It is re-enabled after a run and
        disabled when the mapping tables are missing.
    """
    label_column = "Model Prediction Labels"
    model_feature_column = "Model Input Features"
    dataset_feature_column = "Dataset Features"

    # Cleared or never-validated tables lack these columns and would raise
    # KeyError below.
    if (
        id2label_mapping_dataframe is None
        or feature_mapping_dataframe is None
        or label_column not in id2label_mapping_dataframe
        or model_feature_column not in feature_mapping_dataframe
        or dataset_feature_column not in feature_mapping_dataframe
    ):
        gr.Warning(
            "Label or feature mapping is missing. Please validate your model and dataset first."
        )
        return gr.update(interactive=False)  # Submit button

    label_mapping = {
        str(i): label
        for i, label in id2label_mapping_dataframe[label_column].items()
    }
    feature_mapping = {
        model_feature: dataset_feature
        for model_feature, dataset_feature in zip(
            feature_mapping_dataframe[model_feature_column],
            feature_mapping_dataframe[dataset_feature_column],
        )
    }

    # TODO: Set column mapping for some dataset such as `amazon_polarity`

    if not local:
        gr.Info("TODO: Submit task to an endpoint")
        return gr.update(interactive=True)  # Submit button

    hf_token = os.environ.get(HF_WRITE_TOKEN)
    discussion_repo = os.environ.get(HF_REPO_ID) or os.environ.get(HF_SPACE_ID)
    if hf_token is None or discussion_repo is None:
        # A None argv entry would make the subprocess call raise TypeError.
        gr.Warning(
            f"Cannot run local evaluation: please set the {HF_WRITE_TOKEN} and "
            f"{HF_REPO_ID} (or {HF_SPACE_ID}) environment variables."
        )
        return gr.update(interactive=True)  # Submit button

    # NOTE(review): the write token is passed on the command line, where other
    # processes on the host may see it; consider passing it through the
    # environment if giskard_scanner supports that.
    # NOTE(review): the UI reads scanner settings from ./config.yaml, but the
    # scanner is pointed at ../config.yaml; confirm which path is intended.
    command = [
        "giskard_scanner",
        "--loader",
        "huggingface",
        "--model",
        m_id,
        "--dataset",
        d_id,
        "--dataset_config",
        config,
        "--dataset_split",
        split,
        "--hf_token",
        hf_token,
        "--discussion_repo",
        discussion_repo,
        "--output_format",
        "markdown",
        "--output_portal",
        "huggingface",
        "--feature_mapping",
        json.dumps(feature_mapping),
        "--label_mapping",
        json.dumps(label_mapping),
        "--scan_config",
        "../config.yaml",
    ]

    eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
    start = time.monotonic()
    logging.info("Start local evaluation on %s", eval_str)

    # Blocks until the scanner exits. Its stderr is merged into stdout, which
    # is inherited from this process.
    returncode = subprocess.run(
        command, stderr=subprocess.STDOUT, check=False
    ).returncode

    elapsed = time.monotonic() - start
    message = (
        f"Finished local evaluation exit code {returncode} on {eval_str}: {elapsed:.2f}s"
    )
    logging.info(message)
    gr.Info(message)

    return gr.update(interactive=True)  # Submit button


def get_demo():
    """Build the evaluation UI and register its event handlers.

    Creates the run options, scanner selection, model/dataset inputs, mapping
    previews and submit button, then wires them to the validation
    (``gate_validate_btn`` -> ``try_validate``) and submission
    (``try_submit``) handlers. Only ``gr.Row`` contexts are opened here, so
    this presumably must be called inside a ``gr.Blocks`` context opened by
    the caller. Confirm at the call site.
    """

    def check_dataset_and_get_config(dataset_id):
        """Return a dropdown of the dataset's configs, or None if they cannot be listed."""
        try:
            configs = datasets.get_dataset_config_names(dataset_id)
            return gr.Dropdown(configs, value=configs[0], visible=True)
        except Exception:
            # Dataset may not exist (also hit for an empty id on blur).
            return None

    def check_dataset_and_get_split(dataset_config, dataset_id):
        """Return a dropdown of the dataset's splits, warning and returning None on failure."""
        try:
            splits = list(datasets.load_dataset(dataset_id, dataset_config).keys())
            return gr.Dropdown(splits, value=splits[0], visible=True)
        except Exception as e:
            # Dataset may not exist
            gr.Warning(
                f"Failed to load dataset {dataset_id} with config {dataset_config}: {e}"
            )
            return None

    def clear_column_mapping_tables():
        """Reset the mapping tables after the model or dataset id changes.

        Returns exactly one update per output (submit button, label table,
        feature table). The submit button is disabled because the cleared
        tables no longer contain the columns ``try_submit`` reads.
        """
        return [
            gr.update(interactive=False),  # Submit button
            gr.update(value=[], visible=False, interactive=True),  # Label mapping
            gr.update(value=[], visible=False, interactive=True),  # Feature mapping
        ]

    def gate_validate_btn(
        model_id,
        dataset_id,
        dataset_config,
        dataset_split,
        id2label_mapping_dataframe=None,
        feature_mapping_dataframe=None,
    ):
        """Validate the current inputs and return updates for the 7 validation outputs.

        The mapping tables are passed in only by the table ``input`` events.
        Their contents are then serialized into the column mapping handed to
        ``try_validate``.
        """
        column_mapping = "{}"
        # Keep the id from check_model: it is None when the model is
        # unreachable, which is how try_validate reports that case.
        m_id, ppl = check_model(model_id=model_id)

        if id2label_mapping_dataframe is not None:
            # NOTE(review): Gradio passes DataFrame inputs as pandas DataFrames
            # by default, which expose `.values`, not `.value` (unless a column
            # is named "value"); confirm what convert_column_mapping_to_json expects.
            labels = convert_column_mapping_to_json(
                id2label_mapping_dataframe.value, label="data"
            )
            features = convert_column_mapping_to_json(
                feature_mapping_dataframe.value, label="text"
            )
            column_mapping = json.dumps({**labels, **features}, indent=2)

        if check_column_mapping_keys_validity(column_mapping, ppl) is False:
            gr.Warning("Label mapping table has invalid contents. Please check again.")
            return (
                gr.update(interactive=False),
                gr.update(CONFIRM_MAPPING_DETAILS_FAIL_MD, visible=True),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
            )

        if not (model_id and dataset_id and dataset_config and dataset_split):
            # Not enough input to validate yet: keep submit disabled and hide previews.
            return (
                gr.update(interactive=False),
                gr.update(visible=True),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
            )

        return try_validate(
            m_id,
            ppl,
            dataset_id,
            dataset_config,
            dataset_split,
            column_mapping,
        )

    with gr.Row():
        gr.Markdown(CONFIRM_MAPPING_DETAILS_MD)
    with gr.Row():
        run_local = gr.Checkbox(value=True, label="Run in this Space")
        use_inference = read_inference_type("./config.yaml") == "hf_inference_api"
        run_inference = gr.Checkbox(value=use_inference, label="Run with Inference API")

    with gr.Row():
        selected = read_scanners("./config.yaml")
        # "data_leakage" is always offered; append it only when the config does
        # not already select it, so it is not listed twice.
        if "data_leakage" in selected:
            scan_config = list(selected)
        else:
            scan_config = selected + ["data_leakage"]
        scanners = gr.CheckboxGroup(
            choices=scan_config, value=selected, label="Scan Settings", visible=True
        )

    with gr.Row():
        model_id_input = gr.Textbox(
            label="Hugging Face model id",
            placeholder="cardiffnlp/twitter-roberta-base-sentiment-latest",
        )

        dataset_id_input = gr.Textbox(
            label="Hugging Face Dataset id",
            placeholder="tweet_eval",
        )
    with gr.Row():
        dataset_config_input = gr.Dropdown(label="Dataset Config", visible=False)
        dataset_split_input = gr.Dropdown(label="Dataset Split", visible=False)

    with gr.Row(visible=True) as loading_row:
        gr.Markdown(
            """
                    <p style="text-align: center;">
                    🚀🐢Please validate your model and dataset first...
                    </p>
                    """
        )

    with gr.Row(visible=False) as preview_row:
        gr.Markdown(
            """
            <h1 style="text-align: center;">
            Confirm Pre-processing Details
            </h1>
            Based on your model and dataset, we inferred this label mapping and feature mapping. <b>If the mapping is incorrect, please modify it in the table below.</b>
            """
        )

    with gr.Row():
        id2label_mapping_dataframe = gr.DataFrame(
            label="Preview of label mapping", interactive=True, visible=False
        )
        feature_mapping_dataframe = gr.DataFrame(
            label="Preview of feature mapping", interactive=True, visible=False
        )
    with gr.Row():
        example_input = gr.Markdown("Sample Input: ", visible=False)

    with gr.Row():
        example_labels = gr.Label(label="Model Prediction Sample", visible=False)

    run_btn = gr.Button(
        "Get Evaluation Result",
        variant="primary",
        interactive=False,
        size="lg",
    )

    # Components shared by every validation event; the output order matches
    # the 7-tuple produced by gate_validate_btn / try_validate.
    validation_inputs = [
        model_id_input,
        dataset_id_input,
        dataset_config_input,
        dataset_split_input,
    ]
    mapping_validation_inputs = validation_inputs + [
        id2label_mapping_dataframe,
        feature_mapping_dataframe,
    ]
    validation_outputs = [
        run_btn,
        loading_row,
        preview_row,
        example_input,
        example_labels,
        id2label_mapping_dataframe,
        feature_mapping_dataframe,
    ]
    clear_outputs = [run_btn, id2label_mapping_dataframe, feature_mapping_dataframe]

    model_id_input.blur(clear_column_mapping_tables, outputs=clear_outputs)

    dataset_id_input.blur(
        check_dataset_and_get_config, dataset_id_input, dataset_config_input
    )
    dataset_id_input.submit(
        check_dataset_and_get_config, dataset_id_input, dataset_config_input
    )

    dataset_config_input.change(
        check_dataset_and_get_split,
        inputs=[dataset_config_input, dataset_id_input],
        outputs=[dataset_split_input],
    )

    dataset_id_input.blur(clear_column_mapping_tables, outputs=clear_outputs)

    dataset_config_input.change(
        gate_validate_btn, inputs=validation_inputs, outputs=validation_outputs
    )
    dataset_split_input.change(
        gate_validate_btn, inputs=validation_inputs, outputs=validation_outputs
    )
    id2label_mapping_dataframe.input(
        gate_validate_btn,
        inputs=mapping_validation_inputs,
        outputs=validation_outputs,
    )
    feature_mapping_dataframe.input(
        gate_validate_btn,
        inputs=mapping_validation_inputs,
        outputs=validation_outputs,
    )
    scanners.change(write_scanners, inputs=scanners)
    run_inference.change(write_inference_type, inputs=[run_inference])

    run_btn.click(
        try_submit,
        inputs=[
            model_id_input,
            dataset_id_input,
            dataset_config_input,
            dataset_split_input,
            id2label_mapping_dataframe,
            feature_mapping_dataframe,
            run_local,
        ],
        outputs=[
            run_btn,
        ],
    )