Pusheen commited on
Commit
bda45a0
·
verified ·
1 Parent(s): ab6a5ca

Upload 170 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .DS_Store +0 -0
  2. .gitattributes +4 -0
  3. .gitignore +3 -0
  4. README.md +5 -5
  5. app.py +239 -217
  6. dataset/__pycache__/__init__.cpython-310.pyc +0 -0
  7. dataset/__pycache__/__init__.cpython-38.pyc +0 -0
  8. dataset/__pycache__/catalog.cpython-310.pyc +0 -0
  9. dataset/__pycache__/catalog.cpython-38.pyc +0 -0
  10. dataset/__pycache__/concat_dataset.cpython-310.pyc +0 -0
  11. dataset/__pycache__/concat_dataset.cpython-38.pyc +0 -0
  12. environment.yaml +1 -1
  13. example_component.py +805 -0
  14. gligen/.DS_Store +0 -0
  15. gligen/SD_input_conv_weight_bias.pth +3 -0
  16. gligen/__pycache__/__init__.cpython-310.pyc +0 -0
  17. gligen/__pycache__/__init__.cpython-38.pyc +0 -0
  18. gligen/__pycache__/distributed.cpython-310.pyc +0 -0
  19. gligen/__pycache__/distributed.cpython-38.pyc +0 -0
  20. gligen/__pycache__/evaluator.cpython-310.pyc +0 -0
  21. gligen/__pycache__/evaluator.cpython-38.pyc +0 -0
  22. gligen/__pycache__/task_grounded_generation.cpython-310.pyc +0 -0
  23. gligen/__pycache__/task_grounded_generation.cpython-38.pyc +0 -0
  24. gligen/__pycache__/trainer.cpython-310.pyc +0 -0
  25. gligen/__pycache__/trainer.cpython-38.pyc +0 -0
  26. gligen/evaluator.py +1 -1
  27. gligen/ldm/.DS_Store +0 -0
  28. gligen/ldm/__pycache__/util.cpython-310.pyc +0 -0
  29. gligen/ldm/__pycache__/util.cpython-38.pyc +0 -0
  30. gligen/ldm/data/.DS_Store +0 -0
  31. gligen/ldm/data/imagenet_train_hr_indices.p +3 -0
  32. gligen/ldm/data/imagenet_val_hr_indices.p +3 -0
  33. gligen/ldm/models/.DS_Store +0 -0
  34. gligen/ldm/models/__pycache__/autoencoder.cpython-310.pyc +0 -0
  35. gligen/ldm/models/__pycache__/autoencoder.cpython-38.pyc +0 -0
  36. gligen/ldm/models/diffusion/__pycache__/__init__.cpython-310.pyc +0 -0
  37. gligen/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc +0 -0
  38. gligen/ldm/models/diffusion/__pycache__/ddim.cpython-310.pyc +0 -0
  39. gligen/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc +0 -0
  40. gligen/ldm/models/diffusion/__pycache__/ddpm.cpython-310.pyc +0 -0
  41. gligen/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc +0 -0
  42. gligen/ldm/models/diffusion/__pycache__/gaussian_smoothing.cpython-310.pyc +0 -0
  43. gligen/ldm/models/diffusion/__pycache__/gaussian_smoothing.cpython-38.pyc +0 -0
  44. gligen/ldm/models/diffusion/__pycache__/ldm.cpython-310.pyc +0 -0
  45. gligen/ldm/models/diffusion/__pycache__/ldm.cpython-38.pyc +0 -0
  46. gligen/ldm/models/diffusion/__pycache__/loss.cpython-310.pyc +0 -0
  47. gligen/ldm/models/diffusion/__pycache__/loss.cpython-38.pyc +0 -0
  48. gligen/ldm/models/diffusion/__pycache__/plms.cpython-310.pyc +0 -0
  49. gligen/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc +0 -0
  50. gligen/ldm/models/diffusion/ddim.py +4 -4
.DS_Store ADDED
Binary file (6.15 kB). View file
 
.gitattributes CHANGED
@@ -32,3 +32,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ gligen/ldm/data/imagenet_train_hr_indices.p filter=lfs diff=lfs merge=lfs -text
36
+ gligen/projection_matrix.pth filter=lfs diff=lfs merge=lfs -text
37
+ gligen/ldm/data/imagenet_val_hr_indices.p filter=lfs diff=lfs merge=lfs -text
38
+ gligen/SD_input_conv_weight_bias.pth filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -110,3 +110,6 @@ create_samples/
110
  create_samples/*
111
 
112
  ckpts/*
 
 
 
 
110
  create_samples/*
111
 
112
  ckpts/*
113
+
114
+ **/__pycache__/*
115
+ **/__pycache__
README.md CHANGED
@@ -1,12 +1,12 @@
1
  ---
2
- title: LoCo_Gligen Demo
3
- emoji: 👁
4
- colorFrom: blue
5
- colorTo: purple
6
  sdk: gradio
7
  sdk_version: 3.19.1
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Attention Refocusing
3
+ emoji: 🌖
4
+ colorFrom: yellow
5
+ colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 3.19.1
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import gradio as gr
 
2
  import torch
3
  from omegaconf import OmegaConf
4
  from gligen.task_grounded_generation import grounded_generation_box, load_ckpt, load_common_ckpt
@@ -18,9 +19,11 @@ import warnings
18
 
19
  from datetime import datetime
20
 
 
 
21
  from huggingface_hub import hf_hub_download
22
  hf_hub_download = partial(hf_hub_download, library_name="gligen_demo")
23
-
24
  import sys
25
  sys.tracebacklimit = 0
26
 
@@ -39,8 +42,6 @@ def ckpt_load_helper(modality, is_inpaint, is_style, common_instances=None):
39
  pretrained_ckpt_gligen, config = load_ckpt_config_from_hf(modality)
40
  config = OmegaConf.create( config["_content"] ) # config used in training
41
  config.alpha_scale = 1.0
42
- config.model['params']['is_inpaint'] = is_inpaint
43
- config.model['params']['is_style'] = is_style
44
 
45
  if common_instances is None:
46
  common_ckpt = load_from_hf('gligen/demo_ckpts_legacy', filename=f'common.pth', subfolder='model')
@@ -138,13 +139,25 @@ class ImageMask(gr.components.Image):
138
  if x is None:
139
  return x
140
  if self.tool == "sketch" and self.source in ["upload", "webcam"] and type(x) != dict:
 
141
  decode_image = processing_utils.decode_base64_to_image(x)
 
142
  width, height = decode_image.size
 
 
 
143
  mask = np.zeros((height, width, 4), dtype=np.uint8)
 
144
  mask[..., -1] = 255
145
  mask = self.postprocess(mask)
146
  x = {'image': x, 'mask': mask}
147
- return super().preprocess(x)
 
 
 
 
 
 
148
 
149
 
150
  class Blocks(gr.Blocks):
@@ -180,23 +193,25 @@ class Blocks(gr.Blocks):
180
  inference model
