radames HF staff committed on
Commit ad6ae44
1 Parent(s): 549f820

Add prompt weighting via the PromptWeighting component

README.md CHANGED
@@ -4,7 +4,7 @@ emoji: ⚡️⚡️⚡️⚡️
  colorFrom: yellow
  colorTo: purple
  sdk: gradio
- sdk_version: 4.19.2
+ sdk_version: 4.25.0
  app_file: app.py
  pinned: false
  disable_embedding: true
app.py CHANGED
@@ -1,3 +1,4 @@
+ import spaces
  from diffusers import (
      StableDiffusionXLPipeline,
      EulerDiscreteScheduler,
@@ -7,6 +8,8 @@ from diffusers import (
  import torch
  import os
  from huggingface_hub import hf_hub_download
+ from compel import Compel, ReturnedEmbeddingsType
+ from gradio_promptweighting import PromptWeighting
 
 
  from PIL import Image
@@ -24,13 +27,6 @@ REPO = "ByteDance/SDXL-Lightning"
  CHECKPOINT = "sdxl_lightning_2step_unet.safetensors"
  taesd_model = "madebyollin/taesdxl"
 
- # {
- #     "1-Step": ["sdxl_lightning_1step_unet_x0.safetensors", 1],
- #     "2-Step": ["sdxl_lightning_2step_unet.safetensors", 2],
- #     "4-Step": ["sdxl_lightning_4step_unet.safetensors", 4],
- #     "8-Step": ["sdxl_lightning_8step_unet.safetensors", 8],
- # }
- 
 
  SFAST_COMPILE = os.environ.get("SFAST_COMPILE", "0") == "1"
  SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"
@@ -55,6 +51,15 @@ unet.load_state_dict(load_file(hf_hub_download(REPO, CHECKPOINT), device="cuda"))
  pipe = StableDiffusionXLPipeline.from_pretrained(
      BASE, unet=unet, torch_dtype=torch.float16, variant="fp16", safety_checker=False
  ).to("cuda")
+ unet = unet.to(dtype=torch.float16)
+ 
+ compel = Compel(
+     tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
+     text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
+     returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
+     requires_pooled=[False, True],
+ )
+ 
 
  if USE_TAESD:
      pipe.vae = AutoencoderTiny.from_pretrained(
@@ -115,14 +120,24 @@ if SFAST_COMPILE:
      pipe = compile(pipe, config)
 
 
- def predict(prompt, seed=1231231):
+ @spaces.GPU
+ def predict(prompt, prompt_w, guidance_scale, seed=1231231):
      generator = torch.manual_seed(seed)
      last_time = time.time()
+     prompt_w = " ".join(
+         [f"({p['prompt']}){p['scale']}" for p in prompt_w if p["prompt"]]
+     )
+ 
+     conditioning, pooled = compel([prompt + " " + prompt_w, ""])
+ 
      results = pipe(
-         prompt=prompt,
+         prompt_embeds=conditioning[0:1],
+         pooled_prompt_embeds=pooled[0:1],
+         negative_prompt_embeds=conditioning[1:2],
+         negative_pooled_prompt_embeds=pooled[1:2],
          generator=generator,
          num_inference_steps=2,
-         guidance_scale=0.0,
+         guidance_scale=guidance_scale,
          # width=768,
          # height=768,
          output_type="pil",
@@ -142,12 +157,15 @@ def predict(prompt, seed=1231231):
  css = """
  #container{
      margin: 0 auto;
-     max-width: 40rem;
+     max-width: 80rem;
  }
  #intro{
      max-width: 100%;
      margin: 0 auto;
  }
+ .generating {
+     display: none
+ }
  """
  with gr.Blocks(css=css) as demo:
      with gr.Column(elem_id="container"):
@@ -159,51 +177,61 @@ with gr.Blocks(css=css) as demo:
              elem_id="intro",
          )
          with gr.Row():
-             with gr.Row():
-                 prompt = gr.Textbox(
-                     placeholder="Insert your prompt here:", scale=5, container=False
-                 )
-                 generate_bt = gr.Button("Generate", scale=1)
- 
-         image = gr.Image(type="filepath")
-         with gr.Accordion("Advanced options", open=False):
-             seed = gr.Slider(
-                 randomize=True, minimum=0, maximum=12013012031030, label="Seed", step=1
-             )
-         with gr.Accordion("Run with diffusers"):
-             gr.Markdown(
-                 """## Running SDXL-Lightning with `diffusers`
-             ```py
-             import torch
-             from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
-             from huggingface_hub import hf_hub_download
-             from safetensors.torch import load_file
- 
-             base = "stabilityai/stable-diffusion-xl-base-1.0"
-             repo = "ByteDance/SDXL-Lightning"
-             ckpt = "sdxl_lightning_2step_unet.safetensors"  # Use the correct ckpt for your step setting!
- 
-             # Load model.
-             unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
-             unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
-             pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
- 
-             # Ensure sampler uses "trailing" timesteps.
-             pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
- 
-             # Ensure using the same inference steps as the loaded model and CFG set to 0.
-             pipe("A girl smiling", num_inference_steps=2, guidance_scale=0).images[0].save("output.png")
-             ```
-             """
-             )
- 
-         inputs = [prompt, seed]
+             with gr.Column():
+                 with gr.Group():
+                     prompt = gr.Textbox(
+                         placeholder="Insert your prompt here:",
+                         max_lines=1,
+                         label="Prompt",
+                     )
+                     prompt_w = PromptWeighting(
+                         min=0,
+                         max=3,
+                         step=0.005,
+                         show_label=False,
+                     )
+ 
+                 with gr.Accordion("Advanced options", open=True):
+                     seed = gr.Slider(
+                         minimum=0,
+                         maximum=12013012031030,
+                         label="Seed",
+                         step=1,
+                     )
+                     guidance_scale = gr.Slider(
+                         minimum=0.0,
+                         maximum=20.0,
+                         label="Guidance scale",
+                         value=0.0,
+                         step=0.1,
+                     )
+                 generate_bt = gr.Button("Generate")
+             with gr.Column():
+                 image = gr.Image(type="filepath")
+ 
+         inputs = [
+             prompt,
+             prompt_w,
+             guidance_scale,
+             seed,
+         ]
          outputs = [image]
-         generate_bt.click(
-             fn=predict, inputs=inputs, outputs=outputs, show_progress=False
+ 
+         gr.on(
+             triggers=[
+                 prompt.input,
+                 prompt_w.input,
+                 generate_bt.click,
+                 guidance_scale.input,
+                 seed.input,
+             ],
+             fn=predict,
+             inputs=inputs,
+             outputs=outputs,
+             show_progress="hidden",
+             show_api=False,
+             trigger_mode="always_last",
          )
-         prompt.input(fn=predict, inputs=inputs, outputs=outputs, show_progress=False)
-         seed.change(fn=predict, inputs=inputs, outputs=outputs, show_progress=False)
 
- demo.queue()
+ demo.queue(api_open=False)
  demo.launch()
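For reference, the core of the change above is how `predict` flattens the PromptWeighting component's value into Compel's `(text)weight` syntax. A minimal sketch of that string handling, using a hypothetical `weights` list that mirrors the `{"prompt": ..., "scale": ...}` dicts the component emits (pure string logic, no GPU needed):

```py
# Hypothetical sample of the list-of-dicts value the PromptWeighting
# component passes to predict(); each entry pairs a phrase with a weight.
weights = [
    {"prompt": "cinematic lighting", "scale": 1.5},
    {"prompt": "watercolor", "scale": 0.4},
    {"prompt": "", "scale": 1.0},  # empty phrases are filtered out below
]

# Same flattening as in predict(): Compel parses "(text)weight" as a
# weighted phrase, so each slider value scales that phrase's embedding.
prompt_w = " ".join(f"({p['prompt']}){p['scale']}" for p in weights if p["prompt"])
print(prompt_w)  # (cinematic lighting)1.5 (watercolor)0.4
```

The flattened string is appended to the free-text prompt, and `compel([...])` encodes the combined positive prompt and an empty negative prompt in one batch, which is why the result is sliced into `[0:1]` (positive) and `[1:2]` (negative) before being passed to the pipeline as prompt embeddings.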
 
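The event wiring also changed: instead of separate `.click`, `.input`, and `.change` bindings, a single `gr.on` call now drives `predict` from every control. A minimal, self-contained sketch of that pattern (assuming gradio 4.x; `echo` is a stand-in for `predict`):

```py
import gradio as gr

def echo(text):
    # Stand-in for predict(): any function taking the bound inputs.
    return text

with gr.Blocks() as demo:
    box = gr.Textbox(label="Input")
    btn = gr.Button("Run")
    out = gr.Textbox(label="Output")
    # One handler bound to several triggers; trigger_mode="always_last"
    # coalesces rapid events so only the most recent value runs.
    gr.on(
        triggers=[box.input, btn.click],
        fn=echo,
        inputs=[box],
        outputs=[out],
        trigger_mode="always_last",
        show_progress="hidden",
    )

demo.queue()
demo.launch()
```

Together with `show_progress="hidden"` and the new `.generating { display: none }` CSS rule, this is what lets the demo update live while the user types or drags the weight sliders.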
gradio_promptweighting-0.0.1-py3-none-any.whl ADDED
Binary file (38.8 kB).
 
requirements.txt CHANGED
@@ -1,6 +1,6 @@
- diffusers==0.26.3
+ diffusers==0.27.2
  transformers
- gradio==4.19.2
+ gradio==4.25.0
  torch==2.1.0
  fastapi==0.104.0
  uvicorn==0.23.2
@@ -13,4 +13,7 @@ xformers
  hf_transfer
  huggingface_hub
  safetensors
- stable_fast @ https://github.com/chengzeyi/stable-fast/releases/download/v1.0.2/stable_fast-1.0.2+torch211cu121-cp310-cp310-manylinux2014_x86_64.whl; sys_platform != 'darwin' or platform_machine != 'arm64'
+ compel
+ stable_fast @ https://github.com/chengzeyi/stable-fast/releases/download/v1.0.2/stable_fast-1.0.2+torch211cu121-cp310-cp310-manylinux2014_x86_64.whl ; sys_platform != 'darwin' or platform_machine != 'arm64'
+ spaces
+ gradio_promptweighting @ ./gradio_promptweighting-0.0.1-py3-none-any.whl
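Two of the new dependencies are worth a note: `gradio_promptweighting` is installed from the wheel bundled in the repo via a direct reference, and `spaces` provides the `@spaces.GPU` decorator applied to `predict` in app.py. A minimal sketch of how that decorator behaves (assuming Hugging Face's `spaces` package on a ZeroGPU Space; elsewhere it is effectively a no-op):

```py
import spaces
import torch

@spaces.GPU  # on a ZeroGPU Space, a GPU is attached for the duration of the call
def device_report():
    # Inside the decorated function, CUDA should be available on ZeroGPU hardware.
    return torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu"

print(device_report())
```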