import os
import sys
import pdb
import random
import numpy as np
from PIL import Image, ImageOps, ImageChops
import base64
from io import BytesIO

import torch
from torchvision import transforms
import torchvision.transforms.functional as TF
import gradio as gr

from src.model import make_1step_sched
from src.pix2pix_turbo import Pix2Pix_Turbo

# Pix2Pix-Turbo model, built once at import time and shared by every request.
# "sketch_to_image_stochastic" presumably selects the pretrained sketch-to-image
# variant — TODO confirm against src/pix2pix_turbo.py.
model = Pix2Pix_Turbo("sketch_to_image_stochastic")

# Radio choices, formatted "<emoji> <name>"; run() puts the name after the
# emoji into the text prompt.
ITEMS_NAMES = [ "💡 Lamp","👜 Bag","🛋️ Sofa","🪑 Chair","🏎️ Car","🏍️ Motorbike"]
# Largest int32 value (not referenced by the code visible in this file).
MAX_SEED = np.iinfo(np.int32).max
# Radio option selected when the page loads.
DEFAULT_ITEM_NAME = "💡 Lamp"
def empty_input_image():
    """Return a blank editor value for the sketchpad.

    Used as the sketchpad's initial value and as its clear handler. The dict
    has a ``background`` image, a ``layers`` list with two images and a
    ``composite`` image. Each entry is a distinct 512x512 single-channel
    ("L") image filled with white (255).
    """
    def _white_canvas():
        # A fresh object per call, so no two entries share pixel data.
        return Image.new("L", (512, 512), 255)

    return {
        "background": _white_canvas(),
        "layers": [_white_canvas() for _ in range(2)],
        "composite": _white_canvas(),
    }

def pil_image_to_data_uri(img, format='PNG'):
    """Encode *img* as an inline base64 ``data:`` URI.

    Args:
        img: object with a PIL-style ``save(fp, format=...)`` method.
        format: format name passed to ``save``. Its lower-cased form becomes
            the MIME subtype (``"PNG"`` -> ``image/png``).

    Returns:
        ``"data:image/<format>;base64,<payload>"``.
    """
    with BytesIO() as buffer:
        img.save(buffer, format=format)
        payload = base64.b64encode(buffer.getvalue()).decode()
    mime_subtype = format.lower()
    return "data:image/" + mime_subtype + ";base64," + payload


