PeterL1n commited on
Commit
23f3ac6
1 Parent(s): 7c672bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -19,11 +19,14 @@ opts = {
19
  "8 Steps" : ("sdxl_lightning_8step_unet.safetensors", 8),
20
  }
21
 
 
22
  step_loaded = 4
23
- unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(device, dtype)
24
- unet.load_state_dict(load_file(hf_hub_download(repo, opts["4 Steps"][0]), device=device))
25
  pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=dtype, variant="fp16").to(device, dtype)
 
26
 
 
27
  @spaces.GPU(enable_queue=True)
28
  def generate_image(prompt, option):
29
  global step_loaded
 
19
  "8 Steps" : ("sdxl_lightning_8step_unet.safetensors", 8),
20
  }
21
 
22
+ # Default to load 4-step model.
23
  step_loaded = 4
24
+ unet = UNet2DConditionModel.from_config(base, subfolder="unet")
25
+ unet.load_state_dict(load_file(hf_hub_download(repo, opts["4 Steps"][0])))
26
  pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=dtype, variant="fp16").to(device, dtype)
27
+ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
28
 
29
+ # Inference function.
30
  @spaces.GPU(enable_queue=True)
31
  def generate_image(prompt, option):
32
  global step_loaded