import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from share_btn import community_icon_html, loading_icon_html, share_js
from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler

import lora
import copy
import json
import gc
import random
from urllib.parse import quote
import gdown
import os

import diffusers
from diffusers.utils import load_image
from diffusers.models import ControlNetModel
from diffusers import AutoencoderKL, DPMSolverMultistepScheduler
import cv2
import torch
import numpy as np
from PIL import Image

from insightface.app import FaceAnalysis
from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline, draw_kps
from controlnet_aux import ZoeDetector

from compel import Compel, ReturnedEmbeddingsType

def _lora_entry_from_json(item):
    """Normalize one sdxl_loras.json record into the dict shape used by the app.

    Required keys are copied as-is; optional keys fall back to defaults
    (not pivotal, no text embeddings, zero likes/downloads, not NC, not new).
    """
    return {
        "image": item["image"],
        "title": item["title"],
        "repo": item["repo"],
        "trigger_word": item["trigger_word"],
        "weights": item["weights"],
        "is_compatible": item["is_compatible"],
        "is_pivotal": item.get("is_pivotal", False),
        "text_embedding_weights": item.get("text_embedding_weights", None),
        "likes": item.get("likes", 0),
        "downloads": item.get("downloads", 0),
        "is_nc": item.get("is_nc", False),
        "new": item.get("new", False),
    }


# Gallery catalogue: one normalized dict per LoRA listed in sdxl_loras.json.
with open("sdxl_loras.json", "r") as file:
    data = json.load(file)
sdxl_loras_raw = list(map(_lora_entry_from_json, data))

# Per-LoRA UI presets (matched on their "model" key against a LoRA's repo id).
with open("defaults_data.json", "r") as file:
    lora_defaults = json.load(file)
    

device = "cuda"

# Download every LoRA once at startup and keep its weights in memory, keyed by
# repo id, so that switching LoRAs at request time needs no download or disk read.
state_dicts = {}
for lora_info in sdxl_loras_raw:
    weights_path = hf_hub_download(lora_info["repo"], lora_info["weights"])
    if weights_path.endswith('.safetensors'):
        weights = load_file(weights_path)
    else:
        # NOTE(review): torch.load unpickles arbitrary objects, and these files
        # come from third-party Hub repos listed in sdxl_loras.json. Prefer
        # .safetensors, or torch.load(..., weights_only=True) where supported.
        weights = torch.load(weights_path)
    state_dicts[lora_info["repo"]] = {"saved_name": weights_path, "state_dict": weights}

# Split the "new"-flagged LoRAs out of the main gallery list. The right-hand side
# is evaluated in full before rebinding, so both filters see the full list.
sdxl_loras_raw_new, sdxl_loras_raw = (
    [entry for entry in sdxl_loras_raw if entry.get("new") == True],  # noqa: E712
    [entry for entry in sdxl_loras_raw if entry.get("new") != True],  # noqa: E712
)
    
# download models
# Fetch the InstantID assets (IdentityNet ControlNet config + weights, face
# IP-Adapter) and the LCM-LoRA weights into /data/checkpoints.
# NOTE(review): the LCM-LoRA file does not appear to be loaded anywhere in this
# file — confirm it is still needed.
hf_hub_download(
    repo_id="InstantX/InstantID",
    filename="ControlNetModel/config.json",
    local_dir="/data/checkpoints",
)
hf_hub_download(
    repo_id="InstantX/InstantID",
    filename="ControlNetModel/diffusion_pytorch_model.safetensors",
    local_dir="/data/checkpoints",
)
hf_hub_download(
    repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="/data/checkpoints"
)
hf_hub_download(
    repo_id="latent-consistency/lcm-lora-sdxl",
    filename="pytorch_lora_weights.safetensors",
    local_dir="/data/checkpoints",
)
# download antelopev2
# One-time download + unzip of the insightface "antelopev2" models into
# /data/models. The zip's presence is used as the "already installed" marker.
# NOTE(review): assumes the Drive file is saved as /data/antelopev2.zip; if the
# unzip fails after a successful download, later starts will skip this step.
if not os.path.exists("/data/antelopev2.zip"):
    gdown.download(url="https://drive.google.com/file/d/18wEUfMNohBJ4K3Ly5wpTejPfDzp-8fI8/view?usp=sharing", output="/data/", quiet=False, fuzzy=True)
    os.system("unzip /data/antelopev2.zip -d /data/models/")

