PeterL1n commited on
Commit
f1e3c7d
1 Parent(s): b0f3145

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -16,20 +16,24 @@ opts = {
16
  }
17
 
18
  pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to(device)
 
19
 
20
  # Function
21
  @spaces.GPU(enable_queue=True)
22
  def generate_image(prompt, option):
23
  ckpt, step = opts[option]
24
- pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if step == 1 else "epsilon")
25
- pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
 
 
26
  image = pipe(prompt, num_inference_steps=step, guidance_scale=0).images[0]
27
  return image
28
 
29
 
30
  with gr.Blocks() as demo:
31
- gr.HTML("<h1><center>SDXL-Lightning ⚡</center></h1>")
32
- gr.Markdown("Lightning-fast text-to-image generation! https://huggingface.co/ByteDance/SDXL-Lightning")
 
33
 
34
  with gr.Group():
35
  with gr.Row():
@@ -40,7 +44,7 @@ with gr.Blocks() as demo:
40
  option = gr.Dropdown(
41
  label="Inference steps",
42
  choices=["1 Step", "2 Steps", "4 Steps", "8 Steps"],
43
- value="4-Step",
44
  interactive=True
45
  )
46
  submit = gr.Button(
 
16
  }
17
 
18
  pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to(device)
19
+ last_step = None
20
 
21
  # Function
22
  @spaces.GPU(enable_queue=True)
23
  def generate_image(prompt, option):
24
  ckpt, step = opts[option]
25
+ if last_step != step:
26
+ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if step == 1 else "epsilon")
27
+ pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
28
+ last_step = step
29
  image = pipe(prompt, num_inference_steps=step, guidance_scale=0).images[0]
30
  return image
31
 
32
 
33
  with gr.Blocks() as demo:
34
+ gr.HTML("<h1><center>SDXL-Lightning</center></h1>")
35
+ gr.HTML("<p><center>Lightning-fast text-to-image generation.</center></p>")
36
+ gr.HTML("<p><center><a href='https://huggingface.co/ByteDance/SDXL-Lightning'>https://huggingface.co/ByteDance/SDXL-Lightning</a></center></p>")
37
 
38
  with gr.Group():
39
  with gr.Row():
 
44
  option = gr.Dropdown(
45
  label="Inference steps",
46
  choices=["1 Step", "2 Steps", "4 Steps", "8 Steps"],
47
+ value="4 Steps",
48
  interactive=True
49
  )
50
  submit = gr.Button(