sagemaker committed
Commit 7b843a4
1 Parent(s): bc954b8

some upgrades: adjustable probability threshold and class filtering

Files changed (1)
  1. app.py +33 -10
app.py CHANGED
@@ -4,13 +4,21 @@ from PIL import Image
 import torch
 import matplotlib.pyplot as plt
 import io
+import numpy as np
 
 
 COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
           [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
 
 
-def infer(img, model_name):
+def get_class_list_from_input(classes_string: str):
+    if classes_string == "":
+        return []
+    classes_list = classes_string.split(",")
+    classes_list = [x.strip() for x in classes_list]
+    return classes_list
+
+def infer(img, model_name: str, prob_threshold: float, classes_to_show: str = ""):
     feature_extractor = AutoFeatureExtractor.from_pretrained(f"hustvl/{model_name}")
     model = YolosForObjectDetection.from_pretrained(f"hustvl/{model_name}")
 
@@ -22,26 +30,33 @@ def infer(img, model_name):
     outputs = model(pixel_values, output_attentions=True)
 
     probas = outputs.logits.softmax(-1)[0, :, :-1]
-    keep = probas.max(-1).values > 0.9
+    keep = probas.max(-1).values > prob_threshold
 
     target_sizes = torch.tensor(img.size[::-1]).unsqueeze(0)
     postprocessed_outputs = feature_extractor.post_process(outputs, target_sizes)
     bboxes_scaled = postprocessed_outputs[0]['boxes']
 
-    res_img = plot_results(img, probas[keep], bboxes_scaled[keep], model)
+    classes_list = get_class_list_from_input(classes_to_show)
+    res_img = plot_results(img, probas[keep], bboxes_scaled[keep], model, classes_list)
 
     return res_img
 
-def plot_results(pil_img, prob, boxes, model):
+def plot_results(pil_img, prob, boxes, model, classes_list):
     plt.figure(figsize=(16,10))
     plt.imshow(pil_img)
     ax = plt.gca()
     colors = COLORS * 100
     for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
-        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
-                     fill=False, color=c, linewidth=3))
         cl = p.argmax()
-        text = f'{model.config.id2label[cl.item()]}: {p[cl]:0.2f}'
+        object_class = model.config.id2label[cl.item()]
+
+        if len(classes_list) > 0:
+            if object_class not in classes_list:
+                continue
+
+        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
+                     fill=False, color=c, linewidth=3))
+        text = f'{object_class}: {p[cl]:0.2f}'
         ax.text(xmin, ymin, text, fontsize=15,
                 bbox=dict(facecolor='yellow', alpha=0.5))
     plt.axis('off')
@@ -54,17 +69,25 @@ def fig2img(fig):
     img = Image.open(buf)
     return img
 
-description = """Object Detection with YOLOS. Choose your model and you're good to go."""
+description = """Object Detection with YOLOS. Choose your model and you're good to go.
+
+You can adjust the minimum probability threshold with the slider.
+
+Additionally, you can restrict the classes that are shown by entering a comma-separated list of
+[COCO classes](https://github.com/amikelive/coco-labels/blob/master/coco-labels-2014_2017.txt).
+Leaving the field empty will show all classes."""
 
 image_in = gr.components.Image()
 image_out = gr.components.Image()
 model_choice = gr.components.Dropdown(["yolos-tiny", "yolos-small", "yolos-base", "yolos-small-300", "yolos-small-dwr"], value="yolos-small")
+prob_threshold_slider = gr.components.Slider(minimum=0, maximum=1.0, step=0.01, value=0.9, label="Probability Threshold")
+classes_to_show = gr.components.Textbox(placeholder="e.g. person, boat")
 
 Iface = gr.Interface(
     fn=infer,
-    inputs=[image_in,model_choice],
+    inputs=[image_in, model_choice, prob_threshold_slider, classes_to_show],
     outputs=image_out,
-    examples=[["examples/10_People_Marching_People_Marching_2_120.jpg"], ["examples/12_Group_Group_12_Group_Group_12_26.jpg"], ["examples/43_Row_Boat_Canoe_43_247.jpg"]],
+    #examples=[["examples/10_People_Marching_People_Marching_2_120.jpg"], ["examples/12_Group_Group_12_Group_Group_12_26.jpg"], ["examples/43_Row_Boat_Canoe_43_247.jpg"]],
     title="Object Detection with YOLOS",
     description=description,
 ).launch()
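
For orientation only, here is a minimal sketch of how the two new inputs behave when `infer` is called directly, outside the Gradio UI. It is not part of the commit: the image path and the `"person, boat"` filter are placeholder values, and `yolos-small` is just one of the dropdown choices.

```python
# Hypothetical usage sketch -- assumes the functions defined in app.py are importable.
from PIL import Image

# get_class_list_from_input parses the comma-separated filter string:
#   "person, boat" -> ["person", "boat"]; "" -> [] (every class is shown)
img = Image.open("some_photo.jpg")   # placeholder path; any RGB image works

result = infer(
    img,
    model_name="yolos-small",        # any entry from the model dropdown
    prob_threshold=0.9,              # same default as the new slider
    classes_to_show="person, boat",  # only these COCO classes get boxes drawn
)
result.save("detections.png")        # infer returns a PIL image produced by fig2img
```

Detections scoring below the threshold are dropped before plotting, and an empty filter string keeps every class.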