ClassCat commited on
Commit
30ef691
1 Parent(s): 783a0a3

add app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -20
app.py CHANGED
@@ -43,32 +43,35 @@ def get_figure(in_pil_img, in_results):
43
  for score, label, box in zip(in_results["scores"], in_results["labels"], in_results["boxes"]):
44
  selected_color = choice(COLORS)
45
 
46
- #box = [round(i, 2) for i in box.tolist()]
47
- x, y, w, h = int(box[0]), int(box[1]), int(box[2]-box[0]), int(box[3]-box[1])
48
- print(x, y, w, h)
49
- ax.add_patch(plt.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=3))
50
- ax.text(x, y, f"{model_tiny.config.id2label[label.item()]}: {round(score.item()*100, 1)}%", fontdict=fdic)
51
- #print(
52
- # f"Detected {model_tiny.config.id2label[label.item()]} with confidence "
53
- # f"{round(score.item(), 3)} at location {box}"
54
- #)
55
 
56
  plt.axis("off")
57
 
58
  return plt.gcf()
59
 
60
 
61
- def infer(in_model, in_threshold, in_pil_img):
62
- inputs = image_processor_tiny(images=in_pil_img, return_tensors="pt")
63
- outputs = model_tiny(**inputs)
64
-
65
- # convert outputs (bounding boxes and class logits) to COCO API
66
  target_sizes = torch.tensor([in_pil_img.size[::-1]])
67
- results = image_processor_tiny.post_process_object_detection(outputs, threshold=in_threshold, target_sizes=target_sizes)[
68
- 0
69
- ]
70
- print(results)
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  figure = get_figure(in_pil_img, results)
73
 
74
  buf = io.BytesIO()
@@ -88,7 +91,7 @@ with gr.Blocks(title="YOLOS Object Detection - ClassCat",
88
 
89
  gr.HTML("""<h4 style="color:navy;">1. Select a model.</h4>""")
90
 
91
- model = gr.Radio(["yolos-tiny", "yolos-small"], value="yolos-tiny")
92
 
93
  gr.HTML("""<br/>""")
94
  gr.HTML("""<h4 style="color:navy;">2-a. Select an example by clicking a thumbnail below.</h4>""")
@@ -109,7 +112,14 @@ with gr.Blocks(title="YOLOS Object Detection - ClassCat",
109
  gr.HTML("""<h4 style="color:navy;">4. Then, click "Infer" button to predict object instances. It will take about 10 seconds (on cpu)</h4>""")
110
 
111
  send_btn = gr.Button("Infer")
112
- send_btn.click(fn=infer, inputs=[model, threshold, input_image], outputs=[output_image])
 
 
 
 
 
 
 
113
 
114
  #demo.queue()
115
  demo.launch(debug=True)
 
43
  for score, label, box in zip(in_results["scores"], in_results["labels"], in_results["boxes"]):
44
  selected_color = choice(COLORS)
45
 
46
+ x, y, w, h = round(box[0]), round(box[1]), round(box[2]-box[0]), round(box[3]-box[1])
47
+ #x, y, w, h = int(box[0]), int(box[1]), int(box[2]-box[0]), int(box[3]-box[1])
48
+
49
+ ax.add_patch(plt.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=2))
50
+ ax.text(x, y, f"{model_tiny.config.id2label[label.item()]}: {round(score.item()*100, 2)}%", fontdict=fdic)
 
 
 
 
51
 
52
  plt.axis("off")
53
 
54
  return plt.gcf()
55
 
56
 
57
+ def infer(in_pil_img, in_model="yolos-tiny", in_threshold=0.9):
 
 
 
 
58
  target_sizes = torch.tensor([in_pil_img.size[::-1]])
 
 
 
 
59
 
60
+ if in_model == "yolos-small":
61
+ inputs = image_processor_small(images=in_pil_img, return_tensors="pt")
62
+ outputs = model_small(**inputs)
63
+
64
+ # convert outputs (bounding boxes and class logits) to COCO API
65
+ results = image_processor_small.post_process_object_detection(outputs, threshold=in_threshold, target_sizes=target_sizes)[0]
66
+
67
+ else:
68
+ inputs = image_processor_tiny(images=in_pil_img, return_tensors="pt")
69
+ outputs = model_tiny(**inputs)
70
+
71
+ # convert outputs (bounding boxes and class logits) to COCO API
72
+ results = image_processor_tiny.post_process_object_detection(outputs, threshold=in_threshold, target_sizes=target_sizes)[0]
73
+
74
+
75
  figure = get_figure(in_pil_img, results)
76
 
77
  buf = io.BytesIO()
 
91
 
92
  gr.HTML("""<h4 style="color:navy;">1. Select a model.</h4>""")
93
 
94
+ model = gr.Radio(["yolos-tiny", "yolos-small"], value="yolos-tiny", label="Model name")
95
 
96
  gr.HTML("""<br/>""")
97
  gr.HTML("""<h4 style="color:navy;">2-a. Select an example by clicking a thumbnail below.</h4>""")
 
112
  gr.HTML("""<h4 style="color:navy;">4. Then, click "Infer" button to predict object instances. It will take about 10 seconds (on cpu)</h4>""")
113
 
114
  send_btn = gr.Button("Infer")
115
+ send_btn.click(fn=infer, inputs=[input_image, model, threshold], outputs=[output_image])
116
+
117
+ gr.HTML("""<br/>""")
118
+ gr.HTML("""<h4 style="color:navy;">Reference</h4>""")
119
+ gr.HTML("""<ul>""")
120
+ gr.HTML("""<li><a href="https://huggingface.co/docs/transformers/model_doc/yolos" target="_blank">Hugging Face Transformers - YOLOS</a>""")
121
+ gr.HTML("""</ul>""")
122
+
123
 
124
  #demo.queue()
125
  demo.launch(debug=True)