ysharma HF staff commited on
Commit
f679e15
1 Parent(s): d017ab0

feature for selecting inference steps

Browse files

- Added a dropdown to select inference steps.
- Moved model loading out of inference function
- Changed the Interface to Blocks API

Files changed (1) hide show
  1. app.py +47 -16
app.py CHANGED
@@ -4,34 +4,65 @@ from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
4
  from huggingface_hub import hf_hub_download
5
  import spaces
6
 
 
7
  # Constants
8
  base = "stabilityai/stable-diffusion-xl-base-1.0"
9
  repo = "ByteDance/SDXL-Lightning"
10
- ckpt = "sdxl_lightning_4step_unet.pth"
 
 
 
 
 
11
 
12
- # Function
13
- @spaces.GPU
14
- def generate_image(prompt):
15
- # Ensure model and scheduler are initialized in GPU-enabled function
16
  pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to("cuda")
17
- pipe.unet.load_state_dict(torch.load(hf_hub_download(repo, ckpt), map_location="cuda"))
18
- pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
19
 
20
- image = pipe(prompt, num_inference_steps=4, guidance_scale=0).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  return image
22
 
 
23
  # Gradio Interface
24
  description = """
25
  This demo utilizes the SDXL-Lightning model by ByteDance, which is a fast text-to-image generative model capable of producing high-quality images in 4 steps.
26
  As a community effort, this demo was put together by AngryPenguin. Link to model: https://huggingface.co/ByteDance/SDXL-Lightning
27
  """
28
 
29
- demo = gr.Interface(
30
- fn=generate_image,
31
- inputs="text",
32
- outputs="image",
33
- title="Text-to-Image with SDXL Lightning ⚡",
34
- description=description
35
- )
 
 
36
 
37
- demo.launch()
 
 
 
 
 
 
 
 
 
 
4
  from huggingface_hub import hf_hub_download
5
  import spaces
6
 
7
+
8
  # Constants
9
  base = "stabilityai/stable-diffusion-xl-base-1.0"
10
  repo = "ByteDance/SDXL-Lightning"
11
+ checkpoints = {
12
+ "1-Step" : ["sdxl_lightning_1step_unet_x0.pth", 1],
13
+ "2-Step" : ["sdxl_lightning_2step_unet.pth", 2],
14
+ "4-Step" : ["sdxl_lightning_4step_unet.pth", 4],
15
+ "8-Step" : ["sdxl_lightning_8step_unet.pth", 8],
16
+ }
17
 
18
+
19
+ # Ensure model and scheduler are initialized in GPU-enabled function
20
+ if torch.cuda.is_available():
 
21
  pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to("cuda")
 
 
22
 
23
+
24
+ # Function
25
+ @spaces.GPU(enable_queue=True)
26
+ def generate_image(prompt, ckpt):
27
+
28
+ checkpoint = checkpoints[ckpt][0]
29
+ num_inference_steps = checkpoints[ckpt][1]
30
+
31
+ if num_inference_steps==1:
32
+ # Ensure sampler uses "trailing" timesteps and "sample" prediction type for 1-step inference.
33
+ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample")
34
+ else:
35
+ # Ensure sampler uses "trailing" timesteps.
36
+ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
37
+
38
+ pipe.unet.load_state_dict(torch.load(hf_hub_download(repo, checkpoint), map_location="cuda"))
39
+ image = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0).images[0]
40
  return image
41
 
42
+
43
  # Gradio Interface
44
  description = """
45
  This demo utilizes the SDXL-Lightning model by ByteDance, which is a fast text-to-image generative model capable of producing high-quality images in 4 steps.
46
  As a community effort, this demo was put together by AngryPenguin. Link to model: https://huggingface.co/ByteDance/SDXL-Lightning
47
  """
48
 
49
+ with gr.Blocks(css="style.css") as demo:
50
+ gr.HTML("<h1><center>Text-to-Image with SDXL Lightning ⚡</center></h1>")
51
+ gr.Markdown(description)
52
+ with gr.Group():
53
+ with gr.Row():
54
+ prompt = gr.Textbox(label='Enter you image prompt:', scale=8)
55
+ ckpt = gr.Dropdown(label='Select Inference Steps',choices=['1-Step', '2-Step', '4-Step', '8-Step'], value='4-Step', interactive=True)
56
+ submit = gr.Button(scale=1, variant='primary')
57
+ img = gr.Image(label='SDXL-Lightening Generate Image')
58
 
59
+ prompt.submit(fn=generate_image,
60
+ inputs=[prompt, ckpt],
61
+ outputs=img,
62
+ )
63
+ submit.click(fn=generate_image,
64
+ inputs=[prompt, ckpt],
65
+ outputs=img,
66
+ )
67
+
68
+ demo.queue().launch()