mbrack committed on
Commit
c4e0b79
1 Parent(s): 0699a43

add style guidance

Browse files
Files changed (1) hide show
  1. app.py +217 -53
app.py CHANGED
@@ -19,9 +19,16 @@ if disable_safety:
19
  return images, False
20
  pipe.safety_checker = null_safety
21
 
 
 
 
 
 
22
 
23
  def infer(prompt, steps, scale, seed, editing_prompt_1 = None, reverse_editing_direction_1 = False, edit_warmup_steps_1=10, edit_guidance_scale_1=5, edit_threshold_1=0.95,
24
  editing_prompt_2 = None, reverse_editing_direction_2 = False, edit_warmup_steps_2=10, edit_guidance_scale_2=5, edit_threshold_2=0.95,
 
 
25
  edit_momentum_scale=0.5, edit_mom_beta=0.6):
26
 
27
 
@@ -42,15 +49,52 @@ def infer(prompt, steps, scale, seed, editing_prompt_1 = None, reverse_editing_d
42
  del edit_warmup_steps[index]
43
  del edit_guidance_scale[index]
44
  del edit_threshold[index]
 
 
45
 
 
 
 
 
 
 
 
46
 
47
  gen.manual_seed(seed)
48
  images.extend(pipe(prompt, guidance_scale=scale, num_inference_steps=steps, generator=gen,
49
- editing_prompt=editing_prompt, reverse_editing_direction=reverse_editing_direction, edit_warmup_steps=edit_warmup_steps, edit_guidance_scale=edit_guidance_scale,
 
50
  edit_momentum_scale=edit_momentum_scale, edit_mom_beta=edit_mom_beta
51
  ).images)
52
 
53
- return images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  css = """
56
  a {
@@ -144,13 +188,18 @@ examples = [
144
  'sunglasses',
145
  False,
146
  10,
147
- 6,
148
  0.95,
149
  '',
150
  False,
151
  10,
152
  5,
153
- 0.95
 
 
 
 
 
154
  ],
155
  [
156
  'an image of a crowded boulevard, realistic, 4k',
@@ -166,7 +215,12 @@ examples = [
166
  False,
167
  10,
168
  5,
169
- 0.95
 
 
 
 
 
170
  ],
171
  [
172
  'a castle next to a river',
@@ -182,6 +236,11 @@ examples = [
182
  False,
183
  18,
184
  6,
 
 
 
 
 
185
  0.8
186
  ],
187
  [
@@ -198,7 +257,12 @@ examples = [
198
  False,
199
  5,
200
  5,
201
- 0.9
 
 
 
 
 
202
  ],
203
  [
204
  'a photo of a flowerpot',
@@ -214,7 +278,12 @@ examples = [
214
  False,
215
  10,
216
  5,
217
- 0.95
 
 
 
 
 
218
  ],
219
  [
220
  'a photo of the face of a woman',
@@ -230,10 +299,79 @@ examples = [
230
  False,
231
  13,
232
  3,
233
- 0.925
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  ],
235
  ]
236
 
 
237
  with block:
238
  gr.HTML(
239
  """
@@ -268,48 +406,58 @@ with block:
268
  margin=False,
269
  rounded=(False, True, True, False),
270
  )
271
- with gr.Box():
272
- with gr.Row().style(mobile_collapse=False, equal_height=True):
273
- edit_1 = gr.Textbox(
274
- label="Edit Prompt 1",
275
- show_label=False,
276
- max_lines=1,
277
- placeholder="Enter your 1st edit prompt",
278
- ).style(
279
- border=(True, False, True, True),
280
- rounded=(True, False, False, True),
281
- container=False,
282
- )
283
- with gr.Group():
284
- with gr.Row().style(mobile_collapse=False, equal_height=True):
285
- rev_1 = gr.Checkbox(
286
- label='Reverse')
287
- warmup_1 = gr.Slider(label='Warmup', minimum=0, maximum=50, value=10, step=1, interactive=True)
288
- scale_1 = gr.Slider(label='Scale', minimum=1, maximum=10, value=5, step=0.25, interactive=True)
289
- threshold_1 = gr.Slider(label='Threshold', minimum=0.5, maximum=0.99, value=0.95, steps=0.01, interactive=True)
290
- with gr.Row().style(mobile_collapse=False, equal_height=True):
291
- edit_2 = gr.Textbox(
292
- label="Edit Prompt 2",
293
- show_label=False,
294
- max_lines=1,
295
- placeholder="Enter your 2nd edit prompt",
296
- ).style(
297
- border=(True, False, True, True),
298
- rounded=(True, False, False, True),
299
- container=False,
300
- )
301
- with gr.Group():
302
- with gr.Row().style(mobile_collapse=False, equal_height=True):
303
- rev_2 = gr.Checkbox(
304
- label='Reverse')
305
- warmup_2 = gr.Slider(label='Warmup', minimum=0, maximum=50, value=10, step=1, interactive=True)
306
- scale_2 = gr.Slider(label='Scale', minimum=1, maximum=10, value=5, step=0.25, interactive=True)
307
- threshold_2 = gr.Slider(label='Threshold', minimum=0.5, maximum=0.99, value=0.95, steps=0.01, interactive=True)
308
-
309
-
310
-
 
 
 
 
 
 
 
 
 
 
311
  gallery = gr.Gallery(
312
- label="Generated images", show_label=False, elem_id="gallery"
313
  ).style(grid=[2], height="auto")
