File size: 2,405 Bytes
ed49033
 
894b24d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed49033
 
 
 
 
 
894b24d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from sentence_transformers import util

def calc_recall(true_pos, false_neg, eps=1e-8):
    """Return recall, TP / (TP + FN).

    *eps* is added to the denominator so that zero counts yield 0.0
    instead of raising ZeroDivisionError.
    """
    denominator = true_pos + false_neg + eps
    return true_pos / denominator



def calc_precision(true_pos, false_pos, eps=1e-8):
    """Return precision, TP / (TP + FP).

    *eps* is added to the denominator so that zero counts yield 0.0
    instead of raising ZeroDivisionError.
    """
    denominator = true_pos + false_pos + eps
    return true_pos / denominator



def calc_f1_score(precision, recall, eps=1e-8):
    """Return the F1 score, the harmonic mean of *precision* and *recall*.

    *eps* is added to the denominator so that precision == recall == 0
    yields 0.0 instead of raising ZeroDivisionError.
    """
    numerator = 2 * precision * recall
    denominator = precision + recall + eps
    return numerator / denominator



def calc_metrics(true, predicted, model, threshold=0.95, eps=1e-8, *,
                 return_details=False):
    """Compute entity-level recall, precision and F1 via embedding similarity.

    Each sample pairs a collection of gold entity strings with a collection of
    predicted entity strings. Both are embedded with ``model.encode`` and
    compared with cosine similarity.

    - A gold entity is a true positive when at least one predicted entity has
      similarity >= *threshold* with it; otherwise it is a false negative.
    - A predicted entity with no gold entity at or above *threshold* is a
      false positive.

    NOTE(review): ``true_pos`` counts matched *gold* entities, while
    ``false_pos`` counts unmatched *predicted* entities. Precision therefore
    mixes gold-side and prediction-side counts, as in the original code.

    Args:
        true: Sequence of per-sample gold entity lists.
        predicted: Iterable of per-sample predicted entity lists, aligned
            with *true*.
        model: Object exposing ``encode(texts, convert_to_tensor=True)``,
            presumably a SentenceTransformer.
        threshold: Minimum cosine similarity for two entities to match.
        eps: Small constant added to every metric denominator to avoid
            division by zero.
        return_details: When True, also return the raw counts and the
            sorted, de-duplicated sample indices that produced at least one
            false positive / false negative.

    Returns:
        dict with keys ``'recall'``, ``'precision'`` and ``'f1'``. When
        *return_details* is True it also has ``'true_pos'``, ``'false_pos'``,
        ``'false_neg'``, ``'false_pos_ids'`` and ``'false_neg_ids'``.
    """
    true_pos = 0
    false_pos = 0
    false_neg = 0

    # Sample indices that contributed at least one false positive / negative.
    false_pos_ids = set()
    false_neg_ids = set()

    # NOTE(review): zip() silently stops at the shorter input, so unaligned
    # trailing samples are ignored rather than reported.
    for j, (true_ents, pred_ents) in enumerate(zip(true, predicted)):
        if len(true_ents) == 0:
            # No gold entities: every prediction is a false positive.
            false_pos += len(pred_ents)
            if len(pred_ents) > 0:
                false_pos_ids.add(j)
            continue

        if len(pred_ents) == 0:
            # No predictions (and, from the check above, at least one gold
            # entity): every gold entity is missed.
            false_neg += len(true_ents)
            false_neg_ids.add(j)
            continue

        embed_true = model.encode(true_ents, convert_to_tensor=True)
        embed_pred = model.encode(pred_ents, convert_to_tensor=True)

        # Boolean match matrix: rows are gold entities, columns are predictions.
        matches = util.pytorch_cos_sim(embed_true, embed_pred) >= threshold
        n_true, n_pred = matches.shape[0], matches.shape[1]

        # A gold entity is found if any prediction matches it, and vice versa.
        matched_true = int(matches.any(dim=1).sum())
        matched_pred = int(matches.any(dim=0).sum())

        true_pos += matched_true
        if matched_true < n_true:
            false_neg += n_true - matched_true
            false_neg_ids.add(j)
        if matched_pred < n_pred:
            false_pos += n_pred - matched_pred
            false_pos_ids.add(j)

    # Use the caller's eps consistently for every metric.
    recall = calc_recall(true_pos, false_neg, eps=eps)
    precision = calc_precision(true_pos, false_pos, eps=eps)
    f1_score = calc_f1_score(precision, recall, eps=eps)

    metrics = {
        'recall': recall,
        'precision': precision,
        'f1': f1_score,
    }
    if return_details:
        metrics.update({
            'true_pos': true_pos,
            'false_pos': false_pos,
            'false_neg': false_neg,
            'false_pos_ids': sorted(false_pos_ids),
            'false_neg_ids': sorted(false_neg_ids),
        })
    return metrics