import os
import pathlib
import re
import collections
import functools
import inspect
import sys
from typing import List, Union

import torch
from omegaconf import OmegaConf

import sacrebleu
from rouge_score import rouge_scorer, scoring


class ExitCodeError(Exception):
    pass


def sh(x):
    if os.system(x):
        raise ExitCodeError()


def simple_parse_args_string(args_string):
    """
    Parses something like
        args1=val1,arg2=val2
    into a dictionary.
    """
    args_string = args_string.strip()
    if not args_string:
        return {}
    arg_list = args_string.split(",")
    args_dict = OmegaConf.to_object(OmegaConf.from_dotlist(arg_list))
    return args_dict


def join_iters(iters):
    for iter in iters:
        yield from iter


def chunks(iter, n):
    arr = []
    for x in iter:
        arr.append(x)
        if len(arr) == n:
            yield arr
            arr = []

    if arr:
        yield arr


def group(arr, fn):
    res = collections.defaultdict(list)

    for ob in arr:
        res[fn(ob)].append(ob)

    return list(res.values())


def general_detokenize(string):
    string = string.replace(" n't", "n't")
    string = string.replace(" )", ")")
    string = string.replace("( ", "(")
    string = string.replace('" ', '"')
    string = string.replace(' "', '"')
    string = re.sub(r" (['.,])", r"\1", string)
    return string


def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len):
    """
    - context_len allows for a rolling window context, allowing each prediction
      window to potentially condition on some context

    :param token_list: list
        List of tokens to be PREDICTED
    :param max_seq_len: int
        max_seq_len of model (or max_seq_len we want to use)
    :param context_len: int
        Amount of desired token context for prediction. Needs to be at least 1.
    :param prefix_token: token
        Dummy token like <eos> so the first token has something to condition on.
    :return: generator
        Generator of tuples
            (input_tokens, pred_tokens)
        Note: Score only the last len(pred_tokens) logits of the LM
    """
    assert 1 <= context_len <= max_seq_len
    if not token_list:
        return
    # +1 offset, going from input->preds
    pred_len = max_seq_len - context_len + 1
    predicted = 0

    # Special handling for first window: predict all tokens
    first_seq_len = min(max_seq_len, len(token_list))
    yield ([prefix_token] + token_list[: first_seq_len - 1], token_list[:first_seq_len])
    predicted += first_seq_len

    while predicted < len(token_list):
        window_pred_len = min(len(token_list) - predicted, pred_len)
        window_end = predicted + window_pred_len

        yield (
            token_list[window_end - max_seq_len - 1 : window_end - 1],
            token_list[window_end - window_pred_len : window_end],
        )
        predicted += window_pred_len


def make_disjoint_window(pair):
    """Takes output from get_rolling_token_windows and makes the context not
    overlap with the continuation"""
    a, b = pair
    return a[: len(a) - (len(b) - 1)], b
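
# Illustrative sketch (not part of the original module): how the two helpers
# above cooperate. With max_seq_len=4 and context_len=1, a 10-token stream is
# split into windows whose prediction targets tile the stream exactly once,
# and make_disjoint_window then trims the overlap out of each context.
def _demo_rolling_windows():
    windows = list(
        get_rolling_token_windows(
            token_list=list(range(10)),
            prefix_token=-1,
            max_seq_len=4,
            context_len=1,
        )
    )
    # windows == [([-1, 0, 1, 2], [0, 1, 2, 3]),
    #             ([3, 4, 5, 6],  [4, 5, 6, 7]),
    #             ([5, 6, 7, 8],  [8, 9])]
    # make_disjoint_window maps the last pair to ([5, 6, 7], [8, 9]):
    # the context keeps only tokens that are not also prediction targets.
    return [make_disjoint_window(pair) for pair in windows]
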
def select_continuation_from_batch_left_padding(
    generations: Union[List[List[int]], torch.Tensor], max_context_size: int
):
    """Select the continuation from the batch, removing prompts of different lengths.

    Args:
        generations (Union[List[List[int]], torch.Tensor]):
            A tensor or list-of-lists of shape [batch_size, sequence length].
        max_context_size (int):
            The size of the biggest context; generations will proceed from that
            index.

    Example:
        PAD    PAD Continue : The dog chased the cat  [every day of the week]
        Riddle me      this : The dog chased the cat  [yesterday] PAD PAD PAD PAD

        Output:
        [every day of the week]
        [yesterday] PAD PAD PAD PAD
    """
    return generations[:, max_context_size:]


class Reorderer:
    def __init__(self, arr, fn):
        self.size = len(arr)
        arr = list(enumerate(arr))
        arr = group(arr, lambda x: fn(x[1]))
        arr = [([y[0] for y in x], x[0][1]) for x in arr]
        arr.sort(key=lambda x: fn(x[1]))

        self.arr = arr

    def get_reordered(self):
        return [x[1] for x in self.arr]

    def get_original(self, newarr):
        res = [None] * self.size
        cov = [False] * self.size

        for (inds, _), v in zip(self.arr, newarr):
            for ind in inds:
                res[ind] = v
                cov[ind] = True

        assert all(cov)

        return res


def positional_deprecated(fn):
    """
    A decorator to nudge users into passing only keyword args (`kwargs`) to the
    wrapped function, `fn`.
    """

    @functools.wraps(fn)
    def _wrapper(*args, **kwargs):
        # Methods carry `self` as one implicit positional arg; anything beyond
        # that (or any positional arg on a plain function) triggers the warning.
        if len(args) != (1 if inspect.ismethod(fn) else 0):
            print(
                f"WARNING: using {fn.__name__} with positional arguments is "
                "deprecated and will be disallowed in a future version of "
                "lm-evaluation-harness!"
            )
        return fn(*args, **kwargs)

    return _wrapper


@positional_deprecated
def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
    """
    Search upward in the directory tree to a maximum of three layers
    to find and return the package root (containing the 'tests' folder)
    """
    cur_path = start_path.resolve()
    max_layers = 3
    for _ in range(max_layers):
        if (cur_path / "tests" / "test_version_stable.py").exists():
            return cur_path
        else:
            cur_path = cur_path.parent.resolve()
    raise FileNotFoundError(
        f"Unable to find package root within {max_layers} layers upwards of {start_path}"
    )


@positional_deprecated
def run_task_tests(task_list: List[str]):
    """
    Find the package root and run the tests for the given tasks
    """
    import pytest

    package_root = find_test_root(start_path=pathlib.Path(__file__))
    task_string = " or ".join(task_list)
    args = [
        f"{package_root}/tests/test_version_stable.py",
        f"--rootdir={package_root}",
        "-k",
        f"{task_string}",
    ]
    sys.path.append(str(package_root))
    pytest_return_val = pytest.main(args)
    if pytest_return_val:
        raise ValueError(
            f"Not all tests for the specified tasks ({task_list}) ran "
            f"successfully! Error code: {pytest_return_val}"
        )
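
# Illustrative sketch (not part of the original module): the Reorderer round
# trip. Requests are sorted by a key (here: string length) so similar-length
# items can be batched together, then results are scattered back to their
# original positions; the `.upper()` call stands in for a model invocation.
def _demo_reorderer():
    reqs = ["bb", "a", "ccc"]
    reorderer = Reorderer(reqs, lambda s: len(s))
    ordered = reorderer.get_reordered()     # ["a", "bb", "ccc"]
    results = [s.upper() for s in ordered]  # stand-in for batched model calls
    return reorderer.get_original(results)  # ["BB", "A", "CCC"]
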
def bleu(refs, preds):
    """
    Returns `t5` style BLEU scores. See the related implementation:
    https://github.com/google-research/text-to-text-transfer-transformer/blob/3d10afd51ba97ac29eb66ae701eca274488202f7/t5/evaluation/metrics.py#L41

    :param refs:
        A `list` of `list` of reference `str`s.
    :param preds:
        A `list` of predicted `str`s.
    """
    score = sacrebleu.corpus_bleu(
        preds,
        refs,
        smooth_method="exp",
        smooth_value=0.0,
        force=False,
        lowercase=False,
        tokenize="intl",
        use_effective_order=False,
    ).score
    return score


def rouge(refs, preds):
    """
    Returns `t5` style ROUGE scores. See the related implementation:
    https://github.com/google-research/text-to-text-transfer-transformer/blob/3d10afd51ba97ac29eb66ae701eca274488202f7/t5/evaluation/metrics.py#L68

    :param refs:
        A `list` of reference `str`s.
    :param preds:
        A `list` of predicted `str`s.
    """
    rouge_types = ["rouge1", "rouge2", "rougeLsum"]
    scorer = rouge_scorer.RougeScorer(rouge_types)

    # Add newlines between sentences to correctly compute `rougeLsum`.
    def _prepare_summary(summary):
        summary = summary.replace(" . ", ".\n")
        return summary

    # Accumulate confidence intervals.
    aggregator = scoring.BootstrapAggregator()
    for ref, pred in zip(refs, preds):
        ref = _prepare_summary(ref)
        pred = _prepare_summary(pred)
        aggregator.add_scores(scorer.score(ref, pred))
    result = aggregator.aggregate()
    return {rouge_type: result[rouge_type].mid.fmeasure * 100 for rouge_type in rouge_types}


def rouge2_mecab(refs, preds, tokenizer):
    """This uses a MeCab tokenizer for Japanese text.

    Besides specifying the tokenizer, this does not perform the rougeLsum
    related sentence/newline normalization, and only calculates rouge2.
    Otherwise it is the same as the generic rouge scoring.
    """
    rouge_types = ["rouge2"]
    # mecab-based rouge
    scorer = rouge_scorer.RougeScorer(
        rouge_types,
        tokenizer=tokenizer,
    )

    # Accumulate confidence intervals.
    aggregator = scoring.BootstrapAggregator()
    for ref, pred in zip(refs, preds):
        aggregator.add_scores(scorer.score(ref, pred))
    result = aggregator.aggregate()
    return {rouge_type: result[rouge_type].mid.fmeasure * 100 for rouge_type in rouge_types}
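
# Illustrative sketch (not part of the original module; toy strings are made
# up). Note the different reference shapes: bleu() expects refs as a list of
# reference *streams*, where refs[i][j] is the i-th reference for the j-th
# prediction, while rouge() pairs refs and preds one-to-one via zip.
def _demo_metrics():
    preds = ["the dog chased the cat", "it rained yesterday"]
    refs = [["the dog chased the cat", "it rained all day yesterday"]]
    bleu_score = bleu(refs, preds)        # corpus-level BLEU on a 0-100 scale
    rouge_scores = rouge(refs[0], preds)  # dict of rouge1/rouge2/rougeLsum F1
    return bleu_score, rouge_scores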