ClassCat commited on
Commit
9271906
1 Parent(s): af93aba

update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -43,11 +43,12 @@ def get_figure(in_pil_img, in_results):
43
  for score, label, box in zip(in_results["scores"], in_results["labels"], in_results["boxes"]):
44
  selected_color = choice(COLORS)
45
 
46
- x, y, w, h = torch.round(box[0]).item(), torch.round(box[1]).item(), torch.round(box[2]-box[0]).item(), torch.round(box[3]-box[1]).item()
47
- #x, y, w, h = int(box[0]), int(box[1]), int(box[2]-box[0]), int(box[3]-box[1])
 
48
 
49
  ax.add_patch(plt.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=2))
50
- ax.text(x, y, f"{model_tiny.config.id2label[label.item()]}: {round(score.item()*100, 2)}%", fontdict=fdic)
51
 
52
  plt.axis("off")
53
 
@@ -71,7 +72,6 @@ def infer(in_pil_img, in_model="yolos-tiny", in_threshold=0.9):
71
  # convert outputs (bounding boxes and class logits) to COCO API
72
  results = image_processor_tiny.post_process_object_detection(outputs, threshold=in_threshold, target_sizes=target_sizes)[0]
73
 
74
-
75
  figure = get_figure(in_pil_img, results)
76
 
77
  buf = io.BytesIO()
 
43
  for score, label, box in zip(in_results["scores"], in_results["labels"], in_results["boxes"]):
44
  selected_color = choice(COLORS)
45
 
46
+ box_int = [i.item() for i in torch.round(box).to(torch.int32)]
47
+ x, y, w, h = box[0], box[1], box[2]-box[0], box[3]-box[1]
48
+ #x, y, w, h = torch.round(box[0]).item(), torch.round(box[1]).item(), torch.round(box[2]-box[0]).item(), torch.round(box[3]-box[1]).item()
49
 
50
  ax.add_patch(plt.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=2))
51
+ ax.text(x, y, f"{model_tiny.config.id2label[label.item()]}: {round(score.item()*100, 2)}%", fontdict=fdic, alpha=0.8)
52
 
53
  plt.axis("off")
54
 
 
72
  # convert outputs (bounding boxes and class logits) to COCO API
73
  results = image_processor_tiny.post_process_object_detection(outputs, threshold=in_threshold, target_sizes=target_sizes)[0]
74
 
 
75
  figure = get_figure(in_pil_img, results)
76
 
77
  buf = io.BytesIO()