PeterL1n commited on
Commit
b0f3145
1 Parent(s): 3aee425

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -0
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import spaces
4
+ from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
5
+ from huggingface_hub import hf_hub_download
6
+ from safetensors.torch import load_file
7
+
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+ base = "stabilityai/stable-diffusion-xl-base-1.0"
10
+ repo = "ByteDance/SDXL-Lightning"
11
+ opts = {
12
+ "1 Step" : ["sdxl_lightning_1step_unet_x0.safetensors", 1],
13
+ "2 Steps" : ["sdxl_lightning_2step_unet.safetensors", 2],
14
+ "4 Steps" : ["sdxl_lightning_4step_unet.safetensors", 4],
15
+ "8 Steps" : ["sdxl_lightning_8step_unet.safetensors", 8],
16
+ }
17
+
18
+ pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to(device)
19
+
20
+ # Function
21
+ @spaces.GPU(enable_queue=True)
22
+ def generate_image(prompt, option):
23
+ ckpt, step = opts[option]
24
+ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if step == 1 else "epsilon")
25
+ pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
26
+ image = pipe(prompt, num_inference_steps=step, guidance_scale=0).images[0]
27
+ return image
28
+
29
+
30
+ with gr.Blocks() as demo:
31
+ gr.HTML("<h1><center>SDXL-Lightning ⚡</center></h1>")
32
+ gr.Markdown("Lightning-fast text-to-image generation! https://huggingface.co/ByteDance/SDXL-Lightning")
33
+
34
+ with gr.Group():
35
+ with gr.Row():
36
+ prompt = gr.Textbox(
37
+ label="Text prompt",
38
+ scale=8
39
+ )
40
+ option = gr.Dropdown(
41
+ label="Inference steps",
42
+ choices=["1 Step", "2 Steps", "4 Steps", "8 Steps"],
43
+ value="4-Step",
44
+ interactive=True
45
+ )
46
+ submit = gr.Button(
47
+ scale=1,
48
+ variant="primary"
49
+ )
50
+
51
+ img = gr.Image(label="SDXL-Lightening Generated Image")
52
+
53
+ prompt.submit(
54
+ fn=generate_image,
55
+ inputs=[prompt, option],
56
+ outputs=img,
57
+ )
58
+ submit.click(
59
+ fn=generate_image,
60
+ inputs=[prompt, option],
61
+ outputs=img,
62
+ )
63
+
64
+ demo.queue().launch()