"""Run a ResNet-101 classifier over the test images and write a submission CSV.

NOTE(review): this file was restored from a copy whose whitespace had been
collapsed and in which a few expressions were truncated. Every expression
that had to be reconstructed is marked with a ``NOTE(review)`` comment below
and should be confirmed against the original script.
"""
import json
import os
import zipfile

import pandas as pd
import torch
from PIL import Image
from torchvision import transforms

from model import resnet101


def predict(test_metadata, root_path='/tmp/data/private_testset',
            output_csv_path='./submission.csv', *, model=None, device=None):
    """Classify every image in ``test_metadata`` and write one row per observation.

    Each image is classified on its own. When several images share an
    ``observation_id``, the class of the image with the highest softmax
    probability wins (the first image wins ties).

    Args:
        test_metadata: DataFrame with ``observation_id`` and ``filename``
            columns. A ``class_id`` column (string values) is added to it
            in place.
        root_path: Directory that the ``filename`` entries are relative to.
        output_csv_path: Destination of the CSV holding the
            ``observation_id`` and ``class_id`` columns.
        model: Network used for inference. When omitted, the module-level
            ``model`` created by the ``__main__`` section is used.
        device: Device the input batches are moved to. When omitted, the
            module-level ``device`` is used.

    Raises:
        FileNotFoundError: If an image listed in the metadata is missing.
        KeyError: If ``model``/``device`` are omitted and no module-level
            fallback exists.
    """
    # Backward compatibility: callers that pass nothing keep relying on the
    # globals set up when this file runs as a script.
    if model is None:
        model = globals()["model"]
    if device is None:
        device = globals()["device"]

    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    id_list = test_metadata['observation_id'].tolist()
    img_name_list = test_metadata['filename'].tolist()
    print(os.path.abspath(os.path.dirname(__file__)))

    prob_list = []
    class_id_list = []
    with torch.no_grad():
        for img_name in img_name_list:
            img_path = os.path.join(root_path, img_name)
            if not os.path.exists(img_path):
                raise FileNotFoundError(
                    "file: '{}' dose not exist.".format(img_path))

            # NOTE(review): reconstructed from the truncated `img ='RGB')`;
            # presumably the image was opened with PIL and converted to RGB.
            with Image.open(img_path) as raw_img:
                img = raw_img.convert('RGB')
            batch = torch.unsqueeze(data_transform(img), dim=0)

            # NOTE(review): reconstructed from the truncated `output = model(`;
            # presumably the batch was moved to `device` before the forward pass.
            output = model(batch.to(device))
            scores = torch.softmax(output, dim=1)
            probs, class_ids = torch.max(scores, dim=1)

            # NOTE(review): reconstructed from the truncated `prob =[0]` and
            # `classesId =[0]`; the batch holds one image, so the single
            # element is extracted as a Python scalar.
            prob_list.append(probs.item())
            class_id_list.append(class_ids.item())

    # Keep, per observation, the prediction of its most confident image.
    best_prob = {}
    best_class = {}
    for obs_id, prob, class_id in zip(id_list, prob_list, class_id_list):
        if obs_id not in best_prob or prob > best_prob[obs_id]:
            best_prob[obs_id] = prob
            best_class[obs_id] = class_id

    test_metadata["class_id"] = [str(best_class[obs_id]) for obs_id in id_list]
    user_pred_df = test_metadata.drop_duplicates("observation_id", keep="first")
    user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=False)


if __name__ == '__main__':
    # NOTE(review): the archive path was truncated to "/tmp/data/" in the
    # recovered copy; the file name is inferred from `root_path` below.
    with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
        zip_ref.extractall("/tmp/data")

    root_path = '/tmp/data/private_testset'

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # create model
    model = resnet101(num_classes=1784).to(device)

    # load model weights
    weights_path = "./resNet101.pth"
    if not os.path.exists(weights_path):
        raise FileNotFoundError(
            "file: '{}' dose not exist.".format(weights_path))
    model.load_state_dict(torch.load(weights_path, map_location=device))

    # prediction
    model.eval()

    metadata_file_path = "./SnakeCLEF2024_TestMetadata.csv"
    test_metadata = pd.read_csv(metadata_file_path)

    predict(test_metadata, root_path, model=model, device=device)