omer11a committed
Commit 6c2e4ca
1 Parent(s): 8811432

Update code to new gradio version

Files changed (2):
  1. README.md +2 -2
  2. app.py +33 -34
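This commit migrates the Space from gradio 3.43.2 to 4.22.0. The central API change it adapts to: gradio 4.x removes the per-component `gr.<Component>.update(...)` helpers that 3.x event handlers returned, so handlers now return plain values (or a generic `gr.update(...)` for property changes). A minimal before/after sketch, assuming gradio 4.x; this is illustrative, not code from the repository:

```python
import gradio as gr

def greet(name):
    # gradio 3.x style: return gr.Textbox.update(value=f"Hello, {name}!")
    # gradio 4.x style: return the value itself (or gr.update(...) for properties)
    return f"Hello, {name}!"

with gr.Blocks() as demo:
    name = gr.Textbox(label="Name")
    greeting = gr.Textbox(label="Greeting")
    name.submit(greet, inputs=[name], outputs=[greeting])
```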
README.md CHANGED
@@ -4,10 +4,10 @@ emoji: 😺
 colorFrom: pink
 colorTo: yellow
 sdk: gradio
-sdk_version: 3.43.2
+sdk_version: 4.22.0
 app_file: app.py
 pinned: false
-license: cc-by-nc-sa-3.0
+license: cc-by-nc-sa-4.0
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -12,7 +12,7 @@ from PIL import Image, ImageDraw
 
 from functools import partial
 
-RESOLUTION = 512
+RESOLUTION = 256
 MIN_SIZE = 0.01
 WHITE = 255
 COLORS = ["red", "blue", "green", "orange", "purple", "turquoise", "olive"]
@@ -85,7 +85,7 @@ def generate(
 ):
     subject_token_indices = convert_token_indices(subject_token_indices, nested=True)
     if len(boxes) != len(subject_token_indices):
-        raise ValueError("""
+        raise gr.Error("""
             The number of boxes should be equal to the number of subject token indices.
             Number of boxes drawn: {}, number of grounding tokens: {}.
             """.format(len(boxes), len(subject_token_indices)))
@@ -99,11 +99,6 @@ def generate(
         final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale, classifier_free_guidance_scale,
         num_iterations, loss_threshold, num_guidance_steps, seed)
 
-    blank_samples = batch_size % 2 if batch_size > 1 else 0
-    images = [gr.Image.update(value=x, visible=True) for i, x in enumerate(images)] \
-        + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \
-        + [gr.Image.update(value=None, visible=False) for _ in range(4 - batch_size - blank_samples)]
-
     return images
 
 
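The block deleted above padded a fixed grid of four `gr.Image` outputs to the batch size with `gr.Image.update(...)` calls. Later in this diff the grid is replaced by a single `gr.Gallery`, which accepts a variable-length list of images directly, so the padding logic can go. A minimal sketch of the simplified pattern, assuming gradio 4.x; `generate_stub` is a hypothetical stand-in for the real pipeline call:

```python
import gradio as gr
from PIL import Image

def generate_stub(batch_size):
    # hypothetical stand-in for the diffusion pipeline; the gallery
    # takes the returned list as-is, whatever its length
    return [Image.new("RGB", (256, 256), "gray") for _ in range(int(batch_size))]

with gr.Blocks() as demo:
    batch_size = gr.Slider(1, 4, step=1, value=1, label="Batch size")
    gallery = gr.Gallery(label="Generated Images")
    gr.Button("Generate").click(generate_stub, inputs=[batch_size], outputs=[gallery])
```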
 
@@ -114,22 +109,25 @@ def convert_token_indices(token_indices, nested=False):
     return [int(index.strip()) for index in token_indices.split(',') if len(index.strip()) > 0]
 
 
-def draw(boxes, mask, layout_image):
-    if mask.ndim == 3:
-        mask = WHITE - mask[..., 0]
+def draw(sketchpad):
+    boxes = []
+    for i, layer in enumerate(sketchpad['layers']):
+        mask = (layer != 0)
+        if mask.sum() == 0:
+            raise gr.Error(f'Box in layer {i} is empty')
 
-    mask = (mask != 0).astype('uint8') * WHITE
-    if mask.sum() > 0:
         x1x2 = np.where(mask.max(0) != 0)[0] / RESOLUTION
         y1y2 = np.where(mask.max(1) != 0)[0] / RESOLUTION
         y1, y2 = y1y2.min(), y1y2.max()
         x1, x2 = x1x2.min(), x1x2.max()
 
-        if (x2 - x1 > MIN_SIZE) and (y2 - y1 > MIN_SIZE):
-            boxes.append((x1, y1, x2, y2))
-        layout_image = draw_boxes(boxes)
+        if (x2 - x1 < MIN_SIZE) or (y2 - y1 < MIN_SIZE):
+            raise gr.Error(f'Box in layer {i} is too small')
+
+        boxes.append((x1, y1, x2, y2))
 
-    return [boxes, None, layout_image]
+    layout_image = draw_boxes(boxes)
+    return [boxes, layout_image]
 
 
 def draw_boxes(boxes):
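The rewritten `draw` above consumes the gradio 4.x Sketchpad value, which is no longer a bare mask array: it is a dict with 'background', 'layers' (one array per drawn layer), and 'composite' entries. A small inspection sketch, assuming gradio 4.x defaults (numpy arrays with an alpha channel); `summarize_sketch` is a hypothetical helper, not part of the app:

```python
import numpy as np

def summarize_sketch(sketchpad):
    # sketchpad: a gradio 4.x Sketchpad/ImageEditor value, a dict with
    # "background", "layers", and "composite" keys
    for i, layer in enumerate(sketchpad["layers"]):
        layer = np.asarray(layer)
        # RGBA layers: a pixel counts as drawn if any channel is nonzero
        drawn = (layer != 0).any(axis=-1) if layer.ndim == 3 else (layer != 0)
        print(f"layer {i}: {int(drawn.sum())} drawn pixels")
```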
@@ -146,9 +144,7 @@ def draw_boxes(boxes):
 
 
 def clear(batch_size):
-    blank_samples = batch_size % 2 if batch_size > 1 else 0
-    out_images = [gr.Image.update(value=None, visible=True) for i in range(batch_size)]
-    return [[], None, None] + out_images
+    return [[], None, None, None]
 
 
 def main():
@@ -209,7 +205,7 @@ def main():
             <br>
             <span style="font-size: 18px" id="paper-info">
                 [<a href="https://omer11a.github.io/bounded-attention/" target="_blank">Project Page</a>]
-                [<a href=" " target="_blank">Paper</a>]
+                [<a href="https://arxiv.org/abs/2403.16990" target="_blank">Paper</a>]
                 [<a href="https://github.com/omer11a/bounded-attention" target="_blank">GitHub</a>]
             </span>
         </p>
233
  )
234
 
235
  with gr.Row():
236
- sketch_pad = gr.Sketchpad(label="Sketch Pad", shape=(RESOLUTION, RESOLUTION))
237
- layout_image = gr.Image(type="pil", label="Bounding Boxes")
238
- out_images = gr.Image(type="pil", visible=True, label="Generated Image")
239
 
240
  with gr.Row():
241
  clear_button = gr.Button(value='Clear')
242
- generate_button = gr.Button(value='Generate')
 
 
 
 
243
 
244
  with gr.Accordion("Advanced Options", open=False):
245
  with gr.Column():
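`draw_boxes`, which renders the layout preview that `draw` returns, is unchanged by this commit and never appears in the diff. For orientation only, a hypothetical sketch of such a helper using the `RESOLUTION` and `COLORS` constants and the PIL imports visible above; this is not the author's implementation:

```python
from PIL import Image, ImageDraw

RESOLUTION = 256
COLORS = ["red", "blue", "green", "orange", "purple", "turquoise", "olive"]

def draw_boxes_sketch(boxes):
    # boxes: (x1, y1, x2, y2) tuples in [0, 1] coordinates, as produced by draw()
    image = Image.new("RGB", (RESOLUTION, RESOLUTION), "white")
    canvas = ImageDraw.Draw(image)
    for i, (x1, y1, x2, y2) in enumerate(boxes):
        canvas.rectangle(
            (x1 * RESOLUTION, y1 * RESOLUTION, x2 * RESOLUTION, y2 * RESOLUTION),
            outline=COLORS[i % len(COLORS)],
            width=3,
        )
    return image
```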
@@ -290,21 +289,21 @@ def main():
 
         boxes = gr.State([])
 
-        sketch_pad.edit(
-            draw,
-            inputs=[boxes, sketch_pad, layout_image],
-            outputs=[boxes, sketch_pad, layout_image],
-            queue=False,
-        )
-
         clear_button.click(
             clear,
             inputs=[batch_size],
-            outputs=[boxes, sketch_pad, layout_image, out_images],
+            outputs=[boxes, sketchpad, layout_image, out_images],
+            queue=False,
+        )
+
+        generate_layout_button.click(
+            draw,
+            inputs=[sketchpad],
+            outputs=[boxes, layout_image],
             queue=False,
         )
 
-        generate_button.click(
+        generate_image_button.click(
             fn=partial(generate, device, model),
             inputs=[
                 prompt, subject_token_indices, filter_token_indices, num_tokens,
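The wiring above splits the old one-step flow into two explicit steps: 'Generate layout' converts the sketched layers into boxes held in the `boxes` state, and 'Generate image' feeds those boxes to the model. A minimal sketch of the same two-button pattern, assuming gradio 4.x; both handlers are hypothetical stand-ins:

```python
import gradio as gr

def make_layout(text):
    return text.upper()                 # stand-in for draw()

def make_image(layout):
    return f"generated from: {layout}"  # stand-in for generate()

with gr.Blocks() as demo:
    sketch = gr.Textbox(label="Sketch stand-in")
    layout = gr.Textbox(label="Layout")
    result = gr.Textbox(label="Result")
    gr.Button("Generate layout").click(make_layout, inputs=[sketch], outputs=[layout], queue=False)
    gr.Button("Generate image").click(make_image, inputs=[layout], outputs=[result])
```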
@@ -337,7 +336,7 @@ def main():
         gr.HTML(description)
 
     demo.queue(max_size=50)
-    demo.launch(share=False, show_api=False, show_error=True)
+    demo.launch(show_api=False, show_error=True)
 
 if __name__ == '__main__':
     main()
 