multimodalart HF staff commited on
Commit
f33c43f
1 Parent(s): 178e606

Performance PR

Browse files

- Swap `LCM LoRA` to `SDXL Lightening 2 steps` (faster, more quality)
- Switch regular VAE to tiny TAESD VAE
- Add header and mention to blog
- Add keyboard navigation (`A` to Dislike, `Space` for Neither and `L` to like)
- Disable Safety Filter (redundant in SDXL for this use-case and lots of false positives)

Performance result on A10G:
- < 1s per image

Files changed (1) hide show
  1. app.py +46 -22
app.py CHANGED
@@ -6,7 +6,7 @@ from sklearn.svm import LinearSVC
6
  from sklearn import preprocessing
7
  import pandas as pd
8
 
9
- from diffusers import LCMScheduler
10
  from diffusers.models import ImageProjection
11
  from patch_sdxl import SDEmb
12
  import torch
@@ -22,6 +22,9 @@ from PIL import Image
22
  import requests
23
  from io import BytesIO, StringIO
24
 
 
 
 
25
  prompt_list = [p for p in list(set(
26
  pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str]
27
 
@@ -29,12 +32,16 @@ start_time = time.time()
29
 
30
  ####################### Setup Model
31
  model_id = "stabilityai/stable-diffusion-xl-base-1.0"
32
- lcm_lora_id = "latent-consistency/lcm-lora-sdxl"
33
- pipe = SDEmb.from_pretrained(model_id, variant="fp16", low_cpu_mem_usage=True, device_map="auto")
34
- pipe.load_lora_weights(lcm_lora_id)
35
- pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
36
- pipe.to(device='cuda', dtype=torch.float16)
 
 
 
37
  pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
 
38
  output_hidden_state = False
39
  #######################
40
 
@@ -53,7 +60,7 @@ def predict(
53
  ip_adapter_emb=im_emb.to('cuda'),
54
  height=1024,
55
  width=1024,
56
- num_inference_steps=8,
57
  guidance_scale=0,
58
  ).images[0]
59
  im_emb, _ = pipe.encode_image(
@@ -61,12 +68,6 @@ def predict(
61
  )
62
  return image, im_emb.to(DEVICE)
63
 
64
-
65
-
66
-
67
-
68
-
69
-
70
  # TODO add to state instead of shared across all
71
  glob_idx = 0
72
 
@@ -133,9 +134,9 @@ def next_image(embs, ys, calibrate_prompts):
133
  def start(_, embs, ys, calibrate_prompts):
134
  image, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
135
  return [
136
- gr.Button(value='Like', interactive=True),
137
- gr.Button(value='Neither', interactive=True),
138
- gr.Button(value='Dislike', interactive=True),
139
  gr.Button(value='Start', interactive=False),
140
  image,
141
  embs,
@@ -157,9 +158,32 @@ def choose(choice, embs, ys, calibrate_prompts):
157
  img, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
158
  return img, embs, ys, calibrate_prompts
159
 
160
- css = ".gradio-container{max-width: 700px !important}"
161
- print(css)
162
- with gr.Blocks(css=css) as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  embs = gr.State([])
164
  ys = gr.State([])
165
  calibrate_prompts = gr.State([
@@ -177,9 +201,9 @@ with gr.Blocks(css=css) as demo:
177
  with gr.Row(elem_id='output-image'):
178
  img = gr.Image(interactive=False, elem_id='output-image',width=700)
179
  with gr.Row(equal_height=True):
180
- b3 = gr.Button(value='Dislike', interactive=False,)
181
- b2 = gr.Button(value='Neither', interactive=False,)
182
- b1 = gr.Button(value='Like', interactive=False,)
183
  b1.click(
184
  choose,
185
  [b1, embs, ys, calibrate_prompts],
 
6
  from sklearn import preprocessing
7
  import pandas as pd
8
 
9
+ from diffusers import LCMScheduler, AutoencoderTiny, EulerDiscreteScheduler, UNet2DConditionModel
10
  from diffusers.models import ImageProjection
11
  from patch_sdxl import SDEmb
12
  import torch
 
22
  import requests
23
  from io import BytesIO, StringIO
24
 
25
+ from huggingface_hub import hf_hub_download
26
+ from safetensors.torch import load_file
27
+
28
  prompt_list = [p for p in list(set(
29
  pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str]
30
 
 
32
 
33
  ####################### Setup Model
34
  model_id = "stabilityai/stable-diffusion-xl-base-1.0"
35
+ sdxl_lightening = "ByteDance/SDXL-Lightning"
36
+ ckpt = "sdxl_lightning_2step_unet.safetensors"
37
+ unet = UNet2DConditionModel.from_config(model_id, subfolder="unet").to("cuda", torch.float16)
38
+ unet.load_state_dict(load_file(hf_hub_download(sdxl_lightening, ckpt), device="cuda"))
39
+ pipe = SDEmb.from_pretrained(model_id, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
40
+ pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16)
41
+ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
42
+ pipe.to(device='cuda')
43
  pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
44
+
45
  output_hidden_state = False
46
  #######################
47
 
 
60
  ip_adapter_emb=im_emb.to('cuda'),
61
  height=1024,
62
  width=1024,
63
+ num_inference_steps=2,
64
  guidance_scale=0,
65
  ).images[0]
66
  im_emb, _ = pipe.encode_image(
 
68
  )
69
  return image, im_emb.to(DEVICE)
70
 
 
 
 
 
 
 
71
  # TODO add to state instead of shared across all
72
  glob_idx = 0
73
 
 
134
  def start(_, embs, ys, calibrate_prompts):
135
  image, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
136
  return [
137
+ gr.Button(value='Like (L)', interactive=True),
138
+ gr.Button(value='Neither (Space)', interactive=True),
139
+ gr.Button(value='Dislike (A)', interactive=True),
140
  gr.Button(value='Start', interactive=False),
141
  image,
142
  embs,
 
158
  img, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
159
  return img, embs, ys, calibrate_prompts
160
 
161
+ css = '''.gradio-container{max-width: 700px !important}
162
+ #description{text-align: center}
163
+ #description h1{display: block}
164
+ #description p{margin-top: 0}
165
+ '''
166
+ js = '''
167
+ <script>
168
+ document.addEventListener('keydown', function(event) {
169
+ if (event.key === 'a' || event.key === 'A') {
170
+ // Trigger click on 'dislike' if 'A' is pressed
171
+ document.getElementById('dislike').click();
172
+ } else if (event.key === ' ' || event.keyCode === 32) {
173
+ // Trigger click on 'neither' if Spacebar is pressed
174
+ document.getElementById('neither').click();
175
+ } else if (event.key === 'l' || event.key === 'L') {
176
+ // Trigger click on 'like' if 'L' is pressed
177
+ document.getElementById('like').click();
178
+ }
179
+ });
180
+ </script>
181
+ '''
182
+
183
+ with gr.Blocks(css=css, head=js) as demo:
184
+ gr.Markdown('''# Generative Recommenders
185
+ Explore the latent space without text prompts, based on your preferences. [Learn more on the blog](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/)
186
+ ''', elem_id="description")
187
  embs = gr.State([])
188
  ys = gr.State([])
189
  calibrate_prompts = gr.State([
 
201
  with gr.Row(elem_id='output-image'):
202
  img = gr.Image(interactive=False, elem_id='output-image',width=700)
203
  with gr.Row(equal_height=True):
204
+ b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
205
+ b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither")
206
+ b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")
207
  b1.click(
208
  choose,
209
  [b1, embs, ys, calibrate_prompts],