# FAPM_demo/output/cal_f1.py
# Uploaded by wenkai in commit 1a7e2de ("Upload 29 files"; verified, 2.89 kB).
from collections import Counter

import pandas as pd
def cal_f1(df, standard=False):
    """Compute macro-averaged recall, precision and F1 for multi-label predictions.

    Each row of ``df`` holds the ground-truth terms of one sample in ``label``
    (a ';'-separated string; every term is stripped and lower-cased) and the
    predicted terms in ``pred``. Per term, true positives, false negatives and
    false positives are counted, then:

    * recall    -- mean of tp / (tp + fn) over every distinct ground-truth term;
    * precision -- sum of tp / (tp + fp) over the ground-truth terms, divided by
      the number of distinct predicted terms (predicted terms that never occur
      in any label have tp == 0 and therefore contribute 0);
    * f1        -- 2 * p * r / (p + r), with a tiny epsilon in the denominator.

    Side effect: the parsed lists are written back to ``df`` as the columns
    ``label_list`` and ``pred_list`` (plus ``pred_list_prob`` when
    ``standard`` is False).

    Args:
        df: DataFrame with ``label`` and ``pred`` columns.
        standard: If True, each ``pred`` cell is evaluated as one Python
            expression yielding a sequence whose items' element 0 is the
            predicted term. If False, the cell is split on ';' and each piece
            is evaluated separately, again taking element 0 as the term.

    Returns:
        ``(recall, precision, f1)``. A metric whose averaging set is empty
        (no labels, or no predictions at all) is reported as 0.0 instead of
        raising ZeroDivisionError.
    """
    df['label_list'] = df['label'].apply(lambda x: [i.strip().lower() for i in x.split(';')])
    # NOTE(review): eval() executes whatever code the prediction file contains;
    # if the cells only ever hold literals, ast.literal_eval is the safe choice.
    if standard:
        df['pred_list'] = df['pred'].apply(lambda x: [i[0] for i in eval(str(x))])
    else:
        df['pred_list_prob'] = df['pred'].apply(lambda x: [eval(i.strip()) for i in str(x).split(';')])
        df['pred_list'] = df['pred_list_prob'].apply(lambda x: [i[0] for i in x])

    # NOTE(review): labels are lower-cased but predicted terms are compared
    # verbatim, so a prediction differing only in case counts as a miss --
    # confirm the predictions are already lower-case.
    label_set = {t for row in df['label_list'] for t in row}
    labels = list(label_set)

    tp, fp, fn = Counter(), Counter(), Counter()
    predicted = set()
    for preds, label in zip(df['pred_list'], df['label_list']):
        for t in label:
            if t in preds:
                tp[t] += 1
            else:
                fn[t] += 1
        for p in preds:
            if p not in label:
                fp[p] += 1
        predicted.update(preds)

    total = len(labels)
    p_total = len(predicted)
    # The 1e-8 terms avoid ZeroDivisionError for terms whose counts are all 0
    # (e.g. a label that was never predicted has tp == fp == 0).
    recall = sum(tp[x] / (tp[x] + fn[x] + 1e-8) for x in labels)
    pr = sum(tp[x] / (tp[x] + fp[x] + 1e-8) for x in labels)
    r = recall / total if total else 0.0
    p = pr / p_total if p_total else 0.0
    f1 = 2 * p * r / (p + r + 1e-8)
    # Number of distinct predicted terms that never appear in any label.
    print("preds not in labels: {}".format(len(predicted - label_set)))
    print("recall:{}; percision:{}; f1 score: {}".format(r, p, f1))
    return r, p, f1
# Prediction files to score. Names containing 'standard' are parsed with
# cal_f1(..., standard=True), i.e. each pred cell is evaluated as one sequence.
names = ['output_test_mf_exp_493552.txt', 'output_test_mf_exp_445772_pre.txt', 'output_test_mf_exp_445772.txt', 'output_test_mf_exp_486524.txt', 'output_test_mf_493552_standard.csv', 'output_test_mf_445772_standard.csv', 'output_test_mf_exp_445772_withprompt.txt', 'output_test_mf_exp_506753.txt']
# Alternative BP file list:
#names = ['output_test_bp_exp_451674.txt', 'output_test_bp_exp_493547_pre.txt', 'output_test_bp_exp_496359_withprompt.txt']

# Directory holding the prediction files (cluster-specific absolute path).
RESULT_DIR = '/cluster/home/wenkai/LAVIS/output/mf_bp_cc'

# NOTE(review): this loop runs whenever the module is imported; consider an
# `if __name__ == "__main__":` guard if cal_f1 is meant to be reused.
for name in names:
    print(name)
    df = pd.read_csv('{}/{}'.format(RESULT_DIR, name), sep='|', header=None)
    # Some files start with a 'name|pred|label' header row; drop it. copy()
    # detaches the slice so cal_f1's column assignments don't trigger
    # pandas' SettingWithCopyWarning.
    if df.iloc[0, 0] == 'name':
        df = df.iloc[1:].copy()
    df.columns = ['name', 'pred', 'label']
    cal_f1(df, standard='standard' in name)