from typing import Dict, List

from lightning.pytorch.callbacks import Callback
# NOTE(review): this import lacks the `relik.` prefix used by the imports
# below -- confirm that a top-level `reader` package is really intended.
from reader.data.relik_reader_sample import RelikReaderSample

from relik.reader.relik_reader_predictor import RelikReaderPredictor
from relik.reader.utils.metrics import compute_metrics


class StrongMatching:
    """Exact-match span and relation metrics over predicted samples.

    Calling an instance with a list of samples returns a flat dict of metrics.
    Each sample is read through the attributes ``triplets``,
    ``predicted_relations``, ``entities``, ``predicted_entities`` and
    ``entity_candidates``.
    """

    @staticmethod
    def _relation_keys(triplets, typed: bool) -> set:
        """Turn triplets into a set of hashable comparison keys.

        Args:
            triplets: iterable of dicts with ``subject``/``object`` (each
                holding ``start``/``end`` and, when ``typed``, ``type``) and
                ``relation`` (holding ``name``).
            typed: when True the subject/object types are part of the key;
                otherwise the placeholder ``-1`` is used and the ``type``
                entries are not read at all.
        """
        return {
            (
                triplet["subject"]["start"],
                triplet["subject"]["end"],
                triplet["subject"]["type"] if typed else -1,
                triplet["relation"]["name"],
                triplet["object"]["start"],
                triplet["object"]["end"],
                triplet["object"]["type"] if typed else -1,
            )
            for triplet in triplets
        }

    def __call__(self, predicted_samples: List[RelikReaderSample]) -> Dict:
        """Compute span and relation metrics for ``predicted_samples``.

        Args:
            predicted_samples: samples carrying both gold and predicted
                annotations. A ``triplets`` attribute that is ``None`` is
                replaced in place by an empty list.

        Returns:
            A dict with ``span-*``, ``precision``/``recall``/``f1`` and, when
            the last sample has a truthy ``entity_candidates``, the
            ``*-strict`` entries. In that case the ``span-*`` entries compare
            typed entity tuples; otherwise they compare ``(start, end)``
            boundaries only. For an empty input the non-strict dict is
            returned, computed from all-zero counts (this relies on
            ``compute_metrics`` tolerating zero denominators -- TODO confirm).
        """
        # Relation counters (types ignored).
        correct_predictions = total_predictions = total_gold = 0
        # Relation counters (types included).
        correct_predictions_strict = total_predictions_strict = 0
        # Typed entity span counters.
        correct_span_predictions = total_span_predictions = 0
        total_gold_spans_typed = 0
        # Boundary-only entity span counters.
        correct_predictions_bound = total_predictions_bound = 0
        total_gold_spans_bound = 0

        # Mirrors the original "last sample decides" choice of output layout,
        # but stays defined when there are no samples at all.
        strict_mode = False

        for sample in predicted_samples:
            if sample.triplets is None:
                sample.triplets = []

            strict_mode = bool(sample.entity_candidates)

            if strict_mode:
                predicted_typed = self._relation_keys(
                    sample.predicted_relations, typed=True
                )
                gold_typed = self._relation_keys(sample.triplets, typed=True)
                predicted_spans_typed = set(sample.predicted_entities)
                gold_spans_typed = set(sample.entities)

                correct_span_predictions += len(
                    predicted_spans_typed & gold_spans_typed
                )
                total_span_predictions += len(predicted_spans_typed)
                # Kept separate from the boundary-only gold count below so
                # that gold spans are not counted twice for the same sample.
                total_gold_spans_typed += len(gold_spans_typed)

                correct_predictions_strict += len(predicted_typed & gold_typed)
                total_predictions_strict += len(predicted_typed)

            predicted_untyped = self._relation_keys(
                sample.predicted_relations, typed=False
            )
            gold_untyped = self._relation_keys(sample.triplets, typed=False)

            predicted_bounds = {
                (start, end) for (start, end, _) in sample.predicted_entities
            }
            gold_bounds = {(start, end) for (start, end, _) in sample.entities}

            correct_predictions_bound += len(predicted_bounds & gold_bounds)
            total_predictions_bound += len(predicted_bounds)
            total_gold_spans_bound += len(gold_bounds)

            correct_predictions += len(predicted_untyped & gold_untyped)
            total_predictions += len(predicted_untyped)
            total_gold += len(gold_untyped)

        # Each compute_metrics result is unpacked as (precision, recall, f1).
        precision, recall, f1 = compute_metrics(
            correct_predictions, total_predictions, total_gold
        )

        if strict_mode:
            span_precision, span_recall, span_f1 = compute_metrics(
                correct_span_predictions,
                total_span_predictions,
                total_gold_spans_typed,
            )
            precision_strict, recall_strict, f1_strict = compute_metrics(
                correct_predictions_strict, total_predictions_strict, total_gold
            )
            return {
                "span-precision": span_precision,
                "span-recall": span_recall,
                "span-f1": span_f1,
                "precision": precision,
                "recall": recall,
                "f1": f1,
                "precision-strict": precision_strict,
                "recall-strict": recall_strict,
                "f1-strict": f1_strict,
            }

        bound_precision, bound_recall, bound_f1 = compute_metrics(
            correct_predictions_bound,
            total_predictions_bound,
            total_gold_spans_bound,
        )
        return {
            "span-precision": bound_precision,
            "span-recall": bound_recall,
            "span-f1": bound_f1,
            "precision": precision,
            "recall": recall,
            "f1": f1,
        }


class REStrongMatchingCallback(Callback):
    """Lightning callback that logs ``StrongMatching`` metrics for a dataset."""

    def __init__(self, dataset_path: str, dataset_conf) -> None:
        """Store the dataset location/configuration and build the metric.

        Args:
            dataset_path: forwarded as first argument to the predictor.
            dataset_conf: forwarded as third argument to the predictor.
        """
        super().__init__()
        self.dataset_path = dataset_path
        self.dataset_conf = dataset_conf
        self.strong_matching_metric = StrongMatching()

    def on_validation_epoch_start(self, trainer, pl_module) -> None:
        """Predict on the configured dataset and log every metric as ``val_<name>``.

        Args:
            trainer: unused here.
            pl_module: module exposing ``relik_reader_re_model`` and ``log``.
        """
        relik_reader_predictor = RelikReaderPredictor(pl_module.relik_reader_re_model)
        # The second positional argument of `_predict` is passed as None.
        predicted_samples = list(
            relik_reader_predictor._predict(
                self.dataset_path,
                None,
                self.dataset_conf,
            )
        )
        for k, v in self.strong_matching_metric(predicted_samples).items():
            pl_module.log(f"val_{k}", v)