from pprint import pprint
from typing import List

import datasets
import evaluate
import numpy as np
import pytest


@pytest.fixture(scope="session")
def syntaxgym_dataset():
    """Load the ``subordination_src-src`` suite once per test session."""
    # NOTE(review): the tests below load "cpllab/syntaxgym" while this fixture
    # loads "syntaxgym" -- confirm which dataset path is intended.
    return datasets.load_dataset("syntaxgym", "subordination_src-src")


@pytest.fixture(scope="session")
def syntaxgym_metric():
    """Load the metric under test from the local ``syntaxgym.py`` script."""
    # TODO work out reference
    return evaluate.load("./syntaxgym.py")


@pytest.fixture(scope="session")
def model_ref():
    """Return the model identifier string ``"gpt2"``."""
    # return "hf-internal-testing/tiny-random-gpt_neo"
    return "gpt2"


# Reference region surprisals computed with syntaxgym-core.
# See notebook in https://colab.research.google.com/drive/1qziyPcu65jffizSPi-ZGHKR0x7BaHFMS#scrollTo=RgtnScy6LLKi .
#
# One dict per item; keys are (condition name, region number) tuples.
GPT2_SUBORDINATION_SRC_REFERENCE = [
    {
        ('no-sub_matrix', 1): 13.151199615123803,
        ('no-sub_matrix', 2): 38.503222716703526,
        ('no-sub_matrix', 3): 27.623861034812286,
        ('no-sub_matrix', 4): 48.831672846038224,
        ('no-sub_matrix', 5): 38.08533699286694,
        ('no-sub_no-matrix', 1): 13.151199615123803,
        ('no-sub_no-matrix', 2): 38.503222716703526,
        ('no-sub_no-matrix', 3): 27.623861034812286,
        ('no-sub_no-matrix', 4): 48.831687980511504,
        ('no-sub_no-matrix', 5): 1.8096143510772873,
        ('sub_matrix', 1): 14.905592916748805,
        ('sub_matrix', 2): 39.06304309956175,
        ('sub_matrix', 3): 26.862648365854433,
        ('sub_matrix', 4): 50.56554401687938,
        ('sub_matrix', 5): 26.532245572980194,
        ('sub_no-matrix', 1): 14.905592916748805,
        ('sub_no-matrix', 2): 39.06304309956175,
        ('sub_no-matrix', 3): 26.862648365854433,
        ('sub_no-matrix', 4): 50.56553438585093,
        ('sub_no-matrix', 5): 7.470089829866611,
    },
    {
        ('no-sub_matrix', 1): 10.116093820255577,
        ('no-sub_matrix', 2): 20.96513246705127,
        ('no-sub_matrix', 3): 20.02959138986416,
        ('no-sub_matrix', 4): 23.779661397107446,
        ('no-sub_matrix', 5): 33.2560281692696,
        ('no-sub_no-matrix', 1): 10.116093820255577,
        ('no-sub_no-matrix', 2): 20.96513246705127,
        ('no-sub_no-matrix', 3): 20.02959138986416,
        ('no-sub_no-matrix', 4): 23.779661397107446,
        ('no-sub_no-matrix', 5): 1.9449125865631063,
        ('sub_matrix', 1): 13.545157521732826,
        ('sub_matrix', 2): 24.96048395897244,
        ('sub_matrix', 3): 18.609464944317324,
        ('sub_matrix', 4): 23.057566440062317,
        ('sub_matrix', 5): 26.424454285669032,
        ('sub_no-matrix', 1): 13.545157521732826,
        ('sub_no-matrix', 2): 24.96048395897244,
        ('sub_no-matrix', 3): 18.609464944317324,
        ('sub_no-matrix', 4): 23.057566440062317,
        ('sub_no-matrix', 5): 2.807467838359704,
    },
    {
        ('no-sub_matrix', 1): 11.992867568477442,
        ('no-sub_matrix', 2): 45.813114232935774,
        ('no-sub_matrix', 3): 24.57554828372551,
        ('no-sub_matrix', 4): 45.334025774062916,
        ('no-sub_matrix', 5): 26.208189541862073,
        ('no-sub_no-matrix', 1): 11.992867568477442,
        ('no-sub_no-matrix', 2): 45.813114232935774,
        ('no-sub_no-matrix', 3): 24.57554828372551,
        ('no-sub_no-matrix', 4): 45.33402766587207,
        ('no-sub_no-matrix', 5): 1.8284485151385752,
        ('sub_matrix', 1): 14.219887768799735,
        ('sub_matrix', 2): 46.25055434117979,
        ('sub_matrix', 3): 23.054221678472672,
        ('sub_matrix', 4): 47.08503858470256,
        ('sub_matrix', 5): 22.154772321452022,
        ('sub_no-matrix', 1): 14.219887768799735,
        ('sub_no-matrix', 2): 46.25055434117979,
        ('sub_no-matrix', 3): 23.054221678472672,
        ('sub_no-matrix', 4): 47.08503858470256,
        ('sub_no-matrix', 5): 3.0655133594366757,
    },
    {
        ('no-sub_matrix', 1): 10.55002943802296,
        ('no-sub_matrix', 2): 52.419810137608856,
        ('no-sub_matrix', 3): 23.30710475332303,
        ('no-sub_matrix', 4): 37.957905964008944,
        ('no-sub_matrix', 5): 29.259648135104936,
        ('no-sub_no-matrix', 1): 10.55002943802296,
        ('no-sub_no-matrix', 2): 52.419810137608856,
        ('no-sub_no-matrix', 3): 23.30710475332303,
        ('no-sub_no-matrix', 4): 37.957905964008944,
        ('no-sub_no-matrix', 5): 1.9632913405649093,
        ('sub_matrix', 1): 15.289384584900025,
        ('sub_matrix', 2): 53.93652737134243,
        ('sub_matrix', 3): 19.43915835312633,
        ('sub_matrix', 4): 36.459591551099386,
        ('sub_matrix', 5): 22.185742699245417,
        ('sub_no-matrix', 1): 15.289384584900025,
        ('sub_no-matrix', 2): 53.93652737134243,
        ('sub_no-matrix', 3): 19.43915835312633,
        ('sub_no-matrix', 4): 36.4595598203003,
        ('sub_no-matrix', 5): 5.707732355645454,
    },
    {
        ('no-sub_matrix', 1): 23.543723213902986,
        ('no-sub_matrix', 2): 31.967972102825854,
        ('no-sub_matrix', 3): 29.159572978411727,
        ('no-sub_matrix', 4): 36.61365345925747,
        ('no-sub_matrix', 5): 44.576591305970545,
        ('no-sub_no-matrix', 1): 23.543723213902986,
        ('no-sub_no-matrix', 2): 31.967972102825854,
        ('no-sub_no-matrix', 3): 29.159572978411727,
        ('no-sub_no-matrix', 4): 36.61365345925747,
        ('no-sub_no-matrix', 5): 3.2813457388593714,
        ('sub_matrix', 1): 27.118410129310597,
        ('sub_matrix', 2): 33.909617362987866,
        ('sub_matrix', 3): 28.791166362258743,
        ('sub_matrix', 4): 37.24960609010374,
        ('sub_matrix', 5): 31.660933798006262,
        ('sub_no-matrix', 1): 27.118410129310597,
        ('sub_no-matrix', 2): 33.909617362987866,
        ('sub_no-matrix', 3): 28.791166362258743,
        ('sub_no-matrix', 4): 37.24960609010374,
        ('sub_no-matrix', 5): 7.3613541428239015,
    },
    {
        ('no-sub_matrix', 1): 14.22171869610082,
        ('no-sub_matrix', 2): 30.270423022911977,
        ('no-sub_matrix', 3): 25.973276891204705,
        ('no-sub_matrix', 4): 28.43856735947716,
        ('no-sub_matrix', 5): 57.39887418731055,
        ('no-sub_no-matrix', 1): 14.22171869610082,
        ('no-sub_no-matrix', 2): 30.270423022911977,
        ('no-sub_no-matrix', 3): 25.973276891204705,
        ('no-sub_no-matrix', 4): 28.43856735947716,
        ('no-sub_no-matrix', 5): 1.7127059109344136,
        ('sub_matrix', 1): 16.39289784951447,
        ('sub_matrix', 2): 31.5671111565765,
        ('sub_matrix', 3): 24.54307828171008,
        ('sub_matrix', 4): 29.249645624130757,
        ('sub_matrix', 5): 53.59155769093577,
        ('sub_no-matrix', 1): 16.39289784951447,
        ('sub_no-matrix', 2): 31.5671111565765,
        ('sub_no-matrix', 3): 24.54307828171008,
        ('sub_no-matrix', 4): 29.249645624130757,
        ('sub_no-matrix', 5): 7.225276653947023,
    },
    {
        ('no-sub_matrix', 1): 13.729688714733188,
        ('no-sub_matrix', 2): 36.018118127225165,
        ('no-sub_matrix', 3): 28.232055923783275,
        ('no-sub_matrix', 4): 44.44634394296659,
        ('no-sub_matrix', 5): 38.277975147059344,
        ('no-sub_no-matrix', 1): 13.729688714733188,
        ('no-sub_no-matrix', 2): 36.018118127225165,
        ('no-sub_no-matrix', 3): 28.232055923783275,
        ('no-sub_no-matrix', 4): 44.44634394296659,
        ('no-sub_no-matrix', 5): 3.0318996942908414,
        ('sub_matrix', 1): 16.93528744674245,
        ('sub_matrix', 2): 36.545024814326574,
        ('sub_matrix', 3): 26.279603445823692,
        ('sub_matrix', 4): 46.501226364074995,
        ('sub_matrix', 5): 32.155418057793035,
        ('sub_no-matrix', 1): 16.93528744674245,
        ('sub_no-matrix', 2): 36.545024814326574,
        ('sub_no-matrix', 3): 26.279603445823692,
        ('sub_no-matrix', 4): 46.501226364074995,
        ('sub_no-matrix', 5): 4.4581122618864155,
    },
    {
        ('no-sub_matrix', 1): 15.598113737151568,
        ('no-sub_matrix', 2): 56.12543415244172,
        ('no-sub_matrix', 3): 29.755667770007285,
        ('no-sub_matrix', 4): 51.689282097269995,
        ('no-sub_matrix', 5): 45.575230324010775,
        ('no-sub_no-matrix', 1): 15.598113737151568,
        ('no-sub_no-matrix', 2): 56.12543415244172,
        ('no-sub_no-matrix', 3): 29.755667770007285,
        ('no-sub_no-matrix', 4): 51.68928424705313,
        ('no-sub_no-matrix', 5): 1.235207173694806,
        ('sub_matrix', 1): 18.909088991066888,
        ('sub_matrix', 2): 57.753410746636746,
        ('sub_matrix', 3): 28.677667873674363,
        ('sub_matrix', 4): 51.99410775929489,
        ('sub_matrix', 5): 35.754144966112236,
        ('sub_no-matrix', 1): 18.909088991066888,
        ('sub_no-matrix', 2): 57.753410746636746,
        ('sub_no-matrix', 3): 28.677667873674363,
        ('sub_no-matrix', 4): 51.9941480032352,
        ('sub_no-matrix', 5): 5.033266273930268,
    },
    {
        ('no-sub_matrix', 1): 14.859413855165633,
        ('no-sub_matrix', 2): 34.54519231993284,
        ('no-sub_matrix', 3): 24.26528519671309,
        ('no-sub_matrix', 4): 35.42343514121054,
        ('no-sub_matrix', 5): 55.85308623165151,
        ('no-sub_no-matrix', 1): 14.859413855165633,
        ('no-sub_no-matrix', 2): 34.54519231993284,
        ('no-sub_no-matrix', 3): 24.26528519671309,
        ('no-sub_no-matrix', 4): 35.42343514121054,
        ('no-sub_no-matrix', 5): 2.3309861205259734,
        ('sub_matrix', 1): 17.053809634549854,
        ('sub_matrix', 2): 33.66637542056656,
        ('sub_matrix', 3): 23.26181234829638,
        ('sub_matrix', 4): 35.61438567264568,
        ('sub_matrix', 5): 48.48551986050014,
        ('sub_no-matrix', 1): 17.053809634549854,
        ('sub_no-matrix', 2): 33.66637542056656,
        ('sub_no-matrix', 3): 23.26181234829638,
        ('sub_no-matrix', 4): 35.61438704850689,
        ('sub_no-matrix', 5): 2.969309360231736,
    },
    {
        ('no-sub_matrix', 1): 13.708973748402064,
        ('no-sub_matrix', 2): 31.147590264691182,
        ('no-sub_matrix', 3): 30.495597241955565,
        ('no-sub_matrix', 4): 34.65164493728535,
        ('no-sub_matrix', 5): 35.87510990950117,
        ('no-sub_no-matrix', 1): 13.708973748402064,
        ('no-sub_no-matrix', 2): 31.147590264691182,
        ('no-sub_no-matrix', 3): 30.495597241955565,
        ('no-sub_no-matrix', 4): 34.65164493728535,
        ('no-sub_no-matrix', 5): 3.232032121481573,
        ('sub_matrix', 1): 17.681722076468287,
        ('sub_matrix', 2): 33.77225997922327,
        ('sub_matrix', 3): 29.435808932487806,
        ('sub_matrix', 4): 34.354368969668016,
        ('sub_matrix', 5): 20.802733205442486,
        ('sub_no-matrix', 1): 17.681722076468287,
        ('sub_no-matrix', 2): 33.77225997922327,
        ('sub_no-matrix', 3): 29.435808932487806,
        ('sub_no-matrix', 4): 34.354368969668016,
        ('sub_no-matrix', 5): 3.7902066303710424,
    },
    {
        ('no-sub_matrix', 1): 15.72185319065555,
        ('no-sub_matrix', 2): 45.25539814380218,
        ('no-sub_matrix', 3): 24.94273362957689,
        ('no-sub_matrix', 4): 40.81704901026569,
        ('no-sub_matrix', 5): 42.898794519499596,
        ('no-sub_no-matrix', 1): 15.72185319065555,
        ('no-sub_no-matrix', 2): 45.25539814380218,
        ('no-sub_no-matrix', 3): 24.94273362957689,
        ('no-sub_no-matrix', 4): 40.81704901026569,
        ('no-sub_no-matrix', 5): 2.6826901255924644,
        ('sub_matrix', 1): 17.565795106862403,
        ('sub_matrix', 2): 46.9371803702329,
        ('sub_matrix', 3): 23.887805807796486,
        ('sub_matrix', 4): 39.058599411828766,
        ('sub_matrix', 5): 32.234453544910295,
        ('sub_no-matrix', 1): 17.565795106862403,
        ('sub_no-matrix', 2): 46.9371803702329,
        ('sub_no-matrix', 3): 23.887805807796486,
        ('sub_no-matrix', 4): 39.058599411828766,
        ('sub_no-matrix', 5): 4.214674259243127,
    },
    {
        ('no-sub_matrix', 1): 13.910878628792588,
        ('no-sub_matrix', 2): 33.45626834359109,
        ('no-sub_matrix', 3): 16.127584513594687,
        ('no-sub_matrix', 4): 32.59623120264939,
        ('no-sub_matrix', 5): 29.87568851789407,
        ('no-sub_no-matrix', 1): 13.910878628792588,
        ('no-sub_no-matrix', 2): 33.45626834359109,
        ('no-sub_no-matrix', 3): 16.127584513594687,
        ('no-sub_no-matrix', 4): 32.59623120264939,
        ('no-sub_no-matrix', 5): 2.3891779982892625,
        ('sub_matrix', 1): 17.18981661053988,
        ('sub_matrix', 2): 36.38883326650068,
        ('sub_matrix', 3): 13.081088737716442,
        ('sub_matrix', 4): 33.419732612590224,
        ('sub_matrix', 5): 22.665485632721676,
        ('sub_no-matrix', 1): 17.18981661053988,
        ('sub_no-matrix', 2): 36.38883326650068,
        ('sub_no-matrix', 3): 13.081088737716442,
        ('sub_no-matrix', 4): 33.419732612590224,
        ('sub_no-matrix', 5): 6.155199912348024,
    },
    {
        ('no-sub_matrix', 1): 18.196771699177763,
        ('no-sub_matrix', 2): 35.624058750852136,
        ('no-sub_matrix', 3): 23.746554392851053,
        ('no-sub_matrix', 4): 29.44669921790574,
        ('no-sub_matrix', 5): 39.72412918901379,
        ('no-sub_no-matrix', 1): 18.196771699177763,
        ('no-sub_no-matrix', 2): 35.624058750852136,
        ('no-sub_no-matrix', 3): 23.746554392851053,
        ('no-sub_no-matrix', 4): 29.44669921790574,
        ('no-sub_no-matrix', 5): 2.870123353843486,
        ('sub_matrix', 1): 20.38619930823735,
        ('sub_matrix', 2): 36.29781144853154,
        ('sub_matrix', 3): 22.13637404741934,
        ('sub_matrix', 4): 29.68729899086184,
        ('sub_matrix', 5): 36.993790238103884,
        ('sub_no-matrix', 1): 20.38619930823735,
        ('sub_no-matrix', 2): 36.29781144853154,
        ('sub_no-matrix', 3): 22.13637404741934,
        ('sub_no-matrix', 4): 29.68729899086184,
        ('sub_no-matrix', 5): 7.650303570399713,
    },
    {
        ('no-sub_matrix', 1): 11.992867568477442,
        ('no-sub_matrix', 2): 26.44083030170154,
        ('no-sub_matrix', 3): 27.574921221726136,
        ('no-sub_matrix', 4): 28.94213565689118,
        ('no-sub_matrix', 5): 46.973469397495556,
        ('no-sub_no-matrix', 1): 11.992867568477442,
        ('no-sub_no-matrix', 2): 26.44083030170154,
        ('no-sub_no-matrix', 3): 27.574921221726136,
        ('no-sub_no-matrix', 4): 28.94213565689118,
        ('no-sub_no-matrix', 5): 3.354326576753004,
        ('sub_matrix', 1): 14.434047100994839,
        ('sub_matrix', 2): 26.76571524620116,
        ('sub_matrix', 3): 25.83488399989926,
        ('sub_matrix', 4): 30.263621195061678,
        ('sub_matrix', 5): 36.822532494114455,
        ('sub_no-matrix', 1): 14.434047100994839,
        ('sub_no-matrix', 2): 26.76571524620116,
        ('sub_no-matrix', 3): 25.83488399989926,
        ('sub_no-matrix', 4): 30.263621195061678,
        ('sub_no-matrix', 5): 6.748976893757906,
    },
    {
        ('no-sub_matrix', 1): 16.27614914680276,
        ('no-sub_matrix', 2): 41.35282905624703,
        ('no-sub_matrix', 3): 25.173115913245226,
        ('no-sub_matrix', 4): 52.876981987369014,
        ('no-sub_matrix', 5): 49.49767321075167,
        ('no-sub_no-matrix', 1): 16.27614914680276,
        ('no-sub_no-matrix', 2): 41.35282905624703,
        ('no-sub_no-matrix', 3): 25.173115913245226,
        ('no-sub_no-matrix', 4): 52.876981987369014,
        ('no-sub_no-matrix', 5): 1.5962803636236758,
        ('sub_matrix', 1): 18.735912436641787,
        ('sub_matrix', 2): 43.36213985849511,
        ('sub_matrix', 3): 24.582800598631913,
        ('sub_matrix', 4): 53.1616607417586,
        ('sub_matrix', 5): 41.2664433745972,
        ('sub_no-matrix', 1): 18.735912436641787,
        ('sub_no-matrix', 2): 43.36213985849511,
        ('sub_no-matrix', 3): 24.582800598631913,
        ('sub_no-matrix', 4): 53.16165799003619,
        ('sub_no-matrix', 5): 6.4917878462822305,
    },
    {
        ('no-sub_matrix', 1): 14.036280122634507,
        ('no-sub_matrix', 2): 53.72802368862095,
        ('no-sub_matrix', 3): 18.940766131564004,
        ('no-sub_matrix', 4): 40.74964840745327,
        ('no-sub_matrix', 5): 39.57008490907742,
        ('no-sub_no-matrix', 1): 14.036280122634507,
        ('no-sub_no-matrix', 2): 53.72802368862095,
        ('no-sub_no-matrix', 3): 18.940766131564004,
        ('no-sub_no-matrix', 4): 40.74964840745327,
        ('no-sub_no-matrix', 5): 2.1275557540222967,
        ('sub_matrix', 1): 19.641722357026286,
        ('sub_matrix', 2): 52.709120728751486,
        ('sub_matrix', 3): 17.976257844509426,
        ('sub_matrix', 4): 42.51851542500959,
        ('sub_matrix', 5): 28.25018664655579,
        ('sub_no-matrix', 1): 19.641722357026286,
        ('sub_no-matrix', 2): 52.709120728751486,
        ('sub_no-matrix', 3): 17.976257844509426,
        ('sub_no-matrix', 4): 42.51851267328718,
        ('sub_no-matrix', 5): 5.409622788119386,
    },
    {
        ('no-sub_matrix', 1): 16.961927903326398,
        ('no-sub_matrix', 2): 38.5455951142925,
        ('no-sub_matrix', 3): 25.122316709729276,
        ('no-sub_matrix', 4): 35.90131439006518,
        ('no-sub_matrix', 5): 41.65886977570029,
        ('no-sub_no-matrix', 1): 16.961927903326398,
        ('no-sub_no-matrix', 2): 38.5455951142925,
        ('no-sub_no-matrix', 3): 25.122316709729276,
        ('no-sub_no-matrix', 4): 35.90131439006518,
        ('no-sub_no-matrix', 5): 3.2679255886472447,
        ('sub_matrix', 1): 20.247934372024154,
        ('sub_matrix', 2): 40.408716019775625,
        ('sub_matrix', 3): 23.782735071043668,
        ('sub_matrix', 4): 37.00513584758997,
        ('sub_matrix', 5): 29.22700479607527,
        ('sub_no-matrix', 1): 20.247934372024154,
        ('sub_no-matrix', 2): 40.408716019775625,
        ('sub_no-matrix', 3): 23.782735071043668,
        ('sub_no-matrix', 4): 37.00513584758997,
        ('sub_no-matrix', 5): 4.780011845541033,
    },
    {
        ('no-sub_matrix', 1): 12.109815771064152,
        ('no-sub_matrix', 2): 38.32406752938649,
        ('no-sub_matrix', 3): 25.987801084044044,
        ('no-sub_matrix', 4): 40.40950903177875,
        ('no-sub_matrix', 5): 52.86522525335603,
        ('no-sub_no-matrix', 1): 12.109815771064152,
        ('no-sub_no-matrix', 2): 38.32406752938649,
        ('no-sub_no-matrix', 3): 25.987801084044044,
        ('no-sub_no-matrix', 4): 40.40950903177875,
        ('no-sub_no-matrix', 5): 3.61917194787979,
        ('sub_matrix', 1): 15.130341564722832,
        ('sub_matrix', 2): 37.89719334728088,
        ('sub_matrix', 3): 24.65681032273433,
        ('sub_matrix', 4): 40.731610867030774,
        ('sub_matrix', 5): 37.566910985257906,
        ('sub_no-matrix', 1): 15.130341564722832,
        ('sub_no-matrix', 2): 37.89719334728088,
        ('sub_no-matrix', 3): 24.65681032273433,
        ('sub_no-matrix', 4): 40.731610867030774,
        ('sub_no-matrix', 5): 9.39736249989602,
    },
    {
        ('no-sub_matrix', 1): 16.25058564557851,
        ('no-sub_matrix', 2): 37.20405682898803,
        ('no-sub_matrix', 3): 30.5107090995129,
        ('no-sub_matrix', 4): 44.537084655292894,
        ('no-sub_matrix', 5): 46.50046620075818,
        ('no-sub_no-matrix', 1): 16.25058564557851,
        ('no-sub_no-matrix', 2): 37.20405682898803,
        ('no-sub_no-matrix', 3): 30.5107090995129,
        ('no-sub_no-matrix', 4): 44.537084655292894,
        ('no-sub_no-matrix', 5): 1.8752506698658238,
        ('sub_matrix', 1): 18.440281483079957,
        ('sub_matrix', 2): 38.54769605435544,
        ('sub_matrix', 3): 30.510800250317864,
        ('sub_matrix', 4): 44.99740645329493,
        ('sub_matrix', 5): 39.55738177603457,
        ('sub_no-matrix', 1): 18.440281483079957,
        ('sub_no-matrix', 2): 38.54769605435544,
        ('sub_no-matrix', 3): 30.510800250317864,
        ('sub_no-matrix', 4): 44.99740645329493,
        ('sub_no-matrix', 5): 2.6233048602148386,
    },
    {
        ('no-sub_matrix', 1): 16.324447378609865,
        ('no-sub_matrix', 2): 30.87308462806543,
        ('no-sub_matrix', 3): 22.765564836381643,
        ('no-sub_matrix', 4): 38.337445027901204,
        ('no-sub_matrix', 5): 40.98815076599078,
        ('no-sub_no-matrix', 1): 16.324447378609865,
        ('no-sub_no-matrix', 2): 30.87308462806543,
        ('no-sub_no-matrix', 3): 22.765564836381643,
        ('no-sub_no-matrix', 4): 38.337445027901204,
        ('no-sub_no-matrix', 5): 1.4796406979126138,
        ('sub_matrix', 1): 17.9623592385626,
        ('sub_matrix', 2): 32.36568198294609,
        ('sub_matrix', 3): 22.438215466486483,
        ('sub_matrix', 4): 40.900713840387546,
        ('sub_matrix', 5): 33.396627340011634,
        ('sub_no-matrix', 1): 17.9623592385626,
        ('sub_no-matrix', 2): 32.36568198294609,
        ('sub_no-matrix', 3): 22.438215466486483,
        ('sub_no-matrix', 4): 40.900713840387546,
        ('sub_no-matrix', 5): 6.609518913895668,
    },
    {
        ('no-sub_matrix', 1): 14.033258731424148,
        ('no-sub_matrix', 2): 28.37206528002418,
        ('no-sub_matrix', 3): 27.043658386061033,
        ('no-sub_matrix', 4): 36.167049513436204,
        ('no-sub_matrix', 5): 52.280797076864395,
        ('no-sub_no-matrix', 1): 14.033258731424148,
        ('no-sub_no-matrix', 2): 28.37206528002418,
        ('no-sub_no-matrix', 3): 27.043658386061033,
        ('no-sub_no-matrix', 4): 36.167049513436204,
        ('no-sub_no-matrix', 5): 1.9358795417918389,
        ('sub_matrix', 1): 16.606623097498794,
        ('sub_matrix', 2): 29.98729916366884,
        ('sub_matrix', 3): 24.737985875967603,
        ('sub_matrix', 4): 34.93154214402433,
        ('sub_matrix', 5): 42.35241303296243,
        ('sub_no-matrix', 1): 16.606623097498794,
        ('sub_no-matrix', 2): 29.98729916366884,
        ('sub_no-matrix', 3): 24.737985875967603,
        ('sub_no-matrix', 4): 34.931551775052775,
        ('sub_no-matrix', 5): 7.151971456773863,
    },
    {
        ('no-sub_matrix', 1): 10.482293039084738,
        ('no-sub_matrix', 2): 52.67861788579445,
        ('no-sub_matrix', 3): 21.665543335527666,
        ('no-sub_matrix', 4): 23.53727708917033,
        ('no-sub_matrix', 5): 32.2645584918966,
        ('no-sub_no-matrix', 1): 10.482293039084738,
        ('no-sub_no-matrix', 2): 52.67861788579445,
        ('no-sub_no-matrix', 3): 21.665543335527666,
        ('no-sub_no-matrix', 4): 23.53727708917033,
        ('no-sub_no-matrix', 5): 2.5207572809328243,
        ('sub_matrix', 1): 11.523882918360123,
        ('sub_matrix', 2): 57.336257883871156,
        ('sub_matrix', 3): 21.647716645835132,
        ('sub_matrix', 4): 23.491483569694733,
        ('sub_matrix', 5): 24.264706351480406,
        ('sub_no-matrix', 1): 11.523882918360123,
        ('sub_no-matrix', 2): 57.336257883871156,
        ('sub_no-matrix', 3): 21.647716645835132,
        ('sub_no-matrix', 4): 23.491462243846026,
        ('sub_no-matrix', 5): 9.714244661694366,
    },
    {
        ('no-sub_matrix', 1): 11.992867568477442,
        ('no-sub_matrix', 2): 28.861638231250264,
        ('no-sub_matrix', 3): 24.222607873884137,
        ('no-sub_matrix', 4): 41.28280460012173,
        ('no-sub_matrix', 5): 56.6084264455065,
        ('no-sub_no-matrix', 1): 11.992867568477442,
        ('no-sub_no-matrix', 2): 28.861638231250264,
        ('no-sub_no-matrix', 3): 24.222607873884137,
        ('no-sub_no-matrix', 4): 41.28280460012173,
        ('no-sub_no-matrix', 5): 2.4980576348107437,
        ('sub_matrix', 1): 14.531057698832324,
        ('sub_matrix', 2): 31.280393934821902,
        ('sub_matrix', 3): 20.756528260470358,
        ('sub_matrix', 4): 42.15937712589425,
        ('sub_matrix', 5): 52.45767194621365,
        ('sub_no-matrix', 1): 14.531057698832324,
        ('sub_no-matrix', 2): 31.280393934821902,
        ('sub_no-matrix', 3): 20.756528260470358,
        ('sub_no-matrix', 4): 42.15937712589425,
        ('sub_no-matrix', 5): 4.819862633503057,
    },
]


