import os
import io

import torch
import json
import base64
import gradio as gr
import numpy as np
from pathlib import Path
from PIL import Image

from plots import get_pre_define_colors
from utils.load_model import load_xclip
from utils.predict import xclip_pred


# Per the original author, Hugging Face Spaces does not allow loading the model in
# the main process. This build is not the HF demo, so the model is loaded eagerly
# at import time.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Not at Huggingface demo, load model to main process.")
XCLIP, OWLVIT_PRECESSOR = load_xclip(DEVICE)

print(f"Device: {DEVICE}")


def _load_json(path):
    """Parse a UTF-8 JSON file, closing the file handle as soon as it is read."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


XCLIP_DESC_PATH = "data/jsons/bs_cub_desc.json"
XCLIP_DESC = _load_json(XCLIP_DESC_PATH)
IMAGES_FOLDER = "data/images"
IMAGE2GT = _load_json("data/jsons/image2gt.json")
# NOTE(review): no map_location is given, so torch restores tensors onto the device
# they were saved from; confirm that matches what xclip_pred expects for DEVICE.
CUB_DESC_EMBEDS = torch.load('data/text_embeddings/cub_200_desc.pt')
# JSON object keys are always strings; convert back to the integer class indices.
CUB_IDX2NAME = {int(k): v for k, v in _load_json('data/jsons/cub_desc_idx2name.json').items()}

IMAGE_FILE_LIST = _load_json("data/jsons/file_list.json")
IMAGE_GALLERY = [Image.open(os.path.join(IMAGES_FOLDER, 'org', file_name)).convert('RGB') for file_name in IMAGE_FILE_LIST]

# Order in which per-part descriptors are stored in XCLIP_DESC / returned by xclip_pred.
ORG_PART_ORDER = ['back', 'beak', 'belly', 'breast', 'crown', 'forehead', 'eyes', 'legs', 'wings', 'nape', 'tail', 'throat']
# Order in which parts are displayed in the UI (roughly head to tail).
ORDERED_PARTS = ['crown', 'forehead', 'nape', 'eyes', 'beak', 'throat', 'breast', 'belly', 'back', 'wings', 'legs', 'tail']
COLORS = get_pre_define_colors(12, cmap_set=['Set2', 'tab10'])
SACHIT_COLOR = "#ADD8E6"
VISIBILITY_DICT = _load_json("data/jsons/cub_vis_dict_binary.json")
# Force every part visible for this image. NOTE(review): presumably it is missing or
# wrong in the visibility JSON; confirm.
VISIBILITY_DICT['Eastern_Bluebird.jpg'] = dict.fromkeys(ORDERED_PARTS, True)

# --- Image related functions ---
def img_to_base64(img):
    """Encode an image as a base64 JPEG string for use in a ``data:image/jpeg`` URI.

    Args:
        img: A ``numpy.ndarray`` (converted with ``Image.fromarray``) or a PIL image.

    Returns:
        The base64-encoded JPEG bytes, decoded to ``str``.
    """
    img_pil = Image.fromarray(img) if isinstance(img, np.ndarray) else img
    # JPEG has no alpha channel or palette; Pillow refuses to save modes such as
    # RGBA / LA / P as JPEG, so normalise anything JPEG cannot store to RGB.
    if img_pil.mode not in ("RGB", "L", "CMYK"):
        img_pil = img_pil.convert("RGB")
    buffered = io.BytesIO()
    img_pil.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode()

def create_blank_image(width=500, height=500, color=(255, 255, 255)):
    """Return a solid-colour RGB canvas as a numpy array.

    ``width``/``height`` are in pixels; ``color`` is any fill value accepted by
    ``Image.new`` for mode ``"RGB"``. The result is laid out row-major, i.e.
    with shape ``(height, width, 3)``.
    """
    canvas = Image.new("RGB", (width, height), color)
    return np.array(canvas)

# Convert RGB colors to hex
def rgb_to_hex(rgb):
    """Format integer colour channels as a ``#``-prefixed hex string.

    Each channel becomes two lowercase hex digits, in order, so ``(r, g, b)``
    yields ``#rrggbb`` (and a 4-tuple would yield ``#rrggbbaa``).
    """
    hex_digits = "".join(format(channel, "02x") for channel in rgb)
    return "#" + hex_digits

def load_part_images(file_name: str) -> dict:
    """Load the pre-cropped part images of one gallery image as base64 JPEG strings.

    For every part in ``ORDERED_PARTS`` this looks for
    ``{IMAGES_FOLDER}/boxes/{stem}_{part}.jpg``; parts whose crop file does not
    exist are silently omitted from the result.

    Args:
        file_name: Gallery image file name; only its stem is used.

    Returns:
        ``{part_name: base64_jpeg_str}`` for the parts that have a crop on disk.
    """
    # The original author measured loading all 12 crops at < 0.01 s, so this is
    # not a startup bottleneck.
    base_name = Path(file_name).stem  # loop-invariant: compute once
    boxes_dir = Path(IMAGES_FOLDER) / "boxes"
    part_images = {}
    for part_name in ORDERED_PARTS:
        part_image_path = boxes_dir / f"{base_name}_{part_name}.jpg"
        if not part_image_path.exists():
            continue
        # The context manager releases the file handle deterministically.
        with Image.open(part_image_path) as image:
            part_images[part_name] = img_to_base64(np.array(image))
    return part_images

def generate_xclip_explanations(result_dict: dict, visibility: dict, part_mask: dict = None):
    """Render per-part descriptions and scores as an interactive SVG (HTML string).

    The result_dict needs three keys: 'descriptions', 'pred_scores', 'file_name'
        descriptions: {part_name1: desc_1, part_name2: desc_2, ...}
        pred_scores: {part_name1: score_1, part_name2: score_2, ...}
        file_name: str (key into PART_IMAGES_DICT)

    Args:
        result_dict: See above.
        visibility: ``{part_name: flag}``; parts with a falsy flag are skipped.
        part_mask: ``{part_name: 0/1}``; parts with a falsy value are skipped.
            Defaults to every part in ``ORDERED_PARTS`` enabled.

    Returns:
        An HTML ``<div>`` wrapping an ``<svg>``. Each rendered part (in
        ``ORDERED_PARTS`` order) gets its description (cut into 50-character
        lines), a bar of length ``max(score, 0) * 400`` px coloured with
        ``PART_COLORS[part]``, and the clipped score. Hovering the text or bar
        shows that part's crop in the page element with id ``overlayImage``.
    """
    # Function-local import; the module namespace is left unchanged.
    from html import escape

    if part_mask is None:
        # Fresh dict per call instead of a shared mutable default; same contents
        # as the former default (every part enabled).
        part_mask = dict.fromkeys(ORDERED_PARTS, 1)

    descriptions = result_dict['descriptions']
    image_name = result_dict['file_name']
    part_images = PART_IMAGES_DICT[image_name]
    MAX_LENGTH = 50  # characters per SVG text line
    exp_length = 400  # px: panel width, and bar length for a score of 1
    fontsize = 15

    # Start the SVG inside a div
    svg_parts = [f'<div style="width: {exp_length}px; height: 450px; background-color: white;">',
                 "<svg width=\"100%\" height=\"100%\">"]

    # Add a row for each part that is both visible and enabled in part_mask
    y_offset = 0
    for part in ORDERED_PARTS:
        if not (visibility[part] and part_mask[part]):
            continue

        # Negative scores are clipped to 0 so the bar width is never negative.
        part_score = max(result_dict['pred_scores'][part], 0)
        bar_length = part_score * exp_length

        # Hovering swaps the overlay <img> to this part's crop and makes it visible.
        part_image = part_images.get(part)
        if part_image is not None:
            mouseover_action = f"document.getElementById('overlayImage').src = 'data:image/jpeg;base64,{part_image}'; document.getElementById('overlayImage').style.opacity = 1;"
        else:
            # load_part_images omits crops missing on disk; keep the overlay hidden
            # instead of failing with a KeyError.
            mouseover_action = "document.getElementById('overlayImage').style.opacity = 0;"
        mouseout_action = "document.getElementById('overlayImage').style.opacity = 0;"

        combined_mouseover = f"javascript: {mouseover_action};"
        combined_mouseout = f"javascript: {mouseout_action};"

        # Add the description in fixed-width chunks. ``max(..., 1)`` keeps a single
        # (empty) line for an empty description, and no spurious blank line is
        # emitted when the length is an exact multiple of MAX_LENGTH.
        desc = descriptions[part]
        for start in range(0, max(len(desc), 1), MAX_LENGTH):
            # Descriptions can be user-edited: escape them before embedding in markup.
            desc_line = escape(desc[start:start + MAX_LENGTH], quote=False)
            y_offset += fontsize
            svg_parts.append(f"""
            <text x="0" y="{y_offset}" font-size="{fontsize}"
                onmouseover="{combined_mouseover}"
                onmouseout="{combined_mouseout}">
                {desc_line}
            </text>
            """)

        # Add the bar
        svg_parts.append(f"""
        <rect x="0" y="{y_offset +3}" width="{bar_length}" height="{fontsize*0.7}" fill="{PART_COLORS[part]}"
            onmouseover="{combined_mouseover}"
            onmouseout="{combined_mouseout}">
        </rect>
        """)
        # Add the score
        svg_parts.append(f'<text x="{exp_length - 50}" y="{y_offset+fontsize+3}" font-size="{fontsize}" fill="{PART_COLORS[part]}">{part_score:.2f}</text>')

        y_offset += fontsize + 3

    svg_parts.extend(("</svg>", "</div>"))
    # Join everything into a single string
    markup = "".join(svg_parts)
    return markup



def generate_sachit_explanations(result_dict:dict):
    """Render descriptors as an SVG bar chart sorted by score (HTML string).

    Args:
        result_dict: Must contain 'descriptions' and 'scores', iterables that are
            paired element-wise (``zip`` truncates to the shorter one).

    Returns:
        An HTML ``<div>`` wrapping an ``<svg>`` with one row per description,
        highest score first: the description (cut into 50-character lines), a
        bar of length ``max(score, 0) * 400`` px in ``SACHIT_COLOR``, and the
        clipped score.
    """
    # Function-local import; the module namespace is left unchanged.
    from html import escape

    descriptions = result_dict['descriptions']
    scores = result_dict['scores']
    MAX_LENGTH = 50  # characters per SVG text line
    exp_length = 400  # px: panel width, and bar length for a score of 1
    fontsize = 15

    # Highest-scoring descriptor first (stable for ties).
    ranked = sorted(zip(scores, descriptions), key=lambda pair: pair[0], reverse=True)

    # Start the SVG inside a div
    svg_parts = [f'<div style="width: {exp_length}px; height: 450px; background-color: white;">',
                 "<svg width=\"100%\" height=\"100%\">"]

    # Add a row per descriptor
    y_offset = 0
    for score, desc in ranked:
        # Negative scores are clipped to 0 so the bar width is never negative.
        part_score = max(score, 0)
        bar_length = part_score * exp_length

        # Fixed-width chunks; ``max(..., 1)`` keeps one (empty) line for an empty
        # description and avoids a trailing blank line at exact multiples of MAX_LENGTH.
        for start in range(0, max(len(desc), 1), MAX_LENGTH):
            desc_line = escape(desc[start:start + MAX_LENGTH], quote=False)
            y_offset += fontsize
            svg_parts.append(f"""
            <text x="0" y="{y_offset}" font-size="{fontsize}" fill="black">
                {desc_line}
            </text>
            """)

        # Add the bar
        svg_parts.append(f"""
        <rect x="0" y="{y_offset+3}" width="{bar_length}" height="{fontsize*0.7}" fill="{SACHIT_COLOR}">
        </rect>
        """)

        # Add the score (font-size previously emitted the literal text "fontsize")
        svg_parts.append(f'<text x="{exp_length - 50}" y="{y_offset+fontsize+3}" font-size="{fontsize}" fill="{SACHIT_COLOR}">{part_score:.2f}</text>')

        y_offset += fontsize + 3

    svg_parts.extend(("</svg>", "</div>"))
    # Join everything into a single string
    markup = "".join(svg_parts)
    return markup

# --- Constants created by the functions above ---
# base64 JPEG of a plain canvas built with create_blank_image's defaults; it is the
# initial src of the hidden overlay <img> in update_selected_image.
BLANK_OVERLAY = img_to_base64(create_blank_image())
# Hex colour for each display-ordered part, taken from COLORS at the same index.
PART_COLORS = {part_name: rgb_to_hex(COLORS[idx]) for idx, part_name in enumerate(ORDERED_PARTS)}
# NOTE(review): ``blank_image`` is not referenced anywhere else in the code shown here.
blank_image = np.array(Image.open('data/images/final.png').convert('RGB'))
# Pre-encode the part crops of every gallery image once, at import time.
PART_IMAGES_DICT = dict(zip(IMAGE_FILE_LIST, map(load_part_images, IMAGE_FILE_LIST)))

# --- Gradio Functions ---
def update_selected_image(event: gr.SelectData):
    """Gallery ``select`` handler: display the clicked image and PEEB's prediction for it.

    Args:
        event: Gradio selection event; ``event.index`` indexes ``IMAGE_FILE_LIST``.

    Returns:
        A 6-tuple matching the ``outputs`` wired to ``image_gallery.select``:
        ground-truth markdown; image HTML (the original image plus a hidden
        ``overlayImage`` used by the hover explanations); top-1 prediction
        markdown (green if it matches the ground truth, red otherwise); the
        explanation HTML; the ``current_image`` component; and a visible textbox
        pre-filled with the predicted class's descriptors.

    Side effects:
        Sets ``current_image.state``, ``gt_class.state`` and
        ``current_predicted_class.state``. NOTE(review): these are attributes of
        module-level component objects, so they are shared by every session.
    """
    image_height = 400

    image_name = IMAGE_FILE_LIST[event.index]
    current_image.state = image_name
    # Open the original once; the same PIL image is used for display and prediction.
    org_image = Image.open(os.path.join(IMAGES_FOLDER, 'org', image_name)).convert('RGB')
    img_base64 = f"""
    <div style="position: relative; height: {image_height}px; display: inline-block;">
        <img id="birdImage" src="data:image/jpeg;base64,{img_to_base64(org_image)}" style="height: {image_height}px; width: auto;">
        <img id="overlayImage" src="data:image/jpeg;base64,{BLANK_OVERLAY}" style="position:absolute; top:0; left:0; width:auto; height: {image_height}px; opacity: 0;">
    </div>
    """
    gt_label = IMAGE2GT[image_name]
    gt_class.state = gt_label

    # Initial prediction with the unmodified descriptors (no extra class yet).
    # The local ``image_name`` is used rather than re-reading the shared
    # ``current_image.state``, which another session could overwrite mid-call.
    out_dict = xclip_pred(new_desc=None,
                          new_part_mask=None,
                          new_class=None,
                          org_desc=XCLIP_DESC_PATH,
                          image=org_image,
                          model=XCLIP,
                          owlvit_processor=OWLVIT_PRECESSOR,
                          device=DEVICE,
                          image_name=image_name,
                          cub_embeds=CUB_DESC_EMBEDS,
                          cub_idx2name=CUB_IDX2NAME,
                          descriptors=XCLIP_DESC)
    xclip_label = out_dict['pred_class']
    xclip_pred_score = out_dict['pred_score']
    xclip_part_scores = out_dict['pred_desc_scores']
    result_dict = {'descriptions': dict(zip(ORG_PART_ORDER, out_dict["descriptions"])), 'pred_scores': xclip_part_scores, 'file_name': image_name}
    xclip_exp = generate_xclip_explanations(result_dict, VISIBILITY_DICT[image_name], part_mask=dict.fromkeys(ORDERED_PARTS, 1))

    xclip_color = "green" if xclip_label.strip() == gt_label.strip() else "red"
    xclip_pred_markdown = f"""
        ### <span style='color:{xclip_color}'>{xclip_label} &nbsp;&nbsp;&nbsp; {xclip_pred_score:.4f}</span>
    """

    gt_markdown = f"""
        ## {gt_label}
    """
    current_predicted_class.state = xclip_label

    # Populate the textbox with the predicted class's descriptors: XCLIP_DESC lists
    # them in ORG_PART_ORDER, the textbox shows them in ORDERED_PARTS order.
    descs = XCLIP_DESC[xclip_label]
    descs = {k: descs[i] for i, k in enumerate(ORG_PART_ORDER)}
    custom_text = ["class name: custom"] + [descs[k] for k in ORDERED_PARTS]
    descriptions = ";\n".join(custom_text)
    textbox = gr.Textbox(value=descriptions,
                         lines=12,
                         visible=True,
                         label="XCLIP descriptions",
                         interactive=True,
                         info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}',
                         show_label=False)
    return gt_markdown, img_base64, xclip_pred_markdown, xclip_exp, current_image, textbox

def on_edit_button_click_xclip():
    """'Reset Descriptions' handler: refill the textbox with the predicted class's descriptors.

    Reads ``current_predicted_class.state`` (set by the gallery handler).

    Returns:
        (textbox, hidden_html). The textbox is visible and editable. Its value is
        ``class name: custom`` followed by one descriptor per part in
        ``ORDERED_PARTS`` order, joined with ``";\\n"``. The HTML component is
        hidden, clearing the extra-class explanation.
    """
    hidden_explanation = gr.HTML(visible=False)

    # XCLIP_DESC lists descriptors in ORG_PART_ORDER; re-key them by part name,
    # then emit them in display order.
    class_descs = XCLIP_DESC[current_predicted_class.state]
    by_part = {part: class_descs[idx] for idx, part in enumerate(ORG_PART_ORDER)}
    rows = ["class name: custom"]
    rows.extend(by_part[part] for part in ORDERED_PARTS)

    textbox = gr.Textbox(
        value=";\n".join(rows),
        lines=12,
        visible=True,
        label="XCLIP descriptions",
        interactive=True,
        info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}',
        show_label=False,
    )
    return textbox, hidden_explanation

def convert_input_text_to_xclip_format(textbox_input: str, parts=None):
    """Parse the extra-class textbox into per-part descriptions.

    Expected input (rows separated by ";" followed by a newline)::

        class name: {new class name};
        {part name}: {description};
        ...

    Args:
        textbox_input: Raw textbox contents.
        parts: Part names to report on; defaults to ``ORDERED_PARTS``.

    Returns:
        A ``(descriptions_dict, part_mask, new_class_name)`` tuple:
            descriptions_dict: Maps each lower-cased part name to its full row
                text (``"{part}: {description}"``), or ``""`` for names in
                ``parts`` that are absent from the input.
            part_mask: Maps each name in ``parts`` to 1 if present, else 0.
            new_class_name: The text after the first ":" of the first row.

    Raises:
        ValueError: If the class-name row or a non-empty part row has no ":".
    """
    import re  # local import: only this parser needs regular expressions

    if parts is None:
        parts = ORDERED_PARTS

    # Rows are separated by ";" + newline. Tolerate blanks before the newline and
    # Windows "\r\n" line endings.
    rows = re.split(r";[ \t]*\r?\n", textbox_input)

    # The first row should be "class name: xxx". Split on the first ":" only, so
    # class names that themselves contain ":" are kept whole.
    _, sep, new_class_name = rows[0].partition(":")
    if not sep:
        raise ValueError(f'The first row must look like "class name: {{name}}", got {rows[0]!r}')
    new_class_name = new_class_name.strip().rstrip(";").strip()

    # Build the description dict keyed by (lower-cased) part name.
    descriptions_dict = {}
    for row in rows[1:]:
        # Drop surrounding whitespace and a trailing ";" left on the last row.
        row = row.strip().rstrip(";").rstrip()
        if not row:
            continue
        # Only the first ":" separates the part name; descriptions may contain more.
        part_name, sep, _ = row.partition(":")
        if not sep:
            raise ValueError(f'Each part row must look like "{{part name}}: {{description}}", got {row!r}')
        descriptions_dict[part_name.strip().lower()] = row

    # Fill absent parts with "" and mask them out (0); present parts get 1.
    part_mask = {}
    for part in parts:
        if part in descriptions_dict:
            part_mask[part] = 1
        else:
            descriptions_dict[part] = ""
            part_mask[part] = 0
    return descriptions_dict, part_mask, new_class_name

def on_predict_button_click_xclip(textbox_input: str):
    """'Predict' handler: re-run PEEB with the user-edited extra class.

    Parses ``textbox_input`` with ``convert_input_text_to_xclip_format`` and calls
    ``xclip_pred`` on the currently selected image (``current_image.state``).
    It then renders explanations for both the top-1 prediction and the modified
    (extra-class) prediction, each masked by the parts present in the textbox.

    Returns:
        A 5-tuple matching the ``outputs`` wired to ``xclip_predict_button.click``:
        a hidden textbox, top-1 markdown, top-1 explanation HTML, extra-class
        markdown, and a visible HTML component holding the extra-class
        explanation. Labels are green when they match ``gt_class.state``,
        red otherwise.
    """
    descriptions_dict, part_mask, new_class_name = convert_input_text_to_xclip_format(textbox_input)
    image_name = current_image.state
    image_path = os.path.join(IMAGES_FOLDER, 'org', image_name)

    # Get the new predictions and explanations
    out_dict = xclip_pred(
        new_desc=descriptions_dict,
        new_part_mask=part_mask,
        new_class=new_class_name,
        org_desc=XCLIP_DESC_PATH,
        image=Image.open(image_path).convert('RGB'),
        model=XCLIP,
        owlvit_processor=OWLVIT_PRECESSOR,
        device=DEVICE,
        image_name=image_name,
        cub_embeds=CUB_DESC_EMBEDS,
        cub_idx2name=CUB_IDX2NAME,
        descriptors=XCLIP_DESC,
    )
    top1_label = out_dict['pred_class']
    top1_score = out_dict['pred_score']
    top1_part_scores = out_dict['pred_desc_scores']
    extra_label = out_dict['modified_class']
    extra_score = out_dict['modified_score']
    extra_part_scores = out_dict['modified_desc_scores']

    # Explanations for the original top-1 prediction and for the edited class.
    top1_explanation = generate_xclip_explanations(
        {'descriptions': dict(zip(ORG_PART_ORDER, out_dict["descriptions"])),
         'pred_scores': top1_part_scores,
         'file_name': image_name},
        VISIBILITY_DICT[image_name],
        part_mask,
    )
    extra_explanation = generate_xclip_explanations(
        {'descriptions': dict(zip(ORG_PART_ORDER, out_dict["modified_descriptions"])),
         'pred_scores': extra_part_scores,
         'file_name': image_name},
        VISIBILITY_DICT[image_name],
        part_mask,
    )

    gt_name = gt_class.state.strip()

    def _prediction_markdown(label, score):
        # Green when the label matches the ground truth, red otherwise.
        color = "green" if label.strip() == gt_name else "red"
        return ("\n"
                f"        ### <span style='color:{color}'> {label} &nbsp;&nbsp;&nbsp; {score:.4f}</span>\n"
                "    ")

    top1_markdown = _prediction_markdown(top1_label, top1_score)
    extra_markdown = _prediction_markdown(extra_label, extra_score)

    hidden_textbox = gr.Textbox(visible=False)
    extra_html = gr.HTML(value=extra_explanation, visible=True)
    return hidden_textbox, top1_markdown, top1_explanation, extra_markdown, extra_html


# Stylesheet passed to gr.Blocks(css=...) for the demo page.
# NOTE(review): no component in this file is created with elem_id="container" or
# elem_id="canvas", so those two rules look unused here; confirm before deleting.
custom_css = """
        html, body {
            margin: 0;
            padding: 0;
        }

        #container {
            position: relative;
            width: 400px;
            height: 400px;
            border: 1px solid #000;
            margin: 0 auto; /* This will center the container horizontally */
        }

        #canvas {
            position: absolute;
            top: 0;
            left: 0;
            width: 100%;
            height: 100%;
            object-fit: cover;
        }

