|
import os |
|
import pandas as pd |
|
import torch |
|
from PIL import Image |
|
from torchvision import transforms |
|
|
|
from model import efficientnetv2_l as create_model |
|
|
|
|
|
def predict(test_metadata, root_path='/tmp/data/private_testset', output_csv_path='./submission.csv'):
    """Classify every listed image and write one prediction per observation.

    Each row of ``test_metadata`` names one image. Images are pushed through
    the module-level ``model`` one at a time. When several images share an
    ``observation_id``, the class of the image with the highest softmax
    probability is used for the whole observation; on equal probabilities the
    earliest image wins.

    The function reads the module-level names ``model`` and ``device``. They
    are not parameters, so they must exist before any image is processed,
    otherwise a ``NameError`` is raised.

    Args:
        test_metadata: DataFrame with ``observation_id`` and ``filename``
            columns. A ``class_id`` column (predicted class index as a
            string) is written into it in place.
        root_path: Directory that each ``filename`` value is joined onto.
        output_csv_path: Path of the CSV that receives the
            ``observation_id`` and ``class_id`` columns, one row per
            distinct observation.

    Raises:
        FileNotFoundError: If an image named in the metadata does not exist.
    """
    # Size table keyed by a model-variant letter; only element [1] of the
    # selected entry is used below.
    img_size = {"s": [384, 384],
                "m": [384, 480],
                "l": [384, 480]}
    # NOTE(review): "s" is selected although the imported factory is
    # efficientnetv2_l -- confirm this matches the resolution used in training.
    num_model = "s"
    side = img_size[num_model][1]

    data_transform = transforms.Compose([
        transforms.Resize(side),
        transforms.CenterCrop(side),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    id_list = test_metadata['observation_id'].tolist()
    img_name_list = test_metadata['filename'].tolist()
    print(os.path.abspath(os.path.dirname(__file__)))

    # Best (class, probability) seen so far for each observation.
    best_class = {}
    best_prob = {}

    for obs_id, img_name in zip(id_list, img_name_list):
        img_path = os.path.join(root_path, img_name)
        # Explicit raise instead of assert: asserts vanish under `python -O`.
        if not os.path.exists(img_path):
            raise FileNotFoundError("file: '{}' dose not exist.".format(img_path))

        # The context manager releases the file handle deterministically.
        with Image.open(img_path) as raw_img:
            img = data_transform(raw_img.convert('RGB'))
        batch = torch.unsqueeze(img, dim=0)  # add the batch dimension

        with torch.no_grad():
            output = model(batch.to(device)).cpu()
            probabilities = torch.softmax(output, dim=1)
            top_prob, top_class = torch.max(probabilities, dim=1)
        prob = top_prob.tolist()[0]
        class_id = top_class.tolist()[0]

        # Strict '>' keeps the earliest image when probabilities tie.
        if obs_id not in best_prob or prob > best_prob[obs_id]:
            best_class[obs_id] = class_id
            best_prob[obs_id] = prob

    # Every row of an observation receives that observation's winning class.
    test_metadata["class_id"] = [str(best_class[obs_id]) for obs_id in id_list]

    user_pred_df = test_metadata.drop_duplicates("observation_id", keep="first")
    user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=False)
|
|
|
|
|
if __name__ == '__main__':
    import zipfile

    # Unpack the test-set archive into /tmp/data.
    with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as archive:
        archive.extractall("/tmp/data")
    root_path = '/tmp/data/private_testset'

    # Use the first CUDA device when one is available, otherwise the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    # Build the network on the chosen device, load the checkpoint into it,
    # and switch it to evaluation mode. `model` and `device` stay
    # module-level because predict() reads them as globals.
    model = create_model(num_classes=1784)
    model = model.to(device)

    model_weight_path = "efficientNetV2.pth"
    state_dict = torch.load(model_weight_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    # Read the test metadata and run inference; predict() writes the CSV.
    metadata_file_path = "./SnakeCLEF2024_TestMetadata.csv"
    test_metadata = pd.read_csv(metadata_file_path)
    predict(test_metadata, root_path=root_path)
|
|