Update app.py
app.py CHANGED
@@ -2,9 +2,10 @@ import gradio as gr
 import numpy as np
 import random
 from peft import PeftModel, LoraConfig
-from diffusers import DiffusionPipeline
+from diffusers import DiffusionPipeline, StableDiffusionControlNetPipeline
 from diffusers import ControlNetModel
 import torch
+from PIL import Image

 device = "cuda" if torch.cuda.is_available() else "cpu"
 if torch.cuda.is_available():
@@ -44,45 +45,54 @@ def infer(
     height,
     guidance_scale,
     num_inference_steps,
-    lscale,
-    controlnet_enabled,
-    control_strength,
-    control_mode,
-    control_image,
-    ip_adapter_enabled,
-    ip_adapter_scale,
-    ip_adapter_image,
+    lscale=0.0,
+    controlnet_enabled=False,
+    control_strength=0.0,
+    control_mode=None,
+    control_image=None,
+    ip_adapter_enabled=False,
+    ip_adapter_scale=0.0,
+    ip_adapter_image=None,
     progress=gr.Progress(track_tqdm=True),
 ):
+    control_strength=float(control_strength)
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)

     generator = torch.Generator().manual_seed(seed)
-
+    if ip_adapter_enabled:
+        print("ip_adapter_image")
+        ip_adapter_image = ip_adapter_image.convert('RGB').resize((510, 510))
+        print("ip_adapter_image",ip_adapter_image.size)
     pipe = None
-    if
-
-
-
-
-
-        pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch_dtype)
-        pipe.safety_checker = None
-        pipe.unet = PeftModel.from_pretrained(pipe.unet, "um235/cartoon_cat_stickers")
+    if controlnet_enabled and control_image:
+        controlnet_model = ControlNetModel.from_pretrained(CONTROLNET_MODES.get(control_mode))
+        if model_id == "SD1.5 + lora Unet TextEncoder" or model_id == "SD1.5 + lora Unet":
+            pipe=StableDiffusionControlNetPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5",controlnet=controlnet_model)
+        else:
+            pipe=StableDiffusionControlNetPipeline.from_pretrained(model_id, controlnet=controlnet_model)
     else:
-
-
-
-
-
-
-
-
+        if model_id == "SD1.5 + lora Unet TextEncoder":
+            pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch_dtype)
+            pipe.unet = PeftModel.from_pretrained(pipe.unet, "um235/VanillaCat", subfolder="unet")
+            pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, "um235/VanillaCat", subfolder="text_encoder")
+        elif model_id == "SD1.5 + lora Unet":
+            pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch_dtype)
+            pipe.unet = PeftModel.from_pretrained(pipe.unet, "um235/cartoon_cat_stickers")
+        else:
+            pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
+    if ip_adapter_enabled:
+        print("ip_adapter_enabled",ip_adapter_enabled)
+        pipe.load_ip_adapter("h94/IP-Adapter",subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
+        pipe.set_ip_adapter_scale(ip_adapter_scale)

+    pipe.safety_checker = None
+
     pipe = pipe.to(device)

     image = pipe(
         prompt=prompt,
+        image=control_image,
         negative_prompt=negative_prompt,
         guidance_scale=guidance_scale,
         num_inference_steps=num_inference_steps,
@@ -90,8 +100,8 @@ def infer(
         height=height,
         generator=generator,
         cross_attention_kwargs={"scale": lscale},
-
-
+        controlnet_conditioning_scale=control_strength,
+        ip_adapter_image=ip_adapter_image,
     ).images[0]

     return image, seed
@@ -187,6 +197,7 @@ with gr.Blocks(css=css) as demo:
            )

            ip_adapter_image = gr.Image(label="IP-Adapter Image", type="pil", visible=False)
+

        with gr.Row():
            run_button = gr.Button("Run", scale=0, variant="primary")
@@ -235,7 +246,7 @@ with gr.Blocks(css=css) as demo:
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
-                    value=
+                    value=7.0,
                )

                num_inference_steps = gr.Slider(
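For orientation, the branch added above follows the usual diffusers recipe: build a StableDiffusionControlNetPipeline when a ControlNet mode and conditioning image are supplied, otherwise a plain DiffusionPipeline with the PEFT LoRA wrapped around the UNet, optionally attach an IP-Adapter, then pass the conditioning image, its strength, and the reference image at call time. The sketch below is a minimal standalone illustration of that pattern, not the app's code: the canny ControlNet checkpoint, the blank placeholder images, the 0.6 adapter scale, and the prompt are assumptions, and the app resolves its ControlNet checkpoint through its CONTROLNET_MODES mapping instead.

import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

# Illustrative ControlNet checkpoint; app.py looks its checkpoint up via CONTROLNET_MODES.get(control_mode).
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")

pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet
)
pipe.safety_checker = None
# Without ControlNet, the app instead builds a DiffusionPipeline and wraps pipe.unet
# (and optionally pipe.text_encoder) with peft.PeftModel.from_pretrained, as in the diff above.

# Optional IP-Adapter, mirroring the ip_adapter_enabled branch; 0.6 is an assumed scale.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
pipe.set_ip_adapter_scale(0.6)

pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

# Blank placeholder images here; the app passes the user-supplied Gradio images.
control_image = Image.new("RGB", (512, 512))
reference_image = Image.new("RGB", (510, 510))

image = pipe(
    prompt="cartoon cat sticker",
    image=control_image,                    # ControlNet conditioning image
    controlnet_conditioning_scale=0.8,      # corresponds to control_strength
    ip_adapter_image=reference_image,       # corresponds to ip_adapter_image
    num_inference_steps=25,
    guidance_scale=7.0,                     # matches the new slider default
).images[0]

With the defaults added to the infer() signature (lscale=0.0, controlnet_enabled=False, and so on), callers that predate the ControlNet and IP-Adapter controls can keep invoking it with only the original positional arguments.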