Changed device to global variable in gradio (ie gr.State instance)
Browse files
app.py
CHANGED
@@ -208,26 +208,26 @@ def get_ind_to_filter(text, word_ids, keywords):
|
|
208 |
return inds_to_filter
|
209 |
|
210 |
@spaces.GPU
|
211 |
-
def count(image, text, prompts, state):
|
212 |
print("state: " + str(state))
|
213 |
keywords = "" # do not handle this for now
|
214 |
# Handle no prompt case.
|
215 |
if prompts is None:
|
216 |
prompts = {"image": image, "points": []}
|
217 |
input_image, _ = transform(image, {"exemplars": torch.tensor([])})
|
218 |
-
input_image = input_image.unsqueeze(0).to(
|
219 |
exemplars = get_box_inputs(prompts["points"])
|
220 |
print(exemplars)
|
221 |
input_image_exemplars, exemplars = transform(prompts["image"], {"exemplars": torch.tensor(exemplars)})
|
222 |
-
input_image_exemplars = input_image_exemplars.unsqueeze(0).to(
|
223 |
-
exemplars = [exemplars["exemplars"].to(
|
224 |
|
225 |
with torch.no_grad():
|
226 |
model_output = model(
|
227 |
nested_tensor_from_tensor_list(input_image),
|
228 |
nested_tensor_from_tensor_list(input_image_exemplars),
|
229 |
exemplars,
|
230 |
-
[torch.tensor([0]).to(
|
231 |
captions=[text + " ."] * len(input_image),
|
232 |
)
|
233 |
|
@@ -297,25 +297,25 @@ def count(image, text, prompts, state):
|
|
297 |
return (gr.Image(output_img, visible=True, label=out_label, show_label=True), gr.Number(label="Predicted Count", visible=True, value=boxes.shape[0]), new_submit_btn, gr.Tab(visible=True), step_3, state)
|
298 |
|
299 |
@spaces.GPU
|
300 |
-
def count_main(image, text, prompts):
|
301 |
keywords = "" # do not handle this for now
|
302 |
# Handle no prompt case.
|
303 |
if prompts is None:
|
304 |
prompts = {"image": image, "points": []}
|
305 |
input_image, _ = transform(image, {"exemplars": torch.tensor([])})
|
306 |
-
input_image = input_image.unsqueeze(0).to(
|
307 |
exemplars = get_box_inputs(prompts["points"])
|
308 |
print(exemplars)
|
309 |
input_image_exemplars, exemplars = transform(prompts["image"], {"exemplars": torch.tensor(exemplars)})
|
310 |
-
input_image_exemplars = input_image_exemplars.unsqueeze(0).to(
|
311 |
-
exemplars = [exemplars["exemplars"].to(
|
312 |
|
313 |
with torch.no_grad():
|
314 |
model_output = model(
|
315 |
nested_tensor_from_tensor_list(input_image),
|
316 |
nested_tensor_from_tensor_list(input_image_exemplars),
|
317 |
exemplars,
|
318 |
-
[torch.tensor([0]).to(
|
319 |
captions=[text + " ."] * len(input_image),
|
320 |
)
|
321 |
|
@@ -396,6 +396,7 @@ As shown earlier, there are 3 ways to specify the object to count: (1) with text
|
|
396 |
|
397 |
with gr.Blocks(title="CountGD: Multi-Modal Open-World Counting", theme="soft", head="""<meta name="viewport" content="width=device-width, initial-scale=1, user-scalable=1">""") as demo:
|
398 |
state = gr.State(value=[AppSteps.JUST_TEXT])
|
|
|
399 |
with gr.Tab("Tutorial"):
|
400 |
with gr.Row():
|
401 |
with gr.Column():
|
@@ -419,7 +420,7 @@ with gr.Blocks(title="CountGD: Multi-Modal Open-World Counting", theme="soft", h
|
|
419 |
pred_count = gr.Number(label="Predicted Count", visible=False)
|
420 |
submit_btn = gr.Button("Count", variant="primary", interactive=True)
|
421 |
|
422 |
-
submit_btn.click(fn=remove_label, inputs=[detected_instances], outputs=[detected_instances]).then(fn=count, inputs=[input_image, input_text, exemplar_image, state], outputs=[detected_instances, pred_count, submit_btn, step_2, step_3, state])
|
423 |
exemplar_image.change(check_submit_btn, inputs=[exemplar_image, state], outputs=[submit_btn])
|
424 |
with gr.Tab("App", visible=True) as main_app:
|
425 |
|
@@ -445,7 +446,7 @@ with gr.Blocks(title="CountGD: Multi-Modal Open-World Counting", theme="soft", h
|
|
445 |
submit_btn_main = gr.Button("Count", variant="primary")
|
446 |
clear_btn_main = gr.ClearButton(variant="secondary")
|
447 |
gr.Examples(label="Examples: click on a row to load the example. Add visual exemplars by drawing boxes on the loaded \"Visual Exemplar Image.\"", examples=examples, inputs=[input_image_main, input_text_main, exemplar_image_main])
|
448 |
-
submit_btn_main.click(fn=remove_label, inputs=[detected_instances_main], outputs=[detected_instances_main]).then(fn=count_main, inputs=[input_image_main, input_text_main, exemplar_image_main], outputs=[detected_instances_main, pred_count_main])
|
449 |
clear_btn_main.add([input_image_main, input_text_main, exemplar_image_main, detected_instances_main, pred_count_main])
|
450 |
|
451 |
|
|
|
208 |
return inds_to_filter
|
209 |
|
210 |
@spaces.GPU
|
211 |
+
def count(image, text, prompts, state, device):
|
212 |
print("state: " + str(state))
|
213 |
keywords = "" # do not handle this for now
|
214 |
# Handle no prompt case.
|
215 |
if prompts is None:
|
216 |
prompts = {"image": image, "points": []}
|
217 |
input_image, _ = transform(image, {"exemplars": torch.tensor([])})
|
218 |
+
input_image = input_image.unsqueeze(0).to(device)
|
219 |
exemplars = get_box_inputs(prompts["points"])
|
220 |
print(exemplars)
|
221 |
input_image_exemplars, exemplars = transform(prompts["image"], {"exemplars": torch.tensor(exemplars)})
|
222 |
+
input_image_exemplars = input_image_exemplars.unsqueeze(0).to(device)
|
223 |
+
exemplars = [exemplars["exemplars"].to(device)]
|
224 |
|
225 |
with torch.no_grad():
|
226 |
model_output = model(
|
227 |
nested_tensor_from_tensor_list(input_image),
|
228 |
nested_tensor_from_tensor_list(input_image_exemplars),
|
229 |
exemplars,
|
230 |
+
[torch.tensor([0]).to(device) for _ in range(len(input_image))],
|
231 |
captions=[text + " ."] * len(input_image),
|
232 |
)
|
233 |
|
|
|
297 |
return (gr.Image(output_img, visible=True, label=out_label, show_label=True), gr.Number(label="Predicted Count", visible=True, value=boxes.shape[0]), new_submit_btn, gr.Tab(visible=True), step_3, state)
|
298 |
|
299 |
@spaces.GPU
|
300 |
+
def count_main(image, text, prompts, device):
|
301 |
keywords = "" # do not handle this for now
|
302 |
# Handle no prompt case.
|
303 |
if prompts is None:
|
304 |
prompts = {"image": image, "points": []}
|
305 |
input_image, _ = transform(image, {"exemplars": torch.tensor([])})
|
306 |
+
input_image = input_image.unsqueeze(0).to(device)
|
307 |
exemplars = get_box_inputs(prompts["points"])
|
308 |
print(exemplars)
|
309 |
input_image_exemplars, exemplars = transform(prompts["image"], {"exemplars": torch.tensor(exemplars)})
|
310 |
+
input_image_exemplars = input_image_exemplars.unsqueeze(0).to(device)
|
311 |
+
exemplars = [exemplars["exemplars"].to(device)]
|
312 |
|
313 |
with torch.no_grad():
|
314 |
model_output = model(
|
315 |
nested_tensor_from_tensor_list(input_image),
|
316 |
nested_tensor_from_tensor_list(input_image_exemplars),
|
317 |
exemplars,
|
318 |
+
[torch.tensor([0]).to(device) for _ in range(len(input_image))],
|
319 |
captions=[text + " ."] * len(input_image),
|
320 |
)
|
321 |
|
|
|
396 |
|
397 |
with gr.Blocks(title="CountGD: Multi-Modal Open-World Counting", theme="soft", head="""<meta name="viewport" content="width=device-width, initial-scale=1, user-scalable=1">""") as demo:
|
398 |
state = gr.State(value=[AppSteps.JUST_TEXT])
|
399 |
+
device = gr.State(args.device)
|
400 |
with gr.Tab("Tutorial"):
|
401 |
with gr.Row():
|
402 |
with gr.Column():
|
|
|
420 |
pred_count = gr.Number(label="Predicted Count", visible=False)
|
421 |
submit_btn = gr.Button("Count", variant="primary", interactive=True)
|
422 |
|
423 |
+
submit_btn.click(fn=remove_label, inputs=[detected_instances], outputs=[detected_instances]).then(fn=count, inputs=[input_image, input_text, exemplar_image, state, device], outputs=[detected_instances, pred_count, submit_btn, step_2, step_3, state])
|
424 |
exemplar_image.change(check_submit_btn, inputs=[exemplar_image, state], outputs=[submit_btn])
|
425 |
with gr.Tab("App", visible=True) as main_app:
|
426 |
|
|
|
446 |
submit_btn_main = gr.Button("Count", variant="primary")
|
447 |
clear_btn_main = gr.ClearButton(variant="secondary")
|
448 |
gr.Examples(label="Examples: click on a row to load the example. Add visual exemplars by drawing boxes on the loaded \"Visual Exemplar Image.\"", examples=examples, inputs=[input_image_main, input_text_main, exemplar_image_main])
|
449 |
+
submit_btn_main.click(fn=remove_label, inputs=[detected_instances_main], outputs=[detected_instances_main]).then(fn=count_main, inputs=[input_image_main, input_text_main, exemplar_image_main, device], outputs=[detected_instances_main, pred_count_main])
|
450 |
clear_btn_main.add([input_image_main, input_text_main, exemplar_image_main, detected_instances_main, pred_count_main])
|
451 |
|
452 |
|