import spaces
import gradio as gr
from diffusers import AutoPipelineForInpainting, AutoencoderKL
import torch
from PIL import Image, ImageOps

# Load the "madebyollin/sdxl-vae-fp16-fix" VAE in half precision, then build the
# SDXL inpainting pipeline around it and move the pipeline to the GPU once, at
# import time.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=torch.float16,
)
pipeline = AutoPipelineForInpainting.from_pretrained(
    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to("cuda")

def get_select_index(evt: gr.SelectData):
    """Return the index of the gallery item that triggered the select event.

    Wired as the ``version_gallery.select`` handler; the returned value is
    stored in the hidden ``selected`` component for ``restore_version``.
    """
    clicked = evt.index
    return clicked
    
@spaces.GPU()
def squarify_image(img):
    """Pad *img* with white to a square canvas, anchored at the top-left corner.

    The canvas side length is the longer of the image's width and height, so
    the image fills one full side and white padding is added to the right
    (portrait input) or bottom (landscape input).

    Args:
        img: PIL image to pad.

    Returns:
        A new RGB ``PIL.Image`` of size ``(side, side)``.
    """
    side = max(img.width, img.height)
    canvas = Image.new(mode="RGB", size=(side, side), color="white")
    # Deliberately NOT centred: ``generate`` recovers the original content by
    # cropping the top-left region of the inpainted square, and pastes the
    # mask layer at (0, 0) as well, so the image must sit at (0, 0) too.
    canvas.paste(img, (0, 0))
    return canvas

@spaces.GPU()
def divisible_by_8(image):
    """Return *image* resized so both width and height are multiples of 8.

    Each side is rounded down to the nearest multiple of 8 (e.g. 1023 -> 1016).
    The pixels are resampled to the new size rather than cropped. A side
    shorter than 8 px maps to 0.

    Args:
        image: PIL image.

    Returns:
        The resized PIL image.
    """
    # ``side - side % 8`` floors each dimension to a multiple of 8.
    target_size = tuple(side - side % 8 for side in image.size)
    return image.resize(target_size)

@spaces.GPU()
def restore_version(index, versions):
    """Load a previously generated version back into the sketch pad.

    Args:
        index: Position of the selected gallery item, as held by the hidden
            ``gr.Number``. It may be ``None`` when nothing has been selected
            yet, and may arrive as a float, so it is coerced to ``int``.
        versions: Gallery value; each entry is indexed with ``[0]`` to get the
            image (presumably ``(image, caption)`` pairs as Gradio's Gallery
            delivers them - TODO confirm for the installed Gradio version).

    Returns:
        An ImageEditor value dict with the chosen image as both background and
        composite, and no layers.

    Raises:
        gr.Error: if no version is selected or the index is out of range.
    """
    print('restore version:', index)
    if index is None or not versions:
        raise gr.Error("Select a version in the gallery first.")
    position = int(index)
    if not 0 <= position < len(versions):
        raise gr.Error("The selected version no longer exists; please select again.")
    image = versions[position][0]
    return {'background': image, 'layers': None, 'composite': image}

def _build_inpaint_mask(layer, size):
    """Turn a drawn editor layer into an inpainting mask of *size* (greyscale 'L').

    The layer is pasted at (0, 0) onto an opaque white canvas, using its own
    alpha as the paste mask. The result is converted to greyscale and inverted,
    so dark strokes become white (repaint) and untouched areas become black
    (keep). This assumes the brush colour is dark, e.g. black - TODO confirm
    the ImageMask brush settings.
    """
    mask = Image.new("RGBA", size, "WHITE")
    mask.paste(layer, (0, 0), layer)
    return ImageOps.invert(mask.convert('L'))


def _crop_to_content(result, content_size, square_side):
    """Crop the top-left part of *result* that corresponds to the un-padded input.

    *content_size* is the (width, height) of the image before it was padded to
    a ``square_side`` x ``square_side`` canvas at (0, 0). The crop box is
    scaled by the pipeline's actual output size instead of assuming 1024 px.
    """
    content_w, content_h = content_size
    right = round(content_w * result.width / square_side)
    lower = round(content_h * result.height / square_side)
    return result.crop((0, 0, right, lower))


@spaces.GPU()
def generate(image_editor, prompt, neg_prompt, versions):
    """Inpaint the masked region of the sketch-pad image and record the result.

    Args:
        image_editor: ImageEditor value dict. ``'background'`` is the PIL input
            image and ``'layers'`` is a list of drawn PIL layers; only the
            first layer is used as the mask.
        prompt: Text prompt for the inpainting pipeline.
        neg_prompt: Negative prompt. An empty string is passed as ``None`` so
            the pipeline falls back to its default behaviour.
        versions: Current gallery value, or ``None``/empty on the first run.

    Returns:
        A 3-tuple of:
        - the new ImageEditor value dict;
        - a visible ``gr.Gallery`` holding all versions;
        - an update that makes the restore button visible.

    Raises:
        gr.Error: if no image was uploaded or no mask was drawn.
    """
    if not image_editor or image_editor.get('background') is None:
        raise gr.Error("Upload an image first.")
    layers = image_editor.get('layers') or []
    if not layers:
        raise gr.Error("Draw a mask over the area you want to inpaint.")

    # Fit within 1024x1024 (aspect preserved), then make both sides multiples of 8.
    image = image_editor['background'].convert('RGB')
    image.thumbnail((1024, 1024))
    image = divisible_by_8(image)
    content_size = image.size

    # Size the mask layer to the content before padding, so mask and image both
    # sit at the top-left of the square canvas.
    layer = layers[0].resize(content_size)
    image = squarify_image(image)
    mask = _build_inpaint_mask(layer, image.size)

    # Inpaint
    pipeline.to("cuda")
    final_image = pipeline(prompt=prompt,
                           negative_prompt=neg_prompt or None,
                           image=image,
                           mask_image=mask).images[0]

    # Drop the white padding that squarify_image added.
    final_image = _crop_to_content(final_image, content_size, image.width)

    # gradio.ImageEditor requires a dictionary value
    final_dict = {'background': final_image, 'layers': None, 'composite': final_image}

    # Add the generated image to the version gallery without mutating the input list.
    if not versions:
        final_gallery = [image_editor['background'], final_image]
    else:
        final_gallery = [*versions, final_image]

    return final_dict, gr.Gallery(value=final_gallery, visible=True), gr.update(visible=True)

with gr.Blocks() as demo:
    gr.Markdown("""
    # Inpainting SDXL Sketch Pad
    by [Tony Assi](https://www.tonyassi.com/)

    Please ❤️ this Space. I build custom AI apps for companies. <a href="mailto:tony.assi.media@gmail.com">Email me</a> for business inquiries.
    """)

    with gr.Row():
        # Left column: input image + mask, prompt and generation controls.
        with gr.Column():
            sketch_pad = gr.ImageMask(type='pil', label='Inpaint')
            prompt = gr.Textbox(label="Prompt")
            generate_button = gr.Button("Generate")
            with gr.Accordion("Advanced Settings", open=False):
                neg_prompt = gr.Textbox(label='Negative Prompt', value='ugly, deformed')
        # Right column: history of versions; hidden until the first generation
        # makes it visible.
        with gr.Column():
            version_gallery = gr.Gallery(label="Versions", type="pil", object_fit='contain', visible=False)
            restore_button = gr.Button("Restore Version", visible=False)
            # Hidden holder for the last clicked gallery index; precision=0 so
            # restore_version receives an int it can index the gallery with.
            selected = gr.Number(show_label=False, visible=False, precision=0)

    gr.Examples(
        [[{'background': './tony.jpg', 'layers': ['./tony-mask.jpg'], 'composite': './tony.jpg'}, 'tuxedo', 'ugly', []]],
        [sketch_pad, prompt, neg_prompt, version_gallery],
        [sketch_pad, version_gallery, restore_button],
        generate,
        cache_examples=True,
    )

    # Event wiring: remember the clicked version, generate, and restore.
    version_gallery.select(get_select_index, None, selected)
    generate_button.click(fn=generate, inputs=[sketch_pad, prompt, neg_prompt, version_gallery], outputs=[sketch_pad, version_gallery, restore_button])
    restore_button.click(fn=restore_version, inputs=[selected, version_gallery], outputs=sketch_pad)

demo.launch()