napatswift commited on
Commit
00196b8
·
1 Parent(s): 5242954
Files changed (1) hide show
  1. main.py +8 -6
main.py CHANGED
@@ -88,14 +88,16 @@ def predict(image_input):
88
  mask_images = result.pred_instances.masks.cpu().numpy()
89
  scores = result.pred_instances.scores.cpu().numpy()
90
 
91
- # If there are no tables, return an empty list.
92
- if len(mask_images) == 0:
93
- return []
94
 
95
- # Get mask with highest score
96
- mask_image = mask_images[scores.argmax()]
97
 
98
- return get_bbox(mask_image.astype(np.uint8))
 
 
 
 
99
 
100
 
101
  def run():
 
88
  mask_images = result.pred_instances.masks.cpu().numpy()
89
  scores = result.pred_instances.scores.cpu().numpy()
90
 
91
+ bbox_list = []
 
 
92
 
93
+ # Filter out the masks with a score less than 0.5.
94
+ filtered_mask_images = mask_images[scores > 0.5]
95
 
96
+ # Get the bounding boxes of the tables.
97
+ for mask_image in filtered_mask_images:
98
+ bbox_list.extend(get_bbox(mask_image.astype(np.uint8)))
99
+
100
+ return bbox_list
101
 
102
 
103
  def run():