def run(image, item_name):
    """Render the current sketch with Pix2Pix-Turbo.

    Args:
        image: sketchpad value. Its ``"composite"`` entry is the flattened PIL
            sketch (dark strokes on a white background). May be None.
        item_name: radio label such as ``"💡 Lamp"``. Everything after the
            first whitespace-separated token (the emoji) becomes the subject
            of the text prompt.

    Returns:
        A PIL image. It is a blank white canvas when nothing is drawn yet;
        otherwise it is the model output.
    """
    composite = image.get("composite") if image else None
    if composite is None:
        # Nothing to render (e.g. the editor was just cleared).
        return Image.new("L", (512, 512), 255)

    # Normalise to single-channel so the blank check and the inversion below
    # also work for sketches that are not the default 512x512 "L" canvas.
    composite = composite.convert("L")
    if composite.getextrema()[0] == 255:
        # Every pixel is white: no strokes yet, so skip the expensive model call.
        return Image.new("L", composite.size, 255)

    # Drop the leading emoji and keep the rest (this also supports multi-word names).
    subject = item_name.split(maxsplit=1)[-1]
    prompt = subject + " professional 3d model. octane render, highly detailed, volumetric, dramatic lighting"

    # Invert so strokes become bright, then binarise into a boolean 3-channel tensor.
    sketch_rgb = ImageOps.invert(composite).convert("RGB")
    sketch_mask = TF.to_tensor(sketch_rgb) > 0.5
    with torch.no_grad():
        c_t = sketch_mask.unsqueeze(0).cuda().float()
        # Reseed on every call so a given sketch and prompt reproduce the same noise.
        torch.manual_seed(42)
        _, _, height, width = c_t.shape
        # Noise map of shape (1, 4, H/8, W/8). Presumably this matches the
        # model's latent resolution — TODO confirm.
        noise = torch.randn((1, 4, height // 8, width // 8), device=c_t.device)
        output_image = model(c_t, prompt, deterministic=False, r=0.4, noise_map=noise)
    # Rescale, assuming the model output lies in [-1, 1]. Clamp so that any
    # out-of-range values saturate instead of wrapping when to_pil_image
    # converts the float tensor to uint8.
    output_pil = TF.to_pil_image((output_image[0].cpu() * 0.5 + 0.5).clamp(0, 1))
    return output_pil


def update_canvas(use_line, use_eraser):
    """Return a Gradio update that switches the brush between pencil and eraser.

    Args:
        use_line: True when the pencil (4px black stroke) is selected.
        use_eraser: True when the eraser (20px white stroke) is selected.

    If both flags are set, the pencil wins (unchanged behaviour). If neither
    is set, the pencil settings are used; previously this raised
    UnboundLocalError.
    """
    if use_eraser and not use_line:
        brush_size, _color = 20, "#ffffff"
    else:
        brush_size, _color = 4, "#000000"
    return gr.update(brush_radius=brush_size, brush_color=_color, interactive=True)


def upload_sketch(file):
    """Load an uploaded sketch and return a Gradio update that displays it.

    Args:
        file: uploaded-file object whose ``.name`` is the path on disk.

    The image is converted to grayscale ("L") before it is shown.
    """
    # Close the file handle deterministically. convert() returns a new,
    # fully-loaded image, so it stays valid after the handle is closed.
    with Image.open(file.name) as source:
        _img = source.convert("L")
    return gr.update(value=_img, source="upload", interactive=True)


# JavaScript run once in the browser when the page loads (passed to
# `demo.load(..., js=scripts)`). It registers global helpers for inline
# `onclick` handlers, e.g. the sketch eraser button. Locals are declared with
# `const` so only the intentional `globalThis.*` entry points reach `window`.
# Missing DOM elements are reported with console.warn instead of throwing.
scripts = """
async () => {
    // Simulate a user click on `target`; returns false if it does not exist.
    const simulateClick = (target) => {
        if (!target) {
            return false;
        }
        target.dispatchEvent(new MouseEvent('click', {
            'view': window,
            'bubbles': true,
            'cancelable': true
        }));
        return true;
    };

    // Download the data URI held in the href of element #elementId as fileName.
    const downloadFromElement = (elementId, fileName) => {
        const source = document.getElementById(elementId);
        if (!source || !source.href) {
            console.warn("download skipped: #" + elementId + " not found");
            return;
        }
        const link = document.createElement("a");
        link.setAttribute("href", source.href);
        link.setAttribute("download", fileName);
        document.body.appendChild(link); // Required for Firefox
        link.click();
        document.body.removeChild(link); // Clean up
    };

    // Colour the pencil/eraser buttons so the active tool is highlighted.
    const highlightTool = (pencilActive) => {
        document.getElementById('my-div-pencil').style.backgroundColor = pencilActive ? "gray" : "white";
        document.getElementById('my-div-eraser').style.backgroundColor = pencilActive ? "white" : "gray";
    };

    globalThis.theSketchDownloadFunction = () => {
        downloadFromElement('download_sketch', "sketch.png");
        // also download the generated output
        theOutputDownloadFunction();
        return false;
    };

    globalThis.theOutputDownloadFunction = () => {
        downloadFromElement('download_output', "output.png");
        return false;
    };

    globalThis.DELETE_SKETCH_FUNCTION = () => {
        // NOTE: relies on Gradio's generated (svelte-hashed) class names and
        // will stop matching if the Gradio frontend changes.
        const buttonDel = document.querySelector('#input_image > div.image-container.svelte-1sbaaot > div.controls-wrap.svelte-4lttvb > div > button:nth-child(3)');
        if (!simulateClick(buttonDel)) {
            console.warn("delete sketch: clear button not found");
        }
    };

    globalThis.togglePencil = () => {
        const elPencil = document.getElementById('my-toggle-pencil');
        elPencil.classList.toggle('clicked');
        // forward the click to the Gradio checkbox
        simulateClick(document.querySelector("#cb-line > label > input"));
        const pencilActive = elPencil.classList.contains('clicked');
        document.getElementById('my-toggle-eraser').classList.toggle('clicked', !pencilActive);
        highlightTool(pencilActive);
    };

    globalThis.toggleEraser = () => {
        const elEraser = document.getElementById('my-toggle-eraser');
        elEraser.classList.toggle('clicked');
        // forward the click to the Gradio checkbox
        simulateClick(document.querySelector("#cb-eraser > label > input"));
        const pencilActive = !elEraser.classList.contains('clicked');
        document.getElementById('my-toggle-pencil').classList.toggle('clicked', pencilActive);
        highlightTool(pencilActive);
    };
}
"""
# Extra markup injected into the page <head>: the browser theme colour and the
# PP Neue Montreal web-font stylesheet (presumably used by style.css — confirm).
head="""<meta name="theme-color" content="#000"><link href="https://fonts.cdnfonts.com/css/pp-neue-montreal" rel="stylesheet">"""

# UI layout and event wiring. The order of the `with` blocks below determines
# the on-page layout, so statements here must not be reordered.
with gr.Blocks(css="style.css", head = head) as demo:
    # Page header: title plus logos served from assets/ via `file=` URLs.
    gr.HTML("""<div id="header_block">
      <h1>Dai forma al nuovo<br />design Made in Italy</h1>
      <div id="logos_block">
          <img id="logos_row" src="file=assets/logos.png" alt="logo" />
        <div id="text_row">
          <span>krnl.ai</span><span>//</span
          ><span>eccellenza-italiana.com</span>
        </div>
      </div>
    </div>""")
    with gr.Column(elem_id="main_block"):
        # Side-by-side board: sketch input on one side, generated result on the other.
        with gr.Row(elem_id="board_row"):
            with gr.Group(elem_id="input_image_container", elem_classes="image_container" ):
                # Grayscale PIL sketchpad with a fixed black brush, initialised
                # to the blank canvas returned by empty_input_image().
                # NOTE(review): gr.Brush documents default_size as an int or
                # "auto"; the string "3" may not be read as a size — confirm.
                image = gr.Sketchpad(type="pil", image_mode="L",container=False, height="100%", width="100%", value = empty_input_image,
                    brush = gr.Brush(default_size="3", colors=["#000000"], color_mode="fixed"), layers = False,
                    # Unused options kept for reference: invert_colors=True, shape=(512, 512), brush_radius=4,
                    interactive=True, show_download_button=True, elem_id="input_image", show_label=False)
                # "Draw here" call-to-action image overlaid on the canvas (per its alt text).
                gr.HTML("""<img src="file=assets/drawCta.png" id="draw_cta" alt="draw here image" />""",elem_id="draw_cta_container")
                # Custom clear button; DELETE_SKETCH_FUNCTION is registered by `scripts`.
                gr.HTML("""<button id="eraser" onclick="return DELETE_SKETCH_FUNCTION(this)">
                    <span id="eraser_icon"></span>
                    </button>""",elem_id="eraser_container")
            with gr.Group(elem_id="output_image_container", elem_classes="image_container"):
                result = gr.Image(label="Result",  height="100%", width="100%", elem_id="output_image", show_label=False, show_download_button=True,container=False,)
    # Object-type selector; its value is used in the text prompt by run().
    with gr.Row(elem_id="radio_row"):
        item = gr.Radio(choices=ITEMS_NAMES, value=DEFAULT_ITEM_NAME, show_label=False, container=False)

    # Register the browser-side helper functions once the page loads.
    demo.load(None,None,None,js=scripts)
    inputs = [image, item]
    outputs = [result]
    # Regenerate whenever the object type or the sketch changes. With
    # trigger_mode="always_last", strokes made while a run is in progress
    # collapse into a single follow-up run on the latest sketch.
    item.change(fn=run, inputs=inputs, outputs=outputs)
    image.change(fn=run, inputs=inputs, outputs=outputs, trigger_mode="always_last")
    # Clearing the editor resets it to the blank canvas.
    image.clear(fn=empty_input_image, outputs=image)
if __name__ == "__main__":
    # Only assets/ must be reachable through `file=` URLs (logos.png and
    # drawCta.png in the page HTML). Allowing "." would let any client fetch
    # any file under the working directory, including this source code.
    demo.queue().launch(debug=True, allowed_paths=["assets"])