# Face analyzer used by run_lora to obtain the face embedding and keypoints;
# it runs with the CPU execution provider only.
app = FaceAnalysis(name='antelopev2', root='/data', providers=['CPUExecutionProvider'])
app.prepare(ctx_id=0, det_size=(640, 640))

# Paths of the InstantID assets downloaded into /data/checkpoints above.
face_adapter = f'/data/checkpoints/ip-adapter.bin'
controlnet_path = f'/data/checkpoints/ControlNetModel'

# load IdentityNet
# The pipeline gets two ControlNets in this order: InstantID IdentityNet (face
# keypoints) and Zoe depth. run_lora supplies its control images and
# conditioning scales in the same order.
identitynet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
zoedepthnet = ControlNetModel.from_pretrained("diffusers/controlnet-zoe-depth-sdxl-1.0",torch_dtype=torch.float16)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained("rubbrband/albedobaseXL_v21",
                                                                 vae=vae,
                                                                 controlnet=[identitynet, zoedepthnet],
                                                                 torch_dtype=torch.float16)

# Compel builds SDXL prompt embeddings from both tokenizers/text encoders;
# requires_pooled=[False, True] requests a pooled embedding from the second one only.
compel = Compel(tokenizer=[pipe.tokenizer, pipe.tokenizer_2] , text_encoder=[pipe.text_encoder, pipe.text_encoder_2], returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, requires_pooled=[False, True])

# DPM-Solver multistep scheduler with Karras sigmas, built from the model's own config.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
pipe.load_ip_adapter_instantid(face_adapter)
pipe.set_ip_adapter_scale(0.8)
# Depth estimator producing the Zoe depth control image in run_lora.
zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
zoe.to("cuda")

# NOTE(review): original_pipe is not referenced anywhere else in this file. The
# deep copy (taken before pipe.to(device)) keeps a second full pipeline in host
# memory — consider removing it if nothing outside this file imports it.
original_pipe = copy.deepcopy(pipe)
pipe.to(device)

# State shared across run_lora calls describing what is currently loaded into `pipe`.
last_lora = ""  # repo id of the currently loaded LoRA ("" = none yet)
last_merged = False  # declared global in run_lora but never read or written there
last_fused = False  # True once a LoRA is fused; run_lora unfuses before swapping LoRAs
# Snippet passed to demo.load(js=...).
# NOTE(review): no component in this file has elem_id "button" (the Run button
# uses "run_button"), and `element` is undefined in the snippet, so as written it
# likely throws in the browser — confirm intent.
js = '''
var button = document.getElementById('button');
// Add a click event listener to the button
button.addEventListener('click', function() {
    element.classList.add('selected');
});
'''
def update_selection(selected_state: gr.SelectData, sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative, is_new=False):
    """Handle a gallery click: describe the chosen LoRA and apply its presets.

    Every entry of `lora_defaults` whose "model" equals the chosen repo id
    overrides the incoming control values key by key (later entries win).
    When `is_new` is set, the stored selection index is negated, with index 0
    encoded as -9999.

    Returns:
        (markdown title, prompt-box update, face_strength, image_strength,
        weight, depth_control_scale, negative, selected_state)
    """
    chosen = sdxl_loras[selected_state.index]
    lora_repo = chosen["repo"]
    nc_note = '(non-commercial LoRA, `cc-by-nc`)' if chosen['is_nc'] else ''
    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨ {nc_note}"

    matching_presets = (preset for preset in lora_defaults if preset["model"] == lora_repo)
    for preset in matching_presets:
        face_strength = preset.get("face_strength", face_strength)
        image_strength = preset.get("image_strength", image_strength)
        weight = preset.get("weight", weight)
        depth_control_scale = preset.get("depth_control_scale", depth_control_scale)
        negative = preset.get("negative", negative)

    if is_new:
        # Negative indices mark a selection coming from the "new" gallery.
        selected_state.index = -9999 if selected_state.index == 0 else -selected_state.index

    prompt_box = gr.update(placeholder="Type a prompt to use your selected LoRA")
    return (
        updated_text,
        prompt_box,
        face_strength,
        image_strength,
        weight,
        depth_control_scale,
        negative,
        selected_state,
    )

