import spaces
import gradio as gr
from gradio_imageslider import ImageSlider
import torch

torch.jit.script = lambda f: f
from diffusers import (
    ControlNetModel,
    StableDiffusionXLControlNetImg2ImgPipeline,
    DDIMScheduler,
)
from controlnet_aux import AnylineDetector
from compel import Compel, ReturnedEmbeddingsType
from PIL import Image
import os
import time
import numpy as np

from utils.utils import load_cn_model, load_cn_config, load_tagger_model, load_lora_model, resize_image_aspect_ratio, base_generation
from utils.prompt_analysis import PromptAnalysis

# Working directories (under the current working directory) that hold the
# ControlNet weights/config, the tagger model and the LoRA weights.
path = os.getcwd()
cn_dir = f"{path}/controlnet"
tagger_dir = f"{path}/tagger"
lora_dir = f"{path}/lora"

# exist_ok=True keeps a restart in the same working directory from crashing
# with FileExistsError when the directories were created by a previous run.
os.makedirs(cn_dir, exist_ok=True)
os.makedirs(tagger_dir, exist_ok=True)
os.makedirs(lora_dir, exist_ok=True)

# Populate the directories via the project helpers from utils.utils.
# NOTE(review): presumably these download the assets when missing -- confirm
# against utils/utils.py.
load_cn_model(cn_dir)
load_cn_config(cn_dir)
load_tagger_model(tagger_dir)
load_lora_model(lora_dir)

# Environment-derived flags. SPACES_ZERO_GPU and LOW_MEMORY count as enabled
# only when set to the exact string "1"; IS_SPACE is true whenever SPACE_ID
# is present in the environment.
IS_SPACES_ZERO = os.getenv("SPACES_ZERO_GPU", "0") == "1"
IS_SPACE = "SPACE_ID" in os.environ
LOW_MEMORY = os.environ.get("LOW_MEMORY", "0") == "1"

# Run on CUDA when torch reports it as available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Half precision is used regardless of the selected device.
dtype = torch.float16

print(f"device: {device}")
print(f"dtype: {dtype}")
print(f"low memory: {LOW_MEMORY}")


# Base checkpoint id; the scheduler config is read from its "scheduler" subfolder.
model = "cagliostrolab/animagine-xl-3.1"
scheduler = DDIMScheduler.from_pretrained(model, subfolder="scheduler")

# The ControlNet is loaded from the local directory with the same dtype as the
# pipeline, so both stay in sync if `dtype` is ever changed.
controlnet = ControlNetModel.from_pretrained(
    cn_dir,
    torch_dtype=dtype,
    use_safetensors=True,
)
pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
    model,
    controlnet=controlnet,
    torch_dtype=dtype,
    use_safetensors=True,
    scheduler=scheduler,
)

# Attach the line-art LoRA weights stored in the local LoRA directory.
pipe.load_lora_weights(
    lora_dir,
    weight_name="sdxl_BWLine.safetensors",
)

# Prompt encoder built on both tokenizers / text encoders of the pipeline;
# a pooled embedding is requested only from the second text encoder.
compel = Compel(
    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
)

# Move the pipeline to the selected device after Compel has been constructed
# (same order as before).
pipe = pipe.to(device)



