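"""Gradio demo for Stable Diffusion 1.5 cartoon-cat sticker generation.

Supports the base SD1.5 / SD1.4 checkpoints plus two LoRA presets (UNet-only
and UNet + text encoder), with optional ControlNet conditioning, IP-Adapter
image prompting, a DDIM scheduler toggle, and background removal via rembg.
"""
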
import gradio as gr
import numpy as np
import random
import torch
from peft import PeftModel
from rembg import remove
from diffusers import (
    ControlNetModel,
    DDIMScheduler,
    DiffusionPipeline,
    StableDiffusionControlNetPipeline,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
if torch.cuda.is_available():
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

# Human-readable ControlNet mode names mapped to their SD1.5 checkpoint IDs.
CONTROLNET_MODES = {
    "Canny Edge Detection": "lllyasviel/control_v11p_sd15_canny",
    "Pixel to Pixel": "lllyasviel/control_v11e_sd15_ip2p",
    "Inpainting": "lllyasviel/control_v11p_sd15_inpaint",
    "Multi-Level Line Segments": "lllyasviel/control_v11p_sd15_mlsd",
    "Depth Estimation": "lllyasviel/control_v11f1p_sd15_depth",
    "Surface Normal Estimation": "lllyasviel/control_v11p_sd15_normalbae",
    "Image Segmentation": "lllyasviel/control_v11p_sd15_seg",
    "Line Art Generation": "lllyasviel/control_v11p_sd15_lineart",
    "Anime Line Art": "lllyasviel/control_v11p_sd15_lineart_anime",
    "Human Pose Estimation": "lllyasviel/control_v11p_sd15_openpose",
    "Scribble-Based Generation": "lllyasviel/control_v11p_sd15_scribble",
    "Soft Edge Generation": "lllyasviel/control_v11p_sd15_softedge",
    "Image Shuffling": "lllyasviel/control_v11e_sd15_shuffle",
    "Image Tiling": "lllyasviel/control_v11f1e_sd15_tile",
}

# @spaces.GPU #[uncomment to use ZeroGPU]
def infer(
    model_id,
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    lscale=0.0,
    controlnet_enabled=False,
    control_strength=0.0,
    control_mode=None,
    control_image=None,
    ip_adapter_enabled=False,
    ip_adapter_scale=0.0,
    ip_adapter_image=None,
    d_bckg=False,
    ddim_use=False,
    distill_vae=False,
    progress=gr.Progress(track_tqdm=True),
):
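    """Build the requested pipeline on demand and run a single generation.

    Pipelines are constructed per call (base checkpoint, optional PEFT LoRA
    weights, optional ControlNet / IP-Adapter) and moved to the active device.
    Returns the generated PIL image and the seed that was actually used.
    """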
    control_strength = float(control_strength)
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    generator = torch.Generator().manual_seed(seed)

    if ip_adapter_enabled and ip_adapter_image is not None:
        # IP-Adapter reference images are normalized to a fixed square RGB input.
        ip_adapter_image = ip_adapter_image.convert("RGB").resize((510, 510))
    # The two "SD1.5 + lora ..." presets wrap the SD1.5 base checkpoint with
    # PEFT LoRA weights; any other choice is treated as a plain Hub model ID.
    use_controlnet = controlnet_enabled and control_image is not None
    base_id = (
        "stable-diffusion-v1-5/stable-diffusion-v1-5"
        if model_id.startswith("SD1.5 + lora")
        else model_id
    )
    if use_controlnet:
        controlnet_model = ControlNetModel.from_pretrained(
            CONTROLNET_MODES[control_mode], torch_dtype=torch_dtype
        )
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            base_id, controlnet=controlnet_model, torch_dtype=torch_dtype
        )
    else:
        pipe = DiffusionPipeline.from_pretrained(base_id, torch_dtype=torch_dtype)

    if model_id == "SD1.5 + lora Unet TextEncoder":
        pipe.unet = PeftModel.from_pretrained(pipe.unet, "um235/vCat_v2", subfolder="unet")
        pipe.text_encoder = PeftModel.from_pretrained(
            pipe.text_encoder, "um235/vCat_v2", subfolder="text_encoder"
        )
    elif model_id == "SD1.5 + lora Unet":
        pipe.unet = PeftModel.from_pretrained(pipe.unet, "um235/cartoon_cat_stickers")
    if ip_adapter_enabled and ip_adapter_image is not None:
        pipe.load_ip_adapter(
            "h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin"
        )
        pipe.set_ip_adapter_scale(ip_adapter_scale)

    # Disable the safety checker for this demo.
    pipe.safety_checker = None
    if ddim_use:
        # Rescale betas for zero terminal SNR (avoids over-bright, washed-out samples).
        pipe.scheduler = DDIMScheduler.from_config(
            pipe.scheduler.config, rescale_betas_zero_snr=True
        )
    pipe = pipe.to(device)
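    # "distill_vae" is exposed in the UI but not otherwise consumed in this file.
    # A minimal sketch of the likely intent, assuming the distilled tiny SD1.5
    # VAE ("madebyollin/taesd") is what the checkbox should swap in:
    if distill_vae:
        from diffusers import AutoencoderTiny  # imported lazily; only this path needs it

        pipe.vae = AutoencoderTiny.from_pretrained(
            "madebyollin/taesd", torch_dtype=torch_dtype
        ).to(device)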

    # Build the call explicitly so ControlNet / IP-Adapter inputs are only passed
    # to pipelines that were actually configured for them; stray kwargs are
    # handled inconsistently across diffusers versions.
    call_kwargs = dict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
        cross_attention_kwargs={"scale": lscale},
    )
    if use_controlnet:
        call_kwargs["image"] = control_image
        call_kwargs["controlnet_conditioning_scale"] = control_strength
    if ip_adapter_enabled and ip_adapter_image is not None:
        call_kwargs["ip_adapter_image"] = ip_adapter_image

    image = pipe(**call_kwargs).images[0]
    
    if d_bckg:
        # rembg returns the subject with the background removed (RGBA output).
        image = remove(image)

    return image, seed


examples = [
    "Sticker VanillaCat. Cartoon-style cat with soft yellow fur and a one white flower on its head, sitting in lotus pose on a yoga mat, with its paws pressed together in front of its chest in a prayer position, eyes closed, looking calm and peaceful.",
    "Sticker VanillaCat. Cartoon-style cat with soft yellow fur and a white flower on its head, standing with a mischievous grin, one paw raised playfully, bright eyes full of energy, cheeky and fun, white background",
    "Sticker VanillaCat. Cartoon-style cat with soft yellow fur and a white flower on its head, jumping mid-air with a surprised expression, wide eyes, and mouth open in excitement, paws stretched out, energetic and playful, forest background.",
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

def update_controlnet_visibility(controlnet_enabled):
    # Show or hide the ControlNet strength slider, mode dropdown, and image input.
    update = gr.update(visible=controlnet_enabled)
    return update, update, update

def update_ip_adapter_visibility(ip_adapter_enabled):
    # Show or hide the IP-Adapter scale slider and reference image input.
    update = gr.update(visible=ip_adapter_enabled)
    return update, update


with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# UM235 DIFFUSION Space")
        
        model_id_input = gr.Dropdown(
            label="Choose Model",
            choices=[
                "stable-diffusion-v1-5/stable-diffusion-v1-5",  
                "CompVis/stable-diffusion-v1-4", 
                "SD1.5 + lora Unet TextEncoder",  
                "SD1.5 + lora Unet"
            ],
            value="SD1.5 + lora Unet TextEncoder",  
            show_label=True,
            type="value",  
        )
        with gr.Row():
            lscale = gr.Slider(
                label="Lora scale",
                minimum=0,
                maximum=2,
                step=0.05,
                value=0.85,  
            )
        with gr.Row():
            d_bckg = gr.Checkbox(label="Delete Background", value=False)
            ddim_use = gr.Checkbox(label="Enable DDIMScheduler", value=False)
            distill_vae = gr.Checkbox(label="Use tiny VAE with distill model", value=False)

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            
        with gr.Accordion("ControlNet Settings", open=False):
            controlnet_enabled = gr.Checkbox(label="Enable ControlNet", value=False)

        with gr.Row():
            control_strength = gr.Slider(
                label="ControlNet scale",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.75,
                visible=False,
            )

            control_mode = gr.Dropdown(
                label="ControlNet Mode",
                choices=list(CONTROLNET_MODES.keys()),
                value="Canny Edge Detection",
                visible=False,
            )

        control_image = gr.Image(label="ControlNet Image", type="pil", visible=False)

        with gr.Accordion("IP-Adapter Settings", open=False):
            ip_adapter_enabled = gr.Checkbox(label="Enable IP-Adapter", value=False)

            with gr.Row():
                ip_adapter_scale = gr.Slider(
                    label="IP-Adapter Scale",
                    minimum=0.0,
                    maximum=2.0,
                    step=0.05,
                    value=0.55,
                    visible=False,
                )

            ip_adapter_image = gr.Image(label="IP-Adapter Image", type="pil", visible=False)
            

        with gr.Row():
            run_button = gr.Button("Run", scale=0, variant="primary")

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                visible=True,
                value="worst quality, low quality, low res, blurry, distortion, jpeg artifacts, background",
            )

            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=1274800826,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=False)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512, 
                )

                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,  
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=7.3,  
                )

                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=36,  
                )

        gr.Examples(examples=examples, inputs=[prompt])
    
    # Both the Run button and pressing Enter in the prompt box trigger generation.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            model_id_input,
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            lscale,
            controlnet_enabled,
            control_strength,
            control_mode,
            control_image,
            ip_adapter_enabled,
            ip_adapter_scale,
            ip_adapter_image,
            d_bckg,
            ddim_use,
            distill_vae
        ],
        outputs=[result, seed],
    )
    
    controlnet_enabled.change(
        fn=update_controlnet_visibility,
        inputs=[controlnet_enabled],
        outputs=[control_strength, control_mode, control_image],
    )

    ip_adapter_enabled.change(
        fn=update_ip_adapter_visibility,
        inputs=[ip_adapter_enabled],
        outputs=[ip_adapter_scale, ip_adapter_image],
    )

if __name__ == "__main__":
    demo.launch()
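
# Run locally with `python app.py`; Gradio serves the UI on its default
# http://127.0.0.1:7860. On a Hugging Face ZeroGPU Space, also uncomment the
# @spaces.GPU decorator above infer() and add `import spaces`.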