import os
import pathlib
import re
import collections
import functools
import inspect
import sys
from typing import List, Union

import torch

from omegaconf import OmegaConf

import sacrebleu
from rouge_score import rouge_scorer, scoring

class ExitCodeError(Exception):
    pass


def sh(x):
    if os.system(x):
        raise ExitCodeError()

def simple_parse_args_string(args_string):
    """
    Parses something like
        arg1=val1,arg2=val2
    into a dictionary
    """
    args_string = args_string.strip()
    if not args_string:
        return {}
    arg_list = args_string.split(",")
    args_dict = OmegaConf.to_object(OmegaConf.from_dotlist(arg_list))
    return args_dict
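# Illustrative usage (assumed example, not part of the original module), relying on
# OmegaConf's YAML-style value parsing for the dotlist entries:
#   simple_parse_args_string("model=gpt2,batch_size=8")
#   ->  {"model": "gpt2", "batch_size": 8}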
def join_iters(iters):
    for iter in iters:
        yield from iter


def chunks(iter, n):
    arr = []
    for x in iter:
        arr.append(x)
        if len(arr) == n:
            yield arr
            arr = []

    if arr:
        yield arr
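# Illustrative usage (assumed example, not part of the original module):
#   list(chunks(range(5), 2))  ->  [[0, 1], [2, 3], [4]]
# The final chunk may be shorter than n.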
def group(arr, fn):
    res = collections.defaultdict(list)

    for ob in arr:
        res[fn(ob)].append(ob)

    return list(res.values())
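# Illustrative usage (assumed example, not part of the original module):
#   group(["aa", "b", "cc", "d"], len)  ->  [["aa", "cc"], ["b", "d"]]
# Items are grouped by key, with groups emitted in first-seen order.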
def general_detokenize(string):
    string = string.replace(" n't", "n't")
    string = string.replace(" )", ")")
    string = string.replace("( ", "(")
    string = string.replace('" ', '"')
    string = string.replace(' "', '"')
    string = re.sub(r" (['.,])", r"\1", string)
    return string
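# Illustrative usage (assumed example, not part of the original module):
#   general_detokenize("I do n't know .")  ->  "I don't know."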
def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len):
    """
    - context_len allows for a rolling window context, allowing each prediction window to potentially
      condition on some context

    :param token_list: list
        List of tokens to be PREDICTED
    :param max_seq_len: int
        max_seq_len of model (or max_seq_len we want to use)
    :param context_len: int
        Amount of desired token context for prediction. Needs to be at least 1.
    :param prefix_token: token
        Dummy token like <eos> so the first token has something to condition on
    :return: generator
        Generator of tuples
            (input_tokens, pred_tokens)
        Note: Score only the last len(pred_tokens) logits of the LM
    """
    assert 1 <= context_len <= max_seq_len
    if not token_list:
        return

    pred_len = max_seq_len - context_len + 1
    predicted = 0

    first_seq_len = min(max_seq_len, len(token_list))
    yield ([prefix_token] + token_list[: first_seq_len - 1], token_list[:first_seq_len])
    predicted += first_seq_len

    while predicted < len(token_list):
        window_pred_len = min(len(token_list) - predicted, pred_len)
        window_end = predicted + window_pred_len

        yield (
            token_list[window_end - max_seq_len - 1 : window_end - 1],
            token_list[window_end - window_pred_len : window_end],
        )
        predicted += window_pred_len
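# Illustrative behavior (assumed example, not part of the original module), with
# prefix_token=-1, max_seq_len=4, context_len=2:
#   list(get_rolling_token_windows([0, 1, 2, 3, 4, 5], -1, 4, 2))
#   ->  [([-1, 0, 1, 2], [0, 1, 2, 3]), ([1, 2, 3, 4], [4, 5])]
# The first window predicts every token it can; later windows reuse earlier
# tokens as context and predict only the remaining tokens.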
def make_disjoint_window(pair):
    """Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
    a, b = pair
    return a[: len(a) - (len(b) - 1)], b
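# Illustrative usage (assumed example, not part of the original module):
#   make_disjoint_window(([1, 2, 3, 4], [4, 5]))  ->  ([1, 2, 3], [4, 5])
# The returned context ends right before the first token of the continuation.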
def select_continuation_from_batch_left_padding(
    generations: Union[List[List[int]], torch.Tensor], max_context_size: int
):
    """Select the continuation from the batch, removing prompts of different lengths.

    Args:
        generations (Union[List[List[int]], torch.Tensor]):
            A tensor or list-of-lists of shape [batch_size, sequence length].
        max_context_size (int):
            The size of the biggest context; generations will proceed from that
            index.

    Example:
        PAD PAD Continue : The dog chased the cat [every day of the week]
        Riddle me this : The dog chased the cat [yesterday] PAD PAD PAD PAD
        Output:
        [every day of the week]
        [yesterday] PAD PAD PAD PAD
    """
    return generations[:, max_context_size:]
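# Illustrative usage (assumed example, not part of the original module). Note that
# the 2-D slice requires tensor-style indexing (e.g. a torch.Tensor), not a plain
# list of lists:
#   gens = torch.tensor([[0, 0, 7, 8, 9],
#                        [1, 2, 3, 4, 5]])
#   select_continuation_from_batch_left_padding(gens, max_context_size=2)
#   ->  tensor([[7, 8, 9],
#               [3, 4, 5]])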
class Reorderer:
    def __init__(self, arr, fn):
        self.size = len(arr)
        arr = list(enumerate(arr))
        arr = group(arr, lambda x: fn(x[1]))
        arr = [([y[0] for y in x], x[0][1]) for x in arr]
        arr.sort(key=lambda x: fn(x[1]))

        self.arr = arr

    def get_reordered(self):
        return [x[1] for x in self.arr]

    def get_original(self, newarr):
        res = [None] * self.size
        cov = [False] * self.size

        for (inds, _), v in zip(self.arr, newarr):
            for ind in inds:
                res[ind] = v
                cov[ind] = True

        assert all(cov)

        return res
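# Illustrative usage (assumed example, not part of the original module):
#   ro = Reorderer(["bb", "a", "ccc"], len)
#   ro.get_reordered()                   ->  ["a", "bb", "ccc"]
#   ro.get_original(["A", "BB", "CCC"])  ->  ["BB", "A", "CCC"]
# Elements are sorted by fn for processing and then mapped back to their
# original positions; elements grouped under the same fn key share a single
# entry in the reordered list and receive the same result when mapped back.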
def positional_deprecated(fn):
    """
    A decorator to nudge users into passing only keyword args (`kwargs`) to the
    wrapped function, `fn`.
    """

    @functools.wraps(fn)
    def _wrapper(*args, **kwargs):
        if len(args) != (1 if inspect.ismethod(fn) else 0):
            print(
                f"WARNING: using {fn.__name__} with positional arguments is "
                "deprecated and will be disallowed in a future version of "
                "lm-evaluation-harness!"
            )
        return fn(*args, **kwargs)

    return _wrapper
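# Illustrative usage (assumed example, not part of the original module):
#
#   @positional_deprecated
#   def evaluate(model=None, tasks=None):
#       ...
#
#   evaluate("gpt2", ["lambada"])              # prints the deprecation warning
#   evaluate(model="gpt2", tasks=["lambada"])  # no warning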
@positional_deprecated
def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
    """
    Search upward in the directory tree to a maximum of three layers
    to find and return the package root (containing the 'tests' folder)
    """
    cur_path = start_path.resolve()
    max_layers = 3
    for _ in range(max_layers):
        if (cur_path / "tests" / "test_version_stable.py").exists():
            return cur_path
        else:
            cur_path = cur_path.parent.resolve()
    raise FileNotFoundError(
        f"Unable to find package root within {max_layers} layers upward of {start_path}"
    )

@positional_deprecated
def run_task_tests(task_list: List[str]):
    """
    Find the package root and run the tests for the given tasks
    """
    import pytest

    package_root = find_test_root(start_path=pathlib.Path(__file__))
    task_string = " or ".join(task_list)
    args = [
        f"{package_root}/tests/test_version_stable.py",
        f"--rootdir={package_root}",
        "-k",
        f"{task_string}",
    ]
    sys.path.append(str(package_root))
    pytest_return_val = pytest.main(args)
    if pytest_return_val:
        raise ValueError(
            f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
        )

def bleu(refs, preds):
    """
    Returns `t5`-style BLEU scores. See the related implementation:
    https://github.com/google-research/text-to-text-transfer-transformer/blob/3d10afd51ba97ac29eb66ae701eca274488202f7/t5/evaluation/metrics.py#L41

    :param refs:
        A `list` of `list` of reference `str`s.
    :param preds:
        A `list` of predicted `str`s.
    """
    score = sacrebleu.corpus_bleu(
        preds,
        refs,
        smooth_method="exp",
        smooth_value=0.0,
        force=False,
        lowercase=False,
        tokenize="intl",
        use_effective_order=False,
    ).score
    return score
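# Illustrative usage (assumed example, not part of the original module); sacrebleu
# expects one list per reference stream, aligned with the predictions, so a single
# exactly-matching pair scores 100.0:
#   bleu([["the cat sat on the mat"]], ["the cat sat on the mat"])  ->  100.0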
def rouge(refs, preds):
    """
    Returns `t5`-style ROUGE scores. See the related implementation:
    https://github.com/google-research/text-to-text-transfer-transformer/blob/3d10afd51ba97ac29eb66ae701eca274488202f7/t5/evaluation/metrics.py#L68

    :param refs:
        A `list` of reference `str`s.
    :param preds:
        A `list` of predicted `str`s.
    """
    rouge_types = ["rouge1", "rouge2", "rougeLsum"]
    scorer = rouge_scorer.RougeScorer(rouge_types)

    def _prepare_summary(summary):
        summary = summary.replace(" . ", ".\n")
        return summary

    aggregator = scoring.BootstrapAggregator()
    for ref, pred in zip(refs, preds):
        ref = _prepare_summary(ref)
        pred = _prepare_summary(pred)
        aggregator.add_scores(scorer.score(ref, pred))
    result = aggregator.aggregate()
    return {type: result[type].mid.fmeasure * 100 for type in rouge_types}
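# Illustrative usage (assumed example, not part of the original module); an
# identical reference/prediction pair should yield an f-measure of 100.0 for
# each ROUGE type:
#   rouge(["the cat sat on the mat"], ["the cat sat on the mat"])
#   ->  {"rouge1": 100.0, "rouge2": 100.0, "rougeLsum": 100.0}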
def rouge2_mecab(refs, preds, tokenizer):
    """This uses a MeCab tokenizer for Japanese text.

    Besides specifying the tokenizer, this does not perform the rougeLsum
    related sentence/newline normalization, and only calculates rouge2.
    Otherwise it is the same as the generic rouge scoring.
    """
    rouge_types = ["rouge2"]

    scorer = rouge_scorer.RougeScorer(
        rouge_types,
        tokenizer=tokenizer,
    )

    aggregator = scoring.BootstrapAggregator()
    for ref, pred in zip(refs, preds):
        aggregator.add_scores(scorer.score(ref, pred))
    result = aggregator.aggregate()
    return {type: result[type].mid.fmeasure * 100 for type in rouge_types}