import json
from typing import List, Optional, Union

import argilla as rg
import gradio as gr
import numpy as np
import pandas as pd
from gradio.oauth import (
    OAuthToken,
    get_space,
)
from huggingface_hub import whoami
from jinja2 import Environment, meta

from distilabel_dataset_generator.constants import argilla_client


def get_duplicate_button():
    """Build a Gradio "Duplicate this Space" button when running inside a Space.

    Returns:
        A ``gr.DuplicateButton`` (large size) if ``get_space()`` reports a Space,
        otherwise ``None``.
    """
    running_in_space = get_space() is not None
    return gr.DuplicateButton(size="lg") if running_in_space else None


def list_orgs(oauth_token: Union[OAuthToken, None] = None):
    """Return the namespaces associated with the logged-in Hugging Face account.

    The ``whoami`` response is inspected according to ``data["auth"]["type"]``:

    * ``"oauth"``: the user's own name followed by every organization in ``orgs``.
    * ``"access_token"``: only the organizations in ``orgs``.
    * anything else: entities from the token's fine-grained ``scoped`` entries that
      carry the ``"repo.write"`` permission, with the user's own name moved to the
      front (and not repeated).

    Args:
        oauth_token: Token from the Gradio OAuth login, or ``None`` when logged out.

    Returns:
        A list of namespace names; empty when ``oauth_token`` is ``None``.

    Raises:
        gr.Error: if the ``whoami`` call fails or its response lacks an expected key.
            The original exception is attached as ``__cause__``.
    """
    if oauth_token is None:
        return []
    try:
        data = whoami(oauth_token.token)
        auth_type = data["auth"]["type"]
        if auth_type == "oauth":
            organizations = [data["name"]] + [org["name"] for org in data["orgs"]]
        elif auth_type == "access_token":
            organizations = [org["name"] for org in data["orgs"]]
        else:
            # Keep only entities this token is allowed to write repos to.
            writable = [
                entry["entity"]["name"]
                for entry in data["auth"]["accessToken"]["fineGrained"]["scoped"]
                if "repo.write" in entry["permissions"]
            ]
            # Put the user's own namespace first without duplicating it.
            organizations = [data["name"]] + [
                org for org in writable if org != data["name"]
            ]
    except Exception as e:
        raise gr.Error(
            f"Failed to get organizations: {e}. See if you are logged and connected: https://huggingface.co/settings/connected-applications."
        ) from e
    return organizations


def get_org_dropdown(oauth_token: Union[OAuthToken, None] = None):
    """Create the "Organization" dropdown populated from the user's namespaces.

    Args:
        oauth_token: Token from the Gradio OAuth login, or ``None`` when logged out.

    Returns:
        An interactive ``gr.Dropdown`` that accepts custom values. Its choices are the
        result of ``list_orgs`` (empty when logged out), preselected to the first one.
    """
    choices = [] if oauth_token is None else list_orgs(oauth_token)
    default_choice = next(iter(choices), None)
    return gr.Dropdown(
        label="Organization",
        choices=choices,
        value=default_choice,
        allow_custom_value=True,
        interactive=True,
    )


def swap_visibility(oauth_token: Union[OAuthToken, None]):
    """Return a Gradio update that swaps the main UI's CSS class by login state.

    A truthy ``oauth_token`` selects ``main_ui_logged_in``; otherwise
    ``main_ui_logged_out`` is used.
    """
    css_class = "main_ui_logged_in" if oauth_token else "main_ui_logged_out"
    return gr.update(elem_classes=[css_class])


def get_argilla_client() -> Union[rg.Argilla, None]:
    """Expose the shared Argilla client imported from ``constants``.

    The value may be ``None``, as the return annotation allows.
    """
    client = argilla_client
    return client


def get_preprocess_labels(labels: Optional[List[str]]) -> List[str]:
    """Normalize and de-duplicate a list of labels.

    Each label is lower-cased and stripped of surrounding whitespace. Duplicates
    (after normalization) are removed, keeping the first occurrence.

    Args:
        labels: Raw labels, or ``None``.

    Returns:
        The unique normalized labels in first-seen order; ``[]`` for ``None`` or empty input.
    """
    if not labels:
        return []
    # dict.fromkeys de-duplicates while preserving insertion order. A set's iteration
    # order for strings changes between interpreter runs because of hash randomization.
    return list(dict.fromkeys(label.lower().strip() for label in labels))


def column_to_list(dataframe: pd.DataFrame, column_name: str) -> List[str]:
    """Return the values of ``column_name`` in ``dataframe`` as a Python list.

    Raises:
        ValueError: if the column is not present in the dataframe.
    """
    if column_name not in dataframe.columns:
        raise ValueError(f"Column '{column_name}' does not exist.")
    series = dataframe[column_name]
    return series.to_list()