"""

# Define the Gradio interface
# NOTE(review): the handlers above read and write a ``.state`` attribute on the three
# module-level gr.State objects below. Those values are therefore shared by every
# browser session of this process rather than being per-user state.
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="PEEB") as demo:
    current_image = gr.State("")
    current_predicted_class = gr.State("")
    gt_class = gr.State("")

    with gr.Column():
        title_text = gr.Markdown("# Demo | A classifier with Part-based Explainable and Editable Bottleneck (PEEB)")
        gr.Markdown("PEEB is an image classifier, here for birds, pre-trained on Bird-11K and finetuned on CUB-200 (see our [NAACL 2024 paper](https://arxiv.org/abs/2403.05297) and [code](https://github.com/anguyen8/peeb/tree/inspect_ddp)).\n This **interactive** demo shows how to run PEEB on an existing image and how to **edit** a class' textual description to directly change the classifier to detect one new bird species (without any re-training).")
        gr.Markdown(
            """
            ### Steps:
            1. **Select an image**. Then, PEEB will show its grounded explanations and the top-1 predicted label with associated `softmax` confidence score.
            2. **Hover mouse over text descriptors** to see the corresponding region used to match to each text descriptor.
            3. **Edit the text under [Extra class]()** which corresponds to one extra, new class (i.e. 200+1 = `201`). Further editing will overwrite this class' descriptors.
            4. **Click on [Predict]()** to see the grounded explanations and the top-1 label for the newly modified CUB-201 classifier.            
            """
        )

    # display the gallery of images
    with gr.Column():

        gr.Markdown("## Select an image to start!")
        image_gallery = gr.Gallery(value=IMAGE_GALLERY, label=None, preview=False, allow_preview=False, columns=10, height=250)
        gr.Markdown("### Extra-class descriptors: \n The first row should be `class name: {some name};`, the name of your 201st class. \n For the 12 part descriptors, please use `;` to separate the descriptions for each part, and use the format `{part name}: {descriptions}`.")
        gr.Markdown("**Note:** you can delete a row for any given part (e.g. `nape`) and that part will be removed from all 201 classes in the classifier. For example, you can edit PEEB into a classifier that only identifies birds using 5 parts by deleting all rows corresponding to the other 7 parts.")

        with gr.Row():
            with gr.Column():
                image_label = gr.Markdown("### Class Name")
                org_image = gr.HTML()

            with gr.Column():
                with gr.Row():
                    xclip_predict_button = gr.Button(value="Predict")
                xclip_pred_label = gr.Markdown("### Top-1 class:")
                xclip_explanation = gr.HTML()

            with gr.Column():
                xclip_edit_button = gr.Button(value="Reset Descriptions")
                custom_pred_label = gr.Markdown(
                    "### Extra class:"
                )
                xclip_textbox = gr.Textbox(lines=12, placeholder="Edit the descriptions here", visible=False)
                custom_explanation = gr.HTML()

    gr.HTML("<br>")

    # Event wiring: each outputs list must match the order of the handler's return tuple.
    image_gallery.select(update_selected_image, inputs=None, outputs=[image_label, org_image, xclip_pred_label, xclip_explanation, current_image, xclip_textbox])
    xclip_edit_button.click(on_edit_button_click_xclip, inputs=[], outputs=[xclip_textbox, custom_explanation])
    xclip_predict_button.click(on_predict_button_click_xclip, inputs=[xclip_textbox], outputs=[xclip_textbox, xclip_pred_label, xclip_explanation, custom_pred_label, custom_explanation])

demo.launch()