File size: 5,480 Bytes
4ed7a63
08fc0b2
62bd330
dd2b2f9
fec64c5
ce02dfa
43e72c6
 
 
 
08fc0b2
6f6d7fe
 
 
0467837
cfe97ad
bc11c81
 
6f6d7fe
1e716b6
a1deaea
6f6d7fe
 
 
 
 
a1deaea
4ed7a63
1e716b6
d16752a
 
6f6d7fe
 
d16752a
 
6f6d7fe
 
d16752a
6f6d7fe
d16752a
 
1e716b6
7139441
91e5c41
fd69a13
7139441
 
c0c66f4
b2a4599
c0c66f4
dde83e1
a92c6b1
ce02dfa
9539987
6f6d7fe
 
d16752a
 
965c284
6f6d7fe
 
d16752a
6f6d7fe
 
3e711ca
6f6d7fe
 
 
d16752a
9539987
6f6d7fe
 
4ed7a63
6f6d7fe
 
a92c6b1
 
6f6d7fe
 
 
 
 
93b65f4
3bec293
6f6d7fe
 
 
a11c80c
6f6d7fe
 
eec3e26
6f6d7fe
 
 
 
 
10e906d
 
6f6d7fe
ce02dfa
 
6f6d7fe
e6e1231
08fc0b2
 
 
0143c38
d040427
0143c38
08fc0b2
 
 
f0cce29
429369b
e0d2fcf
a92c6b1
e0fc88a
940b814
a92c6b1
6b87482
ba2bf14
3d2adfb
44e13e8
2395f61
de7dbc9
1e716b6
 
 
 
 
 
 
bc11c81
 
a92c6b1
7139441
c0c66f4
08fc0b2
6f6d7fe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import spaces
import gradio as gr
from diffusers import StableDiffusion3InpaintPipeline, AutoencoderKL
import torch
from PIL import Image, ImageOps
import time
from huggingface_hub import login
import os

# Authenticate with the Hugging Face Hub (presumably required to download the
# model weights below). huggingface_hub.login() with token=None falls back to an
# interactive prompt, which cannot be answered on a headless server, so only
# log in when a token is actually configured.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)

# Load SD3-medium inpainting in half precision. The pipeline is moved to the
# GPU inside `generate`, which runs under @spaces.GPU.
pipeline = StableDiffusion3InpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
)

def get_select_index(evt: gr.SelectData):
    """Report which gallery item the user clicked.

    The handler is wired with no inputs; Gradio supplies `evt` because the
    parameter is annotated with `gr.SelectData`, so that annotation must stay.

    Returns:
        The `index` attribute of the selection event.
    """
    clicked_position = evt.index
    return clicked_position
    
def squarify_image(img):
    """Place `img` on a white square RGB canvas sized to its longest edge.

    The picture is anchored at the top-left corner, i.e. offset (0, 0); the
    padding (if any) lies to the right of or below it. `generate` relies on
    this anchor when it crops the result back with a (0, 0, w, h) box.

    Returns:
        A new square RGB image.
    """
    side = max(img.width, img.height)
    canvas = Image.new(mode="RGB", size=(side, side), color="white")
    # The former offset expression `(bg.width - bg.width) / 2` always evaluates
    # to 0, so the paste position is simply the origin.
    canvas.paste(img, (0, 0))
    return canvas