def _contents_for_role(messages, role: str) -> List[str]:
    """Return the ``content`` of each dict in ``messages`` whose ``role`` equals ``role``."""
    return [
        message["content"]
        for message in messages
        if isinstance(message, dict) and message.get("role") == role
    ]


def _parse_chat_string(value: str) -> Optional[list]:
    """Decode ``value`` as a JSON list of dicts, or return ``None`` if it is not one.

    Invalid JSON, and JSON that decodes to something other than a list of dicts
    (e.g. ``"42"``, ``"null"``, ``'"quoted"'``, ``'[1, 2]'``), yield ``None`` so the
    caller can treat the original string as plain text.
    """
    try:
        parsed = json.loads(value)
    except json.JSONDecodeError:
        return None
    if isinstance(parsed, list) and all(isinstance(item, dict) for item in parsed):
        return parsed
    return None


def _extract_instruction(value) -> str:
    """Return the last user message from a chat-formatted cell, or the cell's plain text."""
    if isinstance(value, (list, np.ndarray)):
        user_contents = _contents_for_role(value, "user")
        return user_contents[-1] if user_contents else ""
    if isinstance(value, str):
        messages = _parse_chat_string(value)
        if messages is None:
            return value
        user_contents = _contents_for_role(messages, "user")
        return user_contents[-1] if user_contents else ""
    # Any other cell type (e.g. a float NaN for a missing value) has no instruction.
    return ""


def _extract_generations(value) -> list:
    """Return the generations contributed by one response cell (possibly none)."""
    if isinstance(value, (list, np.ndarray)):
        if all(isinstance(item, dict) and "role" in item for item in value):
            # A chat transcript contributes only its last assistant turn.
            return _contents_for_role(value, "assistant")[-1:]
        # Any other sequence is taken as a list of generations as-is.
        return list(value)
    if isinstance(value, str):
        messages = _parse_chat_string(value)
        if messages is None:
            return [value]
        return _contents_for_role(messages, "assistant")[-1:]
    return []


def process_columns(
    dataframe,
    instruction_column: str,
    response_columns: Union[str, List[str]],
) -> List[dict]:
    """Convert dataframe rows into ``{"instruction", "generations"}`` records.

    Cells may hold chat transcripts, either as a list/array of ``{"role", "content"}``
    dicts or as a JSON string encoding one, or plain values:

    * The instruction is the last ``"user"`` message of the transcript. A string that
      is not a JSON list of dicts is used verbatim. Any other cell type gives ``""``.
    * Each response column contributes the last ``"assistant"`` message of a
      transcript. A non-transcript list/array contributes all of its items, and a
      plain string contributes itself. Other cell types contribute nothing.

    Args:
        dataframe: Source rows (anything exposing ``iterrows()``, e.g. a pandas DataFrame).
        instruction_column: Name of the column holding the prompt.
        response_columns: One column name, or a list of names, holding responses.

    Returns:
        One dict per row, in row order.
    """
    if isinstance(response_columns, str):
        response_columns = [response_columns]

    data = []
    for _, row in dataframe.iterrows():
        instruction = _extract_instruction(row[instruction_column])
        generations = []
        for col in response_columns:
            generations.extend(_extract_generations(row[col]))
        data.append({"instruction": instruction, "generations": generations})

    return data


def extract_column_names(prompt_template: str) -> List[str]:
    """Return the undeclared variable names referenced by a Jinja2 template.

    Names are those reported by ``jinja2.meta.find_undeclared_variables``, i.e. variables
    the template reads but does not define itself.

    Args:
        prompt_template: Jinja2 template source.

    Returns:
        The variable names, sorted alphabetically.
    """
    env = Environment()
    parsed_content = env.parse(prompt_template)
    variables = meta.find_undeclared_variables(parsed_content)
    # find_undeclared_variables returns a set, whose string iteration order varies
    # between interpreter runs. Sorting keeps the output deterministic.
    return sorted(variables)


def pad_or_truncate_list(lst, target_length):
    """Return ``lst`` resized to exactly ``target_length`` items.

    Longer lists keep their last ``target_length`` items. Shorter lists are
    right-padded with ``None``. ``None`` (or any falsy value) is treated as an empty
    list. A ``target_length`` of zero or less yields ``[]``.
    """
    lst = lst or []
    if target_length <= 0:
        # lst[-0:] would return the whole list, so handle empty targets explicitly.
        return []
    lst_length = len(lst)
    if lst_length >= target_length:
        return lst[-target_length:]
    return lst + [None] * (target_length - lst_length)