def test_gpt_subordination_region_totals(syntaxgym_metric):
    """
    Check region-level surprisals against the original syntaxgym-core
    implementation, using the same underlying `gpt2` model.
    """
    suite_name = "subordination_src-src"
    dataset = datasets.load_dataset("cpllab/syntaxgym", suite_name)

    result = syntaxgym_metric.compute(dataset=dataset["test"], model_id="gpt2")
    region_totals = result[suite_name].region_totals

    # Debug output (shown by pytest on failure): first computed item next to
    # the first reference item.
    pprint(region_totals[0])
    pprint(GPT2_SUBORDINATION_SRC_REFERENCE[0])

    # Freeze one key order so both flattened arrays line up element-wise.
    keys = list(region_totals[0].keys())
    assert set(keys) == set(GPT2_SUBORDINATION_SRC_REFERENCE[0].keys())
    # Fail with a readable message rather than a shape error further down.
    assert len(region_totals) == len(GPT2_SUBORDINATION_SRC_REFERENCE), (
        f"expected {len(GPT2_SUBORDINATION_SRC_REFERENCE)} items, "
        f"got {len(region_totals)}"
    )

    # Flatten item-major, key-minor.
    result_ndarray = np.array(
        [item_totals[key] for item_totals in region_totals for key in keys]
    )
    reference_ndarray = np.array(
        [item_totals[key]
         for item_totals in GPT2_SUBORDINATION_SRC_REFERENCE
         for key in keys]
    )

    # Label every flattened value with (item index, key). Zipping the bare
    # `keys` against the flattened array would stop after the first item,
    # because zip() truncates to its shortest argument.
    labels = [
        (item_idx, key)
        for item_idx in range(len(region_totals))
        for key in keys
    ]
    abs_diffs = np.abs(result_ndarray - reference_ndarray)
    pprint(sorted(zip(labels, abs_diffs), key=lambda pair: -pair[1]))

    np.testing.assert_allclose(result_ndarray, reference_ndarray, atol=1e-3)


def test_evaluation_all_vs_single(syntaxgym_metric):
    """
    Check that a suite's performance is the same when evaluated in the
    composite benchmark vs. evaluated independently.
    """
    suite_name = "number_prep"
    model_id = "hf-internal-testing/tiny-xlm-roberta"

    full_dataset = datasets.load_dataset("cpllab/syntaxgym")
    sub_dataset = datasets.load_dataset("cpllab/syntaxgym", suite_name)

    full_result = syntaxgym_metric.compute(dataset=full_dataset["test"],
                                           model_id=model_id)
    sub_result = syntaxgym_metric.compute(dataset=sub_dataset["test"],
                                          model_id=model_id)

    assert (full_result[suite_name].prediction_results
            == sub_result[suite_name].prediction_results)