Jsonwu commited on
Commit
f2fd017
1 Parent(s): 68135d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -15,22 +15,22 @@ with open(LABELS_PATH, "r") as f:
15
 
16
  img_transforms = transforms.ToTensor()
17
 
18
- # inter_class_nms implemented by GPT
19
- def inter_class_nms(boxes, scores, iou_threshold=0.5):
20
  # Perform non-maximum suppression
21
  keep = nms(boxes, scores, iou_threshold)
22
 
23
  # Filter boxes and scores
24
  new_boxes = boxes[keep]
25
  new_scores = scores[keep]
 
26
 
27
  # Return the result in a dictionary
28
- return {'boxes': new_boxes, 'scores': new_scores}
29
 
30
  def predict(img, conf_thresh=0.4):
31
  img_input = [img_transforms(img)]
32
  _, pred = model(img_input)
33
- pred = [inter_class_nms(pred[0]['boxes'], pred[0]['scores'])]
34
  out_img = img.copy()
35
  draw = ImageDraw.Draw(out_img)
36
  font = ImageFont.truetype("res/Tuffy_Bold.ttf", 25)
 
15
 
16
  img_transforms = transforms.ToTensor()
17
 
18
+ def inter_class_nms(boxes, scores, labels, iou_threshold=0.5):
 
19
  # Perform non-maximum suppression
20
  keep = nms(boxes, scores, iou_threshold)
21
 
22
  # Filter boxes and scores
23
  new_boxes = boxes[keep]
24
  new_scores = scores[keep]
25
+ new_labels = labels[keep]
26
 
27
  # Return the result in a dictionary
28
+ return {'boxes': new_boxes, 'scores': new_scores, 'labels': new_labels}
29
 
30
  def predict(img, conf_thresh=0.4):
31
  img_input = [img_transforms(img)]
32
  _, pred = model(img_input)
33
+ pred = [inter_class_nms(pred[0]['boxes'], pred[0]['scores'], pred[0]['labels'])]
34
  out_img = img.copy()
35
  draw = ImageDraw.Draw(out_img)
36
  font = ImageFont.truetype("res/Tuffy_Bold.ttf", 25)