import pandas as pd
import numpy as np
import os
import sys
from tqdm import tqdm
import timm
import torchvision.transforms as T
from PIL import Image
import torch
from multiprocessing import Pool

from mmpretrain.apis import ImageClassificationInferencer, FeatureExtractor

# Disable mmpretrain's internal progress bar so the tqdm bars below stay readable.
import mmpretrain.utils.progress as progress
progress.disable_progress_bar = True


# Make the parent of this script's directory importable (presumably for project-local modules).
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.realpath(__file__))))


def load_image(path: str, images_root_path: str = "/tmp/data/private_testset") -> np.ndarray:
    # Load the image and reverse the channel order (RGB -> BGR), presumably to match
    # the BGR convention of the mmpretrain/mmcv data pipeline.
    return np.array(Image.open(os.path.join(images_root_path, path)))[:, :, ::-1]

def rerank_poison(poison_status_list: pd.DataFrame, pred_scores: np.ndarray) -> tuple[int, float]:
    """Prefer the best-scoring poisonous class when its score is competitive.

    Assumes `poison_status_list` is sorted by class_id so that its rows align with
    the indices of `pred_scores`.
    """
    class_id = np.argmax(pred_scores)
    class_score = np.max(pred_scores)

    poisonous = poison_status_list.copy()
    poisonous['score'] = pred_scores
    poisonous.sort_values(by=['score'], ascending=False, inplace=True)
    first_poisonous = poisonous[poisonous['poisonous'] == 1].iloc[0]

    # Up-weight the top poisonous candidate by a factor of 13: if it beats the overall
    # top score after scaling, predict it instead (presumably because the challenge
    # metric penalizes mistaking a poisonous species for an edible one more heavily).
    if 13 * first_poisonous['score'] > class_score:
        class_id = first_poisonous['class_id']
        class_score = first_poisonous['score']

    return class_id, class_score


def make_submission(test_metadata, model_path, model_name, output_csv_path="./submission.csv", images_root_path="/tmp/data/private_testset"):
    """Run inference on the test set and write per-observation class predictions to a CSV."""

    #inferencer = ImageClassificationInferencer(model=model_name, pretrained=model_path, device="cuda:0")
    feature_extractor = FeatureExtractor(model=model_name, pretrained=model_path, device="cuda:0")

    # Per-image predictions plus per-observation accumulators used for feature fusion.
    predictions = []
    prediction_scores = []
    prediction_scores_dict = {}
    prediction_feats_dict = {}
    obs_imgs_dict = {}

    BATCH_SIZE = 4
    # Small worker pool to parallelise image loading/decoding across processes.
    p = Pool(BATCH_SIZE)
    # image_paths_next_batch = test_metadata['image_path'][0:BATCH_SIZE]
    # next_batch = p.map_async(load_image, image_paths_next_batch)
    for i in tqdm(range(int(np.ceil(test_metadata.shape[0] / BATCH_SIZE)))):
        # batch_imgs = next_batch.get()
        # image_paths_next_batch = test_metadata['image_path'][(i+1) * BATCH_SIZE:(i+2) * BATCH_SIZE]
        # next_batch = p.map_async(load_image, image_paths_next_batch)
        img_paths_batch = test_metadata['image_path'][i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
        batch_imgs = p.map(load_image, img_paths_batch)
        #batch_imgs = [np.array(Image.open(os.path.join(images_root_path, x)))[:, :, ::-1] for x in test_metadata['image_path'][(i) * BATCH_SIZE:(i+1) * BATCH_SIZE]]
        #results = inferencer(batch_imgs, batch_size=BATCH_SIZE)
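        # Extract backbone features for the batch, stack them into a single tensor,
        # and run only the 'species' task head (the full inferencer call above is
        # kept commented out for reference).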
        feats = feature_extractor(batch_imgs, batch_size=BATCH_SIZE)
        feats = (torch.stack([x[0] for x in feats], dim=0),)
        results = feature_extractor.model.head.task_heads['species'].predict(feats, img_paths=img_paths_batch)
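        # Accumulate per-image scores and features, keyed by observation_id,
        # for the observation-level fusion step below.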
        for res, f, obs_id, img_path in zip(results, feats[0], test_metadata['observation_id'][(i) * BATCH_SIZE:(i+1) * BATCH_SIZE], img_paths_batch):
            #pred_scores = res.species.pred_score.detach().cpu().numpy()
            pred_scores = res.pred_score.detach().cpu().numpy()
            #pred_scores = res['pred_scores']
            predictions.append(np.argmax(pred_scores))
            prediction_scores.append(np.max(pred_scores))
            prediction_scores_dict.setdefault(obs_id, []).append(pred_scores)
            prediction_feats_dict.setdefault(obs_id, []).append(f)
            obs_imgs_dict[obs_id] = img_path

    print('finished inference')

    test_metadata["class_id"] = predictions
    test_metadata["max_score"] = prediction_scores

    # Per-class poisonousness flags; sort by class_id so row order matches the
    # indexing of the score vectors passed to rerank_poison.
    poison_status_list = pd.read_csv('poison_status_list.csv')
    poison_status_list = poison_status_list.sort_values(by=['class_id'])

    poison_classes = set(poison_status_list[poison_status_list['poisonous'] == 1]['class_id'])

    # Fuse all images of an observation by averaging their features, re-score the
    # fused feature with the species head, apply poison-aware re-ranking, and reject
    # high-entropy (uncertain) predictions as unknown (class_id = -1).
    for obs_id, pred_feats in tqdm(prediction_feats_dict.items()):
        #fusion_scores = np.prod(np.array(prediction_scores_dict[obs_id]), axis=0)
        #fusion_scores = np.mean(np.array(prediction_scores_dict[obs_id]), axis=0)
        #fusion_scores = np.max(np.array(prediction_scores_dict[obs_id]), axis=0)
        fusion_feats = torch.mean(torch.stack(pred_feats, dim=0), dim=0, keepdim=True)
        results = feature_extractor.model.head.task_heads['species'].predict((fusion_feats,), img_paths=[obs_imgs_dict[obs_id]])
        fusion_scores = results[0].pred_score.detach().cpu().numpy()
        class_id, class_score = rerank_poison(poison_status_list, fusion_scores)
        # Entropy-based rejection; the epsilon guards against log(0).
        entropy = -np.sum(fusion_scores * np.log(fusion_scores + 1e-12))
        if entropy > 7 or (class_id not in poison_classes and entropy > 2.5):
            class_id = -1
        #if class_score < 0.1:
        #    class_id = -1
        test_metadata.loc[test_metadata["observation_id"] == obs_id, "class_id"] = class_id
        test_metadata.loc[test_metadata["observation_id"] == obs_id, "max_score"] = class_score

    # Keep one row per observation in the submission file.
    user_pred_df = test_metadata.drop_duplicates("observation_id", keep="first")
    user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=False)


if __name__ == "__main__":

    import zipfile

    # Unpack the private test set archive provided in the evaluation environment.
    with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
        zip_ref.extractall("/tmp/data")

    MODEL_PATH = "swinv2_base_w24_b16x8-fp16_fungi+val2_res_384_genus-loss_no-unknown-target-zero_metadata_epochs_4_epoch_2_20240524-a429ecac.pth"
    MODEL_NAME = "swinv2_base_w24_b16x8-fp16_fungi+val2_res_384_genus-loss_no-unknown-target-zero_metadata_epochs_4.py"

    metadata_file_path = "./FungiCLEF2024_TestMetadata.csv"
    test_metadata = pd.read_csv(metadata_file_path)

    make_submission(
        test_metadata=test_metadata,
        model_path=MODEL_PATH,
        model_name=MODEL_NAME
    )