Files changed (1)
  1. demo_gradio.py +6 -5
demo_gradio.py CHANGED
@@ -1,3 +1,4 @@
+
 import spaces
 import huggingface_hub
 
@@ -26,7 +27,7 @@ import glob
 import torch
 import cv2
 import argparse
-
+from diffusers.models.attention_processor import AttnProcessor2_0
 import DPT.util.io
 
 from torchvision.transforms import Compose
@@ -38,7 +39,7 @@ from DPT.dpt.transforms import Resize, NormalizeImage, PrepareForNet
 """
 Get ZeST Ready
 """
-base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
+base_model_path = "Lykon/dreamshaper-xl-lightning"
 image_encoder_path = "models/image_encoder"
 ip_ckpt = "sdxl_models/ip-adapter_sdxl_vit-h.bin"
 controlnet_path = "diffusers/controlnet-depth-sdxl-1.0"
@@ -55,7 +56,7 @@ pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
     add_watermarker=False,
 ).to(device)
 pipe.unet = register_cross_attention_hook(pipe.unet)
-
+pipe.unet.set_attn_processor(AttnProcessor2_0())
 ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device)
 
 
 
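Note on the `set_attn_processor` line above: `AttnProcessor2_0` is the diffusers attention processor backed by PyTorch 2's `torch.nn.functional.scaled_dot_product_attention`, which is generally faster and more memory-efficient than the legacy processor. A minimal sketch of the same pattern on a bare SDXL pipeline (the checkpoint id, dtype, and device here are illustrative assumptions, not taken from this diff):

```python
# Minimal sketch: switch a diffusers UNet to PyTorch 2 SDPA attention.
# Assumes torch >= 2.0 and a recent diffusers release.
import torch
from diffusers import StableDiffusionXLPipeline
from diffusers.models.attention_processor import AttnProcessor2_0

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # illustrative checkpoint
    torch_dtype=torch.float16,
).to("cuda")
pipe.unet.set_attn_processor(AttnProcessor2_0())  # apply SDPA to all attention blocks
```

One caveat worth flagging: `IPAdapterXL` typically installs its own IP-Adapter attention processors when it is constructed, so depending on the ip_adapter implementation, the SDPA processors set here may be partially overwritten on the next line.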
@@ -161,7 +162,7 @@ def greet(input_image, material_exemplar):
 
 
     num_samples = 1
-    images = ip_model.generate(pil_image=ip_image, image=init_img, control_image=depth_map, mask_image=mask, controlnet_conditioning_scale=0.9, num_samples=num_samples, num_inference_steps=30, seed=42)
+    images = ip_model.generate(guidance_scale=2, pil_image=ip_image, image=init_img, control_image=depth_map, mask_image=mask, controlnet_conditioning_scale=0.9, num_samples=num_samples, num_inference_steps=4, seed=42)
 
     return images[0]
 
 
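The sampling changes above go together with the new base model: lightning-distilled SDXL checkpoints such as DreamShaper XL Lightning are trained for very few denoising steps at low classifier-free guidance, which is why `num_inference_steps` drops from 30 to 4 and `guidance_scale=2` is added. A hedged sketch of the usual few-step setup on a plain pipeline (the trailing-timestep Euler scheduler follows the SDXL-Lightning docs; whether this exact scheduler is required for this checkpoint is an assumption, and the prompt is illustrative):

```python
# Sketch: few-step sampling with a lightning-distilled SDXL checkpoint.
import torch
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler

pipe = StableDiffusionXLPipeline.from_pretrained(
    "Lykon/dreamshaper-xl-lightning", torch_dtype=torch.float16
).to("cuda")
# SDXL-Lightning-style models expect trailing timestep spacing
# (an assumption for this particular checkpoint, not part of this diff).
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config, timestep_spacing="trailing"
)
# Distilled checkpoints are tuned for roughly 2-8 steps and low CFG,
# matching the num_inference_steps=4 / guidance_scale=2 used in the diff.
image = pipe("a marble sculpture of a fox",
             num_inference_steps=4, guidance_scale=2).images[0]
```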
@@ -193,4 +194,4 @@ with gr.Blocks(css=css) as demo:
     output_image = gr.Image(label="transfer result")
     submit_btn.click(fn=greet, inputs=[input_image, input_image2], outputs=[output_image])
 
-demo.queue().launch()
+demo.queue().launch()