Spaces:
Running
on
Zero
Running
on
Zero
Update code to new gradio version
Browse files
README.md
CHANGED
@@ -4,10 +4,10 @@ emoji: 😺
|
|
4 |
colorFrom: pink
|
5 |
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
-
license: cc-by-nc-sa-
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
4 |
colorFrom: pink
|
5 |
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.22.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
+
license: cc-by-nc-sa-4.0
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
CHANGED
@@ -12,7 +12,7 @@ from PIL import Image, ImageDraw
|
|
12 |
|
13 |
from functools import partial
|
14 |
|
15 |
-
RESOLUTION =
|
16 |
MIN_SIZE = 0.01
|
17 |
WHITE = 255
|
18 |
COLORS = ["red", "blue", "green", "orange", "purple", "turquoise", "olive"]
|
@@ -85,7 +85,7 @@ def generate(
|
|
85 |
):
|
86 |
subject_token_indices = convert_token_indices(subject_token_indices, nested=True)
|
87 |
if len(boxes) != len(subject_token_indices):
|
88 |
-
raise
|
89 |
The number of boxes should be equal to the number of subject token indices.
|
90 |
Number of boxes drawn: {}, number of grounding tokens: {}.
|
91 |
""".format(len(boxes), len(subject_token_indices)))
|
@@ -99,11 +99,6 @@ def generate(
|
|
99 |
final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale, classifier_free_guidance_scale,
|
100 |
num_iterations, loss_threshold, num_guidance_steps, seed)
|
101 |
|
102 |
-
blank_samples = batch_size % 2 if batch_size > 1 else 0
|
103 |
-
images = [gr.Image.update(value=x, visible=True) for i, x in enumerate(images)] \
|
104 |
-
+ [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \
|
105 |
-
+ [gr.Image.update(value=None, visible=False) for _ in range(4 - batch_size - blank_samples)]
|
106 |
-
|
107 |
return images
|
108 |
|
109 |
|
@@ -114,22 +109,25 @@ def convert_token_indices(token_indices, nested=False):
|
|
114 |
return [int(index.strip()) for index in token_indices.split(',') if len(index.strip()) > 0]
|
115 |
|
116 |
|
117 |
-
def draw(
|
118 |
-
|
119 |
-
|
|
|
|
|
|
|
120 |
|
121 |
-
mask = (mask != 0).astype('uint8') * WHITE
|
122 |
-
if mask.sum() > 0:
|
123 |
x1x2 = np.where(mask.max(0) != 0)[0] / RESOLUTION
|
124 |
y1y2 = np.where(mask.max(1) != 0)[0] / RESOLUTION
|
125 |
y1, y2 = y1y2.min(), y1y2.max()
|
126 |
x1, x2 = x1x2.min(), x1x2.max()
|
127 |
|
128 |
-
if (x2 - x1
|
129 |
-
|
130 |
-
|
|
|
131 |
|
132 |
-
|
|
|
133 |
|
134 |
|
135 |
def draw_boxes(boxes):
|
@@ -146,9 +144,7 @@ def draw_boxes(boxes):
|
|
146 |
|
147 |
|
148 |
def clear(batch_size):
|
149 |
-
|
150 |
-
out_images = [gr.Image.update(value=None, visible=True) for i in range(batch_size)]
|
151 |
-
return [[], None, None] + out_images
|
152 |
|
153 |
|
154 |
def main():
|
@@ -209,7 +205,7 @@ def main():
|
|
209 |
<br>
|
210 |
<span style="font-size: 18px" id="paper-info">
|
211 |
[<a href="https://omer11a.github.io/bounded-attention/" target="_blank">Project Page</a>]
|
212 |
-
[<a href="
|
213 |
[<a href="https://github.com/omer11a/bounded-attention" target="_blank">GitHub</a>]
|
214 |
</span>
|
215 |
</p>
|
@@ -233,13 +229,16 @@ def main():
|
|
233 |
)
|
234 |
|
235 |
with gr.Row():
|
236 |
-
|
237 |
-
layout_image = gr.Image(type="pil", label="Bounding Boxes")
|
238 |
-
out_images = gr.Image(type="pil", visible=True, label="Generated Image")
|
239 |
|
240 |
with gr.Row():
|
241 |
clear_button = gr.Button(value='Clear')
|
242 |
-
|
|
|
|
|
|
|
|
|
243 |
|
244 |
with gr.Accordion("Advanced Options", open=False):
|
245 |
with gr.Column():
|
@@ -290,21 +289,21 @@ def main():
|
|
290 |
|
291 |
boxes = gr.State([])
|
292 |
|
293 |
-
sketch_pad.edit(
|
294 |
-
draw,
|
295 |
-
inputs=[boxes, sketch_pad, layout_image],
|
296 |
-
outputs=[boxes, sketch_pad, layout_image],
|
297 |
-
queue=False,
|
298 |
-
)
|
299 |
-
|
300 |
clear_button.click(
|
301 |
clear,
|
302 |
inputs=[batch_size],
|
303 |
-
outputs=[boxes,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
304 |
queue=False,
|
305 |
)
|
306 |
|
307 |
-
|
308 |
fn=partial(generate, device, model),
|
309 |
inputs=[
|
310 |
prompt, subject_token_indices, filter_token_indices, num_tokens,
|
@@ -337,7 +336,7 @@ def main():
|
|
337 |
gr.HTML(description)
|
338 |
|
339 |
demo.queue(max_size=50)
|
340 |
-
demo.launch(
|
341 |
|
342 |
if __name__ == '__main__':
|
343 |
main()
|
|
|
12 |
|
13 |
from functools import partial
|
14 |
|
15 |
+
RESOLUTION = 256
|
16 |
MIN_SIZE = 0.01
|
17 |
WHITE = 255
|
18 |
COLORS = ["red", "blue", "green", "orange", "purple", "turquoise", "olive"]
|
|
|
85 |
):
|
86 |
subject_token_indices = convert_token_indices(subject_token_indices, nested=True)
|
87 |
if len(boxes) != len(subject_token_indices):
|
88 |
+
raise gr.Error("""
|
89 |
The number of boxes should be equal to the number of subject token indices.
|
90 |
Number of boxes drawn: {}, number of grounding tokens: {}.
|
91 |
""".format(len(boxes), len(subject_token_indices)))
|
|
|
99 |
final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale, classifier_free_guidance_scale,
|
100 |
num_iterations, loss_threshold, num_guidance_steps, seed)
|
101 |
|
|
|
|
|
|
|
|
|
|
|
102 |
return images
|
103 |
|
104 |
|
|
|
109 |
return [int(index.strip()) for index in token_indices.split(',') if len(index.strip()) > 0]
|
110 |
|
111 |
|
112 |
+
def draw(sketchpad):
|
113 |
+
boxes = []
|
114 |
+
for i, layer in enumerate(sketchpad['layers']):
|
115 |
+
mask = (layer != 0)
|
116 |
+
if mask.sum() < 0:
|
117 |
+
raise gr.Error(f'Box in layer {i} is too small')
|
118 |
|
|
|
|
|
119 |
x1x2 = np.where(mask.max(0) != 0)[0] / RESOLUTION
|
120 |
y1y2 = np.where(mask.max(1) != 0)[0] / RESOLUTION
|
121 |
y1, y2 = y1y2.min(), y1y2.max()
|
122 |
x1, x2 = x1x2.min(), x1x2.max()
|
123 |
|
124 |
+
if (x2 - x1 < MIN_SIZE) or (y2 - y1 < MIN_SIZE):
|
125 |
+
raise gr.Error(f'Box in layer {i} is too small')
|
126 |
+
|
127 |
+
boxes.append((x1, y1, x2, y2))
|
128 |
|
129 |
+
layout_image = draw_boxes(boxes)
|
130 |
+
return [boxes, layout_image]
|
131 |
|
132 |
|
133 |
def draw_boxes(boxes):
|
|
|
144 |
|
145 |
|
146 |
def clear(batch_size):
|
147 |
+
return [[], None, None, None]
|
|
|
|
|
148 |
|
149 |
|
150 |
def main():
|
|
|
205 |
<br>
|
206 |
<span style="font-size: 18px" id="paper-info">
|
207 |
[<a href="https://omer11a.github.io/bounded-attention/" target="_blank">Project Page</a>]
|
208 |
+
[<a href="https://arxiv.org/abs/2403.16990" target="_blank">Paper</a>]
|
209 |
[<a href="https://github.com/omer11a/bounded-attention" target="_blank">GitHub</a>]
|
210 |
</span>
|
211 |
</p>
|
|
|
229 |
)
|
230 |
|
231 |
with gr.Row():
|
232 |
+
sketchpad = gr.Sketchpad(label="Sketch Pad", width=RESOLUTION, height=RESOLUTION)
|
233 |
+
layout_image = gr.Image(type="pil", label="Bounding Boxes", interactive=False, width=RESOLUTION, height=RESOLUTION)
|
|
|
234 |
|
235 |
with gr.Row():
|
236 |
clear_button = gr.Button(value='Clear')
|
237 |
+
generate_layout_button = gr.Button(value='Generate layout')
|
238 |
+
generate_image_button = gr.Button(value='Generate image')
|
239 |
+
|
240 |
+
with gr.Row():
|
241 |
+
out_images = gr.Gallery(type="pil", label="Generated Images", interactive=False)
|
242 |
|
243 |
with gr.Accordion("Advanced Options", open=False):
|
244 |
with gr.Column():
|
|
|
289 |
|
290 |
boxes = gr.State([])
|
291 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
292 |
clear_button.click(
|
293 |
clear,
|
294 |
inputs=[batch_size],
|
295 |
+
outputs=[boxes, sketchpad, layout_image, out_images],
|
296 |
+
queue=False,
|
297 |
+
)
|
298 |
+
|
299 |
+
generate_layout_button.click(
|
300 |
+
draw,
|
301 |
+
inputs=[sketchpad],
|
302 |
+
outputs=[boxes, layout_image],
|
303 |
queue=False,
|
304 |
)
|
305 |
|
306 |
+
generate_image_button.click(
|
307 |
fn=partial(generate, device, model),
|
308 |
inputs=[
|
309 |
prompt, subject_token_indices, filter_token_indices, num_tokens,
|
|
|
336 |
gr.HTML(description)
|
337 |
|
338 |
demo.queue(max_size=50)
|
339 |
+
demo.launch(show_api=False, show_error=True)
|
340 |
|
341 |
if __name__ == '__main__':
|
342 |
main()
|