def center_crop_image_as_square(img):
    """Crop `img` to its largest centered square.

    Integer offsets keep the crop box exactly `side` pixels on each axis.
    A float box such as (1.5, 0, 100.5, 99) can be rounded by PIL to an
    off-by-one width, giving a non-square image.

    Args:
        img: PIL image, or None when the user has not uploaded a photo.

    Returns:
        The square-cropped image.

    Raises:
        gr.Error: if no image was provided.
    """
    if img is None:
        raise gr.Error("Please upload a picture of yourself")
    side = min(img.size)
    left = (img.width - side) // 2
    top = (img.height - side) // 2
    return img.crop((left, top, left + side, top + side))
    
def check_selected(selected_state):
    """Stop the event chain with a user-visible error unless a LoRA was picked."""
    if selected_state:
        return
    raise gr.Error("You must select a LoRA")

def merge_incompatible_lora(full_path_lora, lora_scale):
    """Merge a LoRA file into `pipe`'s text encoder and UNet weights in place.

    This uses the local `lora` module instead of diffusers' loader.

    Args:
        full_path_lora: path to the weights file. It may carry an explicit
            multiplier as "path;multiplier", which overrides `lora_scale`.
        lora_scale: multiplier used when the path has no ";multiplier" suffix.
    """
    weights_file, has_multiplier, multiplier_text = full_path_lora.partition(";")
    multiplier = float(multiplier_text) if has_multiplier else lora_scale

    # Pass the path with any ";multiplier" suffix stripped; the suffix is not part
    # of the filename.
    lora_model, weights_sd = lora.create_network_from_weights(
        multiplier,
        weights_file,
        pipe.vae,
        pipe.text_encoder,
        pipe.unet,
        for_inference=True,
    )
    lora_model.merge_to(
        pipe.text_encoder, pipe.unet, weights_sd, torch.float16, "cuda"
    )
    # Drop the temporary network/weights right away to release memory.
    del weights_sd
    del lora_model
    gc.collect()

# Scale the currently fused LoRA was fused with (None = nothing fused yet).
# It is tracked so that changing only the "LoRA weight" slider triggers a re-fuse.
last_lora_scale = None

# Placeholder tokens that pivotal-tuned LoRAs bind their textual-inversion embeddings to.
_PIVOTAL_TOKENS = ("<s0>", "<s1>")


def _largest_face(face_image):
    """Return the insightface detection with the largest bounding-box area.

    Raises:
        gr.Error: if no face is detected, instead of an opaque IndexError.
    """
    faces = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
    if not faces:
        raise gr.Error("No face detected in the uploaded picture. Please try another photo.")
    # bbox is indexed as [x1, y1, x2, y2]; area = (x2 - x1) * (y2 - y1).
    return max(faces, key=lambda f: (f['bbox'][2] - f['bbox'][0]) * (f['bbox'][3] - f['bbox'][1]))


def _apply_prompt_template(repo_name, prompt):
    """Wrap `prompt` in the LoRA's defaults_data.json template, if one is defined.

    The template's "<subject>" placeholder is replaced by the user prompt. If no
    template applies, the prompt is returned unchanged.
    """
    for defaults in lora_defaults:
        if defaults["model"] == repo_name and defaults.get("prompt"):
            return defaults["prompt"].replace("<subject>", prompt)
    return prompt


