"""NER evaluation metrics: span-level (via nervaluate) and token-level (via scikit-learn)."""

from abc import ABC, abstractmethod

from nervaluate import Evaluator
from sklearn.metrics import classification_report, f1_score

from token_level_output import get_token_output_labels


def _span_f1(gt_ner_span, pred_ner_span, tags, scheme: str) -> float:
    """Return nervaluate's overall F1 for ``scheme`` on one document, rounded to 2 places.

    Args:
        gt_ner_span: Ground-truth spans for a single document (wrapped in a
            one-element list because nervaluate expects a list of documents).
        pred_ner_span: Predicted spans for the same document.
        tags: Entity labels nervaluate should score.
        scheme: nervaluate evaluation scheme key, e.g. ``"strict"`` or ``"ent_type"``.
    """
    evaluator = Evaluator([gt_ner_span], [pred_ner_span], tags=tags)
    # The first element returned by evaluate() is the aggregate over all tags,
    # keyed by evaluation scheme.
    return round(evaluator.evaluate()[0][scheme]["f1"], 2)


def _token_f1(gt_ner_span, pred_ner_span, text, tags, average: str) -> float:
    """Return token-level F1 restricted to ``tags``, rounded to 2 places.

    Args:
        gt_ner_span: Ground-truth spans, converted to per-token labels.
        pred_ner_span: Predicted spans, converted to per-token labels.
        text: Document text the spans index into.
        tags: Labels to score; any other token label (e.g. an outside label,
            if get_token_output_labels emits one) is excluded from the average.
        average: scikit-learn averaging mode, ``"micro"`` or ``"macro"``.
    """
    y_true = get_token_output_labels(gt_ner_span, text)
    y_pred = get_token_output_labels(pred_ner_span, text)
    # f1_score is used instead of classification_report(...)["micro avg"]:
    # the report drops the "micro avg" row (emitting "accuracy" instead) when
    # ``tags`` covers every label present, which made the lookup raise KeyError.
    # For the same labels/average, f1_score yields the value that row would hold.
    score = f1_score(y_true, y_pred, labels=tags, average=average)
    # Convert the numpy scalar to a plain float, as the report's dict values were.
    return round(float(score), 2)


class EvaluationMetric(ABC):
    """Base class defining the attributes & methods of an evaluation metric."""

    name: str  # human-readable metric name, set by each subclass
    description: str  # longer explanation of the metric (empty for all current metrics)

    # abstractmethod must be the innermost decorator when combined with staticmethod.
    @staticmethod
    @abstractmethod
    def get_evaluation_metric(gt_ner_span, pred_ner_span, text, tags) -> float:
        """Score predicted NER spans against ground truth for one document.

        Args:
            gt_ner_span: Ground-truth entity spans for the document. Presumably a
                list of ``{"label", "start", "end"}`` dicts (nervaluate's default
                input format) -- TODO confirm against callers.
            pred_ner_span: Predicted entity spans in the same format.
            text: The document text; used only by token-level metrics.
            tags: Entity labels to include in the score.

        Returns:
            The metric's F1 score rounded to 2 decimal places.
        """


class PartialSpanOverlapMetric(EvaluationMetric):
    """Span-level F1 where a prediction matches if it overlaps a gold span of the same type."""

    def __init__(self) -> None:
        super().__init__()
        self.name = "Span Based Evaluation with Partial Overlap"
        self.description = ""

    @staticmethod
    def get_evaluation_metric(gt_ner_span, pred_ner_span, text, tags) -> float:
        """Return nervaluate's ``ent_type`` F1 (``text`` is unused)."""
        return _span_f1(gt_ner_span, pred_ner_span, tags, "ent_type")


class ExactSpanOverlapMetric(EvaluationMetric):
    """Span-level F1 where a prediction matches only with exact boundaries and the same type."""

    def __init__(self) -> None:
        super().__init__()
        self.name = "Span Based Evaluation with Exact Overlap"
        self.description = ""

    @staticmethod
    def get_evaluation_metric(gt_ner_span, pred_ner_span, text, tags) -> float:
        """Return nervaluate's ``strict`` F1 (``text`` is unused)."""
        return _span_f1(gt_ner_span, pred_ner_span, tags, "strict")


class TokenMicroMetric(EvaluationMetric):
    """Token-level F1 micro-averaged over ``tags``."""

    def __init__(self) -> None:
        super().__init__()
        self.name = "Token Based Evaluation with Micro Average"
        self.description = ""

    @staticmethod
    def get_evaluation_metric(gt_ner_span, pred_ner_span, text, tags) -> float:
        """Return micro-averaged token F1 over ``tags``."""
        return _token_f1(gt_ner_span, pred_ner_span, text, tags, "micro")


class TokenMacroMetric(EvaluationMetric):
    """Token-level F1 macro-averaged (unweighted mean of per-tag F1) over ``tags``."""

    def __init__(self) -> None:
        super().__init__()
        self.name = "Token Based Evaluation with Macro Average"
        self.description = ""

    @staticmethod
    def get_evaluation_metric(gt_ner_span, pred_ner_span, text, tags) -> float:
        """Return macro-averaged token F1 over ``tags``."""
        return _token_f1(gt_ner_span, pred_ner_span, text, tags, "macro")


# Registry of one instance of every available metric.
EVALUATION_METRICS = [
    PartialSpanOverlapMetric(),
    ExactSpanOverlapMetric(),
    TokenMicroMetric(),
    TokenMacroMetric(),
]