@spaces.GPU
def predict(
    input_image_path,
    prompt,
    negative_prompt,
    controlnet_conditioning_scale,
):
    """Run the ControlNet img2img pipeline on an input image.

    Args:
        input_image_path: Path of the image file to open with PIL; the
            (resized) image is used as the ControlNet control image.
        prompt: Positive prompt text, encoded with ``compel``.
        negative_prompt: Negative prompt text, encoded with ``compel``.
        controlnet_conditioning_scale: ControlNet conditioning strength;
            converted to ``float`` before being passed to the pipeline.

    Returns:
        The first generated PIL image, resized back to the size of the
        original input image with LANCZOS resampling.
    """
    input_image_pil = Image.open(input_image_path)
    # Remember the original size so the result can be scaled back to it.
    base_size = input_image_pil.size

    # NOTE(review): resize_image_aspect_ratio presumably picks a working
    # resolution while preserving aspect ratio -- confirm in utils/utils.py.
    resize_image = resize_image_aspect_ratio(input_image_pil)
    width, height = resize_image.size

    # Init image for img2img: a base image built from the working size and an
    # opaque white RGBA colour, converted to RGB.
    white_base_pil = base_generation(resize_image.size, (255, 255, 255, 255)).convert("RGB")

    # Encode both prompts in one batch: index 0 is the positive prompt,
    # index 1 the negative prompt.
    conditioning, pooled = compel([prompt, negative_prompt])

    # Fixed seed, so the generator starts from the same state on every call.
    generator = torch.manual_seed(0)
    last_time = time.time()

    result = pipe(
        image=white_base_pil,
        control_image=resize_image,
        strength=1.0,
        prompt_embeds=conditioning[0:1],
        pooled_prompt_embeds=pooled[0:1],
        negative_prompt_embeds=conditioning[1:2],
        negative_pooled_prompt_embeds=pooled[1:2],
        width=width,
        height=height,
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
        # Documented names of the ControlNet start/end parameters.
        control_guidance_start=0.0,
        control_guidance_end=1.0,
        generator=generator,
        num_inference_steps=30,
        guidance_scale=8.5,
        eta=1.0,
    )
    print(f"Time taken: {time.time() - last_time}")

    # The pipeline returns an output object; the PIL image lives in `.images`.
    output_image = result.images[0]
    return output_image.resize(base_size, Image.LANCZOS)


# Custom CSS handed to gr.Blocks. The declarations for #intro appear to have
# been disabled on purpose; they are wrapped in /* ... */ because "#" does not
# start a comment in CSS and would otherwise leave invalid declarations.
css = """
#intro{
    /* max-width: 32rem; */
    /* text-align: center; */
    /* margin: 0 auto; */
}
"""

# UI layout: left column holds the inputs and the generate button, right
# column shows the generated image.
with gr.Blocks(css=css) as demo:
    with gr.Row() as block:
        with gr.Column():
            # Row for uploading the input image
            with gr.Row():
                with gr.Column():
                    input_image_path = gr.Image(label="入力画像", type='filepath')

            # Row for the prompt inputs (layout is built by PromptAnalysis)
            with gr.Row():
                prompt_analysis = PromptAnalysis(tagger_dir)
                [prompt, nega] = prompt_analysis.layout(input_image_path)

            # Row with the slider for the ControlNet conditioning scale
            with gr.Row():
                controlnet_conditioning_scale = gr.Slider(
                    minimum=0.5,
                    maximum=1.25,
                    value=1.0,
                    step=0.01,
                    interactive=True,
                    label="線画忠実度",
                )

            # Row for the generate button; it starts disabled and is enabled
            # by the change listener below once an input image is present.
            with gr.Row():
                generate_button = gr.Button("生成", interactive=False)

        with gr.Column():
            output_image = gr.Image(type="pil", label="Output Image")

        # Components wired into predict()
        inputs = [
            input_image_path,
            prompt,
            nega,
            controlnet_conditioning_scale,
        ]
        outputs = [output_image]

        # Without this listener the button would stay disabled forever:
        # enable it while an image is set, disable it again when cleared.
        input_image_path.change(
            fn=lambda image_path: gr.update(interactive=image_path is not None),
            inputs=[input_image_path],
            outputs=[generate_button],
        )

        # Run predict() when the generate button is clicked
        generate_button.click(
            fn=predict,
            inputs=inputs,
            outputs=outputs,
        )

# Enable the request queue (with the API left open) and start the app with
# the API page shown.
demo.queue(api_open=True).launch(show_api=True)