Hamza011 committed on
Commit
a6ff289
·
verified ·
1 Parent(s): 6706dad

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +146 -0
inference.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import cv2
3
+ import numpy as np
4
+ import onnxruntime
5
+
6
+ from utils import draw_detections
7
+
8
+
9
class YOLOv10:
    """Run a YOLOv10 ONNX model with ONNX Runtime and draw its detections.

    The instance is callable: ``detector(image)`` is equivalent to
    ``detector.detect_objects(image)``.
    """

    def __init__(self, path):
        """Load the ONNX model stored at ``path`` and cache its I/O metadata."""
        self.initialize_model(path)

    def __call__(self, image, conf_threshold=0.3):
        """Detect objects in ``image``; see :meth:`detect_objects`."""
        return self.detect_objects(image, conf_threshold)

    def initialize_model(self, path):
        """Create the inference session and read input/output details.

        Every execution provider reported by ONNX Runtime is offered to the
        session.
        """
        self.session = onnxruntime.InferenceSession(
            path, providers=onnxruntime.get_available_providers()
        )
        # Cache tensor names and the expected input size for later calls.
        self.get_input_details()
        self.get_output_details()

    def detect_objects(self, image, conf_threshold=0.3):
        """Run the full pipeline on ``image`` and return the annotated image.

        Args:
            image: Image array; it is converted with ``cv2.COLOR_BGR2RGB``
                before inference, so BGR channel order is expected.
            conf_threshold: Detections whose score is not strictly greater
                than this value are discarded.

        Returns:
            Whatever ``utils.draw_detections`` returns for the kept boxes.
        """
        input_tensor = self.prepare_input(image)
        return self.inference(image, input_tensor, conf_threshold)

    def prepare_input(self, image):
        """Convert ``image`` into the float32 NCHW tensor fed to the model.

        Also records the original image height/width on the instance so that
        boxes can later be scaled back by :meth:`rescale_boxes`.
        """
        self.img_height, self.img_width = image.shape[:2]

        rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        resized = cv2.resize(rgb, (self.input_width, self.input_height))

        # Scale pixel values to [0, 1], then reorder HWC -> CHW and add the
        # leading batch axis.
        normalized = resized / 255.0
        chw = normalized.transpose(2, 0, 1)
        return chw[np.newaxis, :, :, :].astype(np.float32)

    def inference(self, image, input_tensor, conf_threshold=0.3):
        """Run the session on ``input_tensor`` and draw the results on ``image``.

        The elapsed inference time is printed to stdout.
        """
        start = time.perf_counter()
        outputs = self.session.run(
            self.output_names, {self.input_names[0]: input_tensor}
        )

        print(f"Inference time: {(time.perf_counter() - start)*1000:.2f} ms")
        boxes, scores, class_ids = self.process_output(outputs, conf_threshold)
        return self.draw_detections(image, boxes, scores, class_ids)

    def process_output(self, output, conf_threshold=0.3):
        """Filter raw model output by confidence.

        Args:
            output: Sequence of session outputs; only ``output[0]`` is used.
                Each row is read as four box values, a score (column 4) and
                a class id (column 5).
            conf_threshold: Minimum score (exclusive) for a row to be kept.

        Returns:
            ``(boxes, scores, class_ids)``. When nothing passes the
            threshold, three empty lists are returned.
        """
        raw = np.asarray(output[0])
        # Flatten every leading axis instead of np.squeeze: squeeze would
        # also drop the detection axis when the model returns exactly one
        # row, leaving a 1-D array that cannot be indexed by column.
        predictions = raw.reshape(-1, raw.shape[-1])

        keep = predictions[:, 4] > conf_threshold
        if not keep.any():
            return [], [], []

        predictions = predictions[keep]
        scores = predictions[:, 4]
        class_ids = predictions[:, 5].astype(int)
        boxes = self.extract_boxes(predictions)

        return boxes, scores, class_ids

    def extract_boxes(self, predictions):
        """Return the box columns of ``predictions`` in original-image pixels.

        No xywh -> xyxy conversion is applied; presumably the model already
        emits corner coordinates -- confirm against the exported model.
        """
        boxes = predictions[:, :4]
        return self.rescale_boxes(boxes)

    def rescale_boxes(self, boxes):
        """Scale ``boxes`` from model-input size to the original image size.

        Relies on ``img_width``/``img_height`` recorded by
        :meth:`prepare_input` and the model input size recorded by
        :meth:`get_input_details`.
        """
        input_shape = np.array(
            [self.input_width, self.input_height, self.input_width, self.input_height]
        )
        boxes = np.divide(boxes, input_shape, dtype=np.float32)
        boxes *= np.array(
            [self.img_width, self.img_height, self.img_width, self.img_height]
        )
        return boxes

    def draw_detections(self, image, boxes, scores, class_ids, draw_scores=True, mask_alpha=0.4):
        """Delegate drawing to ``utils.draw_detections``.

        ``draw_scores`` is accepted for interface compatibility but is not
        forwarded to the helper.
        """
        return draw_detections(
            image, boxes, scores, class_ids, mask_alpha
        )

    def get_input_details(self):
        """Record the model's input names and its input height/width.

        Height and width are read from positions 2 and 3 of the first
        input's declared shape.
        """
        model_inputs = self.session.get_inputs()
        self.input_names = [model_input.name for model_input in model_inputs]

        self.input_shape = model_inputs[0].shape
        self.input_height = self.input_shape[2]
        self.input_width = self.input_shape[3]

    def get_output_details(self):
        """Record the names of all model outputs."""
        model_outputs = self.session.get_outputs()
        self.output_names = [model_output.name for model_output in model_outputs]
117
+
118
+
119
if __name__ == "__main__":
    # Demo: download a YOLOv10 ONNX model and a sample photo, run detection,
    # and show the annotated result in an OpenCV window.
    import requests
    from huggingface_hub import hf_hub_download

    model_file = hf_hub_download(
        repo_id="onnx-community/yolov10s", filename="onnx/model.onnx"
    )

    detector = YOLOv10(model_file)

    image_url = "https://live.staticflickr.com/13/19041780_d6fd803de0_3k.jpg"
    # A timeout keeps the demo from hanging forever on a stalled connection,
    # and raise_for_status stops an HTTP error page being decoded as an image.
    response = requests.get(image_url, timeout=30)
    response.raise_for_status()

    # Decode straight from memory: no temporary file is left behind on disk.
    img = cv2.imdecode(
        np.frombuffer(response.content, dtype=np.uint8), cv2.IMREAD_COLOR
    )
    if img is None:
        raise RuntimeError(f"Could not decode image downloaded from {image_url}")

    # Detect objects
    combined_image = detector.detect_objects(img)

    # Show the annotated image until a key is pressed, then release the window.
    cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
    cv2.imshow("Output", combined_image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()