jgauthier committed on
Commit
bcb8ccf
1 Parent(s): 9a11e1b

move metric and tests from dataset repo

Browse files
Files changed (4) hide show
  1. .gitignore +1 -0
  2. prediction.py +235 -0
  3. syntaxgym.py +225 -52
  4. test.py +516 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
1
+ __pycache__
prediction.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union, Optional as TOptional, List as TList
2
+
3
+
4
+ from pyparsing import *
5
+ import numpy as np
6
+
7
# Aggregation functions selectable by name as the ``metric`` of a prediction.
# Each maps a collection of surprisal values to a single number; 'range' is
# numpy's peak-to-peak (max - min).
METRICS = {
    'sum': sum,
    'mean': np.mean,
    'median': np.median,
    'range': np.ptp,
    'max': max,
    'min': min
}


# Enable parser packrat (caching). This is a process-wide pyparsing switch and
# is flipped before any grammar element below is constructed.
ParserElement.enablePackrat()

# Relative and absolute tolerance thresholds for surprisal equality
# (passed to np.isclose when evaluating the "=" comparator).
EQUALITY_RTOL = 1e-5
EQUALITY_ATOL = 1e-3


#######
# Define a grammar for prediction formulae.

# References a surprisal region, written as ``(N;%condition%)`` where N is a
# run of digits or the wildcard ``*``. Parentheses, ``;%`` and the trailing
# ``%`` are suppressed, so the parse yields just [N, condition].
lpar = Suppress("(")
rpar = Suppress(")")
region = lpar + (Word(nums) | "*") + Suppress(";%") + Word(alphanums + "_-") + Suppress("%") + rpar
# Numeric literal (pyparsing's common integer-or-float expression).
literal_float = pyparsing_common.number
33
+
34
class Region(object):
    """
    Formula atom that references the surprisal of a region in a condition.

    Constructed from parser tokens ``[region_number, condition_name]``. The
    region number is kept as parsed (a digit string or the wildcard ``"*"``).
    """

    def __init__(self, tokens):
        self.region_number, self.condition_name = tokens[0], tokens[1]

    def __str__(self):
        # Round-trips to the formula syntax, e.g. ``(2;%cond%)``.
        return "(%s;%%%s%%)" % (self.region_number, self.condition_name)

    def __repr__(self):
        # Note the argument order here: condition first, then region number.
        return "Region(%s,%s)" % (self.condition_name, self.region_number)

    def __call__(self, surprisal_dict):
        """
        Look up this region's value in ``surprisal_dict``, a mapping keyed by
        ``(condition_name, region_number)`` tuples.

        A wildcard region adds up every entry belonging to the condition.
        """
        if self.region_number != "*":
            return surprisal_dict[self.condition_name, int(self.region_number)]

        total = 0
        for (condition, _region), value in surprisal_dict.items():
            if condition == self.condition_name:
                total = total + value
        return total
51
+
52
class LiteralFloat(object):
    """
    Formula atom holding a constant number.

    Constructed from parser tokens whose first element is the literal; the
    value is stored as a Python ``float``.
    """

    def __init__(self, tokens):
        raw_literal = tokens[0]
        self.value = float(raw_literal)

    def __str__(self):
        # Fixed-point rendering with six decimal places.
        return "%f" % (self.value,)

    def __repr__(self):
        return "LiteralFloat(%f)" % (self.value,)

    def __call__(self, surprisal_dict):
        # A literal evaluates to itself; the surprisal mapping is not consulted.
        return self.value
64
+
65
class BinaryOp(object):
    """
    Base class for binary formula nodes (``left <operator> right``).

    Subclasses set ``operators`` to the list of operator symbols they accept
    and implement ``_evaluate``. Leaving ``operators`` as ``None`` disables
    the operator check.
    """

    # Give the attribute a real default: the check in __init__ treats None as
    # "accept anything", but a bare annotation would leave the attribute
    # undefined and make that check raise AttributeError on the base class.
    operators: TOptional[TList[str]] = None

    def __init__(self, tokens):
        """
        Args:
            tokens: Parser tokens whose first element is the triple
                ``[left_operand, operator, right_operand]``.

        Raises:
            ValueError: If ``operators`` is set and the operator is not in it.
        """
        self.operator = tokens[0][1]
        if self.operators is not None and self.operator not in self.operators:
            raise ValueError("Invalid %s operator %s" % (self.__class__.__name__,
                                                         self.operator))
        self.operands = [tokens[0][0], tokens[0][2]]

    def __str__(self):
        return "(%s %s %s)" % (self.operands[0], self.operator, self.operands[1])

    def __repr__(self):
        return "%s(%s)(%s)" % (self.__class__.__name__, self.operator, ",".join(map(repr, self.operands)))

    def __call__(self, surprisal_dict):
        """
        Evaluate both operands against ``surprisal_dict`` (each operand is
        itself callable), then combine the results via ``_evaluate``.
        """
        op_vals = [op(surprisal_dict) for op in self.operands]
        return self._evaluate(op_vals, surprisal_dict)

    def _evaluate(self, evaluated_operands, surprisal_dict):
        """Combine the two evaluated operands; must be provided by subclasses."""
        raise NotImplementedError()
87
+
88
class BoolOp(BinaryOp):
    """Logical combination of two operands: ``&`` (and) or ``|`` (or)."""

    operators = ["&", "|"]

    def _evaluate(self, op_vals, surprisal_dict):
        # Uses Python's `and` / `or`, so the result is one of the operand
        # values rather than a coerced bool.
        symbol = self.operator
        if symbol == "&":
            lhs, rhs = op_vals[0], op_vals[1]
            return lhs and rhs
        if symbol == "|":
            lhs, rhs = op_vals[0], op_vals[1]
            return lhs or rhs
        return None
95
+
96
class FloatOp(BinaryOp):
    """Arithmetic combination of two operands: subtraction or addition."""

    operators = ["-", "+"]

    def _evaluate(self, op_vals, surprisal_dict):
        symbol = self.operator
        if symbol == "-":
            minuend, subtrahend = op_vals[0], op_vals[1]
            return minuend - subtrahend
        if symbol == "+":
            first, second = op_vals[0], op_vals[1]
            return first + second
        return None
103
+
104
class ComparatorOp(BinaryOp):
    """
    Comparison of two operands: ``<``, ``>`` or ``=``.

    Equality is approximate, using the module-level tolerance thresholds.
    """

    operators = ["<", ">", "="]

    def _evaluate(self, op_vals, surprisal_dict):
        symbol = self.operator
        if symbol == "<":
            return op_vals[0] < op_vals[1]
        if symbol == ">":
            return op_vals[0] > op_vals[1]
        if symbol == "=":
            lhs, rhs = op_vals[0], op_vals[1]
            return np.isclose(lhs, rhs,
                              rtol=EQUALITY_RTOL,
                              atol=EQUALITY_ATOL)
        return None
115
+
116
def Chain(op_cls, left_assoc=True):
    """
    Build a parse action that folds a flat ``a op b op c ...`` token group
    into nested ``op_cls`` nodes.

    Args:
        op_cls: Callable invoked as ``op_cls([[left, operator, right]])`` for
            every operator application.
        left_assoc: Only left-associative folding is supported; the returned
            parse action raises ``NotImplementedError`` otherwise.
    """
    def chainer(tokens):
        """
        Fold the alternating operand/operator sequence in ``tokens[0]`` from
        the left, returning the root node (or the lone operand).
        """
        flat = tokens[0]
        symbols = flat[1::2]
        terms = flat[0::2]
        if not left_assoc:
            raise NotImplementedError

        node = terms[0]
        for position, right in enumerate(terms[1:]):
            node = op_cls([[node, symbols[position], right]])

        return node

    return chainer
