hideosnes committed
Commit d51b2d2
1 Parent(s): 73b7db1

Update app.py

Files changed (1)
  1. app.py +160 -99
app.py CHANGED
@@ -21,12 +21,12 @@ snapshot_download(
     repo_id="h94/IP-Adapter", allow_patterns="sdxl_models/*", local_dir="."
 )
 
-# CPU fallback & pipeline-definition
+# global variable
 MAX_SEED = np.iinfo(np.int32).max
 device = "cuda" if torch.cuda.is_available() else "cpu"
 dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
 
-# load models & scheduler (==>EULER) & CN (==>canny > test what's better!!!)
+# initialization
 base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
 image_encoder_path = "sdxl_models/image_encoder"
 ip_ckpt = "sdxl_models/ip-adapter_sdxl.bin"
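The `dtype` line still derives the precision from a substring test on the device string. The intent is simply fp16 on GPU, fp32 on CPU; a minimal sketch of the same fallback with a plain equality check (assumes only `torch`):

```python
import torch

# Pick the device first, then derive the compute dtype from it:
# float16 on CUDA, float32 on CPU (half precision is poorly supported on CPU).
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
```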
@@ -36,14 +36,14 @@ controlnet = ControlNetModel.from_pretrained(
     controlnet_path, use_safetensors=False, torch_dtype=torch.float16
 ).to(device)
 
-# load SDXL lightning >> put Turbo here if fallback to Comfy @Litto
+# load SDXL lightnining
 
 pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
     base_model_path,
-    controlnet = controlnet,
+    controlnet=controlnet,
     torch_dtype=torch.float16,
     variant="fp16",
-    add_watermark=False,
+    add_watermarker=False,
 ).to(device)
 pipe.set_progress_bar_config(disable=True)
 pipe.scheduler = EulerDiscreteScheduler.from_config(
@@ -51,14 +51,14 @@ pipe.scheduler = EulerDiscreteScheduler.from_config(
 )
 pipe.unet.load_state_dict(
     load_file(
-        hf_hub_download(
-            "ByteDance/SDXL-Lightning", "sdxl_lightning_2step_unet.safetensors"
-        ),
-        device="cuda",
-    )
+        hf_hub_download(
+            "ByteDance/SDXL-Lightning", "sdxl_lightning_2step_unet.safetensors"
+        ),
+        device="cuda",
+    )
 )
 
-# load ip-adapter with specific target blocks for style transfer and layout preservation. Should be better than Comfy! Test this!
+# load ip-adapter
 # target_blocks=["block"] for original IP-Adapter
 # target_blocks=["up_blocks.0.attentions.1"] for style blocks only
 # target_blocks = ["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"] # for style+layout blocks
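The `from_config` arguments fall outside this hunk, so only the call site is visible. For context, the ByteDance/SDXL-Lightning model card loads the distilled UNet the same way and configures the Euler scheduler with trailing timestep spacing; a sketch under that assumption, reusing the `pipe` defined above:

```python
from diffusers import EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Swap the base SDXL UNet weights for the distilled 2-step checkpoint.
ckpt = hf_hub_download(
    "ByteDance/SDXL-Lightning", "sdxl_lightning_2step_unet.safetensors"
)
pipe.unet.load_state_dict(load_file(ckpt, device="cuda"))

# Lightning checkpoints are trained with trailing timestep spacing and are
# typically run with guidance disabled, matching this app's guidance_scale=0.0
# default further down.
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config, timestep_spacing="trailing"
)
```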
@@ -67,12 +67,9 @@ ip_model = IPAdapterXL(
     image_encoder_path,
     ip_ckpt,
     device,
-    target_blocks=["up_blocks.0.attentions.1"]
+    target_blocks=["up_blocks.0.attentions.1"],
 )
 
-# Resizing the input image
-# OpenCV goes here!!!
-# Test this with smaller side-no for faster infr
 
 def resize_img(
     input_image,
@@ -91,9 +88,8 @@ def resize_img(
     w, h = round(ratio * w), round(ratio * h)
     ratio = max_side / max(h, w)
     input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode)
-    w = (round(ratio * w) // base_pixel_number) * base_pixel_number
-    w = (round(ratio * h) // base_pixel_number) * base_pixel_number
-    nput_image.resize([w_resize_new, h_resize_new], mode)
+    w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
+    h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
     input_image = input_image.resize([w_resize_new, h_resize_new], mode)
 
     if pad_to_max_side:
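This hunk fixes a real bug: the old code assigned both rounded sides to `w`, referenced the then-undefined names `w_resize_new`/`h_resize_new`, and had a truncated `nput_image.resize(...)` call whose result was discarded. The fixed version floors each side to a multiple of `base_pixel_number`. A self-contained sketch of just that rounding step (`base=64` is an assumption about the function's default; diffusion pipelines expect latent-aligned sizes):

```python
def snap_to_multiple(w: int, h: int, base: int = 64) -> tuple[int, int]:
    # Floor each side to the nearest multiple of `base` so the model's
    # latent grid divides the image evenly.
    return (w // base) * base, (h // base) * base

# Example: a 1000x775 input is snapped to 960x768 before inference.
assert snap_to_multiple(1000, 775) == (960, 768)
```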
@@ -106,31 +102,52 @@ def resize_img(
     input_image = Image.fromarray(res)
     return input_image
 
-# expand example images for endpoints --> info an Johannes/Jascha what to expect
 
 examples = [
     [
-        "./asset/0.jpg",
+        "./assets/0.jpg",
         None,
-        "3D model, cute monster, test prompt",
+        "a cat, masterpiece, best quality, high quality",
         1.0,
         0.0,
     ],
+    [
+        "./assets/1.jpg",
+        None,
+        "a cat, masterpiece, best quality, high quality",
+        1.0,
+        0.0,
+    ],
+    [
+        "./assets/2.jpg",
+        None,
+        "a cat, masterpiece, best quality, high quality",
+        1.0,
+        0.0,
+    ],
+    [
+        "./assets/3.jpg",
+        None,
+        "a cat, masterpiece, best quality, high quality",
+        1.0,
+        0.0,
+    ],
     [
-        "./asset/2.jpg",
-        "./asset/house.jpg",
-        "3D model, cute, kawai, house, another test prompt",
+        "./assets/2.jpg",
+        "./assets/yann-lecun.jpg",
+        "a man, masterpiece, best quality, high quality",
         1.0,
         0.6,
     ],
 ]
 
+
 def run_for_examples(style_image, source_image, prompt, scale, control_scale):
     return create_image(
         image_pil=style_image,
         input_image=source_image,
         prompt=prompt,
-        n_prompt="text, watermark, low res, low quality, worst quality, deformed, blurry",
+        n_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
         scale=scale,
         control_scale=control_scale,
         guidance_scale=0.0,
@@ -141,7 +158,6 @@ def run_for_examples(style_image, source_image, prompt, scale, control_scale):
         neg_content_scale=0,
     )
 
-# Main function for image synthesis (input -> run_for_examples)
 
 @spaces.GPU(enable_queue=True)
 def create_image(
@@ -167,12 +183,20 @@ def create_image(
     elif target == "Load only style blocks":
         # target_blocks=["up_blocks.0.attentions.1"] for style blocks only
         ip_model = IPAdapterXL(
-            pipe, image_encoder_path, ip_ckpt, device, target_blocks=["up_blocks.0.attentions.1"],
+            pipe,
+            image_encoder_path,
+            ip_ckpt,
+            device,
+            target_blocks=["up_blocks.0.attentions.1"],
         )
     elif target == "Load style+layout block":
         # target_blocks = ["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"] # for style+layout blocks
         ip_model = IPAdapterXL(
-            pipe, image_encoder_path, ip_ckpt, device, target_blocks=["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"],
+            pipe,
+            image_encoder_path,
+            ip_ckpt,
+            device,
+            target_blocks=["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"],
         )
 
     if input_image is not None:
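The three radio options differ only in which attention blocks the adapter hooks, per the comments earlier in the file. A sketch of the same dispatch as a lookup table; it reuses the module's `pipe`, `image_encoder_path`, `ip_ckpt`, and `device` globals, and `build_ip_model` is a hypothetical helper, not part of this commit:

```python
# UI mode -> target_blocks, per the comments near the IPAdapterXL setup.
TARGET_BLOCKS = {
    "Load original IP-Adapter": ["block"],
    "Load only style blocks": ["up_blocks.0.attentions.1"],
    "Load style+layout block": [
        "up_blocks.0.attentions.1",
        "down_blocks.2.attentions.1",
    ],
}

def build_ip_model(target: str):
    # Note: instantiating per request re-hooks the adapter on every call;
    # caching one instance per mode would avoid the repeated setup.
    return IPAdapterXL(
        pipe, image_encoder_path, ip_ckpt, device,
        target_blocks=TARGET_BLOCKS[target],
    )
```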
@@ -181,7 +205,7 @@ def create_image(
         detected_map = cv2.Canny(cv_input_image, 50, 200)
         canny_map = Image.fromarray(cv2.cvtColor(detected_map, cv2.COLOR_BGR2RGB))
     else:
-        canny_map = Image.new("RGB", (1024, 1024), color=(255,255,255))
+        canny_map = Image.new("RGB", (1024, 1024), color=(255, 255, 255))
         control_scale = 0
 
     if float(control_scale) == 0:
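When a source image is supplied, the ControlNet condition is a Canny edge map; otherwise a blank white canvas is used and `control_scale` is zeroed so the ControlNet has no effect. A standalone sketch of the edge-map step; this version converts the single-channel Canny output with `COLOR_GRAY2RGB`, which is what a one-channel array calls for:

```python
import cv2
import numpy as np
from PIL import Image

def canny_condition(image_pil: Image.Image, low: int = 50, high: int = 200) -> Image.Image:
    # PIL gives RGB; OpenCV expects BGR. Canny then returns a single-channel
    # uint8 edge map, which is lifted back to 3-channel RGB for ControlNet.
    bgr = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
    edges = cv2.Canny(bgr, low, high)
    return Image.fromarray(cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB))
```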
@@ -189,7 +213,22 @@
 
     if len(neg_content_prompt) > 0 and neg_content_scale != 0:
         images = ip_model.generate(
-            pil_image_image_pil,
+            pil_image=image_pil,
+            prompt=prompt,
+            negative_prompt=n_prompt,
+            scale=scale,
+            guidance_scale=guidance_scale,
+            num_samples=1,
+            num_inference_steps=num_inference_steps,
+            seed=seed,
+            image=canny_map,
+            controlnet_conditioning_scale=float(control_scale),
+            neg_content_prompt=neg_content_prompt,
+            neg_content_scale=neg_content_scale,
+        )
+    else:
+        images = ip_model.generate(
+            pil_image=image_pil,
             prompt=prompt,
             negative_prompt=n_prompt,
             scale=scale,
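The new `else` branch repeats every argument except the two `neg_content_*` ones. A behavior-preserving alternative (a sketch, not what the commit does) builds the shared kwargs once so the branches cannot drift apart:

```python
gen_kwargs = dict(
    pil_image=image_pil,
    prompt=prompt,
    negative_prompt=n_prompt,
    scale=scale,
    guidance_scale=guidance_scale,
    num_samples=1,
    num_inference_steps=num_inference_steps,
    seed=seed,
    image=canny_map,
    controlnet_conditioning_scale=float(control_scale),
)
if len(neg_content_prompt) > 0 and neg_content_scale != 0:
    # Only the negative-content pair differs between the two branches.
    gen_kwargs.update(
        neg_content_prompt=neg_content_prompt,
        neg_content_scale=neg_content_scale,
    )
images = ip_model.generate(**gen_kwargs)
```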
@@ -202,31 +241,47 @@
     )
     image = images[0]
     with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmpfile:
-        image.save(tmpfile, "JPEG", quality=80, optimize=True, progressive=True) # check what happens to imgs when this changes!!!
+        image.save(tmpfile, "JPEG", quality=80, optimize=True, progressive=True)
     return Path(tmpfile.name)
-
+
+
 def pil_to_cv2(image_pil):
     image_np = np.array(image_pil)
     image_cv2 = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
     return image_cv2
 
-# Gradio Description & Frontend Stuff for Space (remove this for Endpoint)
+
+# Description
 title = r"""
-<h1 align="center">MewMewMew: Simsalabim!</h1>
+<h1 align="center">InstantStyle: Free Lunch towards Style-Preserving in Text-to-Image Generation</h1>
 """
 
 description = r"""
-<b>Let's test this! ARM <3 GoldExtra</b><br>
-<b>SDXL-Lightning && IP-Adapter</b>
+<b>Forked from <a href='https://github.com/InstantStyle/InstantStyle' target='_blank'>InstantStyle: Free Lunch towards Style-Preserving in Text-to-Image Generation</a>.<br>
+<b>Model by <a href='https://huggingface.co/ByteDance/SDXL-Lightning' target='_blank'>SDXL Lightning</a> and <a href='https://huggingface.co/h94/IP-Adapter' target='_blank'>IP-Adapter</a>.</b><br>
 """
 
 article = r"""
-Ask Hidéo if something breaks: <a href="mailto:hideo@artificialmuseum.com">Hidéo's Mail</a>
+---
+📝 **Citation**
+<br>
+If our work is helpful for your research or applications, please cite us via:
+```bibtex
+@article{wang2024instantstyle,
+    title={InstantStyle: Free Lunch towards Style-Preserving in Text-to-Image Generation},
+    author={Wang, Haofan and Wang, Qixun and Bai, Xu and Qin, Zekui and Chen, Anthony},
+    journal={arXiv preprint arXiv:2404.02733},
+    year={2024}
+}
+```
+📧 **Contact**
+<br>
+If you have any questions, please feel free to open an issue or directly reach us out at <b>haofanwang.ai@gmail.com</b>.
 """
 
 block = gr.Blocks()
 with block:
-    #description
+    # description
     gr.Markdown(title)
     gr.Markdown(description)
 
@@ -239,71 +294,77 @@ with block:
         with gr.Column():
             prompt = gr.Textbox(
                 label="Prompt",
-                value="mewmewmew, kitty cats, unicorns, uWu",
+                value="a cat, masterpiece, best quality, high quality",
             )
-
+
             scale = gr.Slider(
-                minimum=0, maximum=2.0, step=0.01, value=1.0, label="Maßstab // scale"
+                minimum=0, maximum=2.0, step=0.01, value=1.0, label="Scale"
             )
-            with gr.Accordion(open=False, label="Für Details erweitern!"):
-                target = gr.Radio(
-                    [
-                        "Load only style blocks",
-                        "Load style+layout block",
-                        "Load original IP-Adapter",
-                    ],
-                    value="Load only style blocks",
-                    label="Modus für IP-Adapter auswählen"
-                )
-
-            with gr.Column():
-                src_image_pil = gr.Image(
-                    label="Guidance Image (optional)", type="pil"
-                )
-                control_scale = gr.Slider(
-                    minimum=0, maximum=1.0, step=0.1, value=0.5,
-                    label="ControlNet-Stärke // control_scale",
-                )
-                n_prompt = gr.Textbox(
-                    label="Negative Prompts",
-                    value="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
-                )
-                neg_content_prompt = gr.Textbox(
-                    label="Negative Content Prompt (optional)", value=""
-                )
-                neg_content_scale = gr.Slider(
-                    minimum=0,
-                    maximum=1.0,
-                    step=0.1,
-                    value=0.5,
-                    label="Negative Content Stärke // neg_content_scale"
-                )
-                guidance_scale = gr.Slider(
-                    minimum=0,
-                    maximum=10.0,
-                    step=0.01,
-                    value=0.0,
-                    label="guidance-scale"
-                )
-                num_inference_steps = gr.Slider(
-                    minimum=2,
-                    maximum=50.0,
-                    step=1.0,
-                    value=2,
-                    label="Anzahl der Inference Steps (optional) // num_inference_steps"
-                )
-                seed = gr.Slider(
-                    minimum=-1,
-                    maximum=MAX_SEED,
-                    value=-1,
-                    step=1,
-                    label="Seed Value // -1 = random // Seed-Proof=True"
-                )
+
+            with gr.Accordion(open=False, label="Advanced Options"):
+                target = gr.Radio(
+                    [
+                        "Load only style blocks",
+                        "Load style+layout block",
+                        "Load original IP-Adapter",
+                    ],
+                    value="Load only style blocks",
+                    label="Style mode",
+                )
+                with gr.Column():
+                    src_image_pil = gr.Image(
+                        label="Source Image (optional)", type="pil"
+                    )
+                    control_scale = gr.Slider(
+                        minimum=0,
+                        maximum=1.0,
+                        step=0.01,
+                        value=0.5,
+                        label="Controlnet conditioning scale",
+                    )
+
+                    n_prompt = gr.Textbox(
+                        label="Neg Prompt",
+                        value="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
+                    )
+
+                    neg_content_prompt = gr.Textbox(
+                        label="Neg Content Prompt", value=""
+                    )
+                    neg_content_scale = gr.Slider(
+                        minimum=0,
+                        maximum=1.0,
+                        step=0.01,
+                        value=0.5,
+                        label="Neg Content Scale",
+                    )
+
+                    guidance_scale = gr.Slider(
+                        minimum=0,
+                        maximum=10.0,
+                        step=0.01,
+                        value=0.0,
+                        label="guidance scale",
+                    )
+                    num_inference_steps = gr.Slider(
+                        minimum=2,
+                        maximum=50.0,
+                        step=1.0,
+                        value=2,
+                        label="num inference steps",
+                    )
+                    seed = gr.Slider(
+                        minimum=-1,
+                        maximum=MAX_SEED,
+                        value=-1,
+                        step=1,
+                        label="Seed Value",
+                    )
 
-            generate_button = gr.Button("Simsalabim")
+            generate_button = gr.Button("Generate Image")
 
         with gr.Column():
-            generated_image = gr.Image(label="MewMewMagix uWu")
+            generated_image = gr.Image(label="Generated Image")
 
     inputs = [
         image_pil,
@@ -343,10 +404,10 @@ with block:
         inputs=[image_pil, src_image_pil, prompt, scale, control_scale],
         fn=run_for_examples,
         outputs=[generated_image],
-        cache_examples=False,
+        cache_examples=True,
    )
 
     gr.Markdown(article)
 
-block.queue(api_open=False)
-block.launch(show_api=False)
+block.queue(api_open=False)
+block.launch(show_api=False)