um235 committed on
Commit
458c731
·
verified ·
1 Parent(s): 5b5bb22

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -30
app.py CHANGED
@@ -2,9 +2,10 @@ import gradio as gr
2
  import numpy as np
3
  import random
4
  from peft import PeftModel, LoraConfig
5
- from diffusers import DiffusionPipeline
6
  from diffusers import ControlNetModel
7
  import torch
 
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
  if torch.cuda.is_available():
@@ -44,45 +45,54 @@ def infer(
44
  height,
45
  guidance_scale,
46
  num_inference_steps,
47
- lscale,
48
- controlnet_enabled,
49
- control_strength,
50
- control_mode,
51
- control_image,
52
- ip_adapter_enabled,
53
- ip_adapter_scale,
54
- ip_adapter_image,
55
  progress=gr.Progress(track_tqdm=True),
56
  ):
 
57
  if randomize_seed:
58
  seed = random.randint(0, MAX_SEED)
59
 
60
  generator = torch.Generator().manual_seed(seed)
61
-
 
 
 
62
  pipe = None
63
- if model_id == "SD1.5 + lora Unet TextEncoder":
64
- pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch_dtype)
65
- pipe.unet = PeftModel.from_pretrained(pipe.unet, "um235/VanillaCat", subfolder="unet")
66
- pipe.safety_checker = None
67
- pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, "um235/VanillaCat", subfolder="text_encoder")
68
- elif model_id == "SD1.5 + lora Unet":
69
- pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch_dtype)
70
- pipe.safety_checker = None
71
- pipe.unet = PeftModel.from_pretrained(pipe.unet, "um235/cartoon_cat_stickers")
72
  else:
73
- pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
74
- pipe.safety_checker = None
75
-
76
- if controlnet_enabled:
77
- controlnet_model = CONTROLNET_MODES.get(control_mode)
78
- if controlnet_model:
79
- controlnet_model = ControlNetModel.from_pretrained(controlnet_model)
80
- pipe.controlnet = controlnet_model
 
 
 
 
 
81
 
 
 
82
  pipe = pipe.to(device)
83
 
84
  image = pipe(
85
  prompt=prompt,
 
86
  negative_prompt=negative_prompt,
87
  guidance_scale=guidance_scale,
88
  num_inference_steps=num_inference_steps,
@@ -90,8 +100,8 @@ def infer(
90
  height=height,
91
  generator=generator,
92
  cross_attention_kwargs={"scale": lscale},
93
- control_image=control_image,
94
- controlnet_conditioning_scale=control_strength
95
  ).images[0]
96
 
97
  return image, seed
@@ -187,6 +197,7 @@ with gr.Blocks(css=css) as demo:
187
  )
188
 
189
  ip_adapter_image = gr.Image(label="IP-Adapter Image", type="pil", visible=False)
 
190
 
191
  with gr.Row():
192
  run_button = gr.Button("Run", scale=0, variant="primary")
@@ -235,7 +246,7 @@ with gr.Blocks(css=css) as demo:
235
  minimum=0.0,
236
  maximum=10.0,
237
  step=0.1,
238
- value=9.0,
239
  )
240
 
241
  num_inference_steps = gr.Slider(
 
2
  import numpy as np
3
  import random
4
  from peft import PeftModel, LoraConfig
5
+ from diffusers import DiffusionPipeline, StableDiffusionControlNetPipeline
6
  from diffusers import ControlNetModel
7
  import torch
8
+ from PIL import Image
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
  if torch.cuda.is_available():
 
45
  height,
46
  guidance_scale,
47
  num_inference_steps,
48
+ lscale=0.0,
49
+ controlnet_enabled=False,
50
+ control_strength=0.0,
51
+ control_mode=None,
52
+ control_image=None,
53
+ ip_adapter_enabled=False,
54
+ ip_adapter_scale=0.0,
55
+ ip_adapter_image=None,
56
  progress=gr.Progress(track_tqdm=True),
57
  ):
58
+ control_strength=float(control_strength)
59
  if randomize_seed:
60
  seed = random.randint(0, MAX_SEED)
61
 
62
  generator = torch.Generator().manual_seed(seed)
63
+ if ip_adapter_enabled:
64
+ print("ip_adapter_image")
65
+ ip_adapter_image = ip_adapter_image.convert('RGB').resize((510, 510))
66
+ print("ip_adapter_image",ip_adapter_image.size)
67
  pipe = None
68
+ if controlnet_enabled and control_image:
69
+ controlnet_model = ControlNetModel.from_pretrained(CONTROLNET_MODES.get(control_mode))
70
+ if model_id == "SD1.5 + lora Unet TextEncoder" or model_id == "SD1.5 + lora Unet":
71
+ pipe=StableDiffusionControlNetPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5",controlnet=controlnet_model)
72
+ else:
73
+ pipe=StableDiffusionControlNetPipeline.from_pretrained(model_id, controlnet=controlnet_model)
 
 
 
74
  else:
75
+ if model_id == "SD1.5 + lora Unet TextEncoder":
76
+ pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch_dtype)
77
+ pipe.unet = PeftModel.from_pretrained(pipe.unet, "um235/VanillaCat", subfolder="unet")
78
+ pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, "um235/VanillaCat", subfolder="text_encoder")
79
+ elif model_id == "SD1.5 + lora Unet":
80
+ pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch_dtype)
81
+ pipe.unet = PeftModel.from_pretrained(pipe.unet, "um235/cartoon_cat_stickers")
82
+ else:
83
+ pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
84
+ if ip_adapter_enabled:
85
+ print("ip_adapter_enabled",ip_adapter_enabled)
86
+ pipe.load_ip_adapter("h94/IP-Adapter",subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
87
+ pipe.set_ip_adapter_scale(ip_adapter_scale)
88
 
89
+ pipe.safety_checker = None
90
+
91
  pipe = pipe.to(device)
92
 
93
  image = pipe(
94
  prompt=prompt,
95
+ image=control_image,
96
  negative_prompt=negative_prompt,
97
  guidance_scale=guidance_scale,
98
  num_inference_steps=num_inference_steps,
 
100
  height=height,
101
  generator=generator,
102
  cross_attention_kwargs={"scale": lscale},
103
+ controlnet_conditioning_scale=control_strength,
104
+ ip_adapter_image=ip_adapter_image,
105
  ).images[0]
106
 
107
  return image, seed
 
197
  )
198
 
199
  ip_adapter_image = gr.Image(label="IP-Adapter Image", type="pil", visible=False)
200
+
201
 
202
  with gr.Row():
203
  run_button = gr.Button("Run", scale=0, variant="primary")
 
246
  minimum=0.0,
247
  maximum=10.0,
248
  step=0.1,
249
+ value=7.0,
250
  )
251
 
252
  num_inference_steps = gr.Slider(