136
+
137
# An atom is a region reference or a numeric literal. The parse actions wrap
# the matched tokens in the corresponding callable node class.
atom = region.setParseAction(Region) | literal_float.setParseAction(LiteralFloat)

# Full formula grammar. pyparsing's infixNotation takes its operator levels
# ordered from tightest- to loosest-binding: arithmetic, then comparison, then
# boolean connectives.
# NOTE(review): the comparison level uses ComparatorOp directly instead of
# Chain(...), and BinaryOp.__init__ only reads positions 0..2 of the token
# group -- so a chained comparison such as `a < b < c` would silently drop the
# trailing terms. Confirm that chained comparisons are not expected.
prediction_expr = infixNotation(
    atom,
    [
        (oneOf("- +"), 2, opAssoc.LEFT, Chain(FloatOp)),
        (oneOf("< > ="), 2, opAssoc.LEFT, ComparatorOp),
        (oneOf("& |"), 2, opAssoc.LEFT, Chain(BoolOp)),
    ],
    lpar=lpar, rpar=rpar
)
148
+
149
+
150
class Prediction(object):
    """
    A prediction states an expected relation between language model surprisal
    measures taken in different regions and conditions of a test suite. For
    more information, see :ref:`architecture`.
    """

    def __init__(self, idx: int, formula: Union[str, BinaryOp], metric: str):
        """
        Args:
            idx: A unique prediction ID. This is only relevant for
                serialization.
            formula: Either the textual form of the prediction formula, or a
                formula that has already been parsed. For more information,
                see :ref:`architecture`.
            metric: Name of the metric used to aggregate surprisals within
                regions; must be a key of ``METRICS``.

        Raises:
            ValueError: If a string formula cannot be parsed, or if the metric
                name is not supported.
        """
        parsed = formula
        if isinstance(formula, str):
            try:
                parsed = prediction_expr.parseString(formula, parseAll=True)[0]
            except ParseException as e:
                raise ValueError("Invalid formula expression %r" % (formula,)) from e

        self.idx = idx
        self.formula = parsed

        if metric not in METRICS:
            supported = " ".join(METRICS.keys())
            raise ValueError("Unknown metric %s. Supported metrics: %s" %
                             (metric, supported))
        self.metric = metric

    def __call__(self, item):
        """
        Evaluate the prediction on an item given in its dict representation.
        For more information on item representations, see :ref:`suite_json`.
        """
        # Flatten the item into {(condition_name, region_number): value},
        # picking this prediction's metric out of each region.
        surps = {}
        for condition in item["conditions"]:
            for region_entry in condition["regions"]:
                key = (condition["condition_name"], region_entry["region_number"])
                surps[key] = region_entry["metric_value"][self.metric]
        return self.formula(surps)

    @classmethod
    def from_dict(cls, pred_dict, idx: int, metric: str):
        """
        Build a prediction from its dictionary representation (see
        :ref:`suite_json`). Only the ``"formula"`` type is understood.
        """
        if pred_dict["type"] == "formula":
            return cls(formula=pred_dict["formula"], idx=idx, metric=metric)

        raise ValueError("Unknown prediction type %s" % (pred_dict["type"],))

    @property
    def referenced_regions(self):
        """
        The set of regions this formula refers to, as tuples of the form
        ``(condition_name, region_number)``.
        """
        found = set()
        pending = [self.formula]
        # Walk the formula tree; only Region leaves contribute.
        while pending:
            node = pending.pop()
            if isinstance(node, BinaryOp):
                pending.extend(node.operands)
            elif isinstance(node, Region):
                found.add((node.condition_name, int(node.region_number)))

        return found

    def as_dict(self):
        """
        Serialize to the prediction dictionary representation (see
        :ref:`suite_json`).
        """
        return {"type": "formula", "formula": str(self.formula)}

    def __str__(self):
        return "Prediction(%s)" % (self.formula,)
    __repr__ = __str__

    def __hash__(self):
        return hash(self.formula)

    def __eq__(self, other):
        # Equality is defined purely in terms of the formula hashes.
        if not isinstance(other, Prediction):
            return False
        return hash(self) == hash(other)
syntaxgym.py CHANGED
@@ -13,83 +13,256 @@
13
  # limitations under the License.
14
  """TODO: Add a description here."""
15
 
16
- import evaluate
 
 
 
17
  import datasets
 
 
 
 
 
 
18
 
19
 
20
- # TODO: Add BibTeX citation
21
  _CITATION = """\
22
- @InProceedings{huggingface:module,
23
- title = {A great new module},
24
- authors={huggingface, Inc.},
25
- year={2020}
 
26
  }
27
  """
28
 
29
  # TODO: Add description of the module here
30
- _DESCRIPTION = """\
31
- This new module is designed to solve this great ML task and is crafted with a lot of care.
32
  """
33
 
34
 
35
  # TODO: Add description of the arguments of the module here
36
  _KWARGS_DESCRIPTION = """
37
- Calculates how good are predictions given some references, using certain scores
38
  Args:
39
- predictions: list of predictions to score. Each predictions
40
- should be a string with tokens separated by spaces.
41
- references: list of reference for each prediction. Each
42
- reference should be a string with tokens separated by spaces.
 
 
 
43
  Returns:
44
- accuracy: description of the first score,
45
- another_score: description of the second score,
 
 
 
 
 
46
  Examples:
47
- Examples should be written in doctest format, and should illustrate how
48
- to use the function.
49
 
50
- >>> my_new_module = evaluate.load("my_new_module")
51
- >>> results = my_new_module.compute(references=[0, 1], predictions=[0, 1])
52
- >>> print(results)
53
- {'accuracy': 1.0}
54
  """
55
 
56
- # TODO: Define external resources urls if needed
57
- BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
 
60
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
61
- class {{ cookiecutter.module_class_name }}(evaluate.EvaluationModule):
62
- """TODO: Short description of my evaluation module."""
 
 
63
 
64
  def _info(self):
65
- # TODO: Specifies the evaluate.EvaluationModuleInfo object
 
 
 
66
  return evaluate.EvaluationModuleInfo(
67
- # This is the description that will appear on the modules page.
68
- module_type="{{ cookiecutter.module_type }}",
69
- description=_DESCRIPTION,
70
  citation=_CITATION,
71
- inputs_description=_KWARGS_DESCRIPTION,
72
- # This defines the format of each prediction and reference
73
- features=datasets.Features({
74
- 'predictions': datasets.Value('int64'),
75
- 'references': datasets.Value('int64'),
76
- }),
77
- # Homepage of the module for documentation
78
- homepage="http://module.homepage",
79
- # Additional links to the codebase or references
80
- codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
81
- reference_urls=["http://path.to.reference.url/new_module"]
82
  )
83
 
