File size: 1,590 Bytes
5d666d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337cf38
5d666d5
337cf38
 
5d666d5
337cf38
5d666d5
337cf38
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import pandas as pd
import numpy as np
import os
import subprocess
import sys
from tqdm import tqdm
import timm
import torchvision.transforms as T
from PIL import Image
import torch


# Inference settings consumed by run_inference(): the model config file, the
# matching trained checkpoint (same experiment stem plus epoch and run tag),
# and the score threshold forwarded to the test tool via --threshold.
# Both paths are relative, so the script must be launched from the repo root.
CONFIG_PATH: str = "models/swinv2_base_w24_b16x4-fp16_fungi+val_res_384_cb_epochs_6.py"
CHECKPOINT_PATH: str = (
    "models/swinv2_base_w24_b16x4-fp16_fungi+val_res_384_cb_epochs_6"
    "_epoch_6_20240514-de00365e.pth"
)
SCORE_THRESHOLD: float = 0.2

def run_inference(input_csv, output_csv, data_root_path):
    """Run the test tool in a child process and write its results to a CSV.

    Launches ``tools.test_generate_result_pre-consensus`` as a module with
    CONFIG_PATH, CHECKPOINT_PATH and SCORE_THRESHOLD. The test dataloader's
    annotation file and image prefix are overridden via ``--cfg-options``.

    Args:
        input_csv: Metadata CSV path, passed as
            ``test_dataloader.dataset.ann_file``.
        output_csv: Path passed to the tool as its output file.
        data_root_path: Directory passed as
            ``test_dataloader.dataset.data_prefix``. A trailing ``/`` is
            appended if missing.

    If the child process exits with a non-zero code, an error is printed to
    stderr and the current process exits via ``sys.exit`` with that code.
    """
    if not data_root_path.endswith('/'):
        data_root_path += '/'

    # data_root is blanked so the ann_file/data_prefix values below are not
    # joined onto a root defined in the config.
    # NOTE(review): this depends on the dataset class used by the config -- confirm.
    data_cfg_opts = [
        'test_dataloader.dataset.data_root=',
        f'test_dataloader.dataset.ann_file={input_csv}',
        f'test_dataloader.dataset.data_prefix={data_root_path}',
    ]

    # Use the interpreter running this script rather than whatever "python"
    # resolves to on PATH, so the child sees the same installed packages.
    cmd = [
        sys.executable, '-m',
        'tools.test_generate_result_pre-consensus',
        CONFIG_PATH, CHECKPOINT_PATH,
        output_csv,
        '--threshold', str(SCORE_THRESHOLD),
        '--no-scores',
        '--cfg-options', *data_cfg_opts,
    ]
    return_code = subprocess.run(cmd).returncode
    if return_code != 0:
        print(f'Inference crashed with exit code {return_code}', file=sys.stderr)
        sys.exit(return_code)
    print(f'Written {output_csv}')


if __name__ == "__main__":
    # Script entry point: unpack the private test-set archive into the
    # directory that contains it, then run inference over the extracted
    # files. The metadata CSV is read from the working directory and the
    # results are written to ./submission.csv.
    import zipfile

    archive_path = "/tmp/data/private_testset.zip"
    with zipfile.ZipFile(archive_path, 'r') as archive:
        archive.extractall("/tmp/data")

    run_inference(
        input_csv="./FungiCLEF2024_TestMetadata.csv",
        output_csv="./submission.csv",
        data_root_path="/tmp/data/private_testset",
    )