181
  '''
182
 
183
- @torch.no_grad()
184
- def inference(task, language_instruction, grounding_instruction, inpainting_boxes_nodrop, image,
185
  alpha_sample, guidance_scale, batch_size,
186
  fix_seed, rand_seed, actual_mask, style_image,
187
  *args, **kwargs):
188
- grounding_instruction = json.loads(grounding_instruction)
189
- phrase_list, location_list = [], []
190
- for k, v in grounding_instruction.items():
191
- phrase_list.append(k)
192
- location_list.append(v)
 
 
193
 
194
  placeholder_image = Image.open('images/teddy.jpg').convert("RGB")
195
  image_list = [placeholder_image] * len(phrase_list) # placeholder input for visual prompt, which is disabled
196
 
197
  batch_size = int(batch_size)
198
  if not 1 <= batch_size <= 4:
199
- batch_size = 2
200
 
201
  if style_image == None:
202
  has_text_mask = 1
@@ -212,9 +227,6 @@ def inference(task, language_instruction, grounding_instruction, inpainting_boxe
212
 
213
  location_list += [ [0.0, 0.0, 1, 0.01] ] # style image grounding location
214
 
215
- if task == 'Grounded Inpainting':
216
- alpha_sample = 1.0
217
-
218
  instruction = dict(
219
  prompt = language_instruction,
220
  phrases = phrase_list,
@@ -238,21 +250,19 @@ def inference(task, language_instruction, grounding_instruction, inpainting_boxe
238
  phrase_list=phrase_list)
239
 
240
  with torch.autocast(device_type='cuda', dtype=torch.float16):
241
- if task == 'Grounded Generation':
242
  if style_image == None:
243
- return grounded_generation_box(get_model('base'), instruction, *args, **kwargs)
 
 
244
  else:
245
  return grounded_generation_box(get_model('style'), instruction, *args, **kwargs)
246
- elif task == 'Grounded Inpainting':
247
- assert image is not None
248
- instruction['input_image'] = image.convert("RGB")
249
- return grounded_generation_box(get_model('inpaint'), instruction, *args, **kwargs)
250
 
251
 
252
  def draw_box(boxes=[], texts=[], img=None):
253
  if len(boxes) == 0 and img is None:
254
  return None
255
-
256
  if img is None:
257
  img = Image.new('RGB', (512, 512), (255, 255, 255))
258
  colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
@@ -281,7 +291,7 @@ def get_concat(ims):
281
 
282
  def auto_append_grounding(language_instruction, grounding_texts):
283
  for grounding_text in grounding_texts:
284
- if grounding_text not in language_instruction and grounding_text != 'auto':
285
  language_instruction += "; " + grounding_text
286
  return language_instruction
287
 
@@ -292,6 +302,7 @@ def generate(task, language_instruction, grounding_texts, sketch_pad,
292
  alpha_sample, guidance_scale, batch_size,
293
  fix_seed, rand_seed, use_actual_mask, append_grounding, style_cond_image,
294
  state):
 
295
  if 'boxes' not in state:
296
  state['boxes'] = []
297
 
@@ -307,44 +318,18 @@ Please draw boxes accordingly on the sketch pad.""".format(len(boxes), len(groun
307
 
308
  boxes = (np.asarray(boxes) / 512).tolist()
309
  grounding_instruction = json.dumps({obj: box for obj,box in zip(grounding_texts, boxes)})
310
-
311
  image = None
312
  actual_mask = None
313
- if task == 'Grounded Inpainting':
314
- image = state.get('original_image', sketch_pad['image']).copy()
315
- image = center_crop(image)
316
- image = Image.fromarray(image)
317
-
318
- if use_actual_mask:
319
- actual_mask = sketch_pad['mask'].copy()
320
- if actual_mask.ndim == 3:
321
- actual_mask = actual_mask[..., 0]
322
- actual_mask = center_crop(actual_mask, tgt_size=(64, 64))
323
- actual_mask = torch.from_numpy(actual_mask == 0).float()
324
-
325
- if state.get('inpaint_hw', None):
326
- boxes = np.asarray(boxes) * 0.9 + 0.05
327
- boxes = boxes.tolist()
328
- grounding_instruction = json.dumps({obj: box for obj,box in zip(grounding_texts, boxes) if obj != 'auto'})
329
 
330
  if append_grounding:
331
  language_instruction = auto_append_grounding(language_instruction, grounding_texts)
332
 
333
  gen_images, gen_overlays = inference(
334
- task, language_instruction, grounding_instruction, boxes, image,
335
  alpha_sample, guidance_scale, batch_size,
336
  fix_seed, rand_seed, actual_mask, style_cond_image, clip_model=clip_model,
337
  )
338
-
339
- for idx, gen_image in enumerate(gen_images):
340
-
341
- if task == 'Grounded Inpainting' and state.get('inpaint_hw', None):
342
- hw = min(*state['original_image'].shape[:2])
343
- gen_image = sized_center_fill(state['original_image'].copy(), np.array(gen_image.resize((hw, hw))), hw, hw)
344
- gen_image = Image.fromarray(gen_image)
345
-
346
- gen_images[idx] = gen_image
347
-
348
  blank_samples = batch_size % 2 if batch_size > 1 else 0
349
  gen_images = [gr.Image.update(value=x, visible=True) for i,x in enumerate(gen_images)] \
350
  + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \
@@ -355,6 +340,9 @@ Please draw boxes accordingly on the sketch pad.""".format(len(boxes), len(groun
355
 
356
  def binarize(x):
357
  return (x != 0).astype('uint8') * 255
 
 
 
358
 
359
  def sized_center_crop(img, cropx, cropy):
360
  y, x = img.shape[:2]
@@ -387,10 +375,20 @@ def center_crop(img, HW=None, tgt_size=(512, 512)):
387
  img = img.resize(tgt_size)
388
  return np.array(img)
389
 
390
- def draw(task, input, grounding_texts, new_image_trigger, state):
 
 
 
391
  if type(input) == dict:
392
  image = input['image']
393
  mask = input['mask']
 
 
 
 
 
 
 
394
  else:
395
  mask = input
396
 
@@ -398,24 +396,8 @@ def draw(task, input, grounding_texts, new_image_trigger, state):
398
  mask = mask[..., 0]
399
 
400
  image_scale = 1.0
401
-
402
- # resize trigger
403
- if task == "Grounded Inpainting":
404
- mask_cond = mask.sum() == 0
405
- # size_cond = mask.shape != (512, 512)
406
- if mask_cond and 'original_image' not in state:
407
- image = Image.fromarray(image)
408
- width, height = image.size
409
- scale = 600 / min(width, height)
410
- image = image.resize((int(width * scale), int(height * scale)))
411
- state['original_image'] = np.array(image).copy()
412
- image_scale = float(height / width)
413
- return [None, new_image_trigger + 1, image_scale, state]
414
- else:
415
- original_image = state['original_image']
416
- H, W = original_image.shape[:2]
417
- image_scale = float(H / W)
418
-
419
  mask = binarize(mask)
420
  if mask.shape != (512, 512):
421
  # assert False, "should not receive any non- 512x512 masks."
@@ -424,16 +406,16 @@ def draw(task, input, grounding_texts, new_image_trigger, state):
424
  image = center_crop(state['original_image'], state['inpaint_hw'])
425
  else:
426
  mask = np.zeros((512, 512), dtype=np.uint8)
427
- # mask = center_crop(mask)
428
  mask = binarize(mask)
429
 
430
  if type(mask) != np.ndarray:
431
  mask = np.array(mask)
432
-
433
- if mask.sum() == 0 and task != "Grounded Inpainting":
434
  state = {}
 
435
 
436
- if task != 'Grounded Inpainting':
437
  image = None
438
  else:
439
  image = Image.fromarray(image)
@@ -441,20 +423,20 @@ def draw(task, input, grounding_texts, new_image_trigger, state):
441
  if 'boxes' not in state:
442
  state['boxes'] = []
443
 
444
- if 'masks' not in state or len(state['masks']) == 0:
445
  state['masks'] = []
446
  last_mask = np.zeros_like(mask)
447
  else:
448
  last_mask = state['masks'][-1]
449
-
450
- if type(mask) == np.ndarray and mask.size > 1:
451
  diff_mask = mask - last_mask
452
  else:
453
  diff_mask = np.zeros([])
454
 
455
  if diff_mask.sum() > 0:
456
- x1x2 = np.where(diff_mask.max(0) != 0)[0]
457
- y1y2 = np.where(diff_mask.max(1) != 0)[0]
458
  y1, y2 = y1y2.min(), y1y2.max()
459
  x1, x2 = x1x2.min(), x1x2.max()
460
 
@@ -466,26 +448,73 @@ def draw(task, input, grounding_texts, new_image_trigger, state):
466
  grounding_texts = [x for x in grounding_texts if len(x) > 0]
467
  if len(grounding_texts) < len(state['boxes']):
468
  grounding_texts += [f'Obj. {bid+1}' for bid in range(len(grounding_texts), len(state['boxes']))]
469
-
470
  box_image = draw_box(state['boxes'], grounding_texts, image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471
 
472
- if box_image is not None and state.get('inpaint_hw', None):
473
- inpaint_hw = state['inpaint_hw']
474
- box_image_resize = np.array(box_image.resize((inpaint_hw, inpaint_hw)))
475
- original_image = state['original_image'].copy()
476
- box_image = sized_center_fill(original_image, box_image_resize, inpaint_hw, inpaint_hw)
477
-
478
- return [box_image, new_image_trigger, image_scale, state]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
479
 
480
- def clear(task, sketch_pad_trigger, batch_size, state, switch_task=False):
481
- if task != 'Grounded Inpainting':
482
- sketch_pad_trigger = sketch_pad_trigger + 1
 
483
  blank_samples = batch_size % 2 if batch_size > 1 else 0
484
  out_images = [gr.Image.update(value=None, visible=True) for i in range(batch_size)] \
485
  + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \
486
  + [gr.Image.update(value=None, visible=False) for _ in range(4 - batch_size - blank_samples)]
487
  state = {}
488
- return [None, sketch_pad_trigger, None, 1.0] + out_images + [state]
489
 
490
  css = """
491
  #img2img_image, #img2img_image > .fixed-height, #img2img_image > .fixed-height > div, #img2img_image > .fixed-height > div > img
@@ -502,6 +531,10 @@ css = """
502
  cursor: pointer;
503
  text-decoration: none;
504
  }
 
 
 
 
505
  """
506
 
507
  rescale_js = """
@@ -516,42 +549,48 @@ function(x) {
516
  return x;
517
  }
518
  """
519
-
520
  with Blocks(
521
  css=css,
522
  analytics_enabled=False,
523
- title="GLIGen demo",
524
  ) as main:
525
  description = """<p style="text-align: center; font-weight: bold;">
526
- <span style="font-size: 28px">GLIGen: Open-Set Grounded Text-to-Image Generation</span>
527
  <br>
528
  <span style="font-size: 18px" id="paper-info">
529
- [<a href="https://gligen.github.io" target="_blank">Project Page</a>]
530
- [<a href="https://arxiv.org/abs/2301.07093" target="_blank">Paper</a>]
531
- [<a href="https://github.com/gligen/GLIGEN" target="_blank">GitHub</a>]
532
  </span>
533
  </p>
534
  <p>
535
- To ground concepts of interest with desired spatial specification, please (1) &#9000;&#65039; enter the concept names in <em> Grounding Instruction</em>, and (2) &#128433;&#65039; draw their corresponding bounding boxes one by one using <em> Sketch Pad</em> -- the parsed boxes will be displayed automatically.
536
  <br>
537
  For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/gligen/demo?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a>
538
  </p>
539
  """
540
  gr.HTML(description)
541
-
542
  with gr.Row():
543
  with gr.Column(scale=4):
544
  sketch_pad_trigger = gr.Number(value=0, visible=False)
545
  sketch_pad_resize_trigger = gr.Number(value=0, visible=False)
 
 
546
  init_white_trigger = gr.Number(value=0, visible=False)
547
- image_scale = gr.Number(value=0, elem_id="image_scale", visible=False)
548
  new_image_trigger = gr.Number(value=0, visible=False)
549
-
 
 
550
  task = gr.Radio(
551
- choices=["Grounded Generation", 'Grounded Inpainting'],
552
  type="value",
553
- value="Grounded Generation",
554
  label="Task",
 
 
555
  )
556
  language_instruction = gr.Textbox(
557
  label="Language instruction",
@@ -561,33 +600,38 @@ with Blocks(
561
  )
562
  with gr.Row():
563
  sketch_pad = ImageMask(label="Sketch Pad", elem_id="img2img_image")
564
- out_imagebox = gr.Image(type="pil", label="Parsed Sketch Pad")
565
  with gr.Row():
566
  clear_btn = gr.Button(value='Clear')
567
  gen_btn = gr.Button(value='Generate')
 
 
 
568
  with gr.Accordion("Advanced Options", open=False):
569
  with gr.Column():
570
  alpha_sample = gr.Slider(minimum=0, maximum=1.0, step=0.1, value=0.3, label="Scheduled Sampling (τ)")
571
  guidance_scale = gr.Slider(minimum=0, maximum=50, step=0.5, value=7.5, label="Guidance Scale")
572
- batch_size = gr.Slider(minimum=1, maximum=4, step=1, value=2, label="Number of Samples")
573
  append_grounding = gr.Checkbox(value=True, label="Append grounding instructions to the caption")
574
  use_actual_mask = gr.Checkbox(value=False, label="Use actual mask for inpainting", visible=False)
575
  with gr.Row():
576
  fix_seed = gr.Checkbox(value=True, label="Fixed seed")
577
  rand_seed = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="Seed")
578
- with gr.Row():
579
- use_style_cond = gr.Checkbox(value=False, label="Enable Style Condition")
580
- style_cond_image = gr.Image(type="pil", label="Style Condition", visible=False, interactive=True)
 
581
  with gr.Column(scale=4):
582
  gr.HTML('<span style="font-size: 20px; font-weight: bold">Generated Images</span>')
583
  with gr.Row():
584
  out_gen_1 = gr.Image(type="pil", visible=True, show_label=False)
585
- out_gen_2 = gr.Image(type="pil", visible=True, show_label=False)
586
  with gr.Row():
587
  out_gen_3 = gr.Image(type="pil", visible=False, show_label=False)
588
  out_gen_4 = gr.Image(type="pil", visible=False, show_label=False)
589
 
590
  state = gr.State({})
 
591
 
592
  class Controller:
593
  def __init__(self):
@@ -605,75 +649,43 @@ with Blocks(
605
  return [gr.Image.update(visible=True) for _ in range(n_samples + blank_samples)] \
606
  + [gr.Image.update(visible=False) for _ in range(4 - n_samples - blank_samples)]
607
 
608
- def resize_centercrop(self, state):
609
- self.resizes += 1
610
- image = state['original_image'].copy()
611
- inpaint_hw = int(0.9 * min(*image.shape[:2]))
612
- state['inpaint_hw'] = inpaint_hw
613
- image_cc = center_crop(image, inpaint_hw)
614
- # print(f'resize triggered {self.resizes}', image.shape, '->', image_cc.shape)
615
- return image_cc, state
616
-
617
- def resize_masked(self, state):
618
- self.resizes += 1
619
- image = state['original_image'].copy()
620
- inpaint_hw = int(0.9 * min(*image.shape[:2]))
621
- state['inpaint_hw'] = inpaint_hw
622
- image_mask = sized_center_mask(image, inpaint_hw, inpaint_hw)
623
- state['masked_image'] = image_mask.copy()
624
- # print(f'mask triggered {self.resizes}')
625
- return image_mask, state
626
-
627
- def switch_task_hide_cond(self, task):
628
- cond = False
629
- if task == "Grounded Generation":
630
- cond = True
631
-
632
- return gr.Checkbox.update(visible=cond, value=False), gr.Image.update(value=None, visible=False), gr.Slider.update(visible=cond), gr.Checkbox.update(visible=(not cond), value=False)
633
-
634
  controller = Controller()
635
  main.load(
636
  lambda x:x+1,
637
  inputs=sketch_pad_trigger,
638
  outputs=sketch_pad_trigger,
639
  queue=False)
 
640
  sketch_pad.edit(
641
  draw,
642
- inputs=[task, sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
643
- outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
644
  queue=False,
645
  )
 
 
 
 
 
 
646
  grounding_instruction.change(
647
  draw,
648
- inputs=[task, sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
649
- outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
650
  queue=False,
651
  )
652
  clear_btn.click(
653
  clear,
654
- inputs=[task, sketch_pad_trigger, batch_size, state],
655
- outputs=[sketch_pad, sketch_pad_trigger, out_imagebox, image_scale, out_gen_1, out_gen_2, out_gen_3, out_gen_4, state],
656
- queue=False)
657
- task.change(
658
- partial(clear, switch_task=True),
659
- inputs=[task, sketch_pad_trigger, batch_size, state],
660
- outputs=[sketch_pad, sketch_pad_trigger, out_imagebox, image_scale, out_gen_1, out_gen_2, out_gen_3, out_gen_4, state],
661
  queue=False)
 
662
  sketch_pad_trigger.change(
663
  controller.init_white,
664
  inputs=[init_white_trigger],
665
  outputs=[sketch_pad, image_scale, init_white_trigger],
666
  queue=False)
667
- sketch_pad_resize_trigger.change(
668
- controller.resize_masked,
669
- inputs=[state],
670
- outputs=[sketch_pad, state],
671
- queue=False)
672
- batch_size.change(
673
- controller.change_n_samples,
674
- inputs=[batch_size],
675
- outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4],
676
- queue=False)
677
  gen_btn.click(
678
  generate,
679
  inputs=[
@@ -687,88 +699,98 @@ with Blocks(
687
  outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4, state],
688
  queue=True
689
  )
690
- sketch_pad_resize_trigger.change(
691
- None,
692
- None,
693
- sketch_pad_resize_trigger,
694
- _js=rescale_js,
695
- queue=False)
696
  init_white_trigger.change(
697
  None,
698
  None,
699
  init_white_trigger,
700
  _js=rescale_js,
701
  queue=False)
702
- use_style_cond.change(
703
- lambda cond: gr.Image.update(visible=cond),
704
- use_style_cond,
705
- style_cond_image,
706
- queue=False)
707
- task.change(
708
- controller.switch_task_hide_cond,
709
- inputs=task,
710
- outputs=[use_style_cond, style_cond_image, alpha_sample, use_actual_mask],
711
- queue=False)
712
-
713
- with gr.Column():
714
- gr.Examples(
715
- examples=[
716
- [
717
- "images/blank.png",
718
- "Grounded Generation",
719
- "a dog and an apple",
720
- "a dog;an apple",
721
  ],
722
  [
723
- "images/blank.png",
724
- "Grounded Generation",
725
- "John Lennon is using a pc",
726
- "John Lennon;a pc",
727
- [
728
- "images/blank.png",
729
- "Grounded Generation",
730
- "a painting of a fox sitting in a field at sunrise in the style of Claude Mone",
731
- "fox;sunrise",
732
- ],
733
  ],
734
  [
735
- "images/blank.png",
736
- "Grounded Generation",
737
- "a beautiful painting of hot dog by studio ghibli, octane render, brilliantly coloured",
738
- "hot dog",
 
739
  ],
740
  [
741
- "images/blank.png",
742
- "Grounded Generation",
743
- "a sport car, unreal engine, global illumination, ray tracing",
744
- "a sport car",
 
745
  ],
746
  [
747
- "images/flower_beach.jpg",
748
- "Grounded Inpainting",
749
- "a squirrel and the space needle",
750
- "a squirrel;the space needle",
 
751
  ],
752
  [
753
- "images/arg_corgis.jpeg",
754
- "Grounded Inpainting",
755
- "a dog and a birthday cake",
756
- "a dog; a birthday cake",
 
757
  ],
758
  [
759
- "images/teddy.jpg",
760
- "Grounded Inpainting",
761
- "a teddy bear wearing a santa claus red shirt; holding a Christmas gift box on hand",
762
- "a santa claus shirt; a Christmas gift box",
763
- ],
764
- ],
765
- inputs=[sketch_pad, task, language_instruction, grounding_instruction],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
766
  outputs=None,
767
  fn=None,
768
  cache_examples=False,
 
769
  )
770
 
771
  main.queue(concurrency_count=1, api_open=False)
772
- main.launch(share=False, show_api=False, show_error=True)
773
-
774
-
 
1
  import gradio as gr
2
+ import os
3
  import torch
4
  from omegaconf import OmegaConf
5
  from gligen.task_grounded_generation import grounded_generation_box, load_ckpt, load_common_ckpt
 
19
 
20
  from datetime import datetime
21
 
22
+ from example_component import create_examples
23
+
24
  from huggingface_hub import hf_hub_download
25
  hf_hub_download = partial(hf_hub_download, library_name="gligen_demo")
26
+ import cv2
27
  import sys
28
  sys.tracebacklimit = 0
29
 
 
42
  pretrained_ckpt_gligen, config = load_ckpt_config_from_hf(modality)
43
  config = OmegaConf.create( config["_content"] ) # config used in training
44
  config.alpha_scale = 1.0
 
 
45
 
46
  if common_instances is None:
47
  common_ckpt = load_from_hf('gligen/demo_ckpts_legacy', filename=f'common.pth', subfolder='model')
 
139
  if x is None:
140
  return x
141
  if self.tool == "sketch" and self.source in ["upload", "webcam"] and type(x) != dict:
142
+
143
  decode_image = processing_utils.decode_base64_to_image(x)
144
+ print('decode to 64')
145
  width, height = decode_image.size
146
+ img = np.asarray(decode_image)
147
+ return {'image':img, 'mask':binarize_2(img)}
148
+
149
  mask = np.zeros((height, width, 4), dtype=np.uint8)
150
+
151
  mask[..., -1] = 255
152
  mask = self.postprocess(mask)
153
  x = {'image': x, 'mask': mask}
154
+ print('vao preprocess-------------------------')
155
+ hh = super().preprocess(x)
156
+ if (hh['image'].min()!=255) and (hh['mask'][:,:,:3].max()==0):
157
+
158
+ hh['mask'] = binarize_2(hh['image'])
159
+
160
+ return hh
161
 
162
 
163
  class Blocks(gr.Blocks):
 
193
  inference model
194
  '''
