import pandas as pd
import numpy as np
import os
import subprocess
import sys
from tqdm import tqdm
import timm
import torchvision.transforms as T
from PIL import Image
import torch
# custom script arguments
# mmpretrain model config for the SwinV2-base (window 24, 384px) classifier.
CONFIG_PATH = 'models/swinv2_base_w24_b16x4-fp16_fungi+val_res_384_cb_epochs_6.py'
# Trained weights matching CONFIG_PATH (epoch 6 checkpoint).
CHECKPOINT_PATH = "models/swinv2_base_w24_b16x4-fp16_fungi+val_res_384_cb_epochs_6_epoch_6_20240514-de00365e.pth"
# Score cutoff passed to the test tool via --threshold; presumably predictions
# below this are rejected/marked unknown — confirm against the tool's docs.
SCORE_THRESHOLD = 0.2
def run_inference(input_csv: str, output_csv: str, data_root_path: str) -> None:
    """Run the mmpretrain test tool as a subprocess to generate predictions.

    Args:
        input_csv: Path to the test annotation/metadata CSV.
        output_csv: Destination CSV for the generated predictions.
        data_root_path: Directory containing the test images; a trailing
            '/' is appended if missing.

    Exits the whole process with the child's exit code if inference fails.
    """
    if not data_root_path.endswith('/'):
        data_root_path += '/'
    # NOTE(review): data_root is deliberately left empty so ann_file and
    # data_prefix are used as-is (absolute) — the original had a pointless
    # f-string here with no placeholder; kept the empty value, dropped the
    # 'f'. Confirm against the dataset config before changing.
    data_cfg_opts = [
        'test_dataloader.dataset.data_root=',
        f'test_dataloader.dataset.ann_file={input_csv}',
        f'test_dataloader.dataset.data_prefix={data_root_path}',
    ]
    # sys.executable guarantees the child runs under the same interpreter
    # as this script; subprocess.run replaces the Popen/wait pair.
    result = subprocess.run([
        sys.executable, '-m',
        'tools.test_generate_result_pre-consensus',
        CONFIG_PATH, CHECKPOINT_PATH,
        output_csv,
        '--threshold', str(SCORE_THRESHOLD),
        '--no-scores',
        '--cfg-options'] + data_cfg_opts)
    if result.returncode != 0:
        print(f'Inference crashed with exit code {result.returncode}')
        sys.exit(result.returncode)
    print(f'Written {output_csv}')
if __name__ == "__main__":
    import zipfile

    # Unpack the private test set shipped as a zip archive into /tmp/data.
    archive_path = "/tmp/data/private_testset.zip"
    extract_dir = "/tmp/data"
    with zipfile.ZipFile(archive_path, 'r') as archive:
        archive.extractall(extract_dir)

    # Run inference over the extracted images using the bundled metadata
    # and write the predictions next to this script.
    run_inference(
        "./FungiCLEF2024_TestMetadata.csv",
        "./submission.csv",
        "/tmp/data/private_testset",
    )