bmay commited on
Commit
9e62052
1 Parent(s): b25a60b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -10
app.py CHANGED
@@ -14,7 +14,7 @@ def load_description(fp):
14
 
15
 
16
  @spaces.GPU(duration=90)
17
- def run_theia(image):
18
  theia_model = AutoModel.from_pretrained("theaiinstitute/theia-tiny-patch16-224-cddsv", trust_remote_code=True)
19
  theia_model = theia_model.to('cuda')
20
  target_model_names = [
@@ -41,19 +41,22 @@ def run_theia(image):
41
  mask_generator=mask_generator,
42
  sam_model=sam_model,
43
  depth_anything_decoder=depth_anything_decoder,
44
- pred_iou_thresh=0.5,
45
- stability_score_thresh=0.7,
46
  gt=True,
47
  device='cuda',
48
  )
49
 
50
  _, width, _ = theia_decode_results[0].shape
51
- theia_decode_dino = (255.0 * theia_decode_results[0]).astype(np.uint8)[:, width // 4 : 2 * width // 4, :]
52
- theia_decode_sam = (255.0 * theia_decode_results[0]).astype(np.uint8)[:, 2 * width // 4 : 3 * width // 4, :]
53
- theia_decode_depth = (255.0 * theia_decode_results[0]).astype(np.uint8)[:, 3 * width // 4 :, :]
54
- gt_dino = (255.0 * gt_results[0]).astype(np.uint8)[:, width // 4 : 2 * width // 4, :]
55
- gt_sam = (255.0 * gt_results[0]).astype(np.uint8)[:, 2 * width // 4 : 3 * width // 4, :]
56
- gt_depth = (255.0 * gt_results[0]).astype(np.uint8)[:, 3 * width // 4 :, :]
 
 
 
57
 
58
  dinov2_output = [(theia_decode_dino, "Theia"), (gt_dino, "Ground Truth")]
59
  sam_output = [(theia_decode_sam, "Theia"), (gt_sam, "Ground Truth")]
@@ -67,6 +70,8 @@ with gr.Blocks() as demo:
67
  with gr.Row():
68
  with gr.Column():
69
  input_image = gr.Image(label="Input Image", type="pil")
 
 
70
  submit_button = gr.Button("Submit")
71
 
72
  with gr.Column():
@@ -74,7 +79,11 @@ with gr.Blocks() as demo:
74
  sam_output = gr.Gallery(label="SAM", type="numpy")
75
  depth_anything_output = gr.Gallery(label="Depth-Anything", type="numpy")
76
 
77
- submit_button.click(run_theia, inputs=input_image, outputs=[dinov2_output, sam_output, depth_anything_output])
 
 
 
 
78
 
79
  demo.queue()
80
  demo.launch()
 
14
 
15
 
16
  @spaces.GPU(duration=90)
17
+ def run_theia(image, pred_iou_thresh, stability_score_thresh):
18
  theia_model = AutoModel.from_pretrained("theaiinstitute/theia-tiny-patch16-224-cddsv", trust_remote_code=True)
19
  theia_model = theia_model.to('cuda')
20
  target_model_names = [
 
41
  mask_generator=mask_generator,
42
  sam_model=sam_model,
43
  depth_anything_decoder=depth_anything_decoder,
44
+ pred_iou_thresh=pred_iou_thresh,
45
+ stability_score_thresh=stability_score_thresh,
46
  gt=True,
47
  device='cuda',
48
  )
49
 
50
  _, width, _ = theia_decode_results[0].shape
51
+ theia_decode_results = (255.0 * theia_decode_results[0]).astype(np.uint8)
52
+ theia_decode_dino = theia_decode_results[:, width // 4 : 2 * width // 4, :]
53
+ theia_decode_sam = theia_decode_results[:, 2 * width // 4 : 3 * width // 4, :]
54
+ theia_decode_depth = theia_decode_results[:, 3 * width // 4 :, :]
55
+
56
+ gt_results = (255.0 * gt_results[0]).astype(np.uint8)
57
+ gt_dino = gt_results[:, width // 4 : 2 * width // 4, :]
58
+ gt_sam = gt_results[:, 2 * width // 4 : 3 * width // 4, :]
59
+ gt_depth = gt_results[:, 3 * width // 4 :, :]
60
 
61
  dinov2_output = [(theia_decode_dino, "Theia"), (gt_dino, "Ground Truth")]
62
  sam_output = [(theia_decode_sam, "Theia"), (gt_sam, "Ground Truth")]
 
70
  with gr.Row():
71
  with gr.Column():
72
  input_image = gr.Image(label="Input Image", type="pil")
73
+ pred_iou_thresh = gr.Slider(0.0, 1.0, value=0.5, label="Pred IoU Thresh")
74
+ stability_score_thresh = gr.Slider(0.0, 1.0, value=0.7, label="Stability Score Thresh")
75
  submit_button = gr.Button("Submit")
76
 
77
  with gr.Column():
 
79
  sam_output = gr.Gallery(label="SAM", type="numpy")
80
  depth_anything_output = gr.Gallery(label="Depth-Anything", type="numpy")
81
 
82
+ submit_button.click(
83
+ run_theia,
84
+ inputs=[input_image, pred_iou_thresh, stability_score_thresh],
85
+ outputs=[dinov2_output, sam_output, depth_anything_output]
86
+ )
87
 
88
  demo.queue()
89
  demo.launch()