def _unload_pivotal_embeddings():
    """Remove the pivotal tokens from every SDXL tokenizer/text-encoder pair that has them.

    Per the diffusers docs, SDXL embeddings have to be unloaded from each pair
    explicitly. A no-argument unload_textual_inversion() only targets
    pipe.tokenizer/pipe.text_encoder, so tokens left in tokenizer_2 would make the
    next load fail. Tokens are passed only when present, because asking diffusers
    to remove absent tokens may raise.
    """
    for tokenizer, text_encoder in (
        (pipe.tokenizer, pipe.text_encoder),
        (pipe.tokenizer_2, pipe.text_encoder_2),
    ):
        vocab = tokenizer.get_vocab()
        present = [token for token in _PIVOTAL_TOKENS if token in vocab]
        if present:
            pipe.unload_textual_inversion(tokens=present, tokenizer=tokenizer, text_encoder=text_encoder)


def _load_pivotal_embeddings(repo_name, text_embedding_name):
    """Load a pivotal-tuned LoRA's textual-inversion embeddings into both text encoders."""
    embedding_path = hf_hub_download(repo_id=repo_name, filename=text_embedding_name, repo_type="model")
    state_dict_embedding = load_file(embedding_path)
    # Two key layouts are handled: clip_l/clip_g or text_encoders_0/text_encoders_1.
    if "clip_l" in state_dict_embedding:
        key_1, key_2 = "clip_l", "clip_g"
    else:
        key_1, key_2 = "text_encoders_0", "text_encoders_1"
    pipe.load_textual_inversion(state_dict_embedding[key_1], token=list(_PIVOTAL_TOKENS), text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
    pipe.load_textual_inversion(state_dict_embedding[key_2], token=list(_PIVOTAL_TOKENS), text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)


def _activate_lora(lora_entry, lora_scale):
    """Make `lora_entry` the only LoRA fused into `pipe`, fused at `lora_scale`.

    This is a no-op when the same repo is already fused at the same scale.
    """
    global last_lora, last_fused, last_lora_scale
    repo_name = lora_entry["repo"]
    if last_lora == repo_name and last_lora_scale == lora_scale:
        return

    if last_fused:
        pipe.unfuse_lora()
        last_fused = False
    pipe.unload_lora_weights()
    # Forget the previous LoRA before loading, so a failure below forces a
    # clean reload on the next request instead of silently reusing bad state.
    last_lora = ""
    last_lora_scale = None

    # Hand diffusers a copy so the cached state dict stays pristine, in case
    # load_lora_weights mutates its argument.
    pipe.load_lora_weights(copy.deepcopy(state_dicts[repo_name]["state_dict"]))
    pipe.fuse_lora(lora_scale=lora_scale)
    last_fused = True

    # Drop any embeddings left by a previous pivotal LoRA, then load this one's.
    _unload_pivotal_embeddings()
    if lora_entry["is_pivotal"]:
        _load_pivotal_embeddings(repo_name, lora_entry["text_embedding_weights"])

    last_lora = repo_name
    last_lora_scale = lora_scale


def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, sdxl_loras, progress=gr.Progress(track_tqdm=True)):
    """Generate one 1024x1024 image of the uploaded face in the selected LoRA's style.

    Args:
        face_image: uploaded PIL image (square-cropped earlier in the UI chain).
        prompt: user prompt. It is wrapped in the LoRA's "<subject>" template
            from defaults_data.json when one is defined.
        negative: negative prompt; an empty string means none.
        lora_scale: weight the LoRA is fused with (the "LoRA weight" slider).
        selected_state: gr.SelectData from the gallery click; `.index` indexes `sdxl_loras`.
        face_strength: IdentityNet ControlNet conditioning scale.
        image_strength: how much of the input image to keep. The img2img
            strength is 1 - image_strength.
        guidance_scale: classifier-free guidance scale.
        depth_control_scale: Zoe-depth ControlNet conditioning scale.
        sdxl_loras: current gallery list of LoRA dicts, possibly reordered.
        progress: Gradio progress tracker (tracks tqdm bars).

    Returns:
        (generated image, gr.update making the share group visible)

    Raises:
        gr.Error: if no LoRA is selected, no image was uploaded, or no face is detected.
    """
    # Validate before touching selected_state.index.
    if not selected_state:
        raise gr.Error("You must select a LoRA")
    if face_image is None:
        raise gr.Error("Please upload a picture of yourself")
    lora_entry = sdxl_loras[selected_state.index]
    repo_name = lora_entry["repo"]

    face_info = _largest_face(face_image)
    face_emb = face_info['embedding']
    face_kps = draw_kps(face_image, face_info['kps'])

    prompt = _apply_prompt_template(repo_name, prompt)
    print("Prompt:", prompt)

    # Depth control image, resized to match the keypoint image. PIL's resize
    # takes (width, height).
    with torch.no_grad():
        image_zoe = zoe(face_image)
    width, height = face_kps.size
    control_images = [face_kps, image_zoe.resize((width, height))]

    print("Selected State: ", selected_state.index)
    print("Last LoRA: ", last_lora)
    print("Current LoRA: ", repo_name)
    print("Last fused: ", last_fused)
    _activate_lora(lora_entry, lora_scale)

    conditioning, pooled = compel(prompt)
    if negative:
        negative_conditioning, negative_pooled = compel(negative)
    else:
        negative_conditioning, negative_pooled = None, None

    image = pipe(
        prompt_embeds=conditioning,
        pooled_prompt_embeds=pooled,
        negative_prompt_embeds=negative_conditioning,
        negative_pooled_prompt_embeds=negative_pooled,
        width=1024,
        height=1024,
        image_embeds=face_emb,
        image=face_image,
        strength=1-image_strength,
        control_image=control_images,
        num_inference_steps=20,
        guidance_scale=guidance_scale,
        # Same order as the pipeline's controlnets: [IdentityNet, Zoe depth].
        controlnet_conditioning_scale=[face_strength, depth_control_scale],
    ).images[0]
    gc.collect()
    return image, gr.update(visible=True)

