kvignesh17 commited on
Commit
56daede
1 Parent(s): 44fb0a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -11,14 +11,14 @@ COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
11
  [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
12
 
13
 
14
- def get_class_list_from_input(classes_string: str):
15
  if classes_string == "":
16
  return []
17
  classes_list = classes_string.split(",")
18
  classes_list = [x.strip() for x in classes_list]
19
  return classes_list
20
 
21
- def infer(img, model_name: str, prob_threshold: int, classes_to_show = str):
22
  feature_extractor = AutoFeatureExtractor.from_pretrained(f"hustvl/{model_name}")
23
  model = YolosForObjectDetection.from_pretrained(f"hustvl/{model_name}")
24
 
@@ -36,7 +36,7 @@ def infer(img, model_name: str, prob_threshold: int, classes_to_show = str):
36
  postprocessed_outputs = feature_extractor.post_process(outputs, target_sizes)
37
  bboxes_scaled = postprocessed_outputs[0]['boxes']
38
 
39
- classes_list = get_class_list_from_input(classes_to_show)
40
  res_img = plot_results(img, probas[keep], bboxes_scaled[keep], model, classes_list)
41
 
42
  return res_img
@@ -75,10 +75,10 @@ image_in = gr.components.Image()
75
  image_out = gr.components.Image()
76
  model_choice = gr.components.Dropdown(["yolos-tiny", "yolos-small", "yolos-base", "yolos-small-300", "yolos-small-dwr"], value="yolos-small", label="YOLOS Model")
77
  prob_threshold_slider = gr.components.Slider(minimum=0, maximum=1.0, step=0.01, value=0.9, label="Probability Threshold")
78
- classes_to_show = gr.components.Textbox(placeholder="e.g. person, boat", label="Classes to use (empty means all classes)")
79
 
80
  Iface = gr.Interface(
81
- fn=infer,
82
  inputs=[image_in,model_choice, prob_threshold_slider, classes_to_show],
83
  outputs=image_out,
84
  title="YOLOS - Object Detection",
 
11
  [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
12
 
13
 
14
+ def process_class_list(classes_string: str):
15
  if classes_string == "":
16
  return []
17
  classes_list = classes_string.split(",")
18
  classes_list = [x.strip() for x in classes_list]
19
  return classes_list
20
 
21
+ def model_inference(img, model_name: str, prob_threshold: int, classes_to_show = str):
22
  feature_extractor = AutoFeatureExtractor.from_pretrained(f"hustvl/{model_name}")
23
  model = YolosForObjectDetection.from_pretrained(f"hustvl/{model_name}")
24
 
 
36
  postprocessed_outputs = feature_extractor.post_process(outputs, target_sizes)
37
  bboxes_scaled = postprocessed_outputs[0]['boxes']
38
 
39
+ classes_list = process_class_list(classes_to_show)
40
  res_img = plot_results(img, probas[keep], bboxes_scaled[keep], model, classes_list)
41
 
42
  return res_img
 
75
  image_out = gr.components.Image()
76
  model_choice = gr.components.Dropdown(["yolos-tiny", "yolos-small", "yolos-base", "yolos-small-300", "yolos-small-dwr"], value="yolos-small", label="YOLOS Model")
77
  prob_threshold_slider = gr.components.Slider(minimum=0, maximum=1.0, step=0.01, value=0.9, label="Probability Threshold")
78
+ classes_to_show = gr.components.Textbox(placeholder="e.g. person, truck", label="Classes to use (defaulted to detect all classes)")
79
 
80
  Iface = gr.Interface(
81
+ fn=model_inference,
82
  inputs=[image_in,model_choice, prob_threshold_slider, classes_to_show],
83
  outputs=image_out,
84
  title="YOLOS - Object Detection",