|
import argparse |
|
|
|
from collections import defaultdict |
|
from statistics import mean |
|
from sys import meta_path |
|
|
|
import numpy as np |
|
import pandas as pd |
|
import torch |
|
|
|
# Predictions whose top softmax probability falls below this value are
# written as class -1.
DEFAULT_THRESHOLD = 0.09


def classify_scores(raw_scores, threshold=DEFAULT_THRESHOLD):
    """Return one class id per row of *raw_scores*, or -1 for rejected rows.

    Each row is passed through a softmax. The index of the largest
    probability becomes the class id, unless that probability is strictly
    below *threshold*, in which case the row is labelled -1.

    Args:
        raw_scores: 2-D array-like of numeric scores, shaped
            (rows, classes). Integer input is accepted and converted to
            float64.
        threshold: Minimum top probability required to keep a prediction.

    Returns:
        A 1-D integer numpy array with one entry per input row.

    Raises:
        ValueError: If *raw_scores* is not 2-D or has no class columns.
    """
    # np.array copies, so torch always receives a writable float64 buffer
    # regardless of the dtype or ownership of the input.
    logits = np.array(raw_scores, dtype=np.float64)
    if logits.ndim != 2 or logits.shape[1] == 0:
        raise ValueError(
            'expected a 2-D score matrix with at least one class column'
        )

    probs = torch.softmax(torch.from_numpy(logits), dim=1).numpy()
    class_ids = probs.argmax(axis=1)
    class_ids[probs.max(axis=1) < threshold] = -1
    return class_ids


def main(argv=None):
    """Convert a headerless score CSV into an observation_id,class_id CSV.

    The input file has the observation id in its first column and one score
    per class in the remaining columns. A summary of how many predictions
    were replaced by -1 is printed at the end.

    Args:
        argv: Command-line arguments; defaults to ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('scores')
    parser.add_argument('out')
    parser.add_argument(
        '--threshold',
        type=float,
        default=DEFAULT_THRESHOLD,
        help='minimum softmax probability needed to keep a prediction '
             '(default: %(default)s)',
    )
    args = parser.parse_args(argv)

    scores = pd.read_csv(args.scores, header=None)
    obs_ids = scores.iloc[:, 0]

    # Slice the score columns off *before* converting to numpy so that a
    # non-numeric id column cannot force the scores to object dtype.
    class_ids = classify_scores(scores.iloc[:, 1:].to_numpy(), args.threshold)

    dropped = 0
    total = 0
    # A repeated id reuses the prediction of its first row, which is what
    # the previous lookup-by-id implementation produced.
    first_row_for_id = {}
    with open(args.out, 'w') as f:
        f.write('observation_id,class_id\n')
        for row, obs_id in enumerate(obs_ids):
            cls_id = class_ids[first_row_for_id.setdefault(obs_id, row)]
            if cls_id == -1:
                dropped += 1
            total += 1
            f.write(f'{obs_id},{cls_id}\n')

    print(f'dropped {dropped} out of {total}')


if __name__ == '__main__':
    main()
|
|
|
|
|
|