"""NER evaluation helpers: span-level F1 via nervaluate and token-level F1 via
scikit-learn, computed for the "Disease" and "Drug" entity labels."""

from nervaluate import Evaluator
from sklearn.metrics import classification_report

from token_level_output import get_token_output_labels

EVALUATION_METRICS = [
    "Span Based Evaluation with Partial Overlap",
    "Token Based Evaluation with Micro Avg",
    "Token Based Evaluation with Macro Avg",
]


def get_span_eval(gt_ner_span, pred_ner_span, text):
    """Span-level F1 using nervaluate's "ent_type" scheme (correct entity type with
    overlapping boundaries). `text` is unused here but kept for a uniform signature."""
    evaluator = Evaluator([gt_ner_span], [pred_ner_span], tags=["Disease", "Drug"])
    return round(evaluator.evaluate()[0]["ent_type"]["f1"], 2)


def get_token_micro_eval(gt_ner_span, pred_ner_span, text):
    """Token-level micro-averaged F1 over the "Disease" and "Drug" labels."""
    return round(
        classification_report(
            get_token_output_labels(gt_ner_span, text),
            get_token_output_labels(pred_ner_span, text),
            labels=["Disease", "Drug"],
            output_dict=True,
        )["micro avg"]["f1-score"],
        2,
    )


def get_token_macro_eval(gt_ner_span, pred_ner_span, text):
    """Token-level macro-averaged F1 over the "Disease" and "Drug" labels."""
    return round(
        classification_report(
            get_token_output_labels(gt_ner_span, text),
            get_token_output_labels(pred_ner_span, text),
            labels=["Disease", "Drug"],
            output_dict=True,
        )["macro avg"]["f1-score"],
        2,
    )


def get_evaluation_metric(metric_type, gt_ner_span, pred_ner_span, text):
    """Dispatch to the metric named in EVALUATION_METRICS."""
    match metric_type:
        case "Span Based Evaluation with Partial Overlap":
            return get_span_eval(gt_ner_span, pred_ner_span, text)
        case "Token Based Evaluation with Micro Avg":
            return get_token_micro_eval(gt_ner_span, pred_ner_span, text)
        case "Token Based Evaluation with Macro Avg":
            return get_token_macro_eval(gt_ner_span, pred_ner_span, text)
        case _:
            raise ValueError(f"Unknown metric type: {metric_type}")
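

# Illustrative usage (a minimal sketch, not part of the original module): the span
# dicts below assume nervaluate's {"label", "start", "end"} character-offset format,
# and that token_level_output.get_token_output_labels accepts spans in the same shape.
if __name__ == "__main__":
    sample_text = "Aspirin is used to treat headache."
    gt_spans = [
        {"label": "Drug", "start": 0, "end": 7},
        {"label": "Disease", "start": 25, "end": 33},
    ]
    pred_spans = [{"label": "Drug", "start": 0, "end": 7}]
    for metric in EVALUATION_METRICS:
        print(metric, get_evaluation_metric(metric, gt_spans, pred_spans, sample_text))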