def divisible_by_8(image, multiple=8):
    """Resize `image` so that both dimensions are multiples of `multiple`.

    Each side is rounded down to the nearest multiple, but never below
    `multiple` itself, so very small images do not request a zero-sized
    target dimension.

    Args:
        image: Image object exposing `.size` as (width, height) and
            `.resize((width, height))` (e.g. a PIL image).
        multiple: Required divisor for width and height (default 8).

    Returns:
        The result of `image.resize(...)` with the aligned size.

    Raises:
        ValueError: if `multiple` is smaller than 1.
    """
    if multiple < 1:
        raise ValueError(f"multiple must be >= 1, got {multiple}")

    width, height = image.size

    # Round down to the nearest multiple, clamped so a side never becomes 0.
    new_width = max(multiple, (width // multiple) * multiple)
    new_height = max(multiple, (height // multiple) * multiple)

    return image.resize((new_width, new_height))

def restore_version(index, versions):
    """Load a gallery entry back into the sketch pad.

    Args:
        index: Position of the selected gallery item. It is read from a
            `gr.Number`, which may deliver it as a float, so it is coerced to
            int before indexing. None means nothing has been selected yet.
        versions: Gallery value; each entry is indexed with [0] to obtain the
            image (presumably an (image, caption) pair from gr.Gallery).

    Returns:
        An ImageEditor value dict with the chosen image as both background and
        composite, and no layers.

    Raises:
        gr.Error: if no version is selected or the gallery is empty.
    """
    print('restore version:', index)
    if index is None or not versions:
        raise gr.Error("Select a version in the gallery first.")

    # list indices must be ints; gr.Number can hand over e.g. 1.0.
    image = versions[int(index)][0]
    return {'background': image, 'layers': None, 'composite': image}

def clear_all():
    """Reset the UI to its initial state.

    Returns one update per output, in this order: sketch pad and prompt
    emptied, version gallery emptied and hidden, restore and clear buttons
    hidden. A fresh update object is built for every output.
    """
    sketch_pad_reset = gr.update(value=None)
    prompt_reset = gr.update(value=None)
    gallery_reset = gr.update(value=[], visible=False)
    restore_hidden = gr.update(visible=False)
    clear_hidden = gr.update(visible=False)
    return sketch_pad_reset, prompt_reset, gallery_reset, restore_hidden, clear_hidden

@spaces.GPU()
def generate(image_editor, prompt, neg_prompt, versions, num_inference_steps):
    """Inpaint the masked region of the sketch-pad image with SD3.

    Args:
        image_editor: ImageEditor value dict; 'background' holds the source
            image and 'layers'[0] the user-drawn mask layer.
        prompt: Text prompt describing what to paint into the masked area.
        neg_prompt: Negative prompt, forwarded to the pipeline.
        versions: Current gallery value; None or empty before the first run
            and after "Clear".
        num_inference_steps: Number of denoising steps (coerced to int).

    Returns:
        A 4-tuple: the new ImageEditor value, the updated gallery, and
        visibility updates for the restore and clear buttons.

    Raises:
        gr.Error: if no image was uploaded or no mask layer was drawn.
    """
    start = time.time()

    background = image_editor.get('background') if image_editor else None
    layers = image_editor.get('layers') if image_editor else None
    if background is None:
        raise gr.Error("Upload an image to inpaint first.")
    if not layers:
        raise gr.Error("Draw a mask over the area to inpaint.")

    image = background.convert('RGB')

    # Shrink (thumbnail never enlarges) so the longest side is at most 1024,
    # then snap both sides to multiples of 8.
    image.thumbnail((1024, 1024))
    image = divisible_by_8(image)
    original_image_size = image.size

    # Mask layer, resized to match the resized image.
    layer = layers[0].resize(image.size)

    # Pad to a square; squarify_image keeps the picture at the top-left corner.
    image = squarify_image(image)

    # Composite the strokes onto a white canvas (using the layer's own alpha),
    # then invert: untouched pixels become black. Strokes become light only if
    # the brush colour is dark -- assumption, confirm against the ImageMask brush.
    mask = Image.new("RGBA", image.size, "WHITE")
    mask.paste(layer, (0, 0), layer)
    mask = ImageOps.invert(mask.convert('L'))

    # Inpaint
    pipeline.to("cuda")
    final_image = pipeline(
        prompt=prompt,
        negative_prompt=neg_prompt,
        image=image,
        mask_image=mask,
        num_inference_steps=int(num_inference_steps),
    ).images[0]

    # Scale the pre-padding size to the pipeline's output resolution (derived
    # from the actual output rather than assuming 1024) and crop the top-left
    # region, which is where the picture sits inside the padded square.
    scale_x = final_image.width / image.width
    scale_y = final_image.height / image.height
    crop_box = (
        0,
        0,
        round(original_image_size[0] * scale_x),
        round(original_image_size[1] * scale_y),
    )
    final_image = final_image.crop(crop_box)

    # gradio.ImageEditor requires a dictionary value.
    final_dict = {'background': final_image, 'layers': None, 'composite': final_image}

    # Add the generated image to the version gallery. An empty gallery (first
    # run, or after "Clear") is seeded with the original upload. A new list is
    # built so the incoming `versions` value is not mutated.
    if not versions:
        final_gallery = [background, final_image]
    else:
        final_gallery = [*versions, final_image]

    end = time.time()
    print('time:', end - start)

    return final_dict, gr.Gallery(value=final_gallery, visible=True), gr.update(visible=True), gr.update(visible=True)

with gr.Blocks() as demo:
    gr.Markdown("""
    # Inpainting SD3 Sketch Pad

    Please ❤️ this Space
    """)

    with gr.Row():
        with gr.Column():
            sketch_pad = gr.ImageMask(type='pil', label='Inpaint')
            prompt = gr.Textbox(label="Prompt")
            generate_button = gr.Button(value="Inpaint", variant="primary")
            with gr.Accordion("Advanced Settings", open=False):
                neg_prompt = gr.Textbox(label='Negative Prompt', value='ugly, deformed')
                num_inference_steps = gr.Slider(minimum = 10, maximum = 100, value = 30, step = 1, label = "Number of inference steps", info = "lower=faster, higher=image quality")
        with gr.Column():
            version_gallery = gr.Gallery(label="Versions", type="pil", object_fit='contain', visible=False)
            restore_button = gr.Button("Restore Version", visible=False)
            clear_button = gr.Button('Clear', visible=False)
            # Hidden holder for the selected gallery index. precision=0 makes
            # gr.Number deliver an int; with the default precision the value
            # may arrive as a float, which cannot be used as a list index.
            selected = gr.Number(show_label=False, visible=False, precision=0)

    # Event wiring: selecting a gallery item records its index; the buttons run
    # inpainting, restore the recorded version, or reset the UI.
    version_gallery.select(get_select_index, None, selected)
    generate_button.click(fn=generate, inputs=[sketch_pad, prompt, neg_prompt, version_gallery, num_inference_steps], outputs=[sketch_pad, version_gallery, restore_button, clear_button])
    restore_button.click(fn=restore_version, inputs=[selected, version_gallery], outputs=sketch_pad)
    clear_button.click(clear_all, inputs=None, outputs=[sketch_pad, prompt, version_gallery, restore_button, clear_button])

demo.launch()