195
 
196
+ # @torch.no_grad()
197
+ def inference(task, language_instruction, phrase_list, location_list, inpainting_boxes_nodrop, image,
198
  alpha_sample, guidance_scale, batch_size,
199
  fix_seed, rand_seed, actual_mask, style_image,
200
  *args, **kwargs):
201
+ # import pdb; pdb.set_trace()
202
+
203
+ # grounding_instruction = json.loads(grounding_instruction)
204
+ # phrase_list, location_list = [], []
205
+ # for k, v in grounding_instruction.items():
206
+ # phrase_list.append(k)
207
+ # location_list.append(v)
208
 
209
  placeholder_image = Image.open('images/teddy.jpg').convert("RGB")
210
  image_list = [placeholder_image] * len(phrase_list) # placeholder input for visual prompt, which is disabled
211
 
212
  batch_size = int(batch_size)
213
  if not 1 <= batch_size <= 4:
214
+ batch_size = 1
215
 
216
  if style_image == None:
217
  has_text_mask = 1
 
227
 
228
  location_list += [ [0.0, 0.0, 1, 0.01] ] # style image grounding location
229
 
 
 
 
230
  instruction = dict(
231
  prompt = language_instruction,
232
  phrases = phrase_list,
 
250
  phrase_list=phrase_list)
251
 
252
  with torch.autocast(device_type='cuda', dtype=torch.float16):
253
+ if task == 'User provide boxes' or 'Available boxes':
254
  if style_image == None:
255
+ result = grounded_generation_box(get_model('base'), instruction, *args, **kwargs)
256
+ torch.cuda.empty_cache()
257
+ return result
258
  else:
259
  return grounded_generation_box(get_model('style'), instruction, *args, **kwargs)
 
 
 
 
260
 
261
 
262
  def draw_box(boxes=[], texts=[], img=None):
263
  if len(boxes) == 0 and img is None:
264
  return None
265
+
266
  if img is None:
267
  img = Image.new('RGB', (512, 512), (255, 255, 255))
268
  colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
 
291
 
292
  def auto_append_grounding(language_instruction, grounding_texts):
293
  for grounding_text in grounding_texts:
294
+ if grounding_text.lower() not in language_instruction.lower() and grounding_text != 'auto':
295
  language_instruction += "; " + grounding_text
296
  return language_instruction
297
 
 
302
  alpha_sample, guidance_scale, batch_size,
303
  fix_seed, rand_seed, use_actual_mask, append_grounding, style_cond_image,
304
  state):
305
+
306
  if 'boxes' not in state:
307
  state['boxes'] = []
308
 
 
318
 
319
  boxes = (np.asarray(boxes) / 512).tolist()
320
  grounding_instruction = json.dumps({obj: box for obj,box in zip(grounding_texts, boxes)})
 
321
  image = None
322
  actual_mask = None
323
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
 
325
  if append_grounding:
326
  language_instruction = auto_append_grounding(language_instruction, grounding_texts)
327
 
328
  gen_images, gen_overlays = inference(
329
+ task, language_instruction, grounding_texts,boxes, boxes, image,
330
  alpha_sample, guidance_scale, batch_size,
331
  fix_seed, rand_seed, actual_mask, style_cond_image, clip_model=clip_model,
332
  )
 
 
 
 
 
 
 
 
 
 
333
  blank_samples = batch_size % 2 if batch_size > 1 else 0
334
  gen_images = [gr.Image.update(value=x, visible=True) for i,x in enumerate(gen_images)] \
335
  + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \
 
340
 
341
  def binarize(x):
342
  return (x != 0).astype('uint8') * 255
343
+ def binarize_2(x):
344
+ gray_image = cv2.cvtColor(x, cv2.COLOR_BGR2GRAY)
345
+ return (gray_image!=255).astype('uint8') * 255
346
 
347
  def sized_center_crop(img, cropx, cropy):
348
  y, x = img.shape[:2]
 
375
  img = img.resize(tgt_size)
376
  return np.array(img)
377
 
378
+ # 接收 sketchpad 的输入 (左边)
379
+ def draw(task, input, grounding_texts, new_image_trigger, state, generate_parsed, box_image):
380
+ print('input', generate_parsed)
381
+
382
  if type(input) == dict:
383
  image = input['image']
384
  mask = input['mask']
