MRBagherifar commited on
Commit
fa119c0
1 Parent(s): ae5af25

Update script.py

Browse files
Files changed (1) hide show
  1. script.py +43 -28
script.py CHANGED
@@ -1,59 +1,72 @@
1
  import pandas as pd
2
  import numpy as np
3
- from PIL import Image
4
  import onnxruntime as ort
5
  import os
6
  from tqdm import tqdm
7
-
 
 
 
8
 
9
  def is_gpu_available():
10
  """Check if the python package `onnxruntime-gpu` is installed."""
11
- return ort.get_device() == "GPU"
12
 
13
 
14
- class ONNXWorker:
15
  """Run inference using ONNX runtime."""
16
 
17
- def __init__(self, onnx_path: str):
18
- print("Setting up ONNX runtime session.")
19
- self.use_gpu = is_gpu_available()
20
- if self.use_gpu:
21
- providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
22
- else:
23
- providers = ["CPUExecutionProvider"]
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- print(f"Using {providers}")
26
- self.ort_session = ort.InferenceSession(onnx_path, providers=providers)
27
 
28
- def _resize_image(self, image: np.ndarray) -> np.ndarray:
 
 
29
 
30
- new_size = (384, 384)
31
- return np.array(Image.fromarray(image).resize(new_size))
32
-
33
 
34
  def predict_image(self, image: np.ndarray) -> list():
 
 
 
 
35
 
36
- """Run inference using ONNX runtime."""
37
- resized_image = self._resize_image(image)
38
- logits = self.ort_session.run(None, {"input": resized_image})
39
 
40
  return logits.tolist()
41
 
42
 
43
- def make_submission(test_metadata, model_path, output_csv_path="./submission.csv", images_root_path="/tmp/data/private_testset"):
44
  """Make submission with given """
45
 
46
- model = ONNXWorker(model_path)
47
 
48
  predictions = []
49
 
50
  for _, row in tqdm(test_metadata.iterrows(), total=len(test_metadata)):
51
- image_path = os.path.join(images_root_path, row.filename)
52
 
53
  test_image = Image.open(image_path).convert("RGB")
54
- test_image_resized = np.asarray(test_image.resize((384, 384)))
55
 
56
- logits = model.predict_image(test_image_resized)
57
 
58
  predictions.append(np.argmax(logits))
59
 
@@ -70,12 +83,14 @@ if __name__ == "__main__":
70
  with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
71
  zip_ref.extractall("/tmp/data")
72
 
73
- ONNX_MODEL_PATH = "./MetaFG_meta_2.onnx"
 
74
 
75
  metadata_file_path = "./SnakeCLEF2024-TestMetadata.csv"
76
  test_metadata = pd.read_csv(metadata_file_path)
77
 
78
  make_submission(
79
  test_metadata=test_metadata,
80
- model_path=ONNX_MODEL_PATH,
81
- )
 
 
1
  import pandas as pd
2
  import numpy as np
 
3
  import onnxruntime as ort
4
  import os
5
  from tqdm import tqdm
6
+ import timm
7
+ import torchvision.transforms as T
8
+ from PIL import Image
9
+ import torch
10
 
11
  def is_gpu_available():
12
  """Check if the python package `onnxruntime-gpu` is installed."""
13
+ return torch.cuda.is_available()
14
 
15
 
16
+ class PytorchWorker:
17
  """Run inference using ONNX runtime."""
18
 
19
+ def __init__(self, model_path: str, model_name: str, number_of_categories: int = 1604):
20
+
21
+ def _load_model(model_name, model_path):
22
+
23
+ print("Setting up Pytorch Model")
24
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
25
+ print(f"Using devide: {self.device}")
26
+
27
+ model = timm.create_model(model_name, num_classes=number_of_categories, pretrained=False)
28
+
29
+ # if not torch.cuda.is_available():
30
+ # model_ckpt = torch.load(model_path, map_location=torch.device("cpu"))
31
+ # else:
32
+ # model_ckpt = torch.load(model_path)
33
+
34
+ model_ckpt = torch.load(model_path, map_location=self.device)
35
+ model.load_state_dict(model_ckpt)
36
+
37
+ return model.to(self.device).eval()
38
 
39
+ self.model = _load_model(model_name, model_path)
 
40
 
41
+ self.transforms = T.Compose([T.Resize((299, 299)),
42
+ T.ToTensor(),
43
+ T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
44
 
 
 
 
45
 
46
  def predict_image(self, image: np.ndarray) -> list():
47
+ """Run inference using ONNX runtime.
48
+ :param image: Input image as numpy array.
49
+ :return: A list with logits and confidences.
50
+ """
51
 
52
+ logits = self.model(self.transforms(image).unsqueeze(0).to(self.device))
 
 
53
 
54
  return logits.tolist()
55
 
56
 
57
+ def make_submission(test_metadata, model_path, model_name, output_csv_path="./submission.csv", images_root_path="/tmp/data/private_testset"):
58
  """Make submission with given """
59
 
60
+ model = PytorchWorker(model_path, model_name)
61
 
62
  predictions = []
63
 
64
  for _, row in tqdm(test_metadata.iterrows(), total=len(test_metadata)):
65
+ image_path = os.path.join(images_root_path, row.image_path)
66
 
67
  test_image = Image.open(image_path).convert("RGB")
 
68
 
69
+ logits = model.predict_image(test_image)
70
 
71
  predictions.append(np.argmax(logits))
72
 
 
83
  with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
84
  zip_ref.extractall("/tmp/data")
85
 
86
+ MODEL_PATH = "pytorch_model.bin"
87
+ MODEL_NAME = "tf_efficientnet_b1.ap_in1k"
88
 
89
  metadata_file_path = "./SnakeCLEF2024-TestMetadata.csv"
90
  test_metadata = pd.read_csv(metadata_file_path)
91
 
92
  make_submission(
93
  test_metadata=test_metadata,
94
+ model_path=MODEL_PATH,
95
+ model_name=MODEL_NAME
96
+ )