File size: 3,182 Bytes
0a8dd97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import json

import pandas as pd
import torch
from PIL import Image
from torchvision import transforms
from model import resnet101


def _build_transform():
    """Return eval-time preprocessing: resize to 256, center-crop 224, ImageNet mean/std normalization."""
    return transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])


def _load_image(img_path, transform):
    """Open one image, force 3-channel RGB, and apply ``transform``.

    The ``with`` block releases the file handle that ``Image.open`` keeps
    open; ``convert("RGB")`` makes grayscale/RGBA/CMYK files match the
    3-channel normalization instead of failing in ``Normalize``/``torch.stack``.
    """
    with Image.open(img_path) as img:
        return transform(img.convert("RGB"))


def _classify(model, device, img_paths, transform, batch_size):
    """Run ``model`` over the images in chunks of ``batch_size``.

    Returns:
        ``(top_probs, top_indices)``: two lists with one entry per image, the
        highest softmax probability and the class index that attains it.
    """
    top_probs, top_indices = [], []
    with torch.no_grad():
        for start in range(0, len(img_paths), batch_size):
            chunk = img_paths[start:start + batch_size]
            # Only one chunk of decoded images is held in memory at a time.
            batch = torch.stack([_load_image(p, transform) for p in chunk], dim=0)
            scores = model(batch.to(device)).cpu()
            probs, indices = torch.max(torch.softmax(scores, dim=1), dim=1)
            top_probs.extend(probs.tolist())
            top_indices.extend(indices.tolist())
    return top_probs, top_indices


def _best_class_per_observation(observation_ids, probs, class_indices):
    """Map each observation id to the class index of its most confident image.

    The three arguments are parallel sequences (one entry per image). When an
    observation has several images the highest probability wins; on a tie the
    earliest image is kept.
    """
    best = {}  # observation id -> (probability, class index)
    for obs_id, prob, cls_idx in zip(observation_ids, probs, class_indices):
        if obs_id not in best or prob > best[obs_id][0]:
            best[obs_id] = (prob, cls_idx)
    return {obs_id: cls_idx for obs_id, (_, cls_idx) in best.items()}


def _module_global(name):
    """Return the module-level ``name`` (created by the ``__main__`` block)."""
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            "predict() needs '{0}': pass it as an argument or define a "
            "module-level '{0}'".format(name)) from None


def predict(test_metadata, index2class, root_path='/tmp/data/private_testset',
            output_csv_path='./submission.csv', model=None, device=None,
            batch_size=64):
    """Classify every test image and write one prediction per observation.

    Args:
        test_metadata: DataFrame with ``observation_id`` and ``filename``
            columns, one row per image. A ``class_id`` column is added to it
            in place (kept for callers that read it afterwards).
        index2class: mapping from ``str(class index)`` to the value written in
            the ``class_id`` column.
        root_path: directory that ``filename`` entries are relative to.
        output_csv_path: where the ``observation_id,class_id`` CSV is written.
        model: torch module returning per-class scores for a batch of
            3x224x224 images; defaults to the module-level ``model``.
        device: device the batches are moved to; defaults to the module-level
            ``device``.
        batch_size: number of images decoded and classified per forward pass.

    Returns:
        The submission DataFrame (one row per observation) that was written.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
        AssertionError: if an image file does not exist.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got {}".format(batch_size))

    observation_ids = test_metadata['observation_id'].tolist()
    img_paths = [os.path.join(root_path, name)
                 for name in test_metadata['filename'].tolist()]
    # Validate every path up front so a missing file fails before inference.
    for img_path in img_paths:
        assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)

    labels = []
    if img_paths:  # torch.stack on an empty list would raise
        model = model if model is not None else _module_global("model")
        device = device if device is not None else _module_global("device")
        probs, class_indices = _classify(model, device, img_paths,
                                         _build_transform(), batch_size)
        best = _best_class_per_observation(observation_ids, probs, class_indices)
        # Every image row gets its observation's winning label.
        labels = [index2class[str(best[obs_id])] for obs_id in observation_ids]

    test_metadata["class_id"] = labels

    submission = test_metadata.drop_duplicates("observation_id", keep="first")
    submission = submission[["observation_id", "class_id"]]
    submission.to_csv(output_csv_path, index=False)
    return submission


if __name__ == '__main__':
    # The private test images are expected to be extracted already under
    # ``root_path`` (e.g. from /tmp/data/private_testset.zip into /tmp/data).
    root_path = '/tmp/data/private_testset'

    # Class-index lookup used by ``predict``; its keys are str(class index).
    json_path = './class_indices.json'
    if not os.path.exists(json_path):
        raise FileNotFoundError("file: '{}' does not exist.".format(json_path))
    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(json_path, "r", encoding="utf-8") as json_file:
        index2class = json.load(json_file)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # create model; num_classes must match the classifier head stored in the
    # weights file, because load_state_dict is strict by default.
    model = resnet101(num_classes=1784).to(device)

    # load model weights
    weights_path = "./resNet101.pth"
    if not os.path.exists(weights_path):
        raise FileNotFoundError("file: '{}' does not exist.".format(weights_path))
    model.load_state_dict(torch.load(weights_path, map_location=device))

    # switch to evaluation mode before predicting
    model.eval()

    metadata_file_path = "./SnakeCLEF2024_TestMetadata.csv"
    test_metadata = pd.read_csv(metadata_file_path)
    # ``model`` and ``device`` above are module globals; ``predict`` picks
    # them up from there.
    predict(test_metadata, index2class, root_path)