gatesla commited on
Commit
c9810d9
1 Parent(s): 324c14e

Trying to get YOLOV8

Browse files
Files changed (1) hide show
  1. app.py +61 -36
app.py CHANGED
@@ -6,6 +6,7 @@ import torch
6
  import pathlib
7
  from PIL import Image
8
  from transformers import AutoFeatureExtractor, DetrForObjectDetection, YolosForObjectDetection
 
9
 
10
  import os
11
 
@@ -58,48 +59,71 @@ def detect_objects(model_name,url_input,image_input,threshold):
58
 
59
  #Extract model and feature extractor
60
  feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
61
-
62
- if 'detr' in model_name:
63
-
64
- model = DetrForObjectDetection.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
 
 
65
 
66
- elif 'yolos' in model_name:
 
 
 
 
 
67
 
68
- model = YolosForObjectDetection.from_pretrained(model_name)
 
 
 
69
 
70
- tb_label = ""
71
- if validators.url(url_input):
72
- image = Image.open(requests.get(url_input, stream=True).raw)
73
- tb_label = "Confidence Values URL"
74
 
75
- elif image_input:
76
- image = image_input
77
- tb_label = "Confidence Values Upload"
78
 
79
- #Make prediction
80
- processed_output_list = make_prediction(image, feature_extractor, model)
81
- print("After make_prediction" + str(processed_output_list))
82
- processed_outputs = processed_output_list[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
- #Visualize prediction
85
- viz_img = visualize_prediction(image, processed_outputs, threshold, model.config.id2label)
 
 
 
 
 
 
86
 
87
- # return [viz_img, processed_outputs]
88
- # print(type(viz_img))
89
-
90
- final_str_abv = ""
91
- final_str_else = ""
92
- for score, label, box in sorted(zip(processed_outputs["scores"], processed_outputs["labels"], processed_outputs["boxes"]), key = lambda x: x[0].item(), reverse=True):
93
- box = [round(i, 2) for i in box.tolist()]
94
- if score.item() >= threshold:
95
- final_str_abv += f"Detected `{model.config.id2label[label.item()]}` with confidence `{round(score.item(), 3)}` at location `{box}`\n"
96
- else:
97
- final_str_else += f"Detected `{model.config.id2label[label.item()]}` with confidence `{round(score.item(), 3)}` at location `{box}`\n"
98
-
99
- # https://docs.python.org/3/library/string.html#format-examples
100
- final_str = "{:*^50}\n".format("ABOVE THRESHOLD OR EQUAL") + final_str_abv + "\n{:*^50}\n".format("BELOW THRESHOLD")+final_str_else
101
-
102
- return viz_img, final_str
103
 
104
  def set_example_image(example: list) -> dict:
105
  return gr.Image.update(value=example[0])
@@ -119,10 +143,11 @@ Links to HuggingFace Models:
119
  - [hustvl/yolos-tiny](https://huggingface.co/hustvl/yolos-tiny)
120
  - [facebook/detr-resnet-101-dc5](https://huggingface.co/facebook/detr-resnet-101-dc5)
121
  - [hustvl/yolos-small-300](https://huggingface.co/hustvl/yolos-small-300)
 
122
 
123
  """
124
 
125
- models = ["facebook/detr-resnet-50","facebook/detr-resnet-101",'hustvl/yolos-small','hustvl/yolos-tiny','facebook/detr-resnet-101-dc5', 'hustvl/yolos-small-300']
126
  urls = ["https://c8.alamy.com/comp/J2AB4K/the-new-york-stock-exchange-on-the-wall-street-in-new-york-J2AB4K.jpg"]
127
 
128
  # twitter_link = """
 
6
  import pathlib
7
  from PIL import Image
8
  from transformers import AutoFeatureExtractor, DetrForObjectDetection, YolosForObjectDetection
9
+ from ultralyticsplus import YOLO, render_result
10
 
11
  import os
12
 
 
59
 
60
  #Extract model and feature extractor
61
  feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
62
+
63
+ if 'yolov8' in model_name:
64
+
65
+ model = YOLO(model_name)
66
+ # set model parameters
67
+ model.overrides['conf'] = 0.25 # NMS confidence threshold
68
+ model.overrides['iou'] = 0.45 # NMS IoU threshold
69
+ model.overrides['agnostic_nms'] = False # NMS class-agnostic
70
+ model.overrides['max_det'] = 1000 # maximum number of detections per image
71
+
72
+ results = model.predict(image_input)
73
+
74
+ render = render_result(model=model, image=image_input, result=results[0])
75
+
76
+ return render, ""
77
 
78
+ # for result in results:
79
+ # # https://docs.ultralytics.com/modes/predict/#key-features-of-predict-mode
80
+ # #TODO
81
+ # im_array = result.plot()
82
+ # im = Image.fromarray(im_array[..., ::=1])
83
+
84
 
85
+ else:
86
+ if 'detr' in model_name:
87
+
88
+ model = DetrForObjectDetection.from_pretrained(model_name)
89
 
90
+ elif 'yolos' in model_name:
 
 
 
91
 
92
+ model = YolosForObjectDetection.from_pretrained(model_name)
 
 
93
 
94
+ tb_label = ""
95
+ if validators.url(url_input):
96
+ image = Image.open(requests.get(url_input, stream=True).raw)
97
+ tb_label = "Confidence Values URL"
98
+
99
+ elif image_input:
100
+ image = image_input
101
+ tb_label = "Confidence Values Upload"
102
+
103
+ #Make prediction
104
+ processed_output_list = make_prediction(image, feature_extractor, model)
105
+ print("After make_prediction" + str(processed_output_list))
106
+ processed_outputs = processed_output_list[0]
107
+
108
+ #Visualize prediction
109
+ viz_img = visualize_prediction(image, processed_outputs, threshold, model.config.id2label)
110
+
111
+ # return [viz_img, processed_outputs]
112
+ # print(type(viz_img))
113
 
114
+ final_str_abv = ""
115
+ final_str_else = ""
116
+ for score, label, box in sorted(zip(processed_outputs["scores"], processed_outputs["labels"], processed_outputs["boxes"]), key = lambda x: x[0].item(), reverse=True):
117
+ box = [round(i, 2) for i in box.tolist()]
118
+ if score.item() >= threshold:
119
+ final_str_abv += f"Detected `{model.config.id2label[label.item()]}` with confidence `{round(score.item(), 3)}` at location `{box}`\n"
120
+ else:
121
+ final_str_else += f"Detected `{model.config.id2label[label.item()]}` with confidence `{round(score.item(), 3)}` at location `{box}`\n"
122
 
123
+ # https://docs.python.org/3/library/string.html#format-examples
124
+ final_str = "{:*^50}\n".format("ABOVE THRESHOLD OR EQUAL") + final_str_abv + "\n{:*^50}\n".format("BELOW THRESHOLD")+final_str_else
125
+
126
+ return viz_img, final_str
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
  def set_example_image(example: list) -> dict:
129
  return gr.Image.update(value=example[0])
 
143
  - [hustvl/yolos-tiny](https://huggingface.co/hustvl/yolos-tiny)
144
  - [facebook/detr-resnet-101-dc5](https://huggingface.co/facebook/detr-resnet-101-dc5)
145
  - [hustvl/yolos-small-300](https://huggingface.co/hustvl/yolos-small-300)
146
+ - [mshamrai/yolov8x-visdrone](https://huggingface.co/mshamrai/yolov8x-visdrone)
147
 
148
  """
149
 
150
+ models = ["facebook/detr-resnet-50","facebook/detr-resnet-101",'hustvl/yolos-small','hustvl/yolos-tiny','facebook/detr-resnet-101-dc5', 'hustvl/yolos-small-300', 'mshamrai/yolov8x-visdrone']
151
  urls = ["https://c8.alamy.com/comp/J2AB4K/the-new-york-stock-exchange-on-the-wall-street-in-new-york-J2AB4K.jpg"]
152
 
153
  # twitter_link = """