385
+ if generate_parsed==1:
386
+ generate_parsed = 0
387
+ # import pdb; pdb.set_trace()
388
+ print('do nothing')
389
+
390
+ return [box_image, new_image_trigger, 1., state, generate_parsed]
391
+
392
  else:
393
  mask = input
394
 
 
396
  mask = mask[..., 0]
397
 
398
  image_scale = 1.0
399
+
400
+ print('vao draw--------------------')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
401
  mask = binarize(mask)
402
  if mask.shape != (512, 512):
403
  # assert False, "should not receive any non- 512x512 masks."
 
406
  image = center_crop(state['original_image'], state['inpaint_hw'])
407
  else:
408
  mask = np.zeros((512, 512), dtype=np.uint8)
 
409
  mask = binarize(mask)
410
 
411
  if type(mask) != np.ndarray:
412
  mask = np.array(mask)
413
+ #
414
+ if mask.sum() == 0:
415
  state = {}
416
+ print('delete state')
417
 
418
+ if True:
419
  image = None
420
  else:
421
  image = Image.fromarray(image)
 
423
  if 'boxes' not in state:
424
  state['boxes'] = []
425
 
426
+ if 'masks' not in state or len(state['masks']) == 0 :
427
  state['masks'] = []
428
  last_mask = np.zeros_like(mask)
429
  else:
430
  last_mask = state['masks'][-1]
431
+
432
+ if type(mask) == np.ndarray and mask.size > 1 :
433
  diff_mask = mask - last_mask
434
  else:
435
  diff_mask = np.zeros([])
436
 
437
  if diff_mask.sum() > 0:
438
+ x1x2 = np.where(diff_mask.max(0) > 1)[0]
439
+ y1y2 = np.where(diff_mask.max(1) > 1)[0]
440
  y1, y2 = y1y2.min(), y1y2.max()
441
  x1, x2 = x1x2.min(), x1x2.max()
442
 
 
448
  grounding_texts = [x for x in grounding_texts if len(x) > 0]
449
  if len(grounding_texts) < len(state['boxes']):
450
  grounding_texts += [f'Obj. {bid+1}' for bid in range(len(grounding_texts), len(state['boxes']))]
451
+
452
  box_image = draw_box(state['boxes'], grounding_texts, image)
453
+ generate_parsed = 0
454
+
455
+ return [box_image, new_image_trigger, image_scale, state, generate_parsed]
456
+
457
+ def change_state(bboxes,layout, state, instruction, trigger_stage, boxes):
458
+ if trigger_stage ==0 :
459
+ return [boxes, state, 0]
460
+ # mask =
461
+ state['boxes'] = []
462
+ state['masks'] = []
463
+ image = None
464
+ list_boxes = bboxes.split('/')
465
+ result =[]
466
+ for b in list_boxes:
467
+ ints = b[1:-1].split(',')
468
+ l = []
469
+ for i in ints:
470
+ l.append(int(i))
471
+ result.append(l)
472
+ print('run change state')
473
+
474
+ for box in result:
475
+ state['boxes'].append(box)
476
+ grounding_texts = [x.strip() for x in instruction.split(';')]
477
+ grounding_texts = [x for x in grounding_texts if len(x) > 0]
478
+ if len(grounding_texts) < len(result):
479
+ grounding_texts += [f'Obj. {bid+1}' for bid in range(len(grounding_texts), len(result))]
480
 
481
+ box_image = draw_box(result, grounding_texts)
482
+
483
+ mask = binarize_2(layout['image'])
484
+ state['masks'].append(mask.copy())
485
+ # print('done change state', state)
486
+ print('done change state')
487
+ # import pdb; pdb.set_trace()
488
+ return [box_image,state, trigger_stage]
489
+
490
+ def example_click(name, grounding_instruction, instruction, bboxes,generate_parsed, trigger_parsed):
491
+
492
+ list_boxes = bboxes.split('/')
493
+ result =[]
494
+
495
+ for b in list_boxes:
496
+ ints = b[1:-1].split(',')
497
+ l = []
498
+ for i in ints:
499
+ l.append(int(i))
500
+ result.append(l)
501
+ print('run change state')
502
+
503
+ box_image = draw_box(result, instruction)
504
+ trigger_parsed += 1
505
+ print('done the example click')
506
+ return [box_image, trigger_parsed]
507
 
508
+ def clear(task, sketch_pad_trigger, batch_size, state,trigger_stage, switch_task=False):
509
+
510
+ sketch_pad_trigger = sketch_pad_trigger + 1
511
+ trigger_stage = 0
512
  blank_samples = batch_size % 2 if batch_size > 1 else 0
513
  out_images = [gr.Image.update(value=None, visible=True) for i in range(batch_size)] \
514
  + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \
515
  + [gr.Image.update(value=None, visible=False) for _ in range(4 - batch_size - blank_samples)]
516
  state = {}
517
+ return [None, sketch_pad_trigger, None, 1.0] + out_images + [state] + [trigger_stage]
518
 
519
  css = """
520
  #img2img_image, #img2img_image > .fixed-height, #img2img_image > .fixed-height > div, #img2img_image > .fixed-height > div > img
 
531
  cursor: pointer;
532
  text-decoration: none;
533
  }
534
+ #my_image > div.fixed-height
535
+ {
536
+ height: var(--height) !important;
537
+ }
538
  """
539
 
540
  rescale_js = """
 
549
  return x;
550
  }
551
  """
552
+ # [<a href="https://arxiv.org/abs/2301.07093" target="_blank">Paper</a>]
553
  with Blocks(
554
  css=css,
555
  analytics_enabled=False,
556
+ title="Attention-refocusing demo",
557
  ) as main:
558
  description = """<p style="text-align: center; font-weight: bold;">
559
+ <span style="font-size: 28px">Grounded Text-to-Image Synthesis with Attention Refocusing</span>
560
  <br>
561
  <span style="font-size: 18px" id="paper-info">
562
+ [<a href="https://attention-refocusing.github.io/" target="_blank">Project Page</a>]
563
+
564
+ [<a href="https://github.com/Attention-Refocusing/attention-refocusing" target="_blank">GitHub</a>]
565
  </span>
566
  </p>
567
  <p>
568
+ To identify the areas of interest based on specific spatial parameters, you need to (1) &#9000;&#65039; input the names of the concepts you're interested in <em> Grounding Instruction</em>, and (2) &#128433;&#65039; draw their corresponding bounding boxes using <em> Sketch Pad</em> -- the parsed boxes will automatically be showed up once you've drawn them.
569
  <br>
570
  For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/gligen/demo?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a>
571
  </p>
572
  """
573
  gr.HTML(description)
574
+
575
  with gr.Row():
576
  with gr.Column(scale=4):
577
  sketch_pad_trigger = gr.Number(value=0, visible=False)
578
  sketch_pad_resize_trigger = gr.Number(value=0, visible=False)
579
+ trigger_stage = gr.Number(value=0, visible=False)
580
+
581
  init_white_trigger = gr.Number(value=0, visible=False)
582
+ image_scale = gr.Number(value=1.0, elem_id="image_scale", visible=False)
583
  new_image_trigger = gr.Number(value=0, visible=False)
584
+ text_box = gr.Textbox(visible=False)
585
+ generate_parsed = gr.Number(value=0, visible=False)
586
+
587
  task = gr.Radio(
588
+ choices=["Available boxes", 'User provide boxes'],
589
  type="value",
590
+ value="User provide boxes",
591
  label="Task",
592
+ visible=False
593
+
594
  )
595
  language_instruction = gr.Textbox(
596
  label="Language instruction",
 
600
  )
601
  with gr.Row():
602
  sketch_pad = ImageMask(label="Sketch Pad", elem_id="img2img_image")
603
+ out_imagebox = gr.Image(type="pil",elem_id="my_image" ,label="Parsed Sketch Pad", shape=(512,512))
604
  with gr.Row():
605
  clear_btn = gr.Button(value='Clear')
606
  gen_btn = gr.Button(value='Generate')
607
+ with gr.Row():
608
+ parsed_btn = gr.Button(value='generate parsed boxes')
609
+
610
  with gr.Accordion("Advanced Options", open=False):
611
  with gr.Column():
612
  alpha_sample = gr.Slider(minimum=0, maximum=1.0, step=0.1, value=0.3, label="Scheduled Sampling (τ)")
613
  guidance_scale = gr.Slider(minimum=0, maximum=50, step=0.5, value=7.5, label="Guidance Scale")
614
+ batch_size = gr.Slider(minimum=1, maximum=4,visible=False, step=1, value=1, label="Number of Samples")
615
  append_grounding = gr.Checkbox(value=True, label="Append grounding instructions to the caption")
616
  use_actual_mask = gr.Checkbox(value=False, label="Use actual mask for inpainting", visible=False)
617
  with gr.Row():
618
  fix_seed = gr.Checkbox(value=True, label="Fixed seed")
619
  rand_seed = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="Seed")
620
+
621
+ with gr.Row():
622
+ use_style_cond = gr.Checkbox(value=False,visible=False, label="Enable Style Condition")
623
+ style_cond_image = gr.Image(type="pil",visible=False, label="Style Condition", interactive=True)
624
  with gr.Column(scale=4):
625
  gr.HTML('<span style="font-size: 20px; font-weight: bold">Generated Images</span>')
626
  with gr.Row():
627
  out_gen_1 = gr.Image(type="pil", visible=True, show_label=False)
628
+ out_gen_2 = gr.Image(type="pil", visible=False, show_label=False)
629
  with gr.Row():
630
  out_gen_3 = gr.Image(type="pil", visible=False, show_label=False)
631
  out_gen_4 = gr.Image(type="pil", visible=False, show_label=False)
632
 
633
  state = gr.State({})
634
+
635
 
636
  class Controller:
637
  def __init__(self):
 
649
  return [gr.Image.update(visible=True) for _ in range(n_samples + blank_samples)] \
650
  + [gr.Image.update(visible=False) for _ in range(4 - n_samples - blank_samples)]
651
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
652
  controller = Controller()
653
  main.load(
654
  lambda x:x+1,
655
  inputs=sketch_pad_trigger,
656
  outputs=sketch_pad_trigger,
657
  queue=False)
658
+
659
  sketch_pad.edit(
660
  draw,
661
+ inputs=[task, sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state, generate_parsed, out_imagebox],
662
+ outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state, generate_parsed],
663
  queue=False,
664
  )
665
+ trigger_stage.change(
666
+ change_state,
667
+ inputs=[text_box,sketch_pad, state, grounding_instruction, trigger_stage,out_imagebox],
668
+ outputs=[out_imagebox,state,trigger_stage],
669
+ queue=True
670
+ )
671
  grounding_instruction.change(
672
  draw,
673
+ inputs=[task, sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state, generate_parsed,out_imagebox],
674
+ outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state, generate_parsed],
675
  queue=False,
676
  )
677
  clear_btn.click(
678
  clear,
679
+ inputs=[task, sketch_pad_trigger, batch_size,trigger_stage, state],
680
+ outputs=[sketch_pad, sketch_pad_trigger, out_imagebox, image_scale, out_gen_1, out_gen_2, out_gen_3, out_gen_4, state, trigger_stage],
 
 
 
 
 
681
  queue=False)
