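# resNet101 / snake_script.py
# Uploaded by xcssgzs ("Upload 4 files"), commit 0a8dd97 (verified), 3.18 kB.
# (File-viewer header captured along with the source; kept as a comment so the
# file parses as Python.)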
import os
import json
import pandas as pd
import torch
from PIL import Image
from torchvision import transforms
from model import resnet101
def _resolve_model_and_device(model, device):
    """Fill in ``model``/``device`` from module globals when not passed."""
    if model is None:
        # Backward compatibility: the ``__main__`` section defines ``model``
        # and ``device`` at module scope and calls predict() without them.
        model = globals().get("model")
        if model is None:
            raise ValueError(
                "predict() needs a model: pass model=... or define a "
                "module-level `model`.")
    if device is None:
        device = globals().get("device")
    if device is None:
        device = next(model.parameters()).device
    return model, device


def _load_image(img_path, data_transform):
    """Open one image, force 3-channel RGB, and apply ``data_transform``."""
    # convert("RGB") keeps grayscale/RGBA/palette images compatible with the
    # 3-channel Normalize; the context manager releases the file handle.
    with Image.open(img_path) as img:
        return data_transform(img.convert("RGB"))


def _best_class_per_observation(id_list, probs, class_indices):
    """Map each observation id to the class index of its most confident image.

    Ties keep the earlier image. The returned dict is ordered by each
    observation's first appearance in ``id_list``.
    """
    best_class = {}
    best_prob = {}
    for obs_id, prob, class_idx in zip(id_list, probs, class_indices):
        if obs_id not in best_class or prob > best_prob[obs_id]:
            best_class[obs_id] = class_idx
            best_prob[obs_id] = prob
    return best_class


def predict(test_metadata, index2class, root_path='/tmp/data/private_testset',
            output_csv_path='./submission.csv', model=None, device=None,
            batch_size=64):
    """Classify every test image and write one prediction per observation.

    Each row of ``test_metadata`` is one image (``filename`` column) that
    belongs to an observation (``observation_id`` column). An observation may
    have several images. Every image is classified on its own. For each
    observation, the class of its single most confident image is kept; if
    two images tie, the first one wins.

    Args:
        test_metadata: DataFrame with at least ``observation_id`` and
            ``filename`` columns. It is not modified.
        index2class: Mapping from the model's output index, as a string key
            (the form ``json.load`` yields for ``class_indices.json``), to the
            class id written to the submission.
        root_path: Directory that the ``filename`` values are relative to.
        output_csv_path: Destination of the ``observation_id,class_id`` CSV.
        model: Classifier to run. Defaults to the module-level ``model``
            global created by the ``__main__`` section.
        device: Device for inference. Defaults to the module-level ``device``
            global, or else the device of the model's first parameter.
        batch_size: Images per forward pass; bounds peak memory use.

    Returns:
        The submission DataFrame (``observation_id``, ``class_id``), one row
        per observation, in order of first appearance.

    Raises:
        FileNotFoundError: If a listed image does not exist under
            ``root_path``.
        ValueError: If ``batch_size`` is not positive, or if no model is
            given and no module-level ``model`` exists.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be positive, got {}".format(batch_size))

    id_list = test_metadata['observation_id'].tolist()
    img_name_list = test_metadata['filename'].tolist()

    if not id_list:
        # Nothing to classify (torch.stack([]) would raise): write a
        # header-only submission.
        submission = pd.DataFrame({"observation_id": [], "class_id": []})
        submission.to_csv(output_csv_path, index=False)
        return submission

    model, device = _resolve_model_and_device(model, device)

    # Validate every path up front so a missing file fails before any
    # inference work is done.
    img_paths = [os.path.join(root_path, name) for name in img_name_list]
    for img_path in img_paths:
        if not os.path.exists(img_path):
            raise FileNotFoundError("file: '{}' does not exist.".format(img_path))

    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    probs = []
    class_indices = []
    with torch.no_grad():
        # Process the images in chunks so that only batch_size of them are
        # held in memory at once.
        for start in range(0, len(img_paths), batch_size):
            chunk = img_paths[start:start + batch_size]
            batch_img = torch.stack(
                [_load_image(path, data_transform) for path in chunk], dim=0)
            output = model(batch_img.to(device)).cpu()
            batch_probs, batch_classes = torch.max(
                torch.softmax(output, dim=1), dim=1)
            probs.extend(batch_probs.tolist())
            class_indices.extend(batch_classes.tolist())

    best_class = _best_class_per_observation(id_list, probs, class_indices)
    submission = pd.DataFrame({
        "observation_id": list(best_class),
        "class_id": [index2class[str(idx)] for idx in best_class.values()],
    })
    submission.to_csv(output_csv_path, index=False)
    return submission
if __name__ == '__main__':
    import zipfile  # used only by the optional unzip step below

    # Uncomment when the test set is delivered as a zip archive instead of an
    # already-extracted directory.
    # with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
    #     zip_ref.extractall("/tmp/data")
    root_path = '/tmp/data/private_testset'
    # root_path = "../../data_set/flower_data/val/n1"

    # Read the class index map: model output index (string key) -> class id.
    json_path = './class_indices.json'
    if not os.path.exists(json_path):
        raise FileNotFoundError("file: '{}' does not exist.".format(json_path))
    with open(json_path, "r") as json_file:
        index2class = json.load(json_file)

    # `device` and `model` stay module-level globals: predict() looks them up
    # at module scope when they are not passed to it.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Create the model and load the trained weights onto the chosen device.
    model = resnet101(num_classes=1784).to(device)
    weights_path = "./resNet101.pth"
    if not os.path.exists(weights_path):
        raise FileNotFoundError("file: '{}' does not exist.".format(weights_path))
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()  # switch to evaluation mode before running inference

    metadata_file_path = "./SnakeCLEF2024_TestMetadata.csv"
    # metadata_file_path = "./test.csv"
    test_metadata = pd.read_csv(metadata_file_path)
    predict(test_metadata, index2class, root_path)