# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Joint entity-relation (JER) extraction metric.

Scores predicted relation triplets (e.g. ``"Baris | play | tennis"``) against
reference triplets and reports precision, recall and F1 averaged over examples.
"""

import logging
from operator import eq
from typing import Callable, Iterable, Union

import datasets
import evaluate
import numpy as np


logger = logging.getLogger(__name__)

# TODO: Add BibTeX citation
_CITATION = """\
@InProceedings{huggingface:module,
title = {A great new module},
authors={huggingface, Inc.},
year={2020}
}
"""

_DESCRIPTION = """\
Computes precision, recall, f1 scores for joint entity-relation extraction task.
"""

_KWARGS_DESCRIPTION = """
Computes per-example precision, recall and F1 between predicted and reference
triplets, then averages each score over all examples.
Args:
    predictions: list of predictions to score. Each prediction is a list of
        triplet strings (e.g. "Baris | play | tennis").
    references: list of references, one per prediction. Each reference is a
        list of triplet strings in the same format as the predictions.
    eq_fn: function ``eq_fn(predicted, reference) -> bool`` deciding whether a
        predicted triplet matches a reference triplet. Defaults to
        ``operator.eq`` (exact match).
Returns:
    mean_precision: mean over examples of the fraction of unique predicted
        triplets that match some reference triplet.
    mean_recall: mean over examples of the fraction of unique reference
        triplets matched by some predicted triplet.
    mean_f1: mean over examples of the per-example F1 score.
Examples:
    >>> jer = evaluate.load("jer")
    >>> results = jer.compute(references=[["Baris | play | tennis", "Deniz | travel | London"]], predictions=[["Baris | play | tennis"]])
    >>> print(results)
    {'mean_precision': 1.0, 'mean_recall': 0.5, 'mean_f1': 0.6666666666666666}
"""

# Element type accepted for a single triplet. Values must be hashable because
# they are collected into sets before matching.
Triplet = Union[str, tuple, int]


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class jer(evaluate.Metric):
    """Precision / recall / F1 metric for joint entity-relation extraction triplets."""

    def _info(self):
        """Describe the metric and the input schema expected by ``evaluate``."""
        return evaluate.MetricInfo(
            # This is the description that will appear on the modules page.
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # Each example is a list of triplet strings on both sides.
            features=datasets.Features({
                'predictions': datasets.features.Sequence(datasets.Value('string')),
                'references': datasets.features.Sequence(datasets.Value('string')),
            }),
            # TODO: replace the placeholder homepage / codebase / reference URLs.
            homepage="http://module.homepage",
            codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
            reference_urls=["http://path.to.reference.url/new_module"]
        )

    def _download_and_prepare(self, dl_manager):
        """No external resources are needed; nothing to download."""
        pass

    def _compute(self, predictions, references, eq_fn=eq):
        """Average per-example precision, recall and F1 over all examples.

        Args:
            predictions: sequence of per-example collections of predicted triplets.
            references: sequence of per-example collections of reference
                triplets, aligned one-to-one with ``predictions``.
            eq_fn: ``eq_fn(predicted, reference) -> bool`` matching predicate.

        Returns:
            dict with ``mean_precision``, ``mean_recall`` and ``mean_f1`` as
            Python floats. With no examples, all three are ``0.0`` (the same
            "undefined -> 0" convention used per example).

        Raises:
            ValueError: if ``predictions`` and ``references`` differ in length.
        """
        predictions = list(predictions)
        references = list(references)
        # zip() would silently drop the unmatched tail; fail loudly instead.
        if len(predictions) != len(references):
            raise ValueError(
                f"Got {len(predictions)} predictions but {len(references)} references; "
                "they must be aligned one-to-one."
            )

        score_dicts = [
            self._compute_single(prediction=prediction, reference=reference, eq_fn=eq_fn)
            for prediction, reference in zip(predictions, references)
        ]
        if not score_dicts:
            return {'mean_' + key: 0.0 for key in ('precision', 'recall', 'f1')}

        # float() turns np.float64 into a plain float for clean printing/serialization.
        return {
            'mean_' + key: float(np.mean([scores[key] for scores in score_dicts]))
            for key in score_dicts[0]
        }

    def _compute_single(
        self,
        *,
        prediction: Iterable[Triplet],
        reference: Iterable[Triplet],
        eq_fn: Callable[[Triplet, Triplet], bool],
    ):
        """Compute precision, recall and F1 for a single example.

        Duplicate triplets are collapsed on both sides before matching, so each
        unique triplet is counted at most once.

        Args:
            prediction: predicted triplets for this example.
            reference: reference triplets for this example.
            eq_fn: ``eq_fn(predicted, reference) -> bool`` matching predicate.

        Returns:
            dict with ``precision``, ``recall`` and ``f1``; each is 0 when its
            denominator would be zero.
        """
        # Materialize so one-shot iterables can be measured and iterated again.
        prediction = list(prediction)
        reference = list(reference)

        reference_set = set(reference)
        if len(reference) != len(reference_set):
            logger.warning("Duplicates found in the reference list %s", reference)
        prediction_set = set(prediction)

        # Count matches on each side separately. Counting over the raw reference
        # list (with duplicates) and subtracting from deduplicated set sizes used
        # to yield negative FP/FN and scores above 1; with a non-exact eq_fn one
        # prediction may also match several references (or vice versa).
        matched_predictions = sum(
            1 for pred in prediction_set
            if any(eq_fn(pred, ref) for ref in reference_set)
        )
        matched_references = sum(
            1 for ref in reference_set
            if any(eq_fn(pred, ref) for pred in prediction_set)
        )

        precision = matched_predictions / len(prediction_set) if prediction_set else 0
        recall = matched_references / len(reference_set) if reference_set else 0
        f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0

        return {
            'precision': precision,
            'recall': recall,
            'f1': f1_score
        }


def is_in(target, collection: Iterable, eq_fn=eq) -> bool:
    """Return True if ``eq_fn(item, target)`` holds for any item in ``collection``.

    Note the argument order: the collection element is passed first and
    ``target`` second. Stops at the first match.
    """
    return any(eq_fn(item, target) for item in collection)