682
+
683
  sketch_pad_trigger.change(
684
  controller.init_white,
685
  inputs=[init_white_trigger],
686
  outputs=[sketch_pad, image_scale, init_white_trigger],
687
  queue=False)
688
+
 
 
 
 
 
 
 
 
 
689
  gen_btn.click(
690
  generate,
691
  inputs=[
 
699
  outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4, state],
700
  queue=True
701
  )
 
 
 
 
 
 
702
  init_white_trigger.change(
703
  None,
704
  None,
705
  init_white_trigger,
706
  _js=rescale_js,
707
  queue=False)
708
+ examples = [
709
+ [
710
+ 'guide_imgs/0_a_cat_on_the_right_of_a_dog.jpg',
711
+ "a cat;a dog",
712
+ "a cat on the right of a dog",
713
+ '(291, 88, 481, 301)/(25, 64, 260, 391)',
714
+ 1, 1
 
 
 
 
 
 
 
 
 
 
 
 
715
  ],
716
  [
717
+ 'guide_imgs/0_a_bus_on_the_left_of_a_car.jpg',#'guide_imgs/0_a_bus_on_the_left_of_a_car.jpg',
718
+ "a bus;a car",
719
+ "a bus and a car",
720
+ '(8,128,266,384)/(300,196,502,316)', #'(8,128,266,384)', #/(300,196,502,316)
721
+ 1, 2
 
 
 
 
 
722
  ],
723
  [
724
+ 'guide_imgs/1_Two_cars_on_the_street..jpg',
725
+ "a car;a car",
726
+ "Two cars on the street.",
727
+ '(34, 98, 247, 264)/(271, 122, 481, 293)',
728
+ 1, 3
729
  ],
730
  [
731
+ 'guide_imgs/80_two_apples_lay_side_by_side_on_a_wooden_table,_their_glossy_red_and_green_skins_glinting_in_the_sunlight..jpg',
732
+ "an apple;an apple",
733
+ "two apples lay side by side on a wooden table, their glossy red and green skins glinting in the sunlight.",
734
+ '(40, 210, 235, 450)/(275, 210, 470, 450)',
735
+ 1, 4
736
  ],
737
  [
738
+ 'guide_imgs/10_A_banana_on_the_left_of_an_apple..jpg',
739
+ "a banana;an apple",
740
+ "A banana on the left of an apple.",
741
+ '(62, 193, 225, 354)/(300, 184, 432, 329)',
742
+ 1, 5
743
  ],
744
  [
745
+ 'guide_imgs/15_A_pizza_on_the_right_of_a_suitcase..jpg',
746
+ "a pizza ;a suitcase",
747
+ "A pizza on the right of a suitcase.",
748
+ '(307, 112, 490, 280)/(41, 120, 244, 270)',
749
+ 1, 6
750
  ],
751
  [
752
+ 'guide_imgs/1_A_wine_glass_on_top_of_a_dog..jpg',
753
+ "a wine glass;a dog",
754
+ "A wine glass on top of a dog.",
755
+ '(206, 78, 306, 214)/(137, 222, 367, 432)',
756
+ 1, 7
757
+ ]
758
+ ,
759
+ [
760
+ 'guide_imgs/2_A_bicycle_on_top_of_a_boat..jpg',
761
+ "a bicycle;a boat",
762
+ "A bicycle on top of a boat.",
763
+ '(185, 110, 335, 205)/(111, 228, 401, 373)',
764
+ 1, 8
765
+ ]
766
+ ,
767
+ [
768
+ 'guide_imgs/4_A_laptop_on_top_of_a_teddy_bear..jpg',
769
+ "a laptop;a teddy bear",
770
+ "A laptop on top of a teddy bear.",
771
+ '(180, 70, 332, 210)/(150, 240, 362, 420)',
772
+ 1, 9
773
+ ]
774
+ ,
775
+ [
776
+ 'guide_imgs/0_A_train_on_top_of_a_surfboard..jpg',
777
+ "a train;a surfboard",
778
+ "A train on top of a surfboard.",
779
+ '(130, 80, 385, 240)/(75, 260, 440, 450)',
780
+ 1, 10
781
+ ]
782
+ ]
783
+
784
+ with gr.Column():
785
+
786
+ create_examples(
787
+ examples=examples,
788
+ inputs=[sketch_pad, grounding_instruction,language_instruction , text_box, generate_parsed, trigger_stage],
789
  outputs=None,
790
  fn=None,
791
  cache_examples=False,
792
+
793
  )
794
 
795
  main.queue(concurrency_count=1, api_open=False)
796
+ main.launch(share=False, show_api=False, show_error=True, debug=False,)
 
 
dataset/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (150 Bytes). View file
 
dataset/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (139 Bytes). View file
 
dataset/__pycache__/catalog.cpython-310.pyc ADDED
Binary file (1.1 kB). View file
 
dataset/__pycache__/catalog.cpython-38.pyc ADDED
Binary file (1.11 kB). View file
 
dataset/__pycache__/concat_dataset.cpython-310.pyc ADDED
Binary file (1.88 kB). View file
 
dataset/__pycache__/concat_dataset.cpython-38.pyc ADDED
Binary file (1.88 kB). View file
 
environment.yaml CHANGED
@@ -1,4 +1,4 @@
1
- name: loco_gligen_demo
2
  channels:
3
  - xformers/label/dev
4
  - pytorch
 
1
+ name: gligen_demo
2
  channels:
3
  - xformers/label/dev
4
  - pytorch
