poojapremnath commited on
Commit
4a57161
1 Parent(s): fba9291

Upload two files

Browse files
Files changed (2) hide show
  1. resnet_classifier.pth +3 -0
  2. script.py +81 -0
resnet_classifier.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8ca5ff2798636210fe50f065081559a6de41d940ff8a1b6c3439fcca1f646f5
3
+ size 109894058
script.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import os
4
+ from tqdm import tqdm
5
+ import timm
6
+ import torchvision.transforms as T
7
+ from PIL import Image
8
+ import torch
9
+
10
+ class PytorchWorker:
11
+ """Run inference using PyTorch."""
12
+
13
+ def __init__(self, model_path: str, model_name: str, number_of_categories: int):
14
+
15
+ def _load_model(model_name, model_path):
16
+ print("Setting up Pytorch Model")
17
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
18
+ print(f"Using device: {self.device}")
19
+
20
+ model = timm.create_model(model_name, num_classes=number_of_categories, pretrained=False)
21
+ model_ckpt = torch.load(model_path, map_location=self.device)
22
+ model.load_state_dict(model_ckpt)
23
+ return model.to(self.device).eval()
24
+
25
+ self.model = _load_model(model_name, model_path)
26
+ self.transforms = T.Compose([
27
+ T.Resize((224, 224)),
28
+ T.ToTensor(),
29
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
30
+ ])
31
+
32
+ def predict_image(self, image: Image.Image) -> list:
33
+ """Run inference using PyTorch.
34
+ :param image: Input image as PIL Image.
35
+ :return: A list with logits.
36
+ """
37
+ image_tensor = self.transforms(image).unsqueeze(0).to(self.device)
38
+ logits = self.model(image_tensor)
39
+ return logits.tolist()
40
+
41
+ def make_submission(test_metadata, model_path, model_name, output_csv_path="./submission.csv", images_root_path="/tmp/data/private_testset"):
42
+ """Make submission with given metadata and model."""
43
+
44
+ model = PytorchWorker(model_path, model_name, number_of_categories=1604) # Adjust number_of_categories as needed
45
+
46
+ predictions = []
47
+ observation_predictions = {}
48
+
49
+ for _, row in tqdm(test_metadata.iterrows(), total=len(test_metadata)):
50
+ image_path = os.path.join(images_root_path, row['filename'])
51
+ test_image = Image.open(image_path).convert("RGB")
52
+ logits = model.predict_image(test_image)
53
+ predicted_class = np.argmax(logits)
54
+
55
+ obs_id = row['observation_id']
56
+ if obs_id not in observation_predictions:
57
+ observation_predictions[obs_id] = []
58
+ observation_predictions[obs_id].append(predicted_class)
59
+
60
+ final_predictions = {obs_id: max(set(preds), key=preds.count) for obs_id, preds in observation_predictions.items()}
61
+
62
+ output_df = pd.DataFrame(list(final_predictions.items()), columns=['observation_id', 'class_id'])
63
+ output_df.to_csv(output_csv_path, index=False)
64
+
65
+ if __name__ == "__main__":
66
+ import zipfile
67
+
68
+ with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
69
+ zip_ref.extractall("/tmp/data")
70
+
71
+ MODEL_PATH = "resnet_classifier.pth" # Ensure this matches the filename of your model
72
+ MODEL_NAME = "resnet50" # Adjust this to your specific model
73
+
74
+ metadata_file_path = "./SnakeCLEF2023_TestMetadata.csv"
75
+ test_metadata = pd.read_csv(metadata_file_path)
76
+
77
+ make_submission(
78
+ test_metadata=test_metadata,
79
+ model_path=MODEL_PATH,
80
+ model_name=MODEL_NAME
81
+ )