314
 
315
 
@@ -324,13 +472,29 @@ with block:
324
  #randomize=True,
325
  )
326
 
 
 
 
327
 
328
- ex = gr.Examples(examples=examples, fn=infer, inputs=[text, steps, scale, seed, edit_1, rev_1, warmup_1, scale_1, threshold_1, edit_2, rev_2, warmup_2, scale_2, threshold_2], outputs=gallery, cache_examples=False)
329
- ex.dataset.headers = ['Prompt', 'Steps', 'Scale', 'Seed', 'Edit Prompt 1', 'Reverse 1', 'Warmup 1', 'Scale 1', 'Threshold 1', 'Edit Prompt 2', 'Reverse 2', 'Warmup 2', 'Scale 2', 'Threshold 2']
330
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331
 
332
- text.submit(infer, inputs=[text, steps, scale, seed, edit_1, rev_1, warmup_1, scale_1, threshold_1, edit_2, rev_2, warmup_2, scale_2, threshold_2], outputs=gallery)
333
- btn.click(infer, inputs=[text, steps, scale, seed, edit_1, rev_1, warmup_1, scale_1, threshold_1, edit_2, rev_2, warmup_2, scale_2, threshold_2], outputs=gallery)
334
  gr.HTML(
335
  """
336
  <div class="footer">
 
19
  return images, False
20
  pipe.safety_checker = null_safety
21
 
22
+
23
+ style_embeddings = {
24
+ 'Concept Art': torch.load('embeddings/concept_art.pt'), 'Animation': torch.load('embeddings/animation.pt'), 'Character Design': torch.load('embeddings/character_design.pt')
25
+ , 'Portrait Photo': torch.load('embeddings/portrait_photo.pt'), 'Architecture': torch.load('embeddings/architecture.pt')
26
+ }
27
 
28
  def infer(prompt, steps, scale, seed, editing_prompt_1 = None, reverse_editing_direction_1 = False, edit_warmup_steps_1=10, edit_guidance_scale_1=5, edit_threshold_1=0.95,
29
  editing_prompt_2 = None, reverse_editing_direction_2 = False, edit_warmup_steps_2=10, edit_guidance_scale_2=5, edit_threshold_2=0.95,
30
+ edit_style=None,
31
+ reverse_editing_direction_style = False, edit_warmup_steps_style=5, edit_guidance_scale_style=7, edit_threshold_style=0.8,
32
  edit_momentum_scale=0.5, edit_mom_beta=0.6):
33
 
34
 
 
49
  del edit_warmup_steps[index]
50
  del edit_guidance_scale[index]
51
  del edit_threshold[index]
52
+ editing_prompt_embeddings = None
53
+
54
 
55
+ if edit_style is not None:
56
+ editing_prompt = None
57
+ reverse_editing_direction = reverse_editing_direction_style
58
+ edit_warmup_steps = edit_warmup_steps_style
59
+ edit_guidance_scale = edit_guidance_scale_style
60
+ edit_threshold = edit_threshold_style
61
+ editing_prompt_embeddings = style_embeddings[edit_style]
62
 
63
  gen.manual_seed(seed)
64
  images.extend(pipe(prompt, guidance_scale=scale, num_inference_steps=steps, generator=gen,
65
+ editing_prompt=editing_prompt, editing_prompt_embeddings=editing_prompt_embeddings,
66
+ reverse_editing_direction=reverse_editing_direction, edit_warmup_steps=edit_warmup_steps, edit_guidance_scale=edit_guidance_scale,
67
  edit_momentum_scale=edit_momentum_scale, edit_mom_beta=edit_mom_beta
68
  ).images)
69
 
70
+ return zip(images, ['Original', edit_style if edit_style is not None else 'SEGA'])
71
+
72
+ def reset_style():
73
+ radio = gr.Radio(label='Style', choices=['Concept Art', 'Animation', 'Character Design', 'Portrait Photo', 'Architecture'])
74
+ return radio
75
+
76
+ def reset_text():
77
+ text_1 = gr.Textbox(
78
+ label="Edit Prompt 1",
79
+ show_label=False,
80
+ max_lines=1,
81
+ placeholder="Enter your 1st edit prompt",
82
+ ).style(
83
+ border=(True, False, True, True),
84
+ rounded=(True, False, False, True),
85
+ container=False,
86
+ )
87
+ text_2 = gr.Textbox(
88
+ label="Edit Prompt 2",
89
+ show_label=False,
90
+ max_lines=1,
91
+ placeholder="Enter your 2nd edit prompt",
92
+ ).style(
93
+ border=(True, False, True, True),
94
+ rounded=(True, False, False, True),
95
+ container=False,
96
+ )
97
+ return text_1, text_2
98
 
99
  css = """
100
  a {
 
188
  'sunglasses',
189
  False,
190
  10,
191
+ 5,
192
  0.95,
193
  '',
194
  False,
195
  10,
196
  5,
197
+ 0.95,
198
+ '',
199
+ False,
200
+ 5,
201
+ 7,
202
+ 0.8,
203
  ],
204
  [
205
  'an image of a crowded boulevard, realistic, 4k',
 
215
  False,
216
  10,
217
  5,
218
+ 0.95,
219
+ '',
220
+ False,
221
+ 5,
222
+ 7,
223
+ 0.8
224
  ],
225
  [
226
  'a castle next to a river',
 
236
  False,
237
  18,
238
  6,
239
+ 0.8,
240
+ '',
241
+ False,
242
+ 5,
243
+ 7,
244
  0.8
245
  ],
246
  [
 
257
  False,
258
  5,
259
  5,
260
+ 0.9,
261
+ '',
262
+ False,
263
+ 5,
264
+ 7,
265
+ 0.8
266
  ],
267
  [
268
  'a photo of a flowerpot',
 
278
  False,
279
  10,
280
  5,
281
+ 0.95,
282
+ '',
283
+ False,
284
+ 5,
285
+ 7,
286
+ 0.8
287
  ],
288
  [
289
  'a photo of the face of a woman',
 
299
  False,
300
  13,
301
  3,
302
+ 0.925,
303
+ '',
304
+ False,
305
+ 5,
306
+ 7,
307
+ 0.8
308
+ ],
309
+ [
310
+ 'temple in ruines, forest, stairs, columns',
311
+ 50,
312
+ 7,
313
+ 11,
314
+ '',
315
+ False,
316
+ 10,
317
+ 5,
318
+ 0.95,
319
+ '',
320
+ False,
321
+ 10,
322
+ 5,
323
+ 0.95,
324
+ 'Animation',
325
+ False,
326
+ 5,
327
+ 7,
328
+ 0.8
329
+ ],
330
+ [
331
+ 'city made out of glass',
332
+ 50,
333
+ 7,
334
+ 16,
335
+ '',
336
+ False,
337
+ 10,
338
+ 5,
339
+ 0.95,
340
+ '',
341
+ False,
342
+ 10,
343
+ 5,
344
+ 0.95,
345
+ 'Concept Art',
346
+ False,
347
+ 10,
348
+ 8,
349
+ 0.8
350
+ ],
351
+ [
352
+ 'a man riding a horse',
353
+ 50,
354
+ 7,
355
+ 11,
356
+ '',
357
+ False,
358
+ 10,
359
+ 5,
360
+ 0.95,
361
+ '',
362
+ False,
363
+ 10,
364
+ 5,
365
+ 0.95,
366
+ 'Character Design',
367
+ False,
368
+ 11,
369
+ 8,
370
+ 0.9
371
  ],
372
  ]