example_component.py ADDED
@@ -0,0 +1,805 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Defines helper methods useful for loading and caching Interface examples.
3
+ """
4
+ from __future__ import annotations
5
+
6
+ import ast
7
+ import csv
8
+ import inspect
9
+ import os
10
+ import subprocess
11
+ import tempfile
12
+ import threading
13
+ import warnings
14
+ from pathlib import Path
15
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Tuple
16
+
17
+ import matplotlib
18
+ import matplotlib.pyplot as plt
19
+ import numpy as np
20
+ import PIL
21
+ import PIL.Image
22
+
23
+ from gradio import components, processing_utils, routes, utils
24
+ from gradio.context import Context
25
+ from gradio.documentation import document, set_documentation_group
26
+ from gradio.flagging import CSVLogger
27
+
28
+ if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
29
+ from gradio.components import IOComponent
30
+
31
+ CACHED_FOLDER = "gradio_cached_examples"
32
+ LOG_FILE = "log.csv"
33
+
34
+ set_documentation_group("helpers")
35
+
36
+
37
+ def create_examples(
38
+ examples: List[Any] | List[List[Any]] | str,
39
+ inputs: IOComponent | List[IOComponent],
40
+ outputs: IOComponent | List[IOComponent] | None = None,
41
+ fn: Callable | None = None,
42
+ cache_examples: bool = False,
43
+ examples_per_page: int = 10,
44
+ _api_mode: bool = False,
45
+ label: str | None = None,
46
+ elem_id: str | None = None,
47
+ run_on_click: bool = False,
48
+ preprocess: bool = True,
49
+ postprocess: bool = True,
50
+ batch: bool = False,
51
+ ):
52
+ """Top-level synchronous function that creates Examples. Provided for backwards compatibility, i.e. so that gr.Examples(...) can be used to create the Examples component."""
53
+ examples_obj = Examples(
54
+ examples=examples,
55
+ inputs=inputs,
56
+ outputs=outputs,
57
+ fn=fn,
58
+ cache_examples=cache_examples,
59
+ examples_per_page=examples_per_page,
60
+ _api_mode=_api_mode,
61
+ label=label,
62
+ elem_id=elem_id,
63
+ run_on_click=run_on_click,
64
+ preprocess=preprocess,
65
+ postprocess=postprocess,
66
+ batch=batch,
67
+ _initiated_directly=False,
68
+ )
69
+ utils.synchronize_async(examples_obj.create)
70
+ return examples_obj
71
+
72
+
73
+ class Examples:
74
+ """
75
+ This class is a wrapper over the Dataset component and can be used to create Examples
76
+ for Blocks / Interfaces. Populates the Dataset component with examples and
77
+ assigns event listener so that clicking on an example populates the input/output
78
+ components. Optionally handles example caching for fast inference.
79
+
80
+ Demos: blocks_inputs, fake_gan
81
+ Guides: more_on_examples_and_flagging, using_hugging_face_integrations, image_classification_in_pytorch, image_classification_in_tensorflow, image_classification_with_vision_transformers, create_your_own_friends_with_a_gan
82
+ """
83
+
84
+ def __init__(
85
+ self,
86
+ examples: List[Any] | List[List[Any]] | str,
87
+ inputs: IOComponent | List[IOComponent],
88
+ outputs: IOComponent | List[IOComponent] | None = None,
89
+ fn: Callable | None = None,
90
+ cache_examples: bool = False,
91
+ examples_per_page: int = 10,
92
+ _api_mode: bool = False,
93
+ label: str | None = "Examples",
94
+ elem_id: str | None = None,
95
+ run_on_click: bool = False,
96
+ preprocess: bool = True,
97
+ postprocess: bool = True,
98
+ batch: bool = False,
99
+ _initiated_directly: bool = True,
100
+ ):
101
+ """
102
+ Parameters:
103
+ examples: example inputs that can be clicked to populate specific components. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component. A string path to a directory of examples can also be provided but it should be within the directory with the python file running the gradio app. If there are multiple input components and a directory is provided, a log.csv file must be present in the directory to link corresponding inputs.
104
+ inputs: the component or list of components corresponding to the examples
105
+ outputs: optionally, provide the component or list of components corresponding to the output of the examples. Required if `cache` is True.
106
+ fn: optionally, provide the function to run to generate the outputs corresponding to the examples. Required if `cache` is True.
107
+ cache_examples: if True, caches examples for fast runtime. If True, then `fn` and `outputs` need to be provided
108
+ examples_per_page: how many examples to show per page.
109
+ label: the label to use for the examples component (by default, "Examples")
110
+ elem_id: an optional string that is assigned as the id of this component in the HTML DOM.
111
+ run_on_click: if cache_examples is False, clicking on an example does not run the function when an example is clicked. Set this to True to run the function when an example is clicked. Has no effect if cache_examples is True.
112
+ preprocess: if True, preprocesses the example input before running the prediction function and caching the output. Only applies if cache_examples is True.
113
+ postprocess: if True, postprocesses the example output after running the prediction function and before caching. Only applies if cache_examples is True.
114
+ batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. Used only if cache_examples is True.
115
+ """
116
+ if _initiated_directly:
117
+ warnings.warn(
118
+ "Please use gr.Examples(...) instead of gr.examples.Examples(...) to create the Examples.",
119
+ )
120
+
121
+ if cache_examples and (fn is None or outputs is None):
122
+ raise ValueError("If caching examples, `fn` and `outputs` must be provided")
123
+
124
+ if not isinstance(inputs, list):
125
+ inputs = [inputs]
126
+ if outputs and not isinstance(outputs, list):
127
+ outputs = [outputs]
128
+
129
+ working_directory = Path().absolute()
130
+
131
+ if examples is None:
132
+ raise ValueError("The parameter `examples` cannot be None")
133
+ elif isinstance(examples, list) and (
134
+ len(examples) == 0 or isinstance(examples[0], list)
135
+ ):
136
+ pass
137
+ elif (
138
+ isinstance(examples, list) and len(inputs) == 1
139
+ ): # If there is only one input component, examples can be provided as a regular list instead of a list of lists
140
+ examples = [[e] for e in examples]
141
+ elif isinstance(examples, str):
142
+ if not Path(examples).exists():
143
+ raise FileNotFoundError(
144
+ "Could not find examples directory: " + examples
145
+ )
146
+ working_directory = examples
147
+ if not (Path(examples) / LOG_FILE).exists():
148
+ if len(inputs) == 1:
149
+ examples = [[e] for e in os.listdir(examples)]
150
+ else:
151
+ raise FileNotFoundError(
152
+ "Could not find log file (required for multiple inputs): "
153
+ + LOG_FILE
154
+ )
155
+ else:
156
+ with open(Path(examples) / LOG_FILE) as logs:
157
+ examples = list(csv.reader(logs))
158
+ examples = [
159
+ examples[i][: len(inputs)] for i in range(1, len(examples))
160
+ ] # remove header and unnecessary columns
161
+
162
+ else:
163
+ raise ValueError(
164
+ "The parameter `examples` must either be a string directory or a list"
165
+ "(if there is only 1 input component) or (more generally), a nested "
166
+ "list, where each sublist represents a set of inputs."
167
+ )
168
+
169
+ input_has_examples = [False] * len(inputs)
170
+ for example in examples:
171
+ for idx, example_for_input in enumerate(example):
172
+ if not (example_for_input is None):
173
+ try:
174
+ input_has_examples[idx] = True
175
+ except IndexError:
176
+ pass # If there are more example components than inputs, ignore. This can sometimes be intentional (e.g. loading from a log file where outputs and timestamps are also logged)
177
+
178
+ inputs_with_examples = [
179
+ inp for (inp, keep) in zip(inputs, input_has_examples) if keep
180
+ ]
181
+ non_none_examples = [
182
+ [ex for (ex, keep) in zip(example, input_has_examples) if keep]
183
+ for example in examples
184
+ ]
185
+
186
+ self.examples = examples
187
+ self.non_none_examples = non_none_examples
188
+ self.inputs = inputs
189
+ self.inputs_with_examples = inputs_with_examples
190
+ self.outputs = outputs
191
+ self.fn = fn
192
+ self.cache_examples = cache_examples
193
+ self._api_mode = _api_mode
194
+ self.preprocess = preprocess
195
+ self.postprocess = postprocess
196
+ self.batch = batch
197
+
198
+ with utils.set_directory(working_directory):
199
+ self.processed_examples = [
200
+ [
201
+ component.postprocess(sample)
202
+ for component, sample in zip(inputs, example)
203
+ ]
204
+ for example in examples
205
+ ]
206
+ self.non_none_processed_examples = [
207
+ [ex for (ex, keep) in zip(example, input_has_examples) if keep]
208
+ for example in self.processed_examples
209
+ ]
210
+ if cache_examples:
211
+ for example in self.examples:
212
+ if len([ex for ex in example if ex is not None]) != len(self.inputs):
213
+ warnings.warn(
214
+ "Examples are being cached but not all input components have "
215
+ "example values. This may result in an exception being thrown by "
216
+ "your function. If you do get an error while caching examples, make "
217
+ "sure all of your inputs have example values for all of your examples "
218
+ "or you provide default values for those particular parameters in your function."
219
+ )
220
+ break
221
+
222
+ with utils.set_directory(working_directory):
223
+ self.dataset = components.Dataset(
224
+ components=inputs_with_examples,
225
+ samples=non_none_examples,
226
+ type="index",
227
+ label=label,
228
+ samples_per_page=examples_per_page,
229
+ elem_id=elem_id,
230
+ )
231
+
232
+ self.cached_folder = Path(CACHED_FOLDER) / str(self.dataset._id)
233
+ self.cached_file = Path(self.cached_folder) / "log.csv"
234
+ self.cache_examples = cache_examples
235
+ self.run_on_click = run_on_click
236
+
237
+ async def create(self) -> None:
238
+ """Caches the examples if self.cache_examples is True and creates the Dataset
239
+ component to hold the examples"""
240
+
241
+ async def load_example(example_id):
242
+ # import pdb; pdb.set_trace()
243
+ if self.cache_examples:
244
+ processed_example = self.non_none_processed_examples[
245
+ example_id
246
+ ] + await self.load_from_cache(example_id)
247
+ else:
248
+ processed_example = self.non_none_processed_examples[example_id]
249
+ return utils.resolve_singleton(processed_example)
250
+
251
+ if Context.root_block:
252
+ if self.cache_examples and self.outputs:
253
+ targets = self.inputs_with_examples + self.outputs
254
+ else:
255
+ targets = self.inputs_with_examples
256
+ self.dataset.click(
257
+ load_example,
258
+ inputs=[self.dataset],
259
+ outputs=targets, # type: ignore
260
+ postprocess=False,
261
+ queue=False,
262
+ )
263
+ self.dataset.click(
264
+ self.fn,
265
+ inputs=[self.dataset],
266
+ outputs=targets, # type: ignore
267
+ postprocess=False,
268
+ queue=False,
269
+ )
270
+ # if self.run_on_click and not self.cache_examples:
271
+ # if self.fn is None:
272
+ # raise ValueError("Cannot run_on_click if no function is provided")
273
+ # self.dataset.click(
274
+ # self.fn,
275
+ # inputs=self.inputs, # type: ignore
276
+ # outputs=self.outputs, # type: ignore
277
+ # )
278
+
279
+ if self.cache_examples:
280
+ await self.cache()
281
+
282
+ async def cache(self) -> None:
283
+ """
284
+ Caches all of the examples so that their predictions can be shown immediately.
285
+ """
286
+ if Path(self.cached_file).exists():
287
+ print(
288
+ f"Using cache from '{utils.abspath(self.cached_folder)}' directory. If method or examples have changed since last caching, delete this folder to clear cache."
289
+ )
290
+ else:
291
+ if Context.root_block is None:
292
+ raise ValueError("Cannot cache examples if not in a Blocks context")
293
+
294
+ print(f"Caching examples at: '{utils.abspath(self.cached_folder)}'")
295
+ cache_logger = CSVLogger()
296
+
297
+ # create a fake dependency to process the examples and get the predictions
298
+ dependency = Context.root_block.set_event_trigger(
299
+ event_name="fake_event",
300
+ fn=self.fn,
301
+ inputs=self.inputs_with_examples, # type: ignore
302
+ outputs=self.outputs, # type: ignore
303
+ preprocess=self.preprocess and not self._api_mode,
304
+ postprocess=self.postprocess and not self._api_mode,
305
+ batch=self.batch,
306
+ )
307
+
308
+ fn_index = Context.root_block.dependencies.index(dependency)
309
+ assert self.outputs is not None
310
+ cache_logger.setup(self.outputs, self.cached_folder)
311
+ for example_id, _ in enumerate(self.examples):
312
+ processed_input = self.processed_examples[example_id]
313
+ if self.batch:
314
+ processed_input = [[value] for value in processed_input]
315
+ prediction = await Context.root_block.process_api(
316
+ fn_index=fn_index, inputs=processed_input, request=None, state={}
317
+ )
318
+ output = prediction["data"]
319
+ if self.batch:
320
+ output = [value[0] for value in output]
321
+ cache_logger.flag(output)
322
+ # Remove the "fake_event" to prevent bugs in loading interfaces from spaces
323
+ Context.root_block.dependencies.remove(dependency)
324
+ Context.root_block.fns.pop(fn_index)
325
+
326
+ async def load_from_cache(self, example_id: int) -> List[Any]:
327
+ """Loads a particular cached example for the interface.
328
+ Parameters:
329
+ example_id: The id of the example to process (zero-indexed).
330
+ """
331
+ # import pdb; pdb.set_trace()
332
+ with open(self.cached_file, encoding="utf-8") as cache:
333
+ examples = list(csv.reader(cache))
334
+ example = examples[example_id + 1] # +1 to adjust for header
335
+ output = []
336
+ assert self.outputs is not None
337
+ for component, value in zip(self.outputs, example):
338
+ try:
339
+ value_as_dict = ast.literal_eval(value)
340
+ assert utils.is_update(value_as_dict)
341
+ output.append(value_as_dict)
342
+ except (ValueError, TypeError, SyntaxError, AssertionError):
343
+ output.append(component.serialize(value, self.cached_folder))
344
+ return output
345
+
346
+
347
+ class TrackedIterable:
348
+ def __init__(
349
+ self,
350
+ iterable: Iterable | None,
351
+ index: int | None,
352
+ length: int | None,
353
+ desc: str | None,
354
+ unit: str | None,
355
+ _tqdm=None,
356
+ progress: float | None = None,
357
+ ) -> None:
358
+ self.iterable = iterable
359
+ self.index = index
360
+ self.length = length
361
+ self.desc = desc
362
+ self.unit = unit
363
+ self._tqdm = _tqdm
364
+ self.progress = progress
365
+
366
+
367
+ @document("__call__", "tqdm")
368
+ class Progress(Iterable):
369
+ """
370
+ The Progress class provides a custom progress tracker that is used in a function signature.
371
+ To attach a Progress tracker to a function, simply add a parameter right after the input parameters that has a default value set to a `gradio.Progress()` instance.
372
+ The Progress tracker can then be updated in the function by calling the Progress object or using the `tqdm` method on an Iterable.
373
+ The Progress tracker is currently only available with `queue()`.
374
+ Example:
375
+ import gradio as gr
376
+ import time
377
+ def my_function(x, progress=gr.Progress()):
378
+ progress(0, desc="Starting...")
379
+ time.sleep(1)
380
+ for i in progress.tqdm(range(100)):
381
+ time.sleep(0.1)
382
+ return x
383
+ gr.Interface(my_function, gr.Textbox(), gr.Textbox()).queue().launch()
384
+ Demos: progress
385
+ """
386
+
387
+ def __init__(
388
+ self,
389
+ track_tqdm: bool = False,
390
+ _callback: Callable | None = None, # for internal use only
391
+ _event_id: str | None = None,
392
+ ):
393
+ """
394
+ Parameters:
395
+ track_tqdm: If True, the Progress object will track any tqdm.tqdm iterations with the tqdm library in the function.
396
+ """
397
+ self.track_tqdm = track_tqdm
398
+ self._callback = _callback
399
+ self._event_id = _event_id
400
+ self.iterables: List[TrackedIterable] = []
401
+
402
+ def __len__(self):
403
+ return self.iterables[-1].length
404
+
405
+ def __iter__(self):
406
+ return self
407
+
408
+ def __next__(self):
409
+ """
410
+ Updates progress tracker with next item in iterable.
411
+ """
412
+ if self._callback:
413
+ current_iterable = self.iterables[-1]
414
+ while (
415
+ not hasattr(current_iterable.iterable, "__next__")
416
+ and len(self.iterables) > 0
417
+ ):
418
+ current_iterable = self.iterables.pop()
419
+ self._callback(
420
+ event_id=self._event_id,
421
+ iterables=self.iterables,
422
+ )
423
+ assert current_iterable.index is not None, "Index not set."
424
+ current_iterable.index += 1
425
+ try:
426
+ return next(current_iterable.iterable) # type: ignore
427
+ except StopIteration:
428
+ self.iterables.pop()
429
+ raise StopIteration
430
+ else:
431
+ return self
432
+
433
+ def __call__(
434
+ self,
435
+ progress: float | Tuple[int, int | None] | None,
436
+ desc: str | None = None,
437
+ total: int | None = None,
438
+ unit: str = "steps",
439
+ _tqdm=None,
440
+ ):
441
+ """
442
+ Updates progress tracker with progress and message text.
443
+ Parameters:
444
+ progress: If float, should be between 0 and 1 representing completion. If Tuple, first number represents steps completed, and second value represents total steps or None if unknown. If None, hides progress bar.
445
+ desc: description to display.
446
+ total: estimated total number of steps.
447
+ unit: unit of iterations.
448
+ """
449
+ if self._callback:
450
+ if isinstance(progress, tuple):
451
+ index, total = progress
452
+ progress = None
453
+ else:
454
+ index = None
455
+ self._callback(
456
+ event_id=self._event_id,
457
+ iterables=self.iterables
458
+ + [TrackedIterable(None, index, total, desc, unit, _tqdm, progress)],
459
+ )
460
+ else:
461
+ return progress
462
+
463
+ def tqdm(
464
+ self,
465
+ iterable: Iterable | None,
466
+ desc: str | None = None,
467
+ total: int | None = None,
468
+ unit: str = "steps",
469
+ _tqdm=None,
470
+ *args,
471
+ **kwargs,
472
+ ):
473
+ """
474
+ Attaches progress tracker to iterable, like tqdm.
475
+ Parameters:
476
+ iterable: iterable to attach progress tracker to.
477
+ desc: description to display.
478
+ total: estimated total number of steps.
479
+ unit: unit of iterations.
480
+ """
481
+ if self._callback:
482
+ if iterable is None:
483
+ new_iterable = TrackedIterable(None, 0, total, desc, unit, _tqdm)
484
+ self.iterables.append(new_iterable)
485
+ self._callback(event_id=self._event_id, iterables=self.iterables)
486
+ return self
487
+ length = len(iterable) if hasattr(iterable, "__len__") else None # type: ignore
488
+ self.iterables.append(
489
+ TrackedIterable(iter(iterable), 0, length, desc, unit, _tqdm)
490
+ )
491
+ return self
492
+
493
+ def update(self, n=1):
494
+ """
495
+ Increases latest iterable with specified number of steps.
496
+ Parameters:
497
+ n: number of steps completed.
498
+ """
499
+ if self._callback and len(self.iterables) > 0:
500
+ current_iterable = self.iterables[-1]
501
+ assert current_iterable.index is not None, "Index not set."
502
+ current_iterable.index += n
503
+ self._callback(
504
+ event_id=self._event_id,
505
+ iterables=self.iterables,
506
+ )
507
+ else:
508
+ return
509
+
510
+ def close(self, _tqdm):
511
+ """
512
+ Removes iterable with given _tqdm.
513
+ """
514
+ if self._callback:
515
+ for i in range(len(self.iterables)):
516
+ if id(self.iterables[i]._tqdm) == id(_tqdm):
517
+ self.iterables.pop(i)
518
+ break
519
+ self._callback(
520
+ event_id=self._event_id,
521
+ iterables=self.iterables,
522
+ )
523
+ else:
524
+ return
525
+
526
+
527
+ def create_tracker(root_blocks, event_id, fn, track_tqdm):
528
+
529
+ progress = Progress(_callback=root_blocks._queue.set_progress, _event_id=event_id)
530
+ if not track_tqdm:
531
+ return progress, fn
532
+
533
+ try:
534
+ _tqdm = __import__("tqdm")
535
+ except ModuleNotFoundError:
536
+ return progress, fn
537
+ if not hasattr(root_blocks, "_progress_tracker_per_thread"):
538
+ root_blocks._progress_tracker_per_thread = {}
539
+
540
+ def init_tqdm(self, iterable=None, desc=None, *args, **kwargs):
541
+ self._progress = root_blocks._progress_tracker_per_thread.get(
542
+ threading.get_ident()
543
+ )
544
+ if self._progress is not None:
545
+ self._progress.event_id = event_id
546
+ self._progress.tqdm(iterable, desc, _tqdm=self, *args, **kwargs)
547
+ kwargs["file"] = open(os.devnull, "w")
548
+ self.__init__orig__(iterable, desc, *args, **kwargs)
549
+
550
+ def iter_tqdm(self):
551
+ if self._progress is not None:
552
+ return self._progress
553
+ else:
554
+ return self.__iter__orig__()
555
+
556
+ def update_tqdm(self, n=1):
557
+ if self._progress is not None:
558
+ self._progress.update(n)
559
+ return self.__update__orig__(n)
560
+
561
+ def close_tqdm(self):
562
+ if self._progress is not None:
563
+ self._progress.close(self)
564
+ return self.__close__orig__()
565
+
566
+ def exit_tqdm(self, exc_type, exc_value, traceback):
567
+ if self._progress is not None:
568
+ self._progress.close(self)
569
+ return self.__exit__orig__(exc_type, exc_value, traceback)
570
+
571
+ if not hasattr(_tqdm.tqdm, "__init__orig__"):
572
+ _tqdm.tqdm.__init__orig__ = _tqdm.tqdm.__init__
573
+ _tqdm.tqdm.__init__ = init_tqdm
574
+ if not hasattr(_tqdm.tqdm, "__update__orig__"):
575
+ _tqdm.tqdm.__update__orig__ = _tqdm.tqdm.update
576
+ _tqdm.tqdm.update = update_tqdm
577
+ if not hasattr(_tqdm.tqdm, "__close__orig__"):
578
+ _tqdm.tqdm.__close__orig__ = _tqdm.tqdm.close
579
+ _tqdm.tqdm.close = close_tqdm
580
+ if not hasattr(_tqdm.tqdm, "__exit__orig__"):
581
+ _tqdm.tqdm.__exit__orig__ = _tqdm.tqdm.__exit__
582
+ _tqdm.tqdm.__exit__ = exit_tqdm
583
+ if not hasattr(_tqdm.tqdm, "__iter__orig__"):
584
+ _tqdm.tqdm.__iter__orig__ = _tqdm.tqdm.__iter__
585
+ _tqdm.tqdm.__iter__ = iter_tqdm
586
+ if hasattr(_tqdm, "auto") and hasattr(_tqdm.auto, "tqdm"):
587
+ _tqdm.auto.tqdm = _tqdm.tqdm
588
+
589
+ def tracked_fn(*args):
590
+ thread_id = threading.get_ident()
591
+ root_blocks._progress_tracker_per_thread[thread_id] = progress
592
+ response = fn(*args)
593
+ del root_blocks._progress_tracker_per_thread[thread_id]
594
+ return response
595
+
596
+ return progress, tracked_fn
597
+
598
+
599
+ def special_args(
600
+ fn: Callable,
601
+ inputs: List[Any] | None = None,
602
+ request: routes.Request | None = None,
603
+ ):
604
+ """
605
+ Checks if function has special arguments Request (via annotation) or Progress (via default value).
606
+ If inputs is provided, these values will be loaded into the inputs array.
607
+ Parameters:
608
+ block_fn: function to check.
609
+ inputs: array to load special arguments into.
610
+ request: request to load into inputs.
611
+ Returns:
612
+ updated inputs, request index, progress index
613
+ """
614
+ signature = inspect.signature(fn)
615
+ positional_args = []
616
+ for i, param in enumerate(signature.parameters.values()):
617
+ if param.kind not in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
618
+ break
619
+ positional_args.append(param)
620
+ progress_index = None
621
+ for i, param in enumerate(positional_args):
622
+ if isinstance(param.default, Progress):
623
+ progress_index = i
624
+ if inputs is not None:
625
+ inputs.insert(i, param.default)
626
+ elif param.annotation == routes.Request:
627
+ if inputs is not None:
628
+ inputs.insert(i, request)
629
+ if inputs is not None:
630
+ while len(inputs) < len(positional_args):
631
+ i = len(inputs)
632
+ param = positional_args[i]
633
+ if param.default == param.empty:
634
+ warnings.warn("Unexpected argument. Filling with None.")
635
+ inputs.append(None)
636
+ else:
637
+ inputs.append(param.default)
638
+ return inputs or [], progress_index
639
+
640
+
641
+ @document()
642
+ def update(**kwargs) -> dict:
643
+ """
644
+ Updates component properties. When a function passed into a Gradio Interface or a Blocks events returns a typical value, it updates the value of the output component. But it is also possible to update the properties of an output component (such as the number of lines of a `Textbox` or the visibility of an `Image`) by returning the component's `update()` function, which takes as parameters any of the constructor parameters for that component.
645
+ This is a shorthand for using the update method on a component.
646
+ For example, rather than using gr.Number.update(...) you can just use gr.update(...).
647
+ Note that your editor's autocompletion will suggest proper parameters
648
+ if you use the update method on the component.
649
+ Demos: blocks_essay, blocks_update, blocks_essay_update
650
+
651
+ Parameters:
652
+ kwargs: Key-word arguments used to update the component's properties.
653
+ Example:
654
+ # Blocks Example
655
+ import gradio as gr
656
+ with gr.Blocks() as demo:
657
+ radio = gr.Radio([1, 2, 4], label="Set the value of the number")
658
+ number = gr.Number(value=2, interactive=True)
659
+ radio.change(fn=lambda value: gr.update(value=value), inputs=radio, outputs=number)
660
+ demo.launch()
661
+
662
+ # Interface example
663
+ import gradio as gr
664
+ def change_textbox(choice):
665
+ if choice == "short":
666
+ return gr.Textbox.update(lines=2, visible=True)
667
+ elif choice == "long":
668
+ return gr.Textbox.update(lines=8, visible=True)
669
+ else:
670
+ return gr.Textbox.update(visible=False)
671
+ gr.Interface(
672
+ change_textbox,
673
+ gr.Radio(
674
+ ["short", "long", "none"], label="What kind of essay would you like to write?"
675
+ ),
676
+ gr.Textbox(lines=2),
677
+ live=True,
678
+ ).launch()
679
+ """
680
+ kwargs["__type__"] = "generic_update"
681
+ return kwargs
682
+
683
+
684
+ def skip() -> dict:
685
+ return update()
686
+
687
+
688
+ @document()
689
+ def make_waveform(
690
+ audio: str | Tuple[int, np.ndarray],
691
+ *,
692
+ bg_color: str = "#f3f4f6",
693
+ bg_image: str | None = None,
694
+ fg_alpha: float = 0.75,
695
+ bars_color: str | Tuple[str, str] = ("#fbbf24", "#ea580c"),
696
+ bar_count: int = 50,
697
+ bar_width: float = 0.6,
698
+ ):
699
+ """
700
+ Generates a waveform video from an audio file. Useful for creating an easy to share audio visualization. The output should be passed into a `gr.Video` component.
701
+ Parameters:
702
+ audio: Audio file path or tuple of (sample_rate, audio_data)
703
+ bg_color: Background color of waveform (ignored if bg_image is provided)
704
+ bg_image: Background image of waveform
705
+ fg_alpha: Opacity of foreground waveform
706
+ bars_color: Color of waveform bars. Can be a single color or a tuple of (start_color, end_color) of gradient
707
+ bar_count: Number of bars in waveform
708
+ bar_width: Width of bars in waveform. 1 represents full width, 0.5 represents half width, etc.
709
+ Returns:
710
+ A filepath to the output video.
711
+ """
712
+ if isinstance(audio, str):
713
+ audio_file = audio
714
+ audio = processing_utils.audio_from_file(audio)
715
+ else:
716
+ tmp_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
717
+ processing_utils.audio_to_file(audio[0], audio[1], tmp_wav.name)
718
+ audio_file = tmp_wav.name
719
+ duration = round(len(audio[1]) / audio[0], 4)
720
+
721
+ # Helper methods to create waveform
722
+ def hex_to_RGB(hex_str):
723
+ return [int(hex_str[i : i + 2], 16) for i in range(1, 6, 2)]
724
+
725
+ def get_color_gradient(c1, c2, n):
726
+ assert n > 1
727
+ c1_rgb = np.array(hex_to_RGB(c1)) / 255
728
+ c2_rgb = np.array(hex_to_RGB(c2)) / 255
729
+ mix_pcts = [x / (n - 1) for x in range(n)]
730
+ rgb_colors = [((1 - mix) * c1_rgb + (mix * c2_rgb)) for mix in mix_pcts]
731
+ return [
732
+ "#" + "".join([format(int(round(val * 255)), "02x") for val in item])
733
+ for item in rgb_colors
734
+ ]
735
+
736
+ # Reshape audio to have a fixed number of bars
737
+ samples = audio[1]
738
+ if len(samples.shape) > 1:
739
+ samples = np.mean(samples, 1)
740
+ bins_to_pad = bar_count - (len(samples) % bar_count)
741
+ samples = np.pad(samples, [(0, bins_to_pad)])
742
+ samples = np.reshape(samples, (bar_count, -1))
743
+ samples = np.abs(samples)
744
+ samples = np.max(samples, 1)
745
+
746
+ matplotlib.use("Agg")
747
+ plt.clf()
748
+ # Plot waveform
749
+ color = (
750
+ bars_color
751
+ if isinstance(bars_color, str)
752
+ else get_color_gradient(bars_color[0], bars_color[1], bar_count)
753
+ )
754
+ plt.bar(
755
+ np.arange(0, bar_count),
756
+ samples * 2,
757
+ bottom=(-1 * samples),
758
+ width=bar_width,
759
+ color=color,
760
+ )
761
+ plt.axis("off")
762
+ plt.margins(x=0)
763
+ tmp_img = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
764
+ savefig_kwargs: Dict[str, Any] = {"bbox_inches": "tight"}
765
+ if bg_image is not None:
766
+ savefig_kwargs["transparent"] = True
767
+ else:
768
+ savefig_kwargs["facecolor"] = bg_color
769
+ plt.savefig(tmp_img.name, **savefig_kwargs)
770
+ waveform_img = PIL.Image.open(tmp_img.name)
771
+ waveform_img = waveform_img.resize((1000, 200))
772
+
773
+ # Composite waveform with background image
774
+ if bg_image is not None:
775
+ waveform_array = np.array(waveform_img)
776
+ waveform_array[:, :, 3] = waveform_array[:, :, 3] * fg_alpha
777
+ waveform_img = PIL.Image.fromarray(waveform_array)
778
+
779
+ bg_img = PIL.Image.open(bg_image)
780
+ waveform_width, waveform_height = waveform_img.size
781
+ bg_width, bg_height = bg_img.size
782
+ if waveform_width != bg_width:
783
+ bg_img = bg_img.resize(
784
+ (waveform_width, 2 * int(bg_height * waveform_width / bg_width / 2))
785
+ )
786
+ bg_width, bg_height = bg_img.size
787
+ composite_height = max(bg_height, waveform_height)
788
+ composite = PIL.Image.new("RGBA", (waveform_width, composite_height), "#FFFFFF")
789
+ composite.paste(bg_img, (0, composite_height - bg_height))
790
+ composite.paste(
791
+ waveform_img, (0, composite_height - waveform_height), waveform_img
792
+ )
793
+ composite.save(tmp_img.name)
794
+ img_width, img_height = composite.size
795
+ else:
796
+ img_width, img_height = waveform_img.size
797
+ waveform_img.save(tmp_img.name)
798
+
799
+ # Convert waveform to video with ffmpeg
800
+ output_mp4 = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
801
+
802
+ ffmpeg_cmd = f"""ffmpeg -loop 1 -i {tmp_img.name} -i {audio_file} -vf "color=c=#FFFFFF77:s={img_width}x{img_height}[bar];[0][bar]overlay=-w+(w/{duration})*t:H-h:shortest=1" -t {duration} -y {output_mp4.name}"""
803
+
804
+ subprocess.call(ffmpeg_cmd, shell=True)
805
+ return output_mp4.name
gligen/.DS_Store ADDED
Binary file (6.15 kB). View file
 
gligen/SD_input_conv_weight_bias.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5a0efad69747a766158304f39091c2b6a24cafb5f833d174f32bee8e864a562
3
+ size 130
gligen/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (356 Bytes). View file
 
gligen/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (345 Bytes). View file
 
gligen/__pycache__/distributed.cpython-310.pyc ADDED
Binary file (2.92 kB). View file
 
gligen/__pycache__/distributed.cpython-38.pyc ADDED
Binary file (2.91 kB). View file
 
gligen/__pycache__/evaluator.cpython-310.pyc ADDED
Binary file (5.94 kB). View file
 
gligen/__pycache__/evaluator.cpython-38.pyc ADDED
Binary file (5.9 kB). View file
 
gligen/__pycache__/task_grounded_generation.cpython-310.pyc ADDED
Binary file (9.17 kB). View file
 
gligen/__pycache__/task_grounded_generation.cpython-38.pyc ADDED
Binary file (9.11 kB). View file
 
gligen/__pycache__/trainer.cpython-310.pyc ADDED
Binary file (11.7 kB). View file
 
gligen/__pycache__/trainer.cpython-38.pyc ADDED
Binary file (11.7 kB). View file
 
gligen/evaluator.py CHANGED
@@ -14,7 +14,7 @@ from trainer import read_official_ckpt, batch_to_device, ImageCaptionSaver, wrap
14
  from PIL import Image
15
  import math
16
  import json
17
-
18
 
19
  def draw_masks_from_boxes(boxes,size):
20
 
 
14
  from PIL import Image
15
  import math
16
  import json
17
+ #hello
18
 
19
  def draw_masks_from_boxes(boxes,size):
20
 
gligen/ldm/.DS_Store ADDED
Binary file (6.15 kB). View file
 
gligen/ldm/__pycache__/util.cpython-310.pyc ADDED
Binary file (3.22 kB). View file
 
gligen/ldm/__pycache__/util.cpython-38.pyc ADDED
Binary file (3.2 kB). View file
 
gligen/ldm/data/.DS_Store ADDED
Binary file (6.15 kB). View file
 
gligen/ldm/data/imagenet_train_hr_indices.p ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f86ea1924a1522b20bc0f709a069cc65f09d5fc617a7a31af7aaa3839a5a4d73
3
+ size 132
gligen/ldm/data/imagenet_val_hr_indices.p ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff1f5eb275a93c0fb53e227679f323ea1d024c87db296453296cebeef86fc0f4
3
+ size 131
gligen/ldm/models/.DS_Store ADDED
Binary file (6.15 kB). View file
 
gligen/ldm/models/__pycache__/autoencoder.cpython-310.pyc ADDED
Binary file (1.59 kB). View file
 
gligen/ldm/models/__pycache__/autoencoder.cpython-38.pyc ADDED
Binary file (1.58 kB). View file
 
gligen/ldm/models/diffusion/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (170 Bytes). View file
 
gligen/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (159 Bytes). View file
 
gligen/ldm/models/diffusion/__pycache__/ddim.cpython-310.pyc ADDED
Binary file (4.56 kB). View file
 
gligen/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc ADDED
Binary file (4.57 kB). View file
 
gligen/ldm/models/diffusion/__pycache__/ddpm.cpython-310.pyc ADDED
Binary file (2.09 kB). View file
 
gligen/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc ADDED
Binary file (2.12 kB). View file
 
gligen/ldm/models/diffusion/__pycache__/gaussian_smoothing.cpython-310.pyc ADDED
Binary file (4.07 kB). View file
 
gligen/ldm/models/diffusion/__pycache__/gaussian_smoothing.cpython-38.pyc ADDED
Binary file (4.11 kB). View file
 
gligen/ldm/models/diffusion/__pycache__/ldm.cpython-310.pyc ADDED
Binary file (1.23 kB). View file
 
gligen/ldm/models/diffusion/__pycache__/ldm.cpython-38.pyc ADDED
Binary file (1.21 kB). View file
 
gligen/ldm/models/diffusion/__pycache__/loss.cpython-310.pyc ADDED
Binary file (4.23 kB). View file
 
gligen/ldm/models/diffusion/__pycache__/loss.cpython-38.pyc ADDED
Binary file (4.23 kB). View file
 
gligen/ldm/models/diffusion/__pycache__/plms.cpython-310.pyc ADDED
Binary file (8.65 kB). View file
 
gligen/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc ADDED
Binary file (8.71 kB). View file
 
gligen/ldm/models/diffusion/ddim.py CHANGED
@@ -87,7 +87,9 @@ class DDIMSampler(object):
87
  # set alpha
88
  if self.alpha_generator_func != None:
89
  self.set_alpha_scale(self.model, alphas[i])
90
-
 
 
91
  # run
92
  index = total_steps - i - 1
93
  input["timesteps"] = torch.full((b,), step, device=self.device, dtype=torch.long)
@@ -110,9 +112,7 @@ class DDIMSampler(object):
110
 
111
  e_t = self.model(input)
112
  if uc is not None and guidance_scale != 1:
113
- unconditional_input = dict(x=input["x"], timesteps=input["timesteps"], context=uc)
114
- if "inpainting_extra_input" in input:
115
- unconditional_input["inpainting_extra_input"] = input["inpainting_extra_input"]
116
  e_t_uncond = self.model( unconditional_input )
117
  e_t = e_t_uncond + guidance_scale * (e_t - e_t_uncond)
118
 
 
87
  # set alpha
88
  if self.alpha_generator_func != None:
89
  self.set_alpha_scale(self.model, alphas[i])
90
+ if alphas[i] == 0:
91
+ self.model.restore_first_conv_from_SD()
92
+
93
  # run
94
  index = total_steps - i - 1
95
  input["timesteps"] = torch.full((b,), step, device=self.device, dtype=torch.long)
 
112
 
113
  e_t = self.model(input)
114
  if uc is not None and guidance_scale != 1:
115
+ unconditional_input = dict(x=input["x"], timesteps=input["timesteps"], context=uc, inpainting_extra_input=input["inpainting_extra_input"], grounding_extra_input=input['grounding_extra_input'])
 
 
116
  e_t_uncond = self.model( unconditional_input )
117
  e_t = e_t_uncond + guidance_scale * (e_t - e_t_uncond)
118