SnakeCLEF-resnet / script.py
poojapremnath's picture
Upload two files
4a57161 verified
raw
history blame contribute delete
No virus
3.19 kB
import os
from collections import Counter, defaultdict

import numpy as np
import pandas as pd
import timm
import torch
import torchvision.transforms as T
from PIL import Image
from tqdm import tqdm
class PytorchWorker:
    """Run single-image inference with a timm classifier loaded from a checkpoint.

    The model is created with ``timm.create_model`` (no pretrained weights).
    Its weights are filled from ``model_path`` via ``load_state_dict``. It is
    placed on ``cuda:0`` when CUDA is available, otherwise on the CPU, and put
    in eval mode.
    """

    def __init__(self, model_path: str, model_name: str, number_of_categories: int):
        """Build the model and the image preprocessing pipeline.

        :param model_path: Path to a checkpoint file holding the model's state dict.
        :param model_name: timm architecture name passed to ``timm.create_model``.
        :param number_of_categories: Number of output classes of the classifier head;
            must match the checkpoint or ``load_state_dict`` will fail.
        """
        print("Setting up Pytorch Model")
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        model = timm.create_model(model_name, num_classes=number_of_categories, pretrained=False)
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints (consider weights_only=True on torch versions that support it).
        state_dict = torch.load(model_path, map_location=self.device)
        model.load_state_dict(state_dict)
        self.model = model.to(self.device).eval()

        # Fixed 224x224 resize, conversion to a float tensor, then per-channel
        # normalisation with the commonly used ImageNet mean/std.
        self.transforms = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def predict_image(self, image: "Image.Image") -> list:
        """Return the model's raw logits for a single image.

        :param image: Input image as a PIL Image (callers convert it to RGB).
        :return: Nested list ``logits.tolist()`` for a batch of one, i.e. a list
            containing a single row of per-class logits.
        """
        # unsqueeze(0) adds the batch dimension expected by the model.
        image_tensor = self.transforms(image).unsqueeze(0).to(self.device)
        # Inference only: disable autograd so no graph/activations are retained.
        with torch.no_grad():
            logits = self.model(image_tensor)
        return logits.tolist()
def make_submission(
    test_metadata,
    model_path,
    model_name,
    output_csv_path="./submission.csv",
    images_root_path="/tmp/data/private_testset",
    number_of_categories=1604,
):
    """Predict one class per observation and write the submission CSV.

    Each metadata row is one image. The image is loaded from
    ``os.path.join(images_root_path, row['filename'])`` and classified with
    ``PytorchWorker``. Its arg-max class is recorded under the row's
    ``observation_id``. The prediction for an observation is the most frequent
    image-level class; ties go to the class seen first for that observation.

    :param test_metadata: DataFrame with at least ``filename`` and
        ``observation_id`` columns.
    :param model_path: Path to the model checkpoint (state dict).
    :param model_name: timm architecture name.
    :param output_csv_path: Destination of the ``observation_id,class_id`` CSV.
    :param images_root_path: Directory that the ``filename`` values are relative to.
    :param number_of_categories: Size of the model's classifier head; must match
        the checkpoint (defaults to the previously hard-coded 1604).
    """
    model = PytorchWorker(model_path, model_name, number_of_categories=number_of_categories)

    observation_predictions = defaultdict(list)
    for _, row in tqdm(test_metadata.iterrows(), total=len(test_metadata)):
        image_path = os.path.join(images_root_path, row['filename'])
        # convert() returns a new, fully loaded image, so the file handle can be
        # closed right away instead of leaking one handle per test image.
        with Image.open(image_path) as img:
            test_image = img.convert("RGB")
        logits = model.predict_image(test_image)
        # Plain int so the vote keys and CSV values are ordinary Python ints.
        observation_predictions[row['observation_id']].append(int(np.argmax(logits)))

    # Majority vote over all images of each observation.
    final_predictions = {
        obs_id: Counter(preds).most_common(1)[0][0]
        for obs_id, preds in observation_predictions.items()
    }
    output_df = pd.DataFrame(list(final_predictions.items()), columns=['observation_id', 'class_id'])
    output_df.to_csv(output_csv_path, index=False)
if __name__ == "__main__":
    import zipfile

    # Unpack the private test set into /tmp/data before running inference.
    # NOTE(review): make_submission's default images_root_path is
    # /tmp/data/private_testset, which presumably is the archive's top-level
    # folder — confirm against the archive layout.
    with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as archive:
        archive.extractall("/tmp/data")

    # Checkpoint filename and timm architecture; both must match the trained model.
    MODEL_PATH = "resnet_classifier.pth"
    MODEL_NAME = "resnet50"

    make_submission(
        test_metadata=pd.read_csv("./SnakeCLEF2023_TestMetadata.csv"),
        model_path=MODEL_PATH,
        model_name=MODEL_NAME,
    )