chychiu commited on
Commit
815a0c4
1 Parent(s): 23728f1

delet ckpt

Browse files
Files changed (2) hide show
  1. metaformer-s-224.ckpt +0 -3
  2. script.py +18 -5
metaformer-s-224.ckpt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:30285f565bdeb54f1ce8bfea244bca3d47c04746a649cd319d02f31ec553fd49
3
- size 331278138
 
 
 
 
script.py CHANGED
@@ -15,11 +15,13 @@ def is_gpu_available():
15
  WIDTH = 224
16
  HEIGHT = 224
17
 
 
 
18
 
19
  class PytorchWorker:
20
  """Run inference using ONNX runtime."""
21
 
22
- def __init__(self, model_path: str, model_name: str, number_of_categories: int = 1604):
23
 
24
  def _load_model(model_name, model_path):
25
 
@@ -60,7 +62,7 @@ def make_submission(test_metadata, model_path, model_name, output_csv_path="./su
60
  predictions = []
61
 
62
  for _, row in tqdm(test_metadata.iterrows(), total=len(test_metadata)):
63
- image_path = os.path.join(images_root_path, row.image_path)
64
 
65
  test_image = Image.open(image_path).convert("RGB")
66
 
@@ -73,17 +75,28 @@ def make_submission(test_metadata, model_path, model_name, output_csv_path="./su
73
  user_pred_df = test_metadata.drop_duplicates("observation_id", keep="first")
74
  user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=None)
75
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  if __name__ == "__main__":
78
 
 
 
79
  import zipfile
80
 
81
  with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
82
  zip_ref.extractall("/tmp/data")
83
 
84
- MODEL_PATH = "metaformer-s-224.pth"
85
- MODEL_NAME = "caformer_s18.sail_in22k"
86
-
87
  metadata_file_path = "./FungiCLEF2024_TestMetadata.csv"
88
  test_metadata = pd.read_csv(metadata_file_path)
89
 
 
15
  WIDTH = 224
16
  HEIGHT = 224
17
 
18
+ MODEL_PATH = "metaformer-s-224.pth"
19
+ MODEL_NAME = "caformer_s18.sail_in22k"
20
 
21
  class PytorchWorker:
22
  """Run inference using ONNX runtime."""
23
 
24
+ def __init__(self, model_path: str, model_name: str, number_of_categories: int = 1605):
25
 
26
  def _load_model(model_name, model_path):
27
 
 
62
  predictions = []
63
 
64
  for _, row in tqdm(test_metadata.iterrows(), total=len(test_metadata)):
65
+ image_path = os.path.join(images_root_path, row.image_path.replace("jpg", "JPG"))
66
 
67
  test_image = Image.open(image_path).convert("RGB")
68
 
 
75
  user_pred_df = test_metadata.drop_duplicates("observation_id", keep="first")
76
  user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=None)
77
 
78
+ def test_submission():
79
+
80
+ metadata_file_path = "../val_mini.csv"
81
+ test_metadata = pd.read_csv(metadata_file_path)
82
+
83
+ make_submission(
84
+ test_metadata=test_metadata,
85
+ model_path=MODEL_PATH,
86
+ model_name=MODEL_NAME,
87
+ images_root_path="../data/DF_FULL/"
88
+ )
89
+
90
 
91
  if __name__ == "__main__":
92
 
93
+ # test_submission()
94
+
95
  import zipfile
96
 
97
  with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
98
  zip_ref.extractall("/tmp/data")
99
 
 
 
 
100
  metadata_file_path = "./FungiCLEF2024_TestMetadata.csv"
101
  test_metadata = pd.read_csv(metadata_file_path)
102