Changed device to a global variable in Gradio (i.e. a gr.State instance)
app.py
CHANGED
@@ -208,26 +208,26 @@ def get_ind_to_filter(text, word_ids, keywords):
     return inds_to_filter
 
 @spaces.GPU
-def count(image, text, prompts, state):
+def count(image, text, prompts, state, device):
     print("state: " + str(state))
     keywords = "" # do not handle this for now
     # Handle no prompt case.
     if prompts is None:
         prompts = {"image": image, "points": []}
     input_image, _ = transform(image, {"exemplars": torch.tensor([])})
-    input_image = input_image.unsqueeze(0).to(
+    input_image = input_image.unsqueeze(0).to(device)
     exemplars = get_box_inputs(prompts["points"])
     print(exemplars)
     input_image_exemplars, exemplars = transform(prompts["image"], {"exemplars": torch.tensor(exemplars)})
-    input_image_exemplars = input_image_exemplars.unsqueeze(0).to(
+    input_image_exemplars = input_image_exemplars.unsqueeze(0).to(device)
-    exemplars = [exemplars["exemplars"].to(
+    exemplars = [exemplars["exemplars"].to(device)]
 
     with torch.no_grad():
         model_output = model(
             nested_tensor_from_tensor_list(input_image),
             nested_tensor_from_tensor_list(input_image_exemplars),
             exemplars,
-            [torch.tensor([0]).to(
+            [torch.tensor([0]).to(device) for _ in range(len(input_image))],
             captions=[text + " ."] * len(input_image),
         )
 

@@ -297,25 +297,25 @@ def count(image, text, prompts, state):
     return (gr.Image(output_img, visible=True, label=out_label, show_label=True), gr.Number(label="Predicted Count", visible=True, value=boxes.shape[0]), new_submit_btn, gr.Tab(visible=True), step_3, state)
 
 @spaces.GPU
-def count_main(image, text, prompts):
+def count_main(image, text, prompts, device):
     keywords = "" # do not handle this for now
     # Handle no prompt case.
     if prompts is None:
         prompts = {"image": image, "points": []}
     input_image, _ = transform(image, {"exemplars": torch.tensor([])})
-    input_image = input_image.unsqueeze(0).to(
+    input_image = input_image.unsqueeze(0).to(device)
     exemplars = get_box_inputs(prompts["points"])
     print(exemplars)
     input_image_exemplars, exemplars = transform(prompts["image"], {"exemplars": torch.tensor(exemplars)})
-    input_image_exemplars = input_image_exemplars.unsqueeze(0).to(
+    input_image_exemplars = input_image_exemplars.unsqueeze(0).to(device)
-    exemplars = [exemplars["exemplars"].to(
+    exemplars = [exemplars["exemplars"].to(device)]
 
     with torch.no_grad():
         model_output = model(
             nested_tensor_from_tensor_list(input_image),
             nested_tensor_from_tensor_list(input_image_exemplars),
             exemplars,
-            [torch.tensor([0]).to(
+            [torch.tensor([0]).to(device) for _ in range(len(input_image))],
             captions=[text + " ."] * len(input_image),
         )
 

@@ -396,6 +396,7 @@ As shown earlier, there are 3 ways to specify the object to count: (1) with text
 
 with gr.Blocks(title="CountGD: Multi-Modal Open-World Counting", theme="soft", head="""<meta name="viewport" content="width=device-width, initial-scale=1, user-scalable=1">""") as demo:
     state = gr.State(value=[AppSteps.JUST_TEXT])
+    device = gr.State(args.device)
     with gr.Tab("Tutorial"):
         with gr.Row():
             with gr.Column():

@@ -419,7 +420,7 @@ with gr.Blocks(title="CountGD: Multi-Modal Open-World Counting", theme="soft", h
             pred_count = gr.Number(label="Predicted Count", visible=False)
             submit_btn = gr.Button("Count", variant="primary", interactive=True)
 
-            submit_btn.click(fn=remove_label, inputs=[detected_instances], outputs=[detected_instances]).then(fn=count, inputs=[input_image, input_text, exemplar_image, state], outputs=[detected_instances, pred_count, submit_btn, step_2, step_3, state])
+            submit_btn.click(fn=remove_label, inputs=[detected_instances], outputs=[detected_instances]).then(fn=count, inputs=[input_image, input_text, exemplar_image, state, device], outputs=[detected_instances, pred_count, submit_btn, step_2, step_3, state])
             exemplar_image.change(check_submit_btn, inputs=[exemplar_image, state], outputs=[submit_btn])
     with gr.Tab("App", visible=True) as main_app:
 

@@ -445,7 +446,7 @@ with gr.Blocks(title="CountGD: Multi-Modal Open-World Counting", theme="soft", h
             submit_btn_main = gr.Button("Count", variant="primary")
             clear_btn_main = gr.ClearButton(variant="secondary")
             gr.Examples(label="Examples: click on a row to load the example. Add visual exemplars by drawing boxes on the loaded \"Visual Exemplar Image.\"", examples=examples, inputs=[input_image_main, input_text_main, exemplar_image_main])
-            submit_btn_main.click(fn=remove_label, inputs=[detected_instances_main], outputs=[detected_instances_main]).then(fn=count_main, inputs=[input_image_main, input_text_main, exemplar_image_main], outputs=[detected_instances_main, pred_count_main])
+            submit_btn_main.click(fn=remove_label, inputs=[detected_instances_main], outputs=[detected_instances_main]).then(fn=count_main, inputs=[input_image_main, input_text_main, exemplar_image_main, device], outputs=[detected_instances_main, pred_count_main])
             clear_btn_main.add([input_image_main, input_text_main, exemplar_image_main, detected_instances_main, pred_count_main])
 
 
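
For context, the change follows a common Gradio pattern: instead of the @spaces.GPU handlers reading a module-level device variable, the device string is held in a gr.State component and wired into each event handler as an ordinary input, so Gradio supplies its current value on every call. Below is a minimal, self-contained sketch of that pattern; it is hypothetical demo code, not an excerpt from app.py, and the hard-coded device choice stands in for args.device.

import gradio as gr
import torch

# Stand-in for args.device; an assumption made for this sketch.
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def describe(x, device):
    # `device` arrives as the plain Python value held by the gr.State component.
    t = torch.tensor([float(x)]).to(device)
    return f"tensor {t.item():g} lives on {t.device}"


with gr.Blocks() as demo:
    device = gr.State(DEFAULT_DEVICE)  # passed to handlers like any other input component
    num = gr.Number(label="x", value=3)
    out = gr.Textbox(label="Result")
    btn = gr.Button("Run")
    btn.click(fn=describe, inputs=[num, device], outputs=[out])

demo.launch()

Because gr.State can hold an arbitrary per-session Python object, the same wiring also works for passing a torch.device or other configuration into handlers without relying on module-level globals.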