multimodalart committed
Commit 8733b6e
1 Parent(s): 3b35dc4

Make it compatible with ZeroGPU and add `trigger_mode`
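
On a ZeroGPU Space the GPU is only attached while a function decorated with `spaces.GPU` is running, so this change imports the `spaces` package and decorates `predict` with `@spaces.GPU`. The prompt's `input` event also gets `trigger_mode="always_last"`, so rapid typing re-runs generation only for the most recent input instead of queueing one run per keystroke.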

Files changed (1):
app.py (+4, -1)
app.py CHANGED

@@ -1,3 +1,4 @@
+import spaces
 from diffusers import (
     StableDiffusionXLPipeline,
     EulerDiscreteScheduler,
@@ -55,6 +56,7 @@ unet.load_state_dict(load_file(hf_hub_download(REPO, CHECKPOINT), device="cuda")
 pipe = StableDiffusionXLPipeline.from_pretrained(
     BASE, unet=unet, torch_dtype=torch.float16, variant="fp16", safety_checker=False
 ).to("cuda")
+unet = unet.to(dtype=torch.float16)
 
 if USE_TAESD:
     pipe.vae = AutoencoderTiny.from_pretrained(
@@ -115,6 +117,7 @@ if SFAST_COMPILE:
     pipe = compile(pipe, config)
 
 
+@spaces.GPU
 def predict(prompt, seed=1231231):
     generator = torch.manual_seed(seed)
     last_time = time.time()
@@ -202,7 +205,7 @@ pipe("A girl smiling", num_inference_steps=2, guidance_scale=0).images[0].save("
 generate_bt.click(
     fn=predict, inputs=inputs, outputs=outputs, show_progress=False
 )
-prompt.input(fn=predict, inputs=inputs, outputs=outputs, show_progress=False)
+prompt.input(fn=predict, inputs=inputs, outputs=outputs, trigger_mode="always_last", show_progress=False)
 seed.change(fn=predict, inputs=inputs, outputs=outputs, show_progress=False)
 
 demo.queue()
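
For context, here is a minimal sketch of the two patterns the diff introduces. The component names (`prompt`, `seed`, `image`, `generate_bt`) and the plain SDXL base pipeline are illustrative assumptions; the Space itself loads a custom UNet checkpoint from the Hub.

```python
import gradio as gr
import spaces
import torch
from diffusers import StableDiffusionXLPipeline

# Illustrative pipeline; the actual app loads a UNet state dict from the Hub.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")


@spaces.GPU  # ZeroGPU attaches a GPU only for the duration of each call
def predict(prompt, seed=1231231):
    generator = torch.manual_seed(seed)
    return pipe(
        prompt, num_inference_steps=2, guidance_scale=0, generator=generator
    ).images[0]


with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    seed = gr.Number(label="Seed", value=1231231)
    image = gr.Image(label="Result")
    generate_bt = gr.Button("Generate")

    generate_bt.click(fn=predict, inputs=[prompt, seed], outputs=image, show_progress=False)
    # "always_last": while a run is in flight, only the latest keystroke's
    # event is kept, rather than queueing one generation per keypress.
    prompt.input(
        fn=predict,
        inputs=[prompt, seed],
        outputs=image,
        trigger_mode="always_last",
        show_progress=False,
    )
    seed.change(fn=predict, inputs=[prompt, seed], outputs=image, show_progress=False)

demo.queue()
demo.launch()
```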