# efficientNetV2 / script.py
# Author: xcssgzs
# Last commit: "Update script.py" (15e79c8, verified)
# File size: 3.07 kB (virus scan: clean)
import os
import pandas as pd
import torch
from PIL import Image
from torchvision import transforms
from model import efficientnetv2_l as create_model
def _module_default(name):
    """Return the module-level object called *name*.

    The ``__main__`` block of this script binds ``model`` and ``device`` at module
    scope; ``predict`` falls back to them when they are not passed explicitly.

    Raises:
        NameError: if no module-level object with that name exists.
    """
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            f"predict() needs {name!r}: pass it as an argument or define it at module level"
        ) from None


def _classify_image(model, device, data_transform, img_path):
    """Classify a single image file and return ``(class_id, probability)`` of its top class.

    ``predict`` calls this inside ``torch.no_grad()``.

    Raises:
        FileNotFoundError: if *img_path* does not exist.
    """
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if not os.path.exists(img_path):
        raise FileNotFoundError("file: '{}' does not exist.".format(img_path))
    # The context manager closes the underlying file handle; convert() returns a new,
    # fully loaded image, so it remains usable after the file is closed.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert('RGB')
    batch = torch.unsqueeze(data_transform(img), dim=0)  # add a batch dimension of size 1
    output = model(batch.to(device)).cpu()
    scores = torch.softmax(output, dim=1)
    top_probs, top_classes = torch.max(scores, dim=1)
    return top_classes.item(), top_probs.item()


def predict(test_metadata, root_path='/tmp/data/private_testset', output_csv_path='./submission.csv',
            model=None, device=None):
    """Classify every image listed in *test_metadata* and write one prediction per observation.

    Each row of ``test_metadata`` names an image file (``filename`` column) belonging to an
    observation (``observation_id`` column). Images are classified independently. For each
    observation, the class predicted with the highest softmax probability across its images is
    kept; ties go to the earliest row.

    Args:
        test_metadata: DataFrame with at least ``observation_id`` and ``filename`` columns.
            It is not modified.
        root_path: directory that the ``filename`` values are relative to.
        output_csv_path: destination of the ``observation_id,class_id`` CSV (no index column).
        model: classifier to run. Defaults to the module-level ``model`` bound by the
            ``__main__`` block.
        device: torch device for inference. Defaults to the module-level ``device``.

    Returns:
        None. The predictions are written to ``output_csv_path``.

    Raises:
        FileNotFoundError: if an image listed in the metadata does not exist.
        NameError: if ``model``/``device`` is needed, was not passed, and no module-level
            fallback exists.
    """
    img_size = {"s": [384, 384],  # train_size, val_size
                "m": [384, 480],
                "l": [384, 480]}
    # NOTE(review): the file imports efficientnetv2_l but evaluates at the "s" size (384);
    # presumably this matches the training resolution — confirm.
    num_model = "s"
    eval_size = img_size[num_model][1]
    data_transform = transforms.Compose(
        [transforms.Resize(eval_size),
         transforms.CenterCrop(eval_size),
         transforms.ToTensor(),
         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

    id_list = test_metadata['observation_id'].tolist()
    img_name_list = test_metadata['filename'].tolist()
    print(os.path.abspath(os.path.dirname(__file__)))

    predictions = []  # (class_id, probability) per metadata row, in row order
    if img_name_list:
        # Resolve the fallbacks only when something will actually be classified.
        model = model if model is not None else _module_default('model')
        device = device if device is not None else _module_default('device')
        with torch.no_grad():
            predictions = [
                _classify_image(model, device, data_transform, os.path.join(root_path, img_name))
                for img_name in img_name_list
            ]

    # Keep, per observation, the most confident prediction; strict '>' keeps the first on ties.
    best = {}
    for obs_id, (class_id, prob) in zip(id_list, predictions):
        if obs_id not in best or prob > best[obs_id][1]:
            best[obs_id] = (class_id, prob)

    classes = [str(best[obs_id][0]) for obs_id in id_list]
    # assign() returns a copy, so the caller's DataFrame is left untouched.
    user_pred_df = test_metadata.assign(class_id=classes).drop_duplicates("observation_id", keep="first")
    user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=False)
if __name__ == '__main__':
    # Script entry point: unpack the private test set, build the classifier and write
    # the submission. `model` and `device` are bound at module scope on purpose,
    # because predict() reads them as globals.
    import zipfile

    archive_path = "/tmp/data/private_testset.zip"
    with zipfile.ZipFile(archive_path, 'r') as archive:
        archive.extractall("/tmp/data")
    root_path = '/tmp/data/private_testset'

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Build the network (1784 output classes; presumably must match the checkpoint's
    # classifier head), load the trained weights and switch to inference mode.
    model = create_model(num_classes=1784).to(device)
    model_weight_path = "efficientNetV2.pth"
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints
    # (weights_only=True is available on newer torch versions).
    state_dict = torch.load(model_weight_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    metadata_file_path = "./SnakeCLEF2024_TestMetadata.csv"
    test_metadata = pd.read_csv(metadata_file_path)
    predict(test_metadata, root_path)