File size: 3,598 Bytes
87d91a7
2419492
 
 
 
 
 
3651eaa
 
 
 
 
 
2419492
62b2a7f
0af1d8d
 
 
 
 
 
 
 
 
 
 
 
 
 
c1071da
510810d
089e6cb
3651eaa
7f39ca4
 
 
0699667
7f39ca4
 
0699667
83f75b0
672cfcb
c058625
7f39ca4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c058625
 
7f39ca4
83f75b0
7f39ca4
c058625
 
 
7f39ca4
7d3dd84
c058625
1658ae9
7f39ca4
2c60f43
 
0699667
 
c058625
 
3651eaa
c058625
 
7f39ca4
c058625
 
3651eaa
c058625
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import gradio as gr
from huggingface_hub import login
import os

# Authenticate with the Hugging Face Hub only when a token is configured.
# Calling login() with token=None makes huggingface_hub prompt for a token
# interactively, which would stall or fail a headless server at startup.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)

from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
from diffusers.utils import load_image
from PIL import Image
import torch
import numpy as np
import cv2

# Canny-edge ControlNet for SDXL, loaded in half precision.
controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0",
    torch_dtype=torch.float16,
)

# SDXL base pipeline driven by the ControlNet above, using the fp16 safetensors
# weight variant. A separate VAE (e.g. "madebyollin/sdxl-vae-fp16-fix") could be
# supplied through `vae=`; it is currently not loaded.
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    use_safetensors=True,
    variant="fp16",
    torch_dtype=torch.float16,
    controlnet=controlnet,
)
# Keep the whole pipeline on the GPU. (pipe.enable_model_cpu_offload() is the
# offloading alternative and is not used here.)
pipe.to("cuda")

# One CUDA RNG shared by every request; `infer` re-seeds it before each run.
generator = torch.Generator(device="cuda")

def _canny_control_image(image, low_threshold=100, high_threshold=200):
    """Return a 3-channel PIL image holding the Canny edge map of *image*.

    ``cv2.Canny`` yields a single-channel edge map; it is replicated into three
    identical channels before being wrapped back into a PIL image.
    """
    edges = cv2.Canny(np.array(image), low_threshold, high_threshold)
    edges = np.stack([edges, edges, edges], axis=2)
    return Image.fromarray(edges)


def infer(use_custom_model, model_name, image_in, prompt, preprocessor, controlnet_conditioning_scale, guidance_scale, seed):
    """Generate an image with SDXL + Canny ControlNet and save it to disk.

    Args:
        use_custom_model: When truthy, LoRA weights are loaded from
            ``model_name`` for this call only and removed afterwards.
        model_name: Hub repo id containing ``pytorch_lora_weights.safetensors``.
        image_in: Path of the uploaded conditioning image (the UI component
            uses ``type="filepath"``).
        prompt: Text prompt.
        preprocessor: Conditioning preprocessor; only ``"canny"`` is supported.
        controlnet_conditioning_scale: Strength of the ControlNet conditioning.
        guidance_scale: Classifier-free guidance scale.
        seed: Seed for the shared CUDA generator.

    Returns:
        The path ``"result.png"`` of the saved first output image.

    Raises:
        gr.Error: If no image was uploaded, the preprocessor is unsupported, or
            a custom model is requested without a model name.
    """
    if image_in is None:
        raise gr.Error("Please upload an input image.")
    if preprocessor != "canny":
        # Previously this fell through and crashed with UnboundLocalError.
        raise gr.Error(f"Unsupported preprocessor: {preprocessor!r}")
    if use_custom_model:
        model_name = (model_name or "").strip()
        if not model_name:
            raise gr.Error("Please enter the model to use, e.g. username/my_model.")

    negative_prompt = "extra digit, fewer digits, cropped, worst quality, low quality, glitch, deformed, mutated, ugly, disfigured"
    control_image = _canny_control_image(load_image(image_in))

    pipe_kwargs = dict(
        negative_prompt=negative_prompt,
        image=control_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        guidance_scale=guidance_scale,
        num_inference_steps=50,
        # Slider values may arrive as floats; Generator.manual_seed wants an int.
        generator=generator.manual_seed(int(seed)),
    )

    if use_custom_model:
        lora_scale = 0.9
        try:
            # This is where the trained LoRA weights are loaded.
            pipe.load_lora_weights(model_name, weight_name="pytorch_lora_weights.safetensors", use_auth_token=True)
            images = pipe(
                prompt,
                cross_attention_kwargs={"scale": lora_scale},
                **pipe_kwargs,
            ).images
        finally:
            # `pipe` is shared across requests: without unloading, the LoRA
            # layers would stay attached and later base-model requests would
            # silently run with them.
            pipe.unload_lora_weights()
    else:
        images = pipe(prompt, **pipe_kwargs).images

    images[0].save("result.png")
    return "result.png"

# Single-column UI: inputs top to bottom, then the submit button and output.
with gr.Blocks() as demo:
    with gr.Column():
        use_custom_model = gr.Checkbox(value=False, label="Use a custom model ?")
        model_name = gr.Textbox(
            placeholder="username/my_model",
            label="Model to use",
        )
        image_in = gr.Image(type="filepath", source="upload")
        prompt = gr.Textbox(label="Prompt")
        preprocessor = gr.Dropdown(
            choices=["canny"],
            value="canny",
            label="Preprocessor",
        )
        guidance_scale = gr.Slider(
            minimum=1.0,
            maximum=10.0,
            value=7.5,
            step=0.1,
            type="float",
            label="Guidance Scale",
        )
        controlnet_conditioning_scale = gr.Slider(
            minimum=0.1,
            maximum=0.9,
            value=0.5,
            step=0.01,
            type="float",
            label="Controlnet conditioning Scale",
        )
        seed = gr.Slider(minimum=0, maximum=500000, value=42, step=1, label="seed")

        submit_btn = gr.Button("Submit")
        result = gr.Image(label="Result")

    # Order must match infer's positional parameters.
    infer_inputs = [
        use_custom_model,
        model_name,
        image_in,
        prompt,
        preprocessor,
        controlnet_conditioning_scale,
        guidance_scale,
        seed,
    ]
    submit_btn.click(fn=infer, inputs=infer_inputs, outputs=[result])

# Start the (blocking) web server only when run as a script, so importing this
# module does not launch it as a side effect. The queue serializes requests.
if __name__ == "__main__":
    demo.queue().launch()