def shuffle_gallery(sdxl_loras):
    """Shuffle `sdxl_loras` in place.

    Returns:
        (list of (image, title) gallery items in the new order, the same list object)
    """
    random.shuffle(sdxl_loras)
    gallery_items = list(map(lambda entry: (entry["image"], entry["title"]), sdxl_loras))
    return gallery_items, sdxl_loras

def swap_gallery(order, sdxl_loras):
    """Reorder the gallery list.

    "random" shuffles in place via shuffle_gallery. Any other value is treated
    as a numeric field name and sorted descending, with a missing field counting
    as 0. The sort returns a new list.

    Returns:
        (list of (image, title) gallery items, reordered list of LoRA dicts)
    """
    if order != "random":
        ranked = sorted(sdxl_loras, key=lambda entry: entry.get(order, 0), reverse=True)
        return [(entry["image"], entry["title"]) for entry in ranked], ranked
    return shuffle_gallery(sdxl_loras)

def deselect():
    """Build a Gallery update that clears the current selection."""
    cleared = gr.Gallery(selected_index=None)
    return cleared

with gr.Blocks(css="custom.css") as demo:
    # Title shown while no LoRA is selected; also restored when the gallery is reordered.
    no_selection_title = "### Click on a LoRA in the gallery to select it"
    gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
    # Only consumed by the (currently disabled) "New LoRAs" gallery below.
    gr_sdxl_loras_new = gr.State(value=sdxl_loras_raw_new)
    title = gr.HTML(
        """<h1>Face to All</h1>""",
        elem_id="title",
    )
    # SelectData of the clicked gallery item; None until the user picks a LoRA.
    selected_state = gr.State()
    with gr.Row(elem_id="main_app"):
        with gr.Group(elem_id="gallery_box"):
            photo = gr.Image(label="Upload a picture of yourself", interactive=True, type="pil")
            selected_loras = gr.Gallery(label="Selected LoRAs", height=80, show_share_button=False, visible=False, elem_id="gallery_selected", )
            order_gallery = gr.Radio(choices=["random", "likes"], value="random", label="Order by", elem_id="order_radio")
            #new_gallery = gr.Gallery(
            #    label="New LoRAs",
            #    elem_id="gallery_new",
            #    columns=3,
            #    value=[(item["image"], item["title"]) for item in sdxl_loras_raw_new], allow_preview=False, show_share_button=False)
            gallery = gr.Gallery(
                #value=[(item["image"], item["title"]) for item in sdxl_loras],
                label="SDXL LoRA Gallery",
                allow_preview=False,
                columns=3,
                elem_id="gallery",
                show_share_button=False,
                height=784
            )
        with gr.Column():
            prompt_title = gr.Markdown(
                value=no_selection_title,
                visible=True,
                elem_id="selected_lora",
            )
            with gr.Row():
                prompt = gr.Textbox(label="Prompt", show_label=False, lines=1, max_lines=1, placeholder="A person", elem_id="prompt")
                button = gr.Button("Run", elem_id="run_button")
            with gr.Group(elem_id="share-btn-container", visible=False) as share_group:
                community_icon = gr.HTML(community_icon_html)
                loading_icon = gr.HTML(loading_icon_html)
                share_button = gr.Button("Share to community", elem_id="share-btn")
            result = gr.Image(
                interactive=False, label="Generated Image", elem_id="result-image"
            )
            face_strength = gr.Slider(0, 1, value=0.85, step=0.01, label="Face strength", info="Higher values increase the face likeness but reduce the creative liberty of the models")
            image_strength = gr.Slider(0, 1, value=0.15, step=0.01, label="Image strength", info="Higher values increase the similarity with the structure/colors of the original photo")
            with gr.Accordion("Advanced options", open=False):
                negative = gr.Textbox(label="Negative Prompt")
                weight = gr.Slider(0, 10, value=0.9, step=0.1, label="LoRA weight")
                guidance_scale = gr.Slider(0, 50, value=7, step=0.1, label="Guidance Scale")
                depth_control_scale = gr.Slider(0, 1, value=0.8, step=0.01, label="Zoe Depth ControlNet strength")

    order_gallery.change(
        fn=swap_gallery,
        inputs=[order_gallery, gr_sdxl_loras],
        outputs=[gallery, gr_sdxl_loras],
        queue=False
    ).then(
        # The stored selection is an index into gr_sdxl_loras. After reordering,
        # that index would point at a different LoRA, so clear it and ask again.
        fn=lambda: (None, no_selection_title),
        outputs=[selected_state, prompt_title],
        queue=False,
        show_progress=False,
    )
    gallery.select(
        fn=update_selection,
        inputs=[gr_sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative],
        outputs=[prompt_title, prompt, face_strength, image_strength, weight, depth_control_scale, negative, selected_state],
        queue=False,
        show_progress=False
    )
    #new_gallery.select(
    #    fn=update_selection,
    #    inputs=[gr_sdxl_loras_new, gr.State(True)],
    #    outputs=[prompt_title, prompt, prompt, selected_state, gallery],
    #    queue=False,
    #    show_progress=False
    #)
    # Pressing Enter in the prompt box and clicking Run trigger the same chain:
    # check that a LoRA is selected, square-crop the photo, then generate.
    for trigger in (prompt.submit, button.click):
        trigger(
            fn=check_selected,
            inputs=[selected_state],
            queue=False,
            show_progress=False
        ).success(
            fn=center_crop_image_as_square,
            inputs=[photo],
            outputs=[photo],
            queue=False,
            show_progress=False,
        ).success(
            fn=run_lora,
            inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, gr_sdxl_loras],
            outputs=[result, share_group],
        )
    share_button.click(None, [], [], js=share_js)
    demo.load(fn=shuffle_gallery, inputs=[gr_sdxl_loras], outputs=[gallery, gr_sdxl_loras], queue=False, js=js)
demo.queue(max_size=20)
demo.launch(share=True)