""" DataStats metric. """ | |
import functools | |
from collections import Counter | |
from multiprocessing import Pool | |
from contextlib import contextmanager | |
from typing import List, Any, Dict, Optional | |
from collections import namedtuple as _namedtuple | |
import spacy | |
import datasets | |
import evaluate | |
from packaging import version | |

try:
    _en = spacy.load('en_core_web_sm')
except OSError:
    # The model is not installed yet; download it once, then load it.
    spacy.cli.download('en_core_web_sm')
    _en = spacy.load('en_core_web_sm')


@contextmanager
def filter_logging_context():
    def filter_log(record):
        return "This is expected if you are initialising" not in record.msg

    logger = datasets.utils.logging.get_logger("transformers.modeling_utils")
    logger.addFilter(filter_log)
    try:
        yield
    finally:
        logger.removeFilter(filter_log)


_CITATION = """\
@article{grusky2018newsroom,
  title={Newsroom: A dataset of 1.3 million summaries with diverse extractive strategies},
  author={Grusky, Max and Naaman, Mor and Artzi, Yoav},
  journal={arXiv preprint arXiv:1804.11283},
  year={2018}
}
"""

_DESCRIPTION = """\
DataStats examines summarization strategies using three measures (coverage, density, and compression) that capture the degree of text overlap between the summary and the article and the rate of compression of the information conveyed.
"""

_KWARGS_DESCRIPTION = """
DataStats metric: measures coverage, density, and compression of predictions (summaries) with respect to references (source texts).
Args:
    predictions (list of str): Prediction/candidate sentences (summaries).
    references (list of str or list of list of str): Reference sentences (source texts).
Returns:
    coverage: Percentage of words in the summary that are also in the source text, measuring the extent to which a summary is derivative of the text.
    density: Average length of the extractive fragment to which each summary word belongs.
    compression: Word ratio between the source text and its summary.
Examples:
    >>> predictions = ["hello there", "general kenobi"]
    >>> references = ["hello there", "general kenobi"]
    >>> datastats = evaluate.load("datastats")
    >>> results = datastats.compute(predictions=predictions, references=references)
"""


def find_ngrams(input_list: List[Any], n: int):
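    """
    Yield successive n-grams of length ``n`` from ``input_list`` as tuples.

    Example:
        >>> list(find_ngrams(["a", "b", "c"], 2))
        [('a', 'b'), ('b', 'c')]
    """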
    return zip(*[input_list[i:] for i in range(n)])


def normalize(tokens: List[str], lowercase: bool = False):
    """
    Convert tokens to strings, lowercasing them when ``lowercase`` is True.
    """
    return [str(t).lower() if lowercase else str(t) for t in tokens]


class Fragments:
    Match = _namedtuple("Match", ("summary", "text", "length"))

    def __init__(self, summary, text, lowercase: bool = False):
        if isinstance(summary, str):
            self.summary = summary.split()
        else:
            self.summary = summary
        if isinstance(text, str):
            self.text = text.split()
        else:
            self.text = text
        self._norm_summary = normalize(self.summary, lowercase)
        self._norm_text = normalize(self.text, lowercase)
        self._match(self._norm_summary, self._norm_text)

    def overlaps(self):
        """
        Return a list of Fragments.Match objects between summary and text.

        Each match is a named tuple of the form (summary, text, length), where
        ``summary`` and ``text`` are the start indices of the shared fragment
        in the summary and in the text, and ``length`` is the fragment length.
        """
        return self._matches

    def strings(self, min_length=0, summary_base=True):
        # Compute the strings against the summary or the text?
        base = self.summary if summary_base else self.text
        # Generate strings, filtering out strings below the minimum length.
        strings = []
        for i, j, length in self.overlaps():
            if length <= min_length:
                continue
            # Use the start index that corresponds to the chosen base sequence.
            start = i if summary_base else j
            strings.append(base[start: start + length])
        return strings

    def coverage(self, summary_base=True):
        """
        Return the COVERAGE score of the summary and text.

        Coverage is the fraction of summary words that lie inside an extractive
        fragment shared with the text: sum(len(f) for f in fragments) / len(summary).
        """
        numerator = sum(o.length for o in self.overlaps())
        if summary_base:
            denominator = len(self.summary)
        else:
            denominator = len(self.text)
        if denominator == 0:
            return 0
        else:
            return numerator / denominator

    def density(self, summary_base=True):
        """
        Return the DENSITY score of summary and text.

        Density is the average squared length of the extractive fragments:
        sum(len(f) ** 2 for f in fragments) / len(summary).
        """
        numerator = sum(o.length ** 2 for o in self.overlaps())
        if summary_base:
            denominator = len(self.summary)
        else:
            denominator = len(self.text)
        if denominator == 0:
            return 0
        else:
            return numerator / denominator

    def compression(self, text_to_summary=True):
        """
        Return the COMPRESSION ratio between the text and the summary,
        i.e. len(text) / len(summary) (inverted when ``text_to_summary`` is False).
        """
        ratio = [len(self.text), len(self.summary)]
        try:
            if text_to_summary:
                return ratio[0] / ratio[1]
            else:
                return ratio[1] / ratio[0]
        except ZeroDivisionError:
            return 0

    def _match(self, a, b):
        """
        Raw procedure for matching summary in text, described in the paper.
        """
        self._matches = []
        a_start = b_start = 0
        while a_start < len(a):
            best_match = None
            best_match_length = 0
            # Scan the text for the longest fragment starting at a_start.
            while b_start < len(b):
                if a[a_start] == b[b_start]:
                    a_end = a_start
                    b_end = b_start
                    # Extend the match while the tokens keep agreeing.
                    while a_end < len(a) and b_end < len(b) and b[b_end] == a[a_end]:
                        b_end += 1
                        a_end += 1
                    length = a_end - a_start
                    if length > best_match_length:
                        best_match = Fragments.Match(a_start, b_start, length)
                        best_match_length = length
                    b_start = b_end
                else:
                    b_start += 1
            b_start = 0
            if best_match:
                if best_match_length > 0:
                    self._matches.append(best_match)
                # Skip past the matched fragment in the summary.
                a_start += best_match_length
            else:
                a_start += 1
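

# Illustrative sanity check of the Fragments scores on a toy pair (the strings
# are made up for illustration; the values follow directly from the coverage,
# density, and compression definitions above):
#
#     frags = Fragments("the cat sat", "the cat sat on the mat")
#     frags.coverage()     # 1.0 -> all 3 summary words fall in a shared fragment
#     frags.density()      # 3.0 -> one fragment of length 3: 3 ** 2 / 3
#     frags.compression()  # 2.0 -> 6 text words / 3 summary words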


class DataStatsMetric:
    def __init__(
        self,
        n_gram: int = 3,
        n_workers: int = 24,
        lowercase: bool = False,
        tokenize: bool = True,
    ):
        """
        Data Statistics metric.

        Args:
            n_gram (int): Compute statistics for n-grams up to and including this length.
            n_workers (int): Number of processes to use if using multiprocessing.
            lowercase (bool): Whether to lowercase the input before calculating statistics.
            tokenize (bool): Whether to tokenize the input with spaCy.
        """
        self.n_gram = n_gram
        self.n_workers = n_workers
        self.lowercase = lowercase
        self.tokenize = tokenize

    def evaluate_example(self, summary, input_text):
        if self.tokenize:
            input_text = _en(input_text, disable=["tagger", "parser", "ner", "textcat"])
            input_text = [tok.text for tok in input_text]
            summary = _en(summary, disable=["tagger", "parser", "ner", "textcat"])
            summary = [tok.text for tok in summary]
        fragments = Fragments(summary, input_text, lowercase=self.lowercase)
        coverage = fragments.coverage()
        density = fragments.density()
        compression = fragments.compression()
        score_dict = {"coverage": coverage, "density": density, "compression": compression}
        tokenized_summary = fragments._norm_summary
        tokenized_text = fragments._norm_text
        score_dict["summary_length"] = len(tokenized_summary)
        for i in range(1, self.n_gram + 1):
            input_ngrams = list(find_ngrams(tokenized_text, i))
            summ_ngrams = list(find_ngrams(tokenized_summary, i))
            input_ngrams_set = set(input_ngrams)
            summ_ngrams_set = set(summ_ngrams)
            intersect = summ_ngrams_set.intersection(input_ngrams_set)
            try:
                # Fraction of distinct summary n-grams that do not appear in the input.
                score_dict[f"percentage_novel_{i}-gram"] = (
                    len(summ_ngrams_set) - len(intersect)
                ) / float(len(summ_ngrams_set))
                ngram_counter = Counter()
                ngram_counter.update(summ_ngrams)
                repeated = [key for key, val in ngram_counter.items() if val > 1]
                # Fraction of distinct summary n-grams that occur more than once in the summary.
                score_dict[f"percentage_repeated_{i}-gram_in_summ"] = len(repeated) / float(len(summ_ngrams_set))
            except ZeroDivisionError:
                continue
        return score_dict

    def evaluate_batch(self, summaries, input_texts, aggregate=True):
        corpus_score_dict = Counter()
        with Pool(processes=self.n_workers) as p:
            results = p.starmap(self.evaluate_example, zip(summaries, input_texts))
        if aggregate:
            # Sum the per-example scores, then average over the batch.
            for result in results:
                corpus_score_dict.update(result)
            for key in corpus_score_dict.keys():
                corpus_score_dict[key] /= float(len(input_texts))
            return corpus_score_dict
        else:
            return results

    def supports_multi_ref(self):
        return False
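

# Illustrative direct use of DataStatsMetric (toy strings made up for
# illustration; tokenization runs through spaCy when tokenize=True):
#
#     metric = DataStatsMetric(n_gram=2, n_workers=1)
#     scores = metric.evaluate_example("the cat sat", "the cat sat on the mat")
#     # scores contains "coverage", "density", "compression", "summary_length",
#     # plus "percentage_novel_{n}-gram" and "percentage_repeated_{n}-gram_in_summ".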


class DataStats(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            homepage="",
            inputs_description=_KWARGS_DESCRIPTION,
            features=[
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
                    }
                ),
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Value("string", id="sequence"),
                    }
                ),
            ],
            codebase_urls=["https://github.com/Yale-LILY/SummEval"],
            reference_urls=[
                "https://github.com/lil-lab/newsroom",
                "https://arxiv.org/pdf/2007.12626",
            ],
        )

    def _compute(
        self,
        predictions,
        references,
        n_gram: int = 3,
        n_workers: int = 24,
        lowercase: bool = False,
        tokenize: bool = True,
    ):
        datastats = DataStatsMetric(n_gram, n_workers, lowercase, tokenize)
        results = datastats.evaluate_batch(predictions, references)
        return {
            "coverage": results["coverage"],
            "density": results["density"],
            "compression": results["compression"],
        }
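

# Minimal usage sketch through the `evaluate` API (assumes this module is
# available under the "datastats" path used in the docstring above; adjust the
# load path to match where the metric is hosted):
#
#     import evaluate
#     datastats = evaluate.load("datastats")
#     results = datastats.compute(
#         predictions=["the cat sat on the mat"],
#         references=["the cat sat on the mat near the door"],
#     )
#     # results is a dict with averaged "coverage", "density", and "compression".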