84
- def _download_and_prepare(self, dl_manager):
85
- """Optional: download external resources useful to compute the scores"""
86
- # TODO: Download external resources if needed
87
- pass
88
-
89
- def _compute(self, predictions, references):
90
- """Returns the scores"""
91
- # TODO: Compute the different scores of the module
92
- accuracy = sum(i == j for i, j in zip(predictions, references)) / len(predictions)
93
- return {
94
- "accuracy": accuracy,
95
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  # limitations under the License.
14
  """TODO: Add a description here."""
15
 
16
+ from collections import defaultdict
17
+ from typing import List, Dict, Tuple
18
+ from typing_extensions import TypedDict
19
+
20
  import datasets
21
+ import evaluate
22
+ import numpy as np
23
+ import torch
24
+ from transformers import AutoTokenizer, AutoModelForCausalLM
25
+
26
+ from .prediction import Prediction
27
 
28
 
 
29
  _CITATION = """\
30
+ @inproceedings{Hu:et-al:2020,
31
+ author = {Hu, Jennifer and Gauthier, Jon and Qian, Peng and Wilcox, Ethan and Levy, Roger},
32
+ title = {A systematic assessment of syntactic generalization in neural language models},
33
+ booktitle = {Proceedings of the Association of Computational Linguistics},
34
+ year = {2020}
35
  }
36
  """
37
 
38
  # TODO: Add description of the module here
39
+ _DESCRIPTION = """
 
40
  """
41
 
42
 
43
  # TODO: Add description of the arguments of the module here
44
  _KWARGS_DESCRIPTION = """
45
+ Runs SyntaxGym evaluations on the given model and test suite.
46
  Args:
47
+ suite (Dataset): SyntaxGym test suite loaded as a Dataset.
48
+ model_id (str): model used for calculating surprisals
49
+ NOTE: The SyntaxGym evaluations are only well-defined for causal language models.
50
+ This includes models such as gpt2, causal variations of bert,
51
+ causal versions of t5, and more (the full list can be found
52
+ in the AutoModelForCausalLM documentation here:
53
+ https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModelForCausalLM )
54
  Returns:
55
+ prediction_results: A list of prediction results per item. A list of lists,
56
+ one per item, containing the boolean prediction result for each
57
+ prediction in the test suite,
58
+ region_totals: A list of total surprisals for each region (nested within
59
+ condition and item). A list of dictionaries (one per item), each
60
+ mapping tuples (condition_name, region_number) to a float
61
+ total surprisal value (i.e. negative log-2 probability).
62
  Examples:
63
+ TODO
 
64
 
65
+ >>> my_new_module = evaluate.load("cpllab/syntaxgym")
66
+ >>> ...
 
 
67
  """
68
 
69
+
70
# Feature schema for one condition of a test-suite item: its name, its full
# sentence content, and the sequence of numbered regions with their content.
SUITE_DATASET_CONDITION_SPEC = {
    "condition_name": datasets.Value("string"),
    "content": datasets.Value("string"),
    "regions": datasets.Sequence({
        "region_number": datasets.Value("int32"),
        "content": datasets.Value("string")
    })
}


# Feature schema for one test-suite item: an item number, a sequence of
# conditions (see above), and the prediction formulae as strings.
SUITE_DATASET_SPEC = {
    "item_number": datasets.Value("int32"),
    "conditions": datasets.Sequence(SUITE_DATASET_CONDITION_SPEC),
    "predictions": datasets.Sequence(datasets.Value("string")),
}
85
+
86
+
87
class SyntaxGymMetricResult(TypedDict):
    """Typed shape of the result dictionary annotated on ``SyntaxGym._compute``."""
    # One inner list per item, holding the outcome of each of its predictions.
    prediction_results: List[List[bool]]
    # One dict per item, mapping (condition_name, region_number) to the total
    # surprisal of that region.
    region_totals: List[Dict[Tuple[str, int], float]]
90
 
91
 
92
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class SyntaxGym(evaluate.EvaluationModule):
    """
    Defines SyntaxGym evaluation logic for causal language models.
    """

    def _info(self):
        """Describe the module: type, citation, input features and links."""
        # NOTE(review): `seq` is never used below -- candidate for removal.
        seq = datasets.Sequence
        features = datasets.Features({
            "suite": SUITE_DATASET_SPEC
        })
        return evaluate.EvaluationModuleInfo(
            module_type="metric",
            description="TODO",
            citation=_CITATION,
            inputs_description="TODO",
            features=features,
            homepage="https://syntaxgym.org",
            codebase_urls=["https://github.com/cpllab/syntaxgym-core"],
        )

    def _compute(self, suite, model_id, device=None) -> SyntaxGymMetricResult:
        """
        Load the model and tokenizer named by ``model_id`` and evaluate every
        item of ``suite``.

        Args:
            suite: Iterable of test-suite items.
            model_id: Identifier passed to ``from_pretrained`` for both the
                causal LM and its tokenizer.
            device: One of "gpu", "cpu", "cuda" ("gpu" is an alias for
                "cuda"). When None, CUDA is used if available, else CPU.

        Returns:
            Dict with "prediction_results" and "region_totals", each a list
            with one entry per item.
        """
        if device is not None:
            # NOTE(review): input validation via assert is skipped under -O.
            assert device in ["gpu", "cpu", "cuda"]
            if device == "gpu":
                device = "cuda"
        else:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        model = AutoModelForCausalLM.from_pretrained(model_id)
        model = model.to(device)
        model.eval()

        tokenizer = AutoTokenizer.from_pretrained(model_id)
        # TODO copy from perplexity metric
        # The EOS token doubles as the padding token for batched tokenization.
        tokenizer.pad_token = tokenizer.eos_token

        results = {"prediction_results": [], "region_totals": []}
        # TODO batch all items together
        for item in datasets.logging.tqdm(suite):
            result_single = self._compute_single(item, tokenizer, model, device)

            for k in ["prediction_results", "region_totals"]:
                results[k].append(result_single[k])

        return results

    def _compute_single(self, item, tokenizer, model, device):
        """
        Evaluate one item: compute token surprisals for all of its conditions
        in a single batch, total them per region, and evaluate the item's
        prediction formulae on those totals.

        Returns:
            Dict with "prediction_results" (one formula result per prediction)
            and "region_totals" (``{(condition_name, region_number): float}``).
        """
        # presumably item["conditions"] is a dict of parallel lists (one entry
        # per condition), given how it is indexed here -- verify against caller
        tokenized = tokenizer(item["conditions"]["content"],
                              padding=True,
                              return_tensors="pt",
                              return_offsets_mapping=True).to(device)

        # input_ids: B * T
        input_ids = tokenized["input_ids"]
        assert input_ids.ndim == 2

        # Compute sentence level surprisals.
        with torch.no_grad():
            # Pre-softmax predictive distribution B * T * V
            logits = model(input_ids).logits
            # Negative log-probability, converted from nats to bits.
            surprisals = -logits.log_softmax(dim=2) / np.log(2)

        # surprisals: B * T * V
        assert surprisals.ndim == 3

        # Get surprisals of expected words.
        # The distribution at position t scores the token at position t + 1,
        # so drop the last distribution and the first token to align them.
        surps_shifted = surprisals[:, :-1, :]
        expected_ids = input_ids[:, 1:]

        # TODO: check this logic
        # NOTE(review): `tt` is never used below -- candidate for removal.
        tt = expected_ids.unsqueeze(2)
        # reindexed surprisals: B * (T - 1)
        surprisals = torch.gather(surps_shifted, 2, expected_ids.unsqueeze(2)) \
            .squeeze(2)
        # This is the original, which works but not with multiple axes in expected_ids
        # surprisals = surps_shifted[range(surps_shifted.shape[0]), expected_ids]

        # surprisals is now B * (T - 1)

        #### aggregate
        condition_names = item["conditions"]["condition_name"]
        region_totals = {condition_name: defaultdict(float)
                         for condition_name in condition_names}
        region2tokens = self.compute_region_token_mapping(
            item, input_ids, tokenized["offset_mapping"])

        for i, (i_cond, i_inputs) in enumerate(zip(condition_names, input_ids)):
            for region_number, region_tokens in region2tokens[i_cond].items():
                for token in region_tokens:
                    if token == 0:
                        # surprisal not defined. pass.
                        continue
                    elif token <= surprisals.shape[1]:
                        # Token k's surprisal sits at column k - 1 after the shift.
                        region_totals[i_cond][region_number] += surprisals[i, token - 1]
                    else:
                        # TODO don't think this is an issue, just should clean
                        # up the aggregation output
                        # NOTE(review): this branch is only entered when
                        # token > surprisals.shape[1], so the assertion below
                        # can never hold if it is reached.
                        assert token == surprisals.shape[1], \
                            "%s %s" % (token, surprisals.shape[1])

        # Flatten to {(condition, region): float}, converting accumulated
        # values to plain Python floats.
        region_totals = {(condition_name, region_number): float(total)
                         for condition_name, totals in region_totals.items()
                         for region_number, total in totals.items()}

        results = {
            "prediction_results": [
                Prediction(i, formula, "sum").formula(region_totals)
                for i, formula in enumerate(item["predictions"])
            ],

            "region_totals": region_totals
        }
        return results

    def get_region_edges(self, item, condition_idx):
        """
        Get left edge of each region as a character index.

        Returns a list with one character offset per region of the condition
        at ``condition_idx``.
        """
        # NB this is coupled with `condition_to_string` logic of course

        regions = item["conditions"]["regions"][condition_idx]

        idx = 0
        ret = []
        for r_idx, region_content in enumerate(regions["content"]):
            ret.append(idx)

            region_size = len(region_content)
            # A non-blank region that is not the first one and does not start
            # with a comma is counted one character longer, for its joining
            # space.
            if region_content.strip() != "" and r_idx != 0 and not region_content.startswith(","):
                # Add joining space
                region_size += 1

            idx += region_size

        return ret

    def compute_region_token_mapping(self, item, input_ids: torch.LongTensor,
                                     offset_mapping: List[Tuple[int, int]]
                                     ) -> Dict[str, Dict[int, List[int]]]:
        """
        Assign token positions to regions using the tokenizer's character
        offsets.

        Returns:
            ``{condition_name: {region_number: [token_index, ...]}}`` where
            region numbers are 1-based and token indices are positions along
            the T axis of ``input_ids``.
        """
        # input_ids: B * T
        # offset_mapping: B * T * 2
        # assumes batch is sorted according to item's condition_name order

        condition_names = item["conditions"]["condition_name"]
        region2tokens = {cond: defaultdict(list) for cond in condition_names}

        # Stand-in right edge for the last region (it has no successor).
        max_long = torch.iinfo(torch.int64).max

        input_ids = input_ids.detach()
        for i_cond, (i_tokens, i_offsets) in enumerate(zip(input_ids, offset_mapping)):
            region_edges = self.get_region_edges(item, i_cond)

            # Two cursors advance in lockstep: one over tokens, one over regions.
            t_cursor, r_cursor = 0, 0
            while t_cursor < i_tokens.shape[0]:
                # token = i_tokens[t_cursor]
                token_char_start, token_char_end = i_offsets[t_cursor]

                if token_char_start == token_char_end == 0:
                    # This is a padding token. Skip.
                    # TODO what about BOS/EOS? some models incorporate them
                    t_cursor += 1
                    continue

                region_start = region_edges[r_cursor]
                region_end = region_edges[r_cursor + 1] \
                    if r_cursor + 1 < len(region_edges) else max_long

                # NB region boundaries are left edges, hence the >= here.
                if token_char_start >= region_end:
                    # Token starts past this region: advance the region cursor
                    # and re-examine the same token.
                    r_cursor += 1
                    continue

                region2tokens[condition_names[i_cond]][r_cursor + 1].append(t_cursor)
                t_cursor += 1

        return region2tokens
test.py ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import datasets
4
+ import evaluate
5
+ import numpy as np
6
+
7
+ import pytest
8
+
9
+
10
@pytest.fixture(scope="session")
def syntaxgym_dataset():
    """Session fixture: the ``subordination_src-src`` configuration of the ``syntaxgym`` dataset."""
    suite = datasets.load_dataset("syntaxgym", "subordination_src-src")
    return suite
13
+
14
+
15
@pytest.fixture(scope="session")
def syntaxgym_metric():
    """Session fixture: the metric module loaded from the local ``./syntaxgym.py`` script."""
    metric = evaluate.load("./syntaxgym.py")
    return metric
18
+
19
+
20
@pytest.fixture(scope="session")
def model_ref():
    """Session fixture: identifier of the model the tests evaluate."""
    # return "hf-internal-testing/tiny-random-gpt_neo"
    model_name = "gpt2"
    return model_name
24
+
25
+
26
+ # Reference region surprisals computed with syntaxgym-core.
27
+ # See notebook in https://colab.research.google.com/drive/1qziyPcu65jffizSPi-ZGHKR0x7BaHFMS#scrollTo=RgtnScy6LLKi .
28
+ GPT2_SUBORDINATION_SRC_REFERENCE = \
29
+ [{('no-sub_matrix', 1): 13.151199615123803,
30
+ ('no-sub_matrix', 2): 38.503222716703526,
31
+ ('no-sub_matrix', 3): 27.623861034812286,
32
+ ('no-sub_matrix', 4): 48.831672846038224,
33
+ ('no-sub_matrix', 5): 38.08533699286694,
34
+ ('no-sub_no-matrix', 1): 13.151199615123803,
35
+ ('no-sub_no-matrix', 2): 38.503222716703526,
36
+ ('no-sub_no-matrix', 3): 27.623861034812286,
37
+ ('no-sub_no-matrix', 4): 48.831687980511504,
38
+ ('no-sub_no-matrix', 5): 1.8096143510772873,
39
+ ('sub_matrix', 1): 14.905592916748805,
40
+ ('sub_matrix', 2): 39.06304309956175,
41
+ ('sub_matrix', 3): 26.862648365854433,
42
+ ('sub_matrix', 4): 50.56554401687938,
43
+ ('sub_matrix', 5): 26.532245572980194,
44
+ ('sub_no-matrix', 1): 14.905592916748805,
45
+ ('sub_no-matrix', 2): 39.06304309956175,
46
+ ('sub_no-matrix', 3): 26.862648365854433,
47
+ ('sub_no-matrix', 4): 50.56553438585093,
48
+ ('sub_no-matrix', 5): 7.470089829866611},
49
+ {('no-sub_matrix', 1): 10.116093820255577,
50
+ ('no-sub_matrix', 2): 20.96513246705127,
51
+ ('no-sub_matrix', 3): 20.02959138986416,
52
+ ('no-sub_matrix', 4): 23.779661397107446,
53
+ ('no-sub_matrix', 5): 33.2560281692696,
54
+ ('no-sub_no-matrix', 1): 10.116093820255577,
55
+ ('no-sub_no-matrix', 2): 20.96513246705127,
56
+ ('no-sub_no-matrix', 3): 20.02959138986416,
57
+ ('no-sub_no-matrix', 4): 23.779661397107446,
58
+ ('no-sub_no-matrix', 5): 1.9449125865631063,
59
+ ('sub_matrix', 1): 13.545157521732826,
60
+ ('sub_matrix', 2): 24.96048395897244,
61
+ ('sub_matrix', 3): 18.609464944317324,
62
+ ('sub_matrix', 4): 23.057566440062317,
63
+ ('sub_matrix', 5): 26.424454285669032,
64
+ ('sub_no-matrix', 1): 13.545157521732826,
65
+ ('sub_no-matrix', 2): 24.96048395897244,
66
+ ('sub_no-matrix', 3): 18.609464944317324,
67
+ ('sub_no-matrix', 4): 23.057566440062317,
68
+ ('sub_no-matrix', 5): 2.807467838359704},
69
+ {('no-sub_matrix', 1): 11.992867568477442,
70
+ ('no-sub_matrix', 2): 45.813114232935774,
71
+ ('no-sub_matrix', 3): 24.57554828372551,
72
+ ('no-sub_matrix', 4): 45.334025774062916,
73
+ ('no-sub_matrix', 5): 26.208189541862073,
74
+ ('no-sub_no-matrix', 1): 11.992867568477442,
75
+ ('no-sub_no-matrix', 2): 45.813114232935774,
76
+ ('no-sub_no-matrix', 3): 24.57554828372551,
77
+ ('no-sub_no-matrix', 4): 45.33402766587207,
78
+ ('no-sub_no-matrix', 5): 1.8284485151385752,
79
+ ('sub_matrix', 1): 14.219887768799735,
80
+ ('sub_matrix', 2): 46.25055434117979,
81
+ ('sub_matrix', 3): 23.054221678472672,
82
+ ('sub_matrix', 4): 47.08503858470256,
83
+ ('sub_matrix', 5): 22.154772321452022,
84
+ ('sub_no-matrix', 1): 14.219887768799735,
85
+ ('sub_no-matrix', 2): 46.25055434117979,
86
+ ('sub_no-matrix', 3): 23.054221678472672,
87
+ ('sub_no-matrix', 4): 47.08503858470256,
88
+ ('sub_no-matrix', 5): 3.0655133594366757},
89
+ {('no-sub_matrix', 1): 10.55002943802296,
90
+ ('no-sub_matrix', 2): 52.419810137608856,
91
+ ('no-sub_matrix', 3): 23.30710475332303,
92
+ ('no-sub_matrix', 4): 37.957905964008944,
93
+ ('no-sub_matrix', 5): 29.259648135104936,
94
+ ('no-sub_no-matrix', 1): 10.55002943802296,
95
+ ('no-sub_no-matrix', 2): 52.419810137608856,
96
+ ('no-sub_no-matrix', 3): 23.30710475332303,
97
+ ('no-sub_no-matrix', 4): 37.957905964008944,
98
+ ('no-sub_no-matrix', 5): 1.9632913405649093,
99
+ ('sub_matrix', 1): 15.289384584900025,
100
+ ('sub_matrix', 2): 53.93652737134243,
101
+ ('sub_matrix', 3): 19.43915835312633,
102
+ ('sub_matrix', 4): 36.459591551099386,
103
+ ('sub_matrix', 5): 22.185742699245417,
104
+ ('sub_no-matrix', 1): 15.289384584900025,
105
+ ('sub_no-matrix', 2): 53.93652737134243,
106
+ ('sub_no-matrix', 3): 19.43915835312633,
107
+ ('sub_no-matrix', 4): 36.4595598203003,
108
+ ('sub_no-matrix', 5): 5.707732355645454},
109
+ {('no-sub_matrix', 1): 23.543723213902986,
110
+ ('no-sub_matrix', 2): 31.967972102825854,
111
+ ('no-sub_matrix', 3): 29.159572978411727,
112
+ ('no-sub_matrix', 4): 36.61365345925747,
113
+ ('no-sub_matrix', 5): 44.576591305970545,
114
+ ('no-sub_no-matrix', 1): 23.543723213902986,
115
+ ('no-sub_no-matrix', 2): 31.967972102825854,
116
+ ('no-sub_no-matrix', 3): 29.159572978411727,
117
+ ('no-sub_no-matrix', 4): 36.61365345925747,
118
+ ('no-sub_no-matrix', 5): 3.2813457388593714,
119
+ ('sub_matrix', 1): 27.118410129310597,
120
+ ('sub_matrix', 2): 33.909617362987866,
121
+ ('sub_matrix', 3): 28.791166362258743,
122
+ ('sub_matrix', 4): 37.24960609010374,
123
+ ('sub_matrix', 5): 31.660933798006262,
124
+ ('sub_no-matrix', 1): 27.118410129310597,
125
+ ('sub_no-matrix', 2): 33.909617362987866,
126
+ ('sub_no-matrix', 3): 28.791166362258743,
127
+ ('sub_no-matrix', 4): 37.24960609010374,
128
+ ('sub_no-matrix', 5): 7.3613541428239015},
129
+ {('no-sub_matrix', 1): 14.22171869610082,
130
+ ('no-sub_matrix', 2): 30.270423022911977,
131
+ ('no-sub_matrix', 3): 25.973276891204705,
132
+ ('no-sub_matrix', 4): 28.43856735947716,
133
+ ('no-sub_matrix', 5): 57.39887418731055,
134
+ ('no-sub_no-matrix', 1): 14.22171869610082,
135
+ ('no-sub_no-matrix', 2): 30.270423022911977,
136
+ ('no-sub_no-matrix', 3): 25.973276891204705,
137
+ ('no-sub_no-matrix', 4): 28.43856735947716,
138
+ ('no-sub_no-matrix', 5): 1.7127059109344136,
139
+ ('sub_matrix', 1): 16.39289784951447,
140
+ ('sub_matrix', 2): 31.5671111565765,
141
+ ('sub_matrix', 3): 24.54307828171008,
142
+ ('sub_matrix', 4): 29.249645624130757,
143
+ ('sub_matrix', 5): 53.59155769093577,
144
+ ('sub_no-matrix', 1): 16.39289784951447,
145
+ ('sub_no-matrix', 2): 31.5671111565765,
146
+ ('sub_no-matrix', 3): 24.54307828171008,
147
+ ('sub_no-matrix', 4): 29.249645624130757,
148
+ ('sub_no-matrix', 5): 7.225276653947023},
149
+ {('no-sub_matrix', 1): 13.729688714733188,
150
+ ('no-sub_matrix', 2): 36.018118127225165,
151
+ ('no-sub_matrix', 3): 28.232055923783275,
152
+ ('no-sub_matrix', 4): 44.44634394296659,
153
+ ('no-sub_matrix', 5): 38.277975147059344,
154
+ ('no-sub_no-matrix', 1): 13.729688714733188,
155
+ ('no-sub_no-matrix', 2): 36.018118127225165,
156
+ ('no-sub_no-matrix', 3): 28.232055923783275,
157
+ ('no-sub_no-matrix', 4): 44.44634394296659,
158
+ ('no-sub_no-matrix', 5): 3.0318996942908414,
159
+ ('sub_matrix', 1): 16.93528744674245,
160
+ ('sub_matrix', 2): 36.545024814326574,
161
+ ('sub_matrix', 3): 26.279603445823692,
162
+ ('sub_matrix', 4): 46.501226364074995,
163
+ ('sub_matrix', 5): 32.155418057793035,
164
+ ('sub_no-matrix', 1): 16.93528744674245,
165
+ ('sub_no-matrix', 2): 36.545024814326574,
166
+ ('sub_no-matrix', 3): 26.279603445823692,
167
+ ('sub_no-matrix', 4): 46.501226364074995,
168
+ ('sub_no-matrix', 5): 4.4581122618864155},
169
+ {('no-sub_matrix', 1): 15.598113737151568,
170
+ ('no-sub_matrix', 2): 56.12543415244172,
171
+ ('no-sub_matrix', 3): 29.755667770007285,
172
+ ('no-sub_matrix', 4): 51.689282097269995,
173
+ ('no-sub_matrix', 5): 45.575230324010775,
174
+ ('no-sub_no-matrix', 1): 15.598113737151568,
175
+ ('no-sub_no-matrix', 2): 56.12543415244172,
176
+ ('no-sub_no-matrix', 3): 29.755667770007285,
177
+ ('no-sub_no-matrix', 4): 51.68928424705313,
178
+ ('no-sub_no-matrix', 5): 1.235207173694806,
179
+ ('sub_matrix', 1): 18.909088991066888,
180
+ ('sub_matrix', 2): 57.753410746636746,
181
+ ('sub_matrix', 3): 28.677667873674363,
182
+ ('sub_matrix', 4): 51.99410775929489,
183
+ ('sub_matrix', 5): 35.754144966112236,
184
+ ('sub_no-matrix', 1): 18.909088991066888,
185
+ ('sub_no-matrix', 2): 57.753410746636746,
186
+ ('sub_no-matrix', 3): 28.677667873674363,
187
+ ('sub_no-matrix', 4): 51.9941480032352,
188
+ ('sub_no-matrix', 5): 5.033266273930268},
189
+ {('no-sub_matrix', 1): 14.859413855165633,
190
+ ('no-sub_matrix', 2): 34.54519231993284,
191
+ ('no-sub_matrix', 3): 24.26528519671309,
192
+ ('no-sub_matrix', 4): 35.42343514121054,
193
+ ('no-sub_matrix', 5): 55.85308623165151,
194
+ ('no-sub_no-matrix', 1): 14.859413855165633,
195
+ ('no-sub_no-matrix', 2): 34.54519231993284,
196
+ ('no-sub_no-matrix', 3): 24.26528519671309,
197
+ ('no-sub_no-matrix', 4): 35.42343514121054,
198
+ ('no-sub_no-matrix', 5): 2.3309861205259734,
199
+ ('sub_matrix', 1): 17.053809634549854,
200
+ ('sub_matrix', 2): 33.66637542056656,
201
+ ('sub_matrix', 3): 23.26181234829638,
202
+ ('sub_matrix', 4): 35.61438567264568,
203
+ ('sub_matrix', 5): 48.48551986050014,
204
+ ('sub_no-matrix', 1): 17.053809634549854,
205
+ ('sub_no-matrix', 2): 33.66637542056656,
206
+ ('sub_no-matrix', 3): 23.26181234829638,
207
+ ('sub_no-matrix', 4): 35.61438704850689,
208
+ ('sub_no-matrix', 5): 2.969309360231736},
209
+ {('no-sub_matrix', 1): 13.708973748402064,
210
+ ('no-sub_matrix', 2): 31.147590264691182,
211
+ ('no-sub_matrix', 3): 30.495597241955565,
212
+ ('no-sub_matrix', 4): 34.65164493728535,
213
+ ('no-sub_matrix', 5): 35.87510990950117,
214
+ ('no-sub_no-matrix', 1): 13.708973748402064,
215
+ ('no-sub_no-matrix', 2): 31.147590264691182,
216
+ ('no-sub_no-matrix', 3): 30.495597241955565,
217
+ ('no-sub_no-matrix', 4): 34.65164493728535,
218
+ ('no-sub_no-matrix', 5): 3.232032121481573,
219
+ ('sub_matrix', 1): 17.681722076468287,
220
+ ('sub_matrix', 2): 33.77225997922327,
221
+ ('sub_matrix', 3): 29.435808932487806,
222
+ ('sub_matrix', 4): 34.354368969668016,
223
+ ('sub_matrix', 5): 20.802733205442486,
224
+ ('sub_no-matrix', 1): 17.681722076468287,
225
+ ('sub_no-matrix', 2): 33.77225997922327,
226
+ ('sub_no-matrix', 3): 29.435808932487806,
227
+ ('sub_no-matrix', 4): 34.354368969668016,
228
+ ('sub_no-matrix', 5): 3.7902066303710424},
229
+ {('no-sub_matrix', 1): 15.72185319065555,
230
+ ('no-sub_matrix', 2): 45.25539814380218,
231
+ ('no-sub_matrix', 3): 24.94273362957689,
232
+ ('no-sub_matrix', 4): 40.81704901026569,
233
+ ('no-sub_matrix', 5): 42.898794519499596,
234
+ ('no-sub_no-matrix', 1): 15.72185319065555,
235
+ ('no-sub_no-matrix', 2): 45.25539814380218,
236
+ ('no-sub_no-matrix', 3): 24.94273362957689,
237
+ ('no-sub_no-matrix', 4): 40.81704901026569,
238
+ ('no-sub_no-matrix', 5): 2.6826901255924644,
239
+ ('sub_matrix', 1): 17.565795106862403,
240
+ ('sub_matrix', 2): 46.9371803702329,
241
+ ('sub_matrix', 3): 23.887805807796486,
242
+ ('sub_matrix', 4): 39.058599411828766,
243
+ ('sub_matrix', 5): 32.234453544910295,
244
+ ('sub_no-matrix', 1): 17.565795106862403,
245
+ ('sub_no-matrix', 2): 46.9371803702329,
246
+ ('sub_no-matrix', 3): 23.887805807796486,
247
+ ('sub_no-matrix', 4): 39.058599411828766,
248
+ ('sub_no-matrix', 5): 4.214674259243127},
249
+ {('no-sub_matrix', 1): 13.910878628792588,
250
+ ('no-sub_matrix', 2): 33.45626834359109,
251
+ ('no-sub_matrix', 3): 16.127584513594687,
252
+ ('no-sub_matrix', 4): 32.59623120264939,
253
+ ('no-sub_matrix', 5): 29.87568851789407,
254
+ ('no-sub_no-matrix', 1): 13.910878628792588,
255
+ ('no-sub_no-matrix', 2): 33.45626834359109,
256
+ ('no-sub_no-matrix', 3): 16.127584513594687,
257
+ ('no-sub_no-matrix', 4): 32.59623120264939,
258
+ ('no-sub_no-matrix', 5): 2.3891779982892625,
259
+ ('sub_matrix', 1): 17.18981661053988,
260
+ ('sub_matrix', 2): 36.38883326650068,
261
+ ('sub_matrix', 3): 13.081088737716442,
262
+ ('sub_matrix', 4): 33.419732612590224,
263
+ ('sub_matrix', 5): 22.665485632721676,
264
+ ('sub_no-matrix', 1): 17.18981661053988,
265
+ ('sub_no-matrix', 2): 36.38883326650068,
266
+ ('sub_no-matrix', 3): 13.081088737716442,
267
+ ('sub_no-matrix', 4): 33.419732612590224,
268
+ ('sub_no-matrix', 5): 6.155199912348024},
269
+ {('no-sub_matrix', 1): 18.196771699177763,
270
+ ('no-sub_matrix', 2): 35.624058750852136,
271
+ ('no-sub_matrix', 3): 23.746554392851053,
272
+ ('no-sub_matrix', 4): 29.44669921790574,
273
+ ('no-sub_matrix', 5): 39.72412918901379,
274
+ ('no-sub_no-matrix', 1): 18.196771699177763,
275
+ ('no-sub_no-matrix', 2): 35.624058750852136,
276
+ ('no-sub_no-matrix', 3): 23.746554392851053,
277
+ ('no-sub_no-matrix', 4): 29.44669921790574,
278
+ ('no-sub_no-matrix', 5): 2.870123353843486,
279
+ ('sub_matrix', 1): 20.38619930823735,
280
+ ('sub_matrix', 2): 36.29781144853154,
281
+ ('sub_matrix', 3): 22.13637404741934,
282
+ ('sub_matrix', 4): 29.68729899086184,
283
+ ('sub_matrix', 5): 36.993790238103884,
284
+ ('sub_no-matrix', 1): 20.38619930823735,
285
+ ('sub_no-matrix', 2): 36.29781144853154,
286
+ ('sub_no-matrix', 3): 22.13637404741934,
287
+ ('sub_no-matrix', 4): 29.68729899086184,
288
+ ('sub_no-matrix', 5): 7.650303570399713},
289
+ {('no-sub_matrix', 1): 11.992867568477442,
290
+ ('no-sub_matrix', 2): 26.44083030170154,
291
+ ('no-sub_matrix', 3): 27.574921221726136,
292
+ ('no-sub_matrix', 4): 28.94213565689118,
293
+ ('no-sub_matrix', 5): 46.973469397495556,
294
+ ('no-sub_no-matrix', 1): 11.992867568477442,
295
+ ('no-sub_no-matrix', 2): 26.44083030170154,
296
+ ('no-sub_no-matrix', 3): 27.574921221726136,
297
+ ('no-sub_no-matrix', 4): 28.94213565689118,
298
+ ('no-sub_no-matrix', 5): 3.354326576753004,
299
+ ('sub_matrix', 1): 14.434047100994839,
300
+ ('sub_matrix', 2): 26.76571524620116,
301
+ ('sub_matrix', 3): 25.83488399989926,
302
+ ('sub_matrix', 4): 30.263621195061678,
303
+ ('sub_matrix', 5): 36.822532494114455,
304
+ ('sub_no-matrix', 1): 14.434047100994839,
305
+ ('sub_no-matrix', 2): 26.76571524620116,
306
+ ('sub_no-matrix', 3): 25.83488399989926,
307
+ ('sub_no-matrix', 4): 30.263621195061678,
308
+ ('sub_no-matrix', 5): 6.748976893757906},
309
+ {('no-sub_matrix', 1): 16.27614914680276,
310
+ ('no-sub_matrix', 2): 41.35282905624703,
311
+ ('no-sub_matrix', 3): 25.173115913245226,
312
+ ('no-sub_matrix', 4): 52.876981987369014,
313
+ ('no-sub_matrix', 5): 49.49767321075167,
314
+ ('no-sub_no-matrix', 1): 16.27614914680276,
315
+ ('no-sub_no-matrix', 2): 41.35282905624703,
316
+ ('no-sub_no-matrix', 3): 25.173115913245226,
317
+ ('no-sub_no-matrix', 4): 52.876981987369014,
318
+ ('no-sub_no-matrix', 5): 1.5962803636236758,
319
+ ('sub_matrix', 1): 18.735912436641787,
320
+ ('sub_matrix', 2): 43.36213985849511,
321
+ ('sub_matrix', 3): 24.582800598631913,
322
+ ('sub_matrix', 4): 53.1616607417586,
323
+ ('sub_matrix', 5): 41.2664433745972,
324
+ ('sub_no-matrix', 1): 18.735912436641787,
325
+ ('sub_no-matrix', 2): 43.36213985849511,
326
+ ('sub_no-matrix', 3): 24.582800598631913,
327
+ ('sub_no-matrix', 4): 53.16165799003619,
328
+ ('sub_no-matrix', 5): 6.4917878462822305},
329
+ {('no-sub_matrix', 1): 14.036280122634507,
330
+ ('no-sub_matrix', 2): 53.72802368862095,
331
+ ('no-sub_matrix', 3): 18.940766131564004,
332
+ ('no-sub_matrix', 4): 40.74964840745327,
333
+ ('no-sub_matrix', 5): 39.57008490907742,
334
+ ('no-sub_no-matrix', 1): 14.036280122634507,
335
+ ('no-sub_no-matrix', 2): 53.72802368862095,
336
+ ('no-sub_no-matrix', 3): 18.940766131564004,
337
+ ('no-sub_no-matrix', 4): 40.74964840745327,
338
+ ('no-sub_no-matrix', 5): 2.1275557540222967,
339
+ ('sub_matrix', 1): 19.641722357026286,
340
+ ('sub_matrix', 2): 52.709120728751486,
341
+ ('sub_matrix', 3): 17.976257844509426,
342
+ ('sub_matrix', 4): 42.51851542500959,
343
+ ('sub_matrix', 5): 28.25018664655579,
344
+ ('sub_no-matrix', 1): 19.641722357026286,
345
+ ('sub_no-matrix', 2): 52.709120728751486,
346
+ ('sub_no-matrix', 3): 17.976257844509426,
347
+ ('sub_no-matrix', 4): 42.51851267328718,
348
+ ('sub_no-matrix', 5): 5.409622788119386},
349
+ {('no-sub_matrix', 1): 16.961927903326398,
350
+ ('no-sub_matrix', 2): 38.5455951142925,
351
+ ('no-sub_matrix', 3): 25.122316709729276,
352
+ ('no-sub_matrix', 4): 35.90131439006518,
353
+ ('no-sub_matrix', 5): 41.65886977570029,
354
+ ('no-sub_no-matrix', 1): 16.961927903326398,
355
+ ('no-sub_no-matrix', 2): 38.5455951142925,
356
+ ('no-sub_no-matrix', 3): 25.122316709729276,
357
+ ('no-sub_no-matrix', 4): 35.90131439006518,
358
+ ('no-sub_no-matrix', 5): 3.2679255886472447,
359
+ ('sub_matrix', 1): 20.247934372024154,
360
+ ('sub_matrix', 2): 40.408716019775625,
361
+ ('sub_matrix', 3): 23.782735071043668,
362
+ ('sub_matrix', 4): 37.00513584758997,
363
+ ('sub_matrix', 5): 29.22700479607527,
364
+ ('sub_no-matrix', 1): 20.247934372024154,
365
+ ('sub_no-matrix', 2): 40.408716019775625,
366
+ ('sub_no-matrix', 3): 23.782735071043668,
367
+ ('sub_no-matrix', 4): 37.00513584758997,
368
+ ('sub_no-matrix', 5): 4.780011845541033},
369
+ {('no-sub_matrix', 1): 12.109815771064152,
370
+ ('no-sub_matrix', 2): 38.32406752938649,
371
+ ('no-sub_matrix', 3): 25.987801084044044,
372
+ ('no-sub_matrix', 4): 40.40950903177875,
373
+ ('no-sub_matrix', 5): 52.86522525335603,
374
+ ('no-sub_no-matrix', 1): 12.109815771064152,
375
+ ('no-sub_no-matrix', 2): 38.32406752938649,
376
+ ('no-sub_no-matrix', 3): 25.987801084044044,
377
+ ('no-sub_no-matrix', 4): 40.40950903177875,
378
+ ('no-sub_no-matrix', 5): 3.61917194787979,
379
+ ('sub_matrix', 1): 15.130341564722832,
380
+ ('sub_matrix', 2): 37.89719334728088,
381
+ ('sub_matrix', 3): 24.65681032273433,
382
+ ('sub_matrix', 4): 40.731610867030774,
383
+ ('sub_matrix', 5): 37.566910985257906,
384
+ ('sub_no-matrix', 1): 15.130341564722832,
385
+ ('sub_no-matrix', 2): 37.89719334728088,
386
+ ('sub_no-matrix', 3): 24.65681032273433,
387
+ ('sub_no-matrix', 4): 40.731610867030774,
388
+ ('sub_no-matrix', 5): 9.39736249989602},
389
+ {('no-sub_matrix', 1): 16.25058564557851,
390
+ ('no-sub_matrix', 2): 37.20405682898803,
391
+ ('no-sub_matrix', 3): 30.5107090995129,
392
+ ('no-sub_matrix', 4): 44.537084655292894,
393
+ ('no-sub_matrix', 5): 46.50046620075818,
394
+ ('no-sub_no-matrix', 1): 16.25058564557851,
395
+ ('no-sub_no-matrix', 2): 37.20405682898803,
396
+ ('no-sub_no-matrix', 3): 30.5107090995129,
397
+ ('no-sub_no-matrix', 4): 44.537084655292894,
398
+ ('no-sub_no-matrix', 5): 1.8752506698658238,
399
+ ('sub_matrix', 1): 18.440281483079957,
400
+ ('sub_matrix', 2): 38.54769605435544,
401
+ ('sub_matrix', 3): 30.510800250317864,
402
+ ('sub_matrix', 4): 44.99740645329493,
403
+ ('sub_matrix', 5): 39.55738177603457,
404
+ ('sub_no-matrix', 1): 18.440281483079957,
405
+ ('sub_no-matrix', 2): 38.54769605435544,
406
+ ('sub_no-matrix', 3): 30.510800250317864,
407
+ ('sub_no-matrix', 4): 44.99740645329493,
408
+ ('sub_no-matrix', 5): 2.6233048602148386},
409
+ {('no-sub_matrix', 1): 16.324447378609865,
410
+ ('no-sub_matrix', 2): 30.87308462806543,
411
+ ('no-sub_matrix', 3): 22.765564836381643,
412
+ ('no-sub_matrix', 4): 38.337445027901204,
413
+ ('no-sub_matrix', 5): 40.98815076599078,
414
+ ('no-sub_no-matrix', 1): 16.324447378609865,
415
+ ('no-sub_no-matrix', 2): 30.87308462806543,
416
+ ('no-sub_no-matrix', 3): 22.765564836381643,
417
+ ('no-sub_no-matrix', 4): 38.337445027901204,
418
+ ('no-sub_no-matrix', 5): 1.4796406979126138,
419
+ ('sub_matrix', 1): 17.9623592385626,
420
+ ('sub_matrix', 2): 32.36568198294609,
421
+ ('sub_matrix', 3): 22.438215466486483,
422
+ ('sub_matrix', 4): 40.900713840387546,
423
+ ('sub_matrix', 5): 33.396627340011634,
424
+ ('sub_no-matrix', 1): 17.9623592385626,
425
+ ('sub_no-matrix', 2): 32.36568198294609,
426
+ ('sub_no-matrix', 3): 22.438215466486483,
427
+ ('sub_no-matrix', 4): 40.900713840387546,
428
+ ('sub_no-matrix', 5): 6.609518913895668},
429
+ {('no-sub_matrix', 1): 14.033258731424148,
430
+ ('no-sub_matrix', 2): 28.37206528002418,
431
+ ('no-sub_matrix', 3): 27.043658386061033,
432
+ ('no-sub_matrix', 4): 36.167049513436204,
433
+ ('no-sub_matrix', 5): 52.280797076864395,
434
+ ('no-sub_no-matrix', 1): 14.033258731424148,
435
+ ('no-sub_no-matrix', 2): 28.37206528002418,
436
+ ('no-sub_no-matrix', 3): 27.043658386061033,
437
+ ('no-sub_no-matrix', 4): 36.167049513436204,
438
+ ('no-sub_no-matrix', 5): 1.9358795417918389,
439
+ ('sub_matrix', 1): 16.606623097498794,
440
+ ('sub_matrix', 2): 29.98729916366884,
441
+ ('sub_matrix', 3): 24.737985875967603,
442
+ ('sub_matrix', 4): 34.93154214402433,
443
+ ('sub_matrix', 5): 42.35241303296243,
444
+ ('sub_no-matrix', 1): 16.606623097498794,
445
+ ('sub_no-matrix', 2): 29.98729916366884,
446
+ ('sub_no-matrix', 3): 24.737985875967603,
447
+ ('sub_no-matrix', 4): 34.931551775052775,
448
+ ('sub_no-matrix', 5): 7.151971456773863},
449
+ {('no-sub_matrix', 1): 10.482293039084738,
450
+ ('no-sub_matrix', 2): 52.67861788579445,
451
+ ('no-sub_matrix', 3): 21.665543335527666,
452
+ ('no-sub_matrix', 4): 23.53727708917033,
453
+ ('no-sub_matrix', 5): 32.2645584918966,
454
+ ('no-sub_no-matrix', 1): 10.482293039084738,
455
+ ('no-sub_no-matrix', 2): 52.67861788579445,
456
+ ('no-sub_no-matrix', 3): 21.665543335527666,
457
+ ('no-sub_no-matrix', 4): 23.53727708917033,
458
+ ('no-sub_no-matrix', 5): 2.5207572809328243,
459
+ ('sub_matrix', 1): 11.523882918360123,
460
+ ('sub_matrix', 2): 57.336257883871156,
461
+ ('sub_matrix', 3): 21.647716645835132,
462
+ ('sub_matrix', 4): 23.491483569694733,
463
+ ('sub_matrix', 5): 24.264706351480406,
464
+ ('sub_no-matrix', 1): 11.523882918360123,
465
+ ('sub_no-matrix', 2): 57.336257883871156,
466
+ ('sub_no-matrix', 3): 21.647716645835132,
467
+ ('sub_no-matrix', 4): 23.491462243846026,
468
+ ('sub_no-matrix', 5): 9.714244661694366},
469
+ {('no-sub_matrix', 1): 11.992867568477442,
470
+ ('no-sub_matrix', 2): 28.861638231250264,
471
+ ('no-sub_matrix', 3): 24.222607873884137,
472
+ ('no-sub_matrix', 4): 41.28280460012173,
473
+ ('no-sub_matrix', 5): 56.6084264455065,
474
+ ('no-sub_no-matrix', 1): 11.992867568477442,
475
+ ('no-sub_no-matrix', 2): 28.861638231250264,
476
+ ('no-sub_no-matrix', 3): 24.222607873884137,
477
+ ('no-sub_no-matrix', 4): 41.28280460012173,
478
+ ('no-sub_no-matrix', 5): 2.4980576348107437,
479
+ ('sub_matrix', 1): 14.531057698832324,
480
+ ('sub_matrix', 2): 31.280393934821902,
481
+ ('sub_matrix', 3): 20.756528260470358,
482
+ ('sub_matrix', 4): 42.15937712589425,
483
+ ('sub_matrix', 5): 52.45767194621365,
484
+ ('sub_no-matrix', 1): 14.531057698832324,
485
+ ('sub_no-matrix', 2): 31.280393934821902,
486
+ ('sub_no-matrix', 3): 20.756528260470358,
487
+ ('sub_no-matrix', 4): 42.15937712589425,
488
+ ('sub_no-matrix', 5): 4.819862633503057}]
489
+
490
+
491
def test_gpt_subordination_region_totals():
    """
    Check region-level surprisals against the original syntaxgym-core
    implementation, using the same underlying `gpt2` model.

    Loads the `subordination_src-src` suite, runs the local `syntaxgym.py`
    metric with `model_id="gpt2"`, and compares every per-item region total
    against `GPT2_SUBORDINATION_SRC_REFERENCE` with an absolute tolerance of
    1e-3.
    """
    from pprint import pprint

    dataset = datasets.load_dataset("cpllab/syntaxgym", "subordination_src-src")
    metric = evaluate.load("./syntaxgym.py")
    result = metric.compute(suite=dataset["test"], model_id="gpt2")
    region_totals = result["region_totals"]

    # Debugging aid: show the first item from both sides.
    pprint(region_totals[0])
    pprint(GPT2_SUBORDINATION_SRC_REFERENCE[0])

    # Fail with a clear message if the number of items differs, rather than
    # with a shape error from the array comparison below.
    assert len(region_totals) == len(GPT2_SUBORDINATION_SRC_REFERENCE)

    # Freeze one key order so both arrays are flattened identically.
    keys = list(region_totals[0].keys())
    assert set(keys) == set(GPT2_SUBORDINATION_SRC_REFERENCE[0].keys())

    result_ndarray = np.concatenate(
        [np.array([item_totals[key] for key in keys])
         for item_totals in region_totals])
    reference_ndarray = np.concatenate(
        [np.array([item_totals[key] for key in keys])
         for item_totals in GPT2_SUBORDINATION_SRC_REFERENCE])

    # Label each flattened entry with (item index, key) so the diagnostic
    # covers every item; zipping against `keys` alone would stop after the
    # first item's entries.
    labels = [(item_idx, key)
              for item_idx in range(len(region_totals))
              for key in keys]
    abs_errors = np.abs(result_ndarray - reference_ndarray)
    pprint(sorted(zip(labels, abs_errors), key=lambda x: -x[1])[:20])

    np.testing.assert_allclose(result_ndarray, reference_ndarray, atol=1e-3)