373
 
374
+
375
  with block:
376
  gr.HTML(
377
  """
 
406
  margin=False,
407
  rounded=(False, True, True, False),
408
  )
409
+ with gr.Tabs() as tabs:
410
+ with gr.TabItem('Text Guidance', id=0):
411
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
412
+ edit_1 = gr.Textbox(
413
+ label="Edit Prompt 1",
414
+ show_label=False,
415
+ max_lines=1,
416
+ placeholder="Enter your 1st edit prompt",
417
+ ).style(
418
+ border=(True, False, True, True),
419
+ rounded=(True, False, False, True),
420
+ container=False,
421
+ )
422
+ with gr.Group():
423
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
424
+ rev_1 = gr.Checkbox(
425
+ label='Negative Guidance')
426
+ warmup_1 = gr.Slider(label='Warmup', minimum=0, maximum=50, value=10, step=1, interactive=True)
427
+ scale_1 = gr.Slider(label='Scale', minimum=1, maximum=10, value=5, step=0.25, interactive=True)
428
+ threshold_1 = gr.Slider(label='Threshold', minimum=0.5, maximum=0.99, value=0.95, steps=0.01, interactive=True)
429
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
430
+ edit_2 = gr.Textbox(
431
+ label="Edit Prompt 2",
432
+ show_label=False,
433
+ max_lines=1,
434
+ placeholder="Enter your 2nd edit prompt",
435
+ ).style(
436
+ border=(True, False, True, True),
437
+ rounded=(True, False, False, True),
438
+ container=False,
439
+ )
440
+ with gr.Group():
441
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
442
+ rev_2 = gr.Checkbox(
443
+ label='Negative Guidance')
444
+ warmup_2 = gr.Slider(label='Warmup', minimum=0, maximum=50, value=10, step=1, interactive=True)
445
+ scale_2 = gr.Slider(label='Scale', minimum=1, maximum=10, value=5, step=0.25, interactive=True)
446
+ threshold_2 = gr.Slider(label='Threshold', minimum=0.5, maximum=0.99, value=0.95, steps=0.01, interactive=True)
447
+ with gr.TabItem("Style Guidance", id=1):
448
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
449
+ style = gr.Radio(label='Style', choices=['Concept Art', 'Animation', 'Character Design', 'Portrait Photo', 'Architecture'], interactive=True)
450
+ with gr.Group():
451
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
452
+ rev_style = gr.Checkbox(
453
+ label='Negative Guidance', interactive=False)
454
+ warmup_style = gr.Slider(label='Warmup', minimum=0, maximum=50, value=5, step=1, interactive=True)
455
+ scale_style = gr.Slider(label='Scale', minimum=1, maximum=10, value=7, step=0.25, interactive=True)
456
+ threshold_style = gr.Slider(label='Threshold', minimum=0.5, maximum=0.99, value=0.8, steps=0.01, interactive=True)
457
+
458
+
459
  gallery = gr.Gallery(
460
+ label=("Generated images"), show_label=False, elem_id="gallery"
461
  ).style(grid=[2], height="auto")
462
 
463
 
 
472
  #randomize=True,
473
  )
474
 
475
+
476
+ ex = gr.Examples(examples=examples, fn=infer, inputs=[text, steps, scale, seed, edit_1, rev_1, warmup_1, scale_1, threshold_1, edit_2, rev_2, warmup_2, scale_2, threshold_2, style, rev_style, warmup_style, scale_style, threshold_style], outputs=gallery, cache_examples=False)
477
+ ex.dataset.headers = ['Prompt', 'Steps', 'Scale', 'Seed', 'Edit Prompt 1', 'Negation 1', 'Warmup 1', 'Scale 1', 'Threshold 1', 'Edit Prompt 2', 'Negation 2', 'Warmup 2', 'Scale 2', 'Threshold 2', 'Style', 'Style Negation', 'Style Warmup', 'Style Scale', 'Style Threshold']
478
 
 
 
479
 
480
+ text.submit(infer, inputs=[text, steps, scale, seed, edit_1, rev_1, warmup_1, scale_1, threshold_1, edit_2, rev_2, warmup_2, scale_2, threshold_2, style, rev_style, warmup_style, scale_style, threshold_style], outputs=gallery)
481
+ btn.click(infer, inputs=[text, steps, scale, seed, edit_1, rev_1, warmup_1, scale_1, threshold_1, edit_2, rev_2, warmup_2, scale_2, threshold_2, style, rev_style, warmup_style, scale_style, threshold_style], outputs=gallery)
482
+ #btn.click(change_tab, None, tabs)
483
+
484
+ edit_1.change(reset_style, outputs=style)
485
+ edit_2.change(reset_style, outputs=style)
486
+
487
+ rev_1.change(reset_style, outputs=style)
488
+ rev_2.change(reset_style, outputs=style)
489
+
490
+ warmup_1.change(reset_style, outputs=style)
491
+ warmup_2.change(reset_style, outputs=style)
492
+
493
+ threshold_1.change(reset_style, outputs=style)
494
+ threshold_2.change(reset_style, outputs=style)
495
+ #style.change(reset_text, outputs=[edit_1, edit_2])
496
 
497
+
 
498
  gr.HTML(
499
  """
500
  <div class="footer">