Hei-Ha committed
Commit a1740b4 · 1 Parent(s): 2a3af43
Files changed (3)
  1. README.md +3 -3
  2. app.py +167 -70
  3. requirements.txt +15 -4
README.md CHANGED
@@ -4,10 +4,10 @@ emoji: 🐢
 colorFrom: indigo
 colorTo: blue
 sdk: gradio
-sdk_version: 4.19.1
+sdk_version: 4.19.2
 app_file: app.py
 pinned: false
 license: mit
+disable_embedding: true
+header: mini
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,40 +1,79 @@
-import gradio as gr
+from diffusers import (
+    StableDiffusionXLPipeline,
+    EulerDiscreteScheduler,
+    UNet2DConditionModel,
+    AutoencoderTiny,
+)
 import torch
-from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
-from huggingface_hub import hf_hub_download
-from safetensors.torch import load_file
-import spaces
 import os
-from PIL import Image
+from huggingface_hub import hf_hub_download
+
+from PIL import Image
+import gradio as gr
+import time
+from safetensors.torch import load_file
+import tempfile
+from pathlib import Path
 
+# Constants
+BASE = "stabilityai/stable-diffusion-xl-base-1.0"
+REPO = "ByteDance/SDXL-Lightning"
+# 2-step checkpoint
+CHECKPOINT = "sdxl_lightning_2step_unet.safetensors"
+taesd_model = "madebyollin/taesdxl"
+
+# {
+#     "1-Step": ["sdxl_lightning_1step_unet_x0.safetensors", 1],
+#     "2-Step": ["sdxl_lightning_2step_unet.safetensors", 2],
+#     "4-Step": ["sdxl_lightning_4step_unet.safetensors", 4],
+#     "8-Step": ["sdxl_lightning_8step_unet.safetensors", 8],
+# }
+
+SFAST_COMPILE = os.environ.get("SFAST_COMPILE", "0") == "1"
 SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"
+USE_TAESD = os.environ.get("USE_TAESD", "0") == "1"
 
-base = "stabilityai/stable-diffusion-xl-base-1.0"
-repo = "ByteDance/SDXL-Lightning"
-checkpoints = {
-    "1-Step": ["sdxl_lightning_1step_unet_x0.safetensors", 1],
-    "2-Step": ["sdxl_lightning_2step_unet.safetensors", 2],
-    "4-Step": ["sdxl_lightning_4step_unet.safetensors", 4],
-    "8-Step": ["sdxl_lightning_8step_unet.safetensors", 8],
-}
-
-print('--------------------------')
-print(torch.cuda.is_available())
-print('--------------------------')
-# Ensure model and scheduler are initialized in GPU-enabled function
-
-if torch.cuda.is_available():
-    pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to("cuda")
+# Use CUDA when available, otherwise fall back to CPU.
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+torch_device = device
+torch_dtype = torch.float16
 
+print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
+print(f"SFAST_COMPILE: {SFAST_COMPILE}")
+print(f"USE_TAESD: {USE_TAESD}")
+print(f"device: {device}")
+
+unet = UNet2DConditionModel.from_config(BASE, subfolder="unet").to(
+    "cuda", torch.float16
+)
+unet.load_state_dict(load_file(hf_hub_download(REPO, CHECKPOINT), device="cuda"))
+pipe = StableDiffusionXLPipeline.from_pretrained(
+    BASE, unet=unet, torch_dtype=torch.float16, variant="fp16", safety_checker=False
+).to("cuda")
+
+if USE_TAESD:
+    pipe.vae = AutoencoderTiny.from_pretrained(
+        taesd_model, torch_dtype=torch_dtype, use_safetensors=True
+    ).to(device)
+
+# Ensure the sampler uses "trailing" timesteps.
+pipe.scheduler = EulerDiscreteScheduler.from_config(
+    pipe.scheduler.config, timestep_spacing="trailing"
+)
+pipe.set_progress_bar_config(disable=True)
 if SAFETY_CHECKER:
     from safety_checker import StableDiffusionSafetyChecker
     from transformers import CLIPFeatureExtractor
 
     safety_checker = StableDiffusionSafetyChecker.from_pretrained(
         "CompVis/stable-diffusion-safety-checker"
-    ).to("cuda")
+    ).to(device)
     feature_extractor = CLIPFeatureExtractor.from_pretrained(
         "openai/clip-vit-base-patch32"
     )
@@ -42,67 +81,125 @@ if SAFETY_CHECKER:
     def check_nsfw_images(
         images: list[Image.Image],
     ) -> tuple[list[Image.Image], list[bool]]:
-        safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
+        safety_checker_input = feature_extractor(images, return_tensors="pt").to(device)
        has_nsfw_concepts = safety_checker(
             images=[images],
-            clip_input=safety_checker_input.pixel_values.to("cuda")
+            clip_input=safety_checker_input.pixel_values.to(torch_device),
         )
 
         return images, has_nsfw_concepts
 
-# Function
-@spaces.GPU(enable_queue=True)
-def generate_image(prompt, ckpt):
-    checkpoint = checkpoints[ckpt][0]
-    num_inference_steps = checkpoints[ckpt][1]
-
-    if num_inference_steps == 1:
-        # Ensure the sampler uses "trailing" timesteps and "sample" prediction type for 1-step inference.
-        pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample")
-    else:
-        # Ensure the sampler uses "trailing" timesteps.
-        pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
-
-    pipe.unet.load_state_dict(load_file(hf_hub_download(repo, checkpoint), device="cuda"))
-    results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0)
+
+if SFAST_COMPILE:
+    from sfast.compilers.diffusion_pipeline_compiler import compile, CompilationConfig
+
+    # stable-fast compilation
+    config = CompilationConfig.Default()
+    try:
+        import xformers
+
+        config.enable_xformers = True
+    except ImportError:
+        print("xformers not installed, skip")
+    try:
+        import triton
+
+        config.enable_triton = True
+    except ImportError:
+        print("Triton not installed, skip")
+    # CUDA Graph is suggested for small batch sizes and small resolutions
+    # to reduce CPU overhead, but it can increase GPU memory use.
+    config.enable_cuda_graph = True
+
+    pipe = compile(pipe, config)
+
+
+def predict(prompt, seed=1231231):
+    generator = torch.manual_seed(seed)
+    last_time = time.time()
+    results = pipe(
+        prompt=prompt,
+        generator=generator,
+        num_inference_steps=2,
+        guidance_scale=0.0,
+        # width=768,
+        # height=768,
+        output_type="pil",
+    )
+    print(f"Pipe took {time.time() - last_time} seconds")
     if SAFETY_CHECKER:
         images, has_nsfw_concepts = check_nsfw_images(results.images)
         if any(has_nsfw_concepts):
             gr.Warning("NSFW content detected.")
             return Image.new("RGB", (512, 512))
-        return images[0]
-    return results.images[0]
+    image = results.images[0]
+    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmpfile:
+        image.save(tmpfile, "JPEG", quality=80, optimize=True, progressive=True)
+    return Path(tmpfile.name)
 
 
-# Gradio Interface
-description = """
-This demo utilizes the SDXL-Lightning model by ByteDance, which is a lightning-fast text-to-image generative model capable of producing high-quality images in 4 steps.
-As a community effort, this demo was put together by AngryPenguin. Link to model: https://huggingface.co/ByteDance/SDXL-Lightning
+css = """
+#container{
+    margin: 0 auto;
+    max-width: 40rem;
+}
+#intro{
+    max-width: 100%;
+    margin: 0 auto;
+}
 """
-
-with gr.Blocks(css="style.css") as demo:
-    gr.HTML("<h1><center>Text-to-Image with SDXL-Lightning ⚡</center></h1>")
-    gr.Markdown(description)
-    with gr.Group():
+with gr.Blocks(css=css) as demo:
+    with gr.Column(elem_id="container"):
+        gr.Markdown(
+            """
+            # SDXL-Lightning Text-to-Image in 2 Steps
+            **Model**: https://huggingface.co/ByteDance/SDXL-Lightning
+            """,
+            elem_id="intro",
+        )
         with gr.Row():
-            prompt = gr.Textbox(label='Enter your image prompt:', scale=8)
-            ckpt = gr.Dropdown(label='Select inference steps', choices=['1-Step', '2-Step', '4-Step', '8-Step'], value='4-Step', interactive=True)
-            submit = gr.Button(scale=1, variant='primary')
-        img = gr.Image(label='SDXL-Lightning Generated Image')
-
-    prompt.submit(fn=generate_image,
-                  inputs=[prompt, ckpt],
-                  outputs=img,
-                  )
-    submit.click(fn=generate_image,
-                 inputs=[prompt, ckpt],
-                 outputs=img,
-                 )
-
-demo.queue().launch()
+            with gr.Row():
+                prompt = gr.Textbox(
+                    placeholder="Insert your prompt here:", scale=5, container=False
+                )
+                generate_bt = gr.Button("Generate", scale=1)
+
+        image = gr.Image(type="filepath")
+        with gr.Accordion("Advanced options", open=False):
+            seed = gr.Slider(
+                randomize=True, minimum=0, maximum=12013012031030, label="Seed", step=1
+            )
+        with gr.Accordion("Run with diffusers"):
+            gr.Markdown(
+                """## Running SDXL-Lightning with `diffusers`
+            ```py
+            import torch
+            from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
+            from huggingface_hub import hf_hub_download
+            from safetensors.torch import load_file
+
+            base = "stabilityai/stable-diffusion-xl-base-1.0"
+            repo = "ByteDance/SDXL-Lightning"
+            ckpt = "sdxl_lightning_2step_unet.safetensors"  # Use the correct ckpt for your step setting!
+
+            # Load model.
+            unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
+            unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
+            pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")

+            # Ensure the sampler uses "trailing" timesteps.
+            pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+
+            # Use the same number of inference steps as the loaded checkpoint, and set CFG to 0.
+            pipe("A girl smiling", num_inference_steps=2, guidance_scale=0).images[0].save("output.png")
+            ```
+            """
+            )
+
+        inputs = [prompt, seed]
+        outputs = [image]
+        generate_bt.click(
+            fn=predict, inputs=inputs, outputs=outputs, show_progress=False
+        )
+        prompt.input(fn=predict, inputs=inputs, outputs=outputs, show_progress=False)
+        seed.change(fn=predict, inputs=inputs, outputs=outputs, show_progress=False)
+
+demo.queue()
+demo.launch()
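The rewrite above pins the app to the 2-step checkpoint and keeps the old step table only as a comment. For reference, here is a minimal sketch of how the other checkpoints from that table could be swapped into the same pipeline; `load_lightning_unet` is a hypothetical helper, not part of this commit:

```py
# Hypothetical helper (not part of this commit): swap a different
# SDXL-Lightning checkpoint into an already-built pipeline.
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Checkpoint names taken from the commented-out table in app.py.
LIGHTNING_CHECKPOINTS = {
    1: "sdxl_lightning_1step_unet_x0.safetensors",
    2: "sdxl_lightning_2step_unet.safetensors",
    4: "sdxl_lightning_4step_unet.safetensors",
    8: "sdxl_lightning_8step_unet.safetensors",
}

def load_lightning_unet(pipe, steps, repo="ByteDance/SDXL-Lightning"):
    # Load the UNet weights that match the requested step count.
    ckpt = hf_hub_download(repo, LIGHTNING_CHECKPOINTS[steps])
    pipe.unet.load_state_dict(load_file(ckpt, device="cuda"))
    # Per the removed generate_image code, the 1-step x0 checkpoint also
    # needs prediction_type="sample" on the scheduler.
    return pipe
```

Whatever step count is loaded, inference should run with that same `num_inference_steps` and `guidance_scale=0`, as both the old and new code do.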
requirements.txt CHANGED
@@ -1,5 +1,16 @@
+diffusers==0.26.3
 transformers
-diffusers
-torch
-accelerate
-gradio
+gradio==4.19.2
+torch==2.1.0
+fastapi==0.104.0
+uvicorn==0.23.2
+Pillow==10.1.0
+accelerate==0.24.0
+compel==2.0.2
+controlnet-aux==0.0.7
+peft==0.6.0
+xformers
+hf_transfer
+huggingface_hub
+safetensors
+stable_fast @ https://github.com/chengzeyi/stable-fast/releases/download/v1.0.2/stable_fast-1.0.2+torch211cu121-cp310-cp310-manylinux2014_x86_64.whl; sys_platform != 'darwin' or platform_machine != 'arm64'
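Several of these new pins (xformers, the stable_fast wheel, and the TAESD weights pulled at runtime) only matter for the optional paths in app.py, which are gated behind environment variables read at import time. A minimal sketch, assuming app.py is on the import path, of enabling them from Python rather than the shell:

```py
# Sketch under stated assumptions (not part of this commit): the flags must be
# set before app.py is imported, since app.py reads os.environ at import time.
import os

os.environ["SFAST_COMPILE"] = "1"   # needs the stable_fast wheel, plus xformers/triton if present
os.environ["USE_TAESD"] = "1"       # downloads madebyollin/taesdxl at startup
os.environ["SAFETY_CHECKER"] = "1"  # needs safety_checker.py and transformers

import app  # importing app.py runs demo.launch()
```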