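"""SnakeCLEF 2024 inference script: runs a ResNet-101 checkpoint over the private
test images and writes a per-observation submission.csv."""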
import os
import json
import pandas as pd
import torch
from PIL import Image
from torchvision import transforms
from model import resnet101


def predict(test_metadata, root_path='/tmp/data/private_testset', output_csv_path='./submission.csv'):
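    """Classify every image listed in test_metadata and write the submission CSV.

    test_metadata must provide 'observation_id' and 'filename' columns; images are
    read from root_path. Uses the module-level `model` and `device` created in
    __main__.
    """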

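    # standard ImageNet-style preprocessing: resize, 224x224 center crop, normalize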
    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # read observation ids and image filenames from the test metadata
    # img_name_list = ["1163.jpg", "1164.jpg"]
    # id_list = [1, 2]
    id_list = test_metadata['observation_id'].tolist()
    img_name_list = test_metadata['filename'].tolist()
    # debug aid: show the directory this script is running from
    print(os.path.abspath(os.path.dirname(__file__)))

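    # per-observation bookkeeping: best class id and probability seen so far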
    id2classId = dict()
    id2prob = dict()
    prob_list = list()
    classId_list = list()
    for img_name in img_name_list:
        img_path = os.path.join(root_path, img_name)
        assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
        img = Image.open(img_path).convert('RGB')
        img = data_transform(img)
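        # add a batch dimension: [C, H, W] -> [1, C, H, W]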
        img = torch.unsqueeze(img, dim=0)

        with torch.no_grad():
            # predict class: forward pass, softmax, keep top-1 probability and class index
            output = model(img.to(device)).cpu()
            probabilities = torch.softmax(output, dim=1)
            prob, class_id = torch.max(probabilities, dim=1)
            prob_list.append(prob.item())
            classId_list.append(class_id.item())

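    # an observation may have several images: keep its most confident prediction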
    for i, obs_id in enumerate(id_list):
        if obs_id not in id2classId:
            id2classId[obs_id] = classId_list[i]
            id2prob[obs_id] = prob_list[i]
        elif prob_list[i] > id2prob[obs_id]:
            id2classId[obs_id] = classId_list[i]
            id2prob[obs_id] = prob_list[i]

    classes = [str(id2classId[obs_id]) for obs_id in id_list]
    test_metadata["class_id"] = classes

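    # one row per observation in the final submission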
    user_pred_df = test_metadata.drop_duplicates("observation_id", keep="first")
    user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=False)


if __name__ == '__main__':

    import zipfile

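    # extract the zipped private test set into /tmp/data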
    with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
        zip_ref.extractall("/tmp/data")
    root_path = '/tmp/data/private_testset'

    # root_path = "../../data_set/flower_data/val/n1"

    # json_file = open(json_path, "r")
    # index2class = json.load(json_file)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # create model
    model = resnet101(num_classes=1784).to(device)

    # load model weights
    weights_path = "./resNet101.pth"
    assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))

    # switch to evaluation mode before inference (disables dropout, uses running batch-norm stats)
    model.eval()

    metadata_file_path = "./SnakeCLEF2024_TestMetadata.csv"
    # metadata_file_path = "./test1.csv"
    test_metadata = pd.read_csv(metadata_file_path)
    predict(test_metadata, root_path)