MRBagherifar
commited on
Commit
•
75d7b65
1
Parent(s):
6f66955
Update script.py
Browse files
script.py
CHANGED
@@ -26,21 +26,16 @@ class ONNXWorker:
|
|
26 |
self.ort_session = ort.InferenceSession(onnx_path, providers=providers)
|
27 |
|
28 |
def _resize_image(self, image: np.ndarray) -> np.ndarray:
|
29 |
-
"""
|
30 |
-
:param image:
|
31 |
-
:return:
|
32 |
-
"""
|
33 |
|
34 |
-
|
35 |
-
|
|
|
36 |
|
37 |
def predict_image(self, image: np.ndarray) -> list():
|
38 |
-
"""Run inference using ONNX runtime.
|
39 |
-
:param image: Input image as numpy array.
|
40 |
-
:return: A list with logits and confidences.
|
41 |
-
"""
|
42 |
|
43 |
-
|
|
|
|
|
44 |
|
45 |
return logits.tolist()
|
46 |
|
@@ -56,7 +51,7 @@ def make_submission(test_metadata, model_path, output_csv_path="./submission.csv
|
|
56 |
image_path = os.path.join(images_root_path, row.filename)
|
57 |
|
58 |
test_image = Image.open(image_path).convert("RGB")
|
59 |
-
test_image_resized = np.asarray(test_image.resize((
|
60 |
|
61 |
logits = model.predict_image(test_image_resized)
|
62 |
|
|
|
26 |
self.ort_session = ort.InferenceSession(onnx_path, providers=providers)
|
27 |
|
28 |
def _resize_image(self, image: np.ndarray) -> np.ndarray:
|
|
|
|
|
|
|
|
|
29 |
|
30 |
+
new_size = (384, 384)
|
31 |
+
return np.array(Image.fromarray(image).resize(new_size))
|
32 |
+
|
33 |
|
34 |
def predict_image(self, image: np.ndarray) -> list():
|
|
|
|
|
|
|
|
|
35 |
|
36 |
+
"""Run inference using ONNX runtime."""
|
37 |
+
resized_image = self._resize_image(image)
|
38 |
+
logits = self.ort_session.run(None, {"input": resized_image})
|
39 |
|
40 |
return logits.tolist()
|
41 |
|
|
|
51 |
image_path = os.path.join(images_root_path, row.filename)
|
52 |
|
53 |
test_image = Image.open(image_path).convert("RGB")
|
54 |
+
test_image_resized = np.asarray(test_image.resize((384, 384)))
|
55 |
|
56 |
logits = model.predict_image(test_image_resized)
|
57 |
|