Lonly-geese commited on
Commit
a37893b
1 Parent(s): da36ba0

Delete script.py

Browse files
Files changed (1) hide show
  1. script.py +0 -97
script.py DELETED
@@ -1,97 +0,0 @@
1
- import pandas as pd
2
- import numpy as np
3
- from PIL import Image
4
- import onnxruntime as ort
5
- import os
6
- from tqdm import tqdm
7
-
8
-
9
- def is_gpu_available():
10
- """Check if the python package `onnxruntime-gpu` is installed."""
11
- return ort.get_device() == "GPU"
12
-
13
-
14
- class ONNXWorker:
15
- """Run inference using ONNX runtime."""
16
-
17
- def __init__(self, onnx_path: str):
18
- print("Setting up ONNX runtime session.")
19
- self.use_gpu = is_gpu_available()
20
- if self.use_gpu:
21
- providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
22
- else:
23
- providers = ["CPUExecutionProvider"]
24
-
25
- print(f"Using {providers}")
26
- self.ort_session = ort.InferenceSession(onnx_path, providers=providers)
27
-
28
- def _resize_image(self, image: np.ndarray) -> np.ndarray:
29
- """
30
-
31
- :param image:
32
- :return:
33
- """
34
-
35
- newsize = (300, 300)
36
- im1 = im1.resize(newsize)
37
-
38
- def predict_image(self, image: np.ndarray) -> list():
39
- """Run inference using ONNX runtime.
40
-
41
- :param image: Input image as numpy array.
42
- :return: A list with logits and confidences.
43
- """
44
-
45
- logits= self.ort_session.run(None, {"input": image.astype(dtype=np.float32)})
46
-
47
- return logits
48
-
49
-
50
- def make_submission(test_metadata, model_path, output_csv_path="./submission.csv", images_root_path="/tmp/data/private_testset"):
51
- """Make submission with given """
52
-
53
- model = ONNXWorker(model_path)
54
-
55
- predictions = []
56
-
57
- for _, row in tqdm(test_metadata.iterrows(), total=len(test_metadata)):
58
- image_path = os.path.join(images_root_path, row.image_path)
59
-
60
- test_image = Image.open(image_path).convert("RGB")
61
- test_image_resized = np.asarray(test_image.resize((256, 256)))
62
- mean=np.array([0.485, 0.456, 0.406])
63
- std=np.array([0.229, 0.224, 0.225])
64
- mean=mean[None,None,:]
65
- std=std[None,None,:]
66
- test_image_resized=test_image_resized/255
67
- test_image_resized=(test_image_resized-mean)/std
68
- test_image_resized=test_image_resized.astype(np.float32)
69
- test_image_resized=test_image_resized[None,:,:,:].transpose(0,3,1,2)
70
-
71
-
72
- logits = model.predict_image(test_image_resized)[0]
73
-
74
- predictions.append(np.argmax(logits))
75
-
76
- test_metadata["class_id"] = predictions
77
-
78
- user_pred_df = test_metadata.drop_duplicates("observation_id", keep="first")
79
- user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=None)
80
-
81
-
82
- if __name__ == "__main__":
83
-
84
- import zipfile
85
-
86
- with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
87
- zip_ref.extractall("/tmp/data")
88
-
89
- ONNX_MODEL_PATH = "./convt_gem.onnx"
90
-
91
- metadata_file_path = "SnakeCLEF2024-TestMetadata.csv"
92
- test_metadata = pd.read_csv(metadata_file_path)
93
-
94
- make_submission(
95
- test_metadata=test_metadata,
96
- model_path=ONNX_MODEL_PATH,
97
- )