syntaxgym / test /test_syntaxgym.py
Jon Gauthier
refactor metric to support evaluating `all-2020` split
8cca3d0
from typing import List
import datasets
import evaluate
import numpy as np
import pytest
@pytest.fixture(scope="session")
def syntaxgym_dataset():
    """
    Session-wide copy of the SyntaxGym ``subordination_src-src`` suite.

    Uses the same Hub id (``cpllab/syntaxgym``) as the tests in this module,
    so the fixture and the tests load the same dataset.
    """
    return datasets.load_dataset("cpllab/syntaxgym", "subordination_src-src")
@pytest.fixture(scope="session")
def syntaxgym_metric():
    """
    Session-wide SyntaxGym metric, loaded with ``evaluate.load`` from the
    ``syntaxgym.py`` script. The path is relative to the working directory
    pytest runs from.
    """
    # TODO work out reference (presumably: load by Hub id rather than local
    # path -- confirm intent).
    metric_script = "./syntaxgym.py"
    metric = evaluate.load(metric_script)
    return metric
@pytest.fixture(scope="session")
def model_ref():
    """
    Hugging Face model id for tests that need a language model.

    A lighter alternative that was left commented out here is
    ``"hf-internal-testing/tiny-random-gpt_neo"``.
    """
    model_id = "gpt2"
    return model_id
# Reference region surprisals computed with syntaxgym-core for the
# "subordination_src-src" suite, using the `gpt2` model.
# See the notebook at
# https://colab.research.google.com/drive/1qziyPcu65jffizSPi-ZGHKR0x7BaHFMS#scrollTo=RgtnScy6LLKi
#
# Structure: a list with one dict per suite item (23 items; presumably in the
# order of the dataset's "test" split -- confirm). Each dict maps
# (condition_name, region_number) -> total surprisal of that region, for the
# four conditions no-sub_matrix, no-sub_no-matrix, sub_matrix, sub_no-matrix
# and regions 1-5. Units presumably follow syntaxgym-core's convention
# (bits) -- TODO confirm.
GPT2_SUBORDINATION_SRC_REFERENCE = \
[{('no-sub_matrix', 1): 13.151199615123803,
('no-sub_matrix', 2): 38.503222716703526,
('no-sub_matrix', 3): 27.623861034812286,
('no-sub_matrix', 4): 48.831672846038224,
('no-sub_matrix', 5): 38.08533699286694,
('no-sub_no-matrix', 1): 13.151199615123803,
('no-sub_no-matrix', 2): 38.503222716703526,
('no-sub_no-matrix', 3): 27.623861034812286,
('no-sub_no-matrix', 4): 48.831687980511504,
('no-sub_no-matrix', 5): 1.8096143510772873,
('sub_matrix', 1): 14.905592916748805,
('sub_matrix', 2): 39.06304309956175,
('sub_matrix', 3): 26.862648365854433,
('sub_matrix', 4): 50.56554401687938,
('sub_matrix', 5): 26.532245572980194,
('sub_no-matrix', 1): 14.905592916748805,
('sub_no-matrix', 2): 39.06304309956175,
('sub_no-matrix', 3): 26.862648365854433,
('sub_no-matrix', 4): 50.56553438585093,
('sub_no-matrix', 5): 7.470089829866611},
{('no-sub_matrix', 1): 10.116093820255577,
('no-sub_matrix', 2): 20.96513246705127,
('no-sub_matrix', 3): 20.02959138986416,
('no-sub_matrix', 4): 23.779661397107446,
('no-sub_matrix', 5): 33.2560281692696,
('no-sub_no-matrix', 1): 10.116093820255577,
('no-sub_no-matrix', 2): 20.96513246705127,
('no-sub_no-matrix', 3): 20.02959138986416,
('no-sub_no-matrix', 4): 23.779661397107446,
('no-sub_no-matrix', 5): 1.9449125865631063,
('sub_matrix', 1): 13.545157521732826,
('sub_matrix', 2): 24.96048395897244,
('sub_matrix', 3): 18.609464944317324,
('sub_matrix', 4): 23.057566440062317,
('sub_matrix', 5): 26.424454285669032,
('sub_no-matrix', 1): 13.545157521732826,
('sub_no-matrix', 2): 24.96048395897244,
('sub_no-matrix', 3): 18.609464944317324,
('sub_no-matrix', 4): 23.057566440062317,
('sub_no-matrix', 5): 2.807467838359704},
{('no-sub_matrix', 1): 11.992867568477442,
('no-sub_matrix', 2): 45.813114232935774,
('no-sub_matrix', 3): 24.57554828372551,
('no-sub_matrix', 4): 45.334025774062916,
('no-sub_matrix', 5): 26.208189541862073,
('no-sub_no-matrix', 1): 11.992867568477442,
('no-sub_no-matrix', 2): 45.813114232935774,
('no-sub_no-matrix', 3): 24.57554828372551,
('no-sub_no-matrix', 4): 45.33402766587207,
('no-sub_no-matrix', 5): 1.8284485151385752,
('sub_matrix', 1): 14.219887768799735,
('sub_matrix', 2): 46.25055434117979,
('sub_matrix', 3): 23.054221678472672,
('sub_matrix', 4): 47.08503858470256,
('sub_matrix', 5): 22.154772321452022,
('sub_no-matrix', 1): 14.219887768799735,
('sub_no-matrix', 2): 46.25055434117979,
('sub_no-matrix', 3): 23.054221678472672,
('sub_no-matrix', 4): 47.08503858470256,
('sub_no-matrix', 5): 3.0655133594366757},
{('no-sub_matrix', 1): 10.55002943802296,
('no-sub_matrix', 2): 52.419810137608856,
('no-sub_matrix', 3): 23.30710475332303,
('no-sub_matrix', 4): 37.957905964008944,
('no-sub_matrix', 5): 29.259648135104936,
('no-sub_no-matrix', 1): 10.55002943802296,
('no-sub_no-matrix', 2): 52.419810137608856,
('no-sub_no-matrix', 3): 23.30710475332303,
('no-sub_no-matrix', 4): 37.957905964008944,
('no-sub_no-matrix', 5): 1.9632913405649093,
('sub_matrix', 1): 15.289384584900025,
('sub_matrix', 2): 53.93652737134243,
('sub_matrix', 3): 19.43915835312633,
('sub_matrix', 4): 36.459591551099386,
('sub_matrix', 5): 22.185742699245417,
('sub_no-matrix', 1): 15.289384584900025,
('sub_no-matrix', 2): 53.93652737134243,
('sub_no-matrix', 3): 19.43915835312633,
('sub_no-matrix', 4): 36.4595598203003,
('sub_no-matrix', 5): 5.707732355645454},
{('no-sub_matrix', 1): 23.543723213902986,
('no-sub_matrix', 2): 31.967972102825854,
('no-sub_matrix', 3): 29.159572978411727,
('no-sub_matrix', 4): 36.61365345925747,
('no-sub_matrix', 5): 44.576591305970545,
('no-sub_no-matrix', 1): 23.543723213902986,
('no-sub_no-matrix', 2): 31.967972102825854,
('no-sub_no-matrix', 3): 29.159572978411727,
('no-sub_no-matrix', 4): 36.61365345925747,
('no-sub_no-matrix', 5): 3.2813457388593714,
('sub_matrix', 1): 27.118410129310597,
('sub_matrix', 2): 33.909617362987866,
('sub_matrix', 3): 28.791166362258743,
('sub_matrix', 4): 37.24960609010374,
('sub_matrix', 5): 31.660933798006262,
('sub_no-matrix', 1): 27.118410129310597,
('sub_no-matrix', 2): 33.909617362987866,
('sub_no-matrix', 3): 28.791166362258743,
('sub_no-matrix', 4): 37.24960609010374,
('sub_no-matrix', 5): 7.3613541428239015},
{('no-sub_matrix', 1): 14.22171869610082,
('no-sub_matrix', 2): 30.270423022911977,
('no-sub_matrix', 3): 25.973276891204705,
('no-sub_matrix', 4): 28.43856735947716,
('no-sub_matrix', 5): 57.39887418731055,
('no-sub_no-matrix', 1): 14.22171869610082,
('no-sub_no-matrix', 2): 30.270423022911977,
('no-sub_no-matrix', 3): 25.973276891204705,
('no-sub_no-matrix', 4): 28.43856735947716,
('no-sub_no-matrix', 5): 1.7127059109344136,
('sub_matrix', 1): 16.39289784951447,
('sub_matrix', 2): 31.5671111565765,
('sub_matrix', 3): 24.54307828171008,
('sub_matrix', 4): 29.249645624130757,
('sub_matrix', 5): 53.59155769093577,
('sub_no-matrix', 1): 16.39289784951447,
('sub_no-matrix', 2): 31.5671111565765,
('sub_no-matrix', 3): 24.54307828171008,
('sub_no-matrix', 4): 29.249645624130757,
('sub_no-matrix', 5): 7.225276653947023},
{('no-sub_matrix', 1): 13.729688714733188,
('no-sub_matrix', 2): 36.018118127225165,
('no-sub_matrix', 3): 28.232055923783275,
('no-sub_matrix', 4): 44.44634394296659,
('no-sub_matrix', 5): 38.277975147059344,
('no-sub_no-matrix', 1): 13.729688714733188,
('no-sub_no-matrix', 2): 36.018118127225165,
('no-sub_no-matrix', 3): 28.232055923783275,
('no-sub_no-matrix', 4): 44.44634394296659,
('no-sub_no-matrix', 5): 3.0318996942908414,
('sub_matrix', 1): 16.93528744674245,
('sub_matrix', 2): 36.545024814326574,
('sub_matrix', 3): 26.279603445823692,
('sub_matrix', 4): 46.501226364074995,
('sub_matrix', 5): 32.155418057793035,
('sub_no-matrix', 1): 16.93528744674245,
('sub_no-matrix', 2): 36.545024814326574,
('sub_no-matrix', 3): 26.279603445823692,
('sub_no-matrix', 4): 46.501226364074995,
('sub_no-matrix', 5): 4.4581122618864155},
{('no-sub_matrix', 1): 15.598113737151568,
('no-sub_matrix', 2): 56.12543415244172,
('no-sub_matrix', 3): 29.755667770007285,
('no-sub_matrix', 4): 51.689282097269995,
('no-sub_matrix', 5): 45.575230324010775,
('no-sub_no-matrix', 1): 15.598113737151568,
('no-sub_no-matrix', 2): 56.12543415244172,
('no-sub_no-matrix', 3): 29.755667770007285,
('no-sub_no-matrix', 4): 51.68928424705313,
('no-sub_no-matrix', 5): 1.235207173694806,
('sub_matrix', 1): 18.909088991066888,
('sub_matrix', 2): 57.753410746636746,
('sub_matrix', 3): 28.677667873674363,
('sub_matrix', 4): 51.99410775929489,
('sub_matrix', 5): 35.754144966112236,
('sub_no-matrix', 1): 18.909088991066888,
('sub_no-matrix', 2): 57.753410746636746,
('sub_no-matrix', 3): 28.677667873674363,
('sub_no-matrix', 4): 51.9941480032352,
('sub_no-matrix', 5): 5.033266273930268},
{('no-sub_matrix', 1): 14.859413855165633,
('no-sub_matrix', 2): 34.54519231993284,
('no-sub_matrix', 3): 24.26528519671309,
('no-sub_matrix', 4): 35.42343514121054,
('no-sub_matrix', 5): 55.85308623165151,
('no-sub_no-matrix', 1): 14.859413855165633,
('no-sub_no-matrix', 2): 34.54519231993284,
('no-sub_no-matrix', 3): 24.26528519671309,
('no-sub_no-matrix', 4): 35.42343514121054,
('no-sub_no-matrix', 5): 2.3309861205259734,
('sub_matrix', 1): 17.053809634549854,
('sub_matrix', 2): 33.66637542056656,
('sub_matrix', 3): 23.26181234829638,
('sub_matrix', 4): 35.61438567264568,
('sub_matrix', 5): 48.48551986050014,
('sub_no-matrix', 1): 17.053809634549854,
('sub_no-matrix', 2): 33.66637542056656,
('sub_no-matrix', 3): 23.26181234829638,
('sub_no-matrix', 4): 35.61438704850689,
('sub_no-matrix', 5): 2.969309360231736},
{('no-sub_matrix', 1): 13.708973748402064,
('no-sub_matrix', 2): 31.147590264691182,
('no-sub_matrix', 3): 30.495597241955565,
('no-sub_matrix', 4): 34.65164493728535,
('no-sub_matrix', 5): 35.87510990950117,
('no-sub_no-matrix', 1): 13.708973748402064,
('no-sub_no-matrix', 2): 31.147590264691182,
('no-sub_no-matrix', 3): 30.495597241955565,
('no-sub_no-matrix', 4): 34.65164493728535,
('no-sub_no-matrix', 5): 3.232032121481573,
('sub_matrix', 1): 17.681722076468287,
('sub_matrix', 2): 33.77225997922327,
('sub_matrix', 3): 29.435808932487806,
('sub_matrix', 4): 34.354368969668016,
('sub_matrix', 5): 20.802733205442486,
('sub_no-matrix', 1): 17.681722076468287,
('sub_no-matrix', 2): 33.77225997922327,
('sub_no-matrix', 3): 29.435808932487806,
('sub_no-matrix', 4): 34.354368969668016,
('sub_no-matrix', 5): 3.7902066303710424},
{('no-sub_matrix', 1): 15.72185319065555,
('no-sub_matrix', 2): 45.25539814380218,
('no-sub_matrix', 3): 24.94273362957689,
('no-sub_matrix', 4): 40.81704901026569,
('no-sub_matrix', 5): 42.898794519499596,
('no-sub_no-matrix', 1): 15.72185319065555,
('no-sub_no-matrix', 2): 45.25539814380218,
('no-sub_no-matrix', 3): 24.94273362957689,
('no-sub_no-matrix', 4): 40.81704901026569,
('no-sub_no-matrix', 5): 2.6826901255924644,
('sub_matrix', 1): 17.565795106862403,
('sub_matrix', 2): 46.9371803702329,
('sub_matrix', 3): 23.887805807796486,
('sub_matrix', 4): 39.058599411828766,
('sub_matrix', 5): 32.234453544910295,
('sub_no-matrix', 1): 17.565795106862403,
('sub_no-matrix', 2): 46.9371803702329,
('sub_no-matrix', 3): 23.887805807796486,
('sub_no-matrix', 4): 39.058599411828766,
('sub_no-matrix', 5): 4.214674259243127},
{('no-sub_matrix', 1): 13.910878628792588,
('no-sub_matrix', 2): 33.45626834359109,
('no-sub_matrix', 3): 16.127584513594687,
('no-sub_matrix', 4): 32.59623120264939,
('no-sub_matrix', 5): 29.87568851789407,
('no-sub_no-matrix', 1): 13.910878628792588,
('no-sub_no-matrix', 2): 33.45626834359109,
('no-sub_no-matrix', 3): 16.127584513594687,
('no-sub_no-matrix', 4): 32.59623120264939,
('no-sub_no-matrix', 5): 2.3891779982892625,
('sub_matrix', 1): 17.18981661053988,
('sub_matrix', 2): 36.38883326650068,
('sub_matrix', 3): 13.081088737716442,
('sub_matrix', 4): 33.419732612590224,
('sub_matrix', 5): 22.665485632721676,
('sub_no-matrix', 1): 17.18981661053988,
('sub_no-matrix', 2): 36.38883326650068,
('sub_no-matrix', 3): 13.081088737716442,
('sub_no-matrix', 4): 33.419732612590224,
('sub_no-matrix', 5): 6.155199912348024},
{('no-sub_matrix', 1): 18.196771699177763,
('no-sub_matrix', 2): 35.624058750852136,
('no-sub_matrix', 3): 23.746554392851053,
('no-sub_matrix', 4): 29.44669921790574,
('no-sub_matrix', 5): 39.72412918901379,
('no-sub_no-matrix', 1): 18.196771699177763,
('no-sub_no-matrix', 2): 35.624058750852136,
('no-sub_no-matrix', 3): 23.746554392851053,
('no-sub_no-matrix', 4): 29.44669921790574,
('no-sub_no-matrix', 5): 2.870123353843486,
('sub_matrix', 1): 20.38619930823735,
('sub_matrix', 2): 36.29781144853154,
('sub_matrix', 3): 22.13637404741934,
('sub_matrix', 4): 29.68729899086184,
('sub_matrix', 5): 36.993790238103884,
('sub_no-matrix', 1): 20.38619930823735,
('sub_no-matrix', 2): 36.29781144853154,
('sub_no-matrix', 3): 22.13637404741934,
('sub_no-matrix', 4): 29.68729899086184,
('sub_no-matrix', 5): 7.650303570399713},
{('no-sub_matrix', 1): 11.992867568477442,
('no-sub_matrix', 2): 26.44083030170154,
('no-sub_matrix', 3): 27.574921221726136,
('no-sub_matrix', 4): 28.94213565689118,
('no-sub_matrix', 5): 46.973469397495556,
('no-sub_no-matrix', 1): 11.992867568477442,
('no-sub_no-matrix', 2): 26.44083030170154,
('no-sub_no-matrix', 3): 27.574921221726136,
('no-sub_no-matrix', 4): 28.94213565689118,
('no-sub_no-matrix', 5): 3.354326576753004,
('sub_matrix', 1): 14.434047100994839,
('sub_matrix', 2): 26.76571524620116,
('sub_matrix', 3): 25.83488399989926,
('sub_matrix', 4): 30.263621195061678,
('sub_matrix', 5): 36.822532494114455,
('sub_no-matrix', 1): 14.434047100994839,
('sub_no-matrix', 2): 26.76571524620116,
('sub_no-matrix', 3): 25.83488399989926,
('sub_no-matrix', 4): 30.263621195061678,
('sub_no-matrix', 5): 6.748976893757906},
{('no-sub_matrix', 1): 16.27614914680276,
('no-sub_matrix', 2): 41.35282905624703,
('no-sub_matrix', 3): 25.173115913245226,
('no-sub_matrix', 4): 52.876981987369014,
('no-sub_matrix', 5): 49.49767321075167,
('no-sub_no-matrix', 1): 16.27614914680276,
('no-sub_no-matrix', 2): 41.35282905624703,
('no-sub_no-matrix', 3): 25.173115913245226,
('no-sub_no-matrix', 4): 52.876981987369014,
('no-sub_no-matrix', 5): 1.5962803636236758,
('sub_matrix', 1): 18.735912436641787,
('sub_matrix', 2): 43.36213985849511,
('sub_matrix', 3): 24.582800598631913,
('sub_matrix', 4): 53.1616607417586,
('sub_matrix', 5): 41.2664433745972,
('sub_no-matrix', 1): 18.735912436641787,
('sub_no-matrix', 2): 43.36213985849511,
('sub_no-matrix', 3): 24.582800598631913,
('sub_no-matrix', 4): 53.16165799003619,
('sub_no-matrix', 5): 6.4917878462822305},
{('no-sub_matrix', 1): 14.036280122634507,
('no-sub_matrix', 2): 53.72802368862095,
('no-sub_matrix', 3): 18.940766131564004,
('no-sub_matrix', 4): 40.74964840745327,
('no-sub_matrix', 5): 39.57008490907742,
('no-sub_no-matrix', 1): 14.036280122634507,
('no-sub_no-matrix', 2): 53.72802368862095,
('no-sub_no-matrix', 3): 18.940766131564004,
('no-sub_no-matrix', 4): 40.74964840745327,
('no-sub_no-matrix', 5): 2.1275557540222967,
('sub_matrix', 1): 19.641722357026286,
('sub_matrix', 2): 52.709120728751486,
('sub_matrix', 3): 17.976257844509426,
('sub_matrix', 4): 42.51851542500959,
('sub_matrix', 5): 28.25018664655579,
('sub_no-matrix', 1): 19.641722357026286,
('sub_no-matrix', 2): 52.709120728751486,
('sub_no-matrix', 3): 17.976257844509426,
('sub_no-matrix', 4): 42.51851267328718,
('sub_no-matrix', 5): 5.409622788119386},
{('no-sub_matrix', 1): 16.961927903326398,
('no-sub_matrix', 2): 38.5455951142925,
('no-sub_matrix', 3): 25.122316709729276,
('no-sub_matrix', 4): 35.90131439006518,
('no-sub_matrix', 5): 41.65886977570029,
('no-sub_no-matrix', 1): 16.961927903326398,
('no-sub_no-matrix', 2): 38.5455951142925,
('no-sub_no-matrix', 3): 25.122316709729276,
('no-sub_no-matrix', 4): 35.90131439006518,
('no-sub_no-matrix', 5): 3.2679255886472447,
('sub_matrix', 1): 20.247934372024154,
('sub_matrix', 2): 40.408716019775625,
('sub_matrix', 3): 23.782735071043668,
('sub_matrix', 4): 37.00513584758997,
('sub_matrix', 5): 29.22700479607527,
('sub_no-matrix', 1): 20.247934372024154,
('sub_no-matrix', 2): 40.408716019775625,
('sub_no-matrix', 3): 23.782735071043668,
('sub_no-matrix', 4): 37.00513584758997,
('sub_no-matrix', 5): 4.780011845541033},
{('no-sub_matrix', 1): 12.109815771064152,
('no-sub_matrix', 2): 38.32406752938649,
('no-sub_matrix', 3): 25.987801084044044,
('no-sub_matrix', 4): 40.40950903177875,
('no-sub_matrix', 5): 52.86522525335603,
('no-sub_no-matrix', 1): 12.109815771064152,
('no-sub_no-matrix', 2): 38.32406752938649,
('no-sub_no-matrix', 3): 25.987801084044044,
('no-sub_no-matrix', 4): 40.40950903177875,
('no-sub_no-matrix', 5): 3.61917194787979,
('sub_matrix', 1): 15.130341564722832,
('sub_matrix', 2): 37.89719334728088,
('sub_matrix', 3): 24.65681032273433,
('sub_matrix', 4): 40.731610867030774,
('sub_matrix', 5): 37.566910985257906,
('sub_no-matrix', 1): 15.130341564722832,
('sub_no-matrix', 2): 37.89719334728088,
('sub_no-matrix', 3): 24.65681032273433,
('sub_no-matrix', 4): 40.731610867030774,
('sub_no-matrix', 5): 9.39736249989602},
{('no-sub_matrix', 1): 16.25058564557851,
('no-sub_matrix', 2): 37.20405682898803,
('no-sub_matrix', 3): 30.5107090995129,
('no-sub_matrix', 4): 44.537084655292894,
('no-sub_matrix', 5): 46.50046620075818,
('no-sub_no-matrix', 1): 16.25058564557851,
('no-sub_no-matrix', 2): 37.20405682898803,
('no-sub_no-matrix', 3): 30.5107090995129,
('no-sub_no-matrix', 4): 44.537084655292894,
('no-sub_no-matrix', 5): 1.8752506698658238,
('sub_matrix', 1): 18.440281483079957,
('sub_matrix', 2): 38.54769605435544,
('sub_matrix', 3): 30.510800250317864,
('sub_matrix', 4): 44.99740645329493,
('sub_matrix', 5): 39.55738177603457,
('sub_no-matrix', 1): 18.440281483079957,
('sub_no-matrix', 2): 38.54769605435544,
('sub_no-matrix', 3): 30.510800250317864,
('sub_no-matrix', 4): 44.99740645329493,
('sub_no-matrix', 5): 2.6233048602148386},
{('no-sub_matrix', 1): 16.324447378609865,
('no-sub_matrix', 2): 30.87308462806543,
('no-sub_matrix', 3): 22.765564836381643,
('no-sub_matrix', 4): 38.337445027901204,
('no-sub_matrix', 5): 40.98815076599078,
('no-sub_no-matrix', 1): 16.324447378609865,
('no-sub_no-matrix', 2): 30.87308462806543,
('no-sub_no-matrix', 3): 22.765564836381643,
('no-sub_no-matrix', 4): 38.337445027901204,
('no-sub_no-matrix', 5): 1.4796406979126138,
('sub_matrix', 1): 17.9623592385626,
('sub_matrix', 2): 32.36568198294609,
('sub_matrix', 3): 22.438215466486483,
('sub_matrix', 4): 40.900713840387546,
('sub_matrix', 5): 33.396627340011634,
('sub_no-matrix', 1): 17.9623592385626,
('sub_no-matrix', 2): 32.36568198294609,
('sub_no-matrix', 3): 22.438215466486483,
('sub_no-matrix', 4): 40.900713840387546,
('sub_no-matrix', 5): 6.609518913895668},
{('no-sub_matrix', 1): 14.033258731424148,
('no-sub_matrix', 2): 28.37206528002418,
('no-sub_matrix', 3): 27.043658386061033,
('no-sub_matrix', 4): 36.167049513436204,
('no-sub_matrix', 5): 52.280797076864395,
('no-sub_no-matrix', 1): 14.033258731424148,
('no-sub_no-matrix', 2): 28.37206528002418,
('no-sub_no-matrix', 3): 27.043658386061033,
('no-sub_no-matrix', 4): 36.167049513436204,
('no-sub_no-matrix', 5): 1.9358795417918389,
('sub_matrix', 1): 16.606623097498794,
('sub_matrix', 2): 29.98729916366884,
('sub_matrix', 3): 24.737985875967603,
('sub_matrix', 4): 34.93154214402433,
('sub_matrix', 5): 42.35241303296243,
('sub_no-matrix', 1): 16.606623097498794,
('sub_no-matrix', 2): 29.98729916366884,
('sub_no-matrix', 3): 24.737985875967603,
('sub_no-matrix', 4): 34.931551775052775,
('sub_no-matrix', 5): 7.151971456773863},
{('no-sub_matrix', 1): 10.482293039084738,
('no-sub_matrix', 2): 52.67861788579445,
('no-sub_matrix', 3): 21.665543335527666,
('no-sub_matrix', 4): 23.53727708917033,
('no-sub_matrix', 5): 32.2645584918966,
('no-sub_no-matrix', 1): 10.482293039084738,
('no-sub_no-matrix', 2): 52.67861788579445,
('no-sub_no-matrix', 3): 21.665543335527666,
('no-sub_no-matrix', 4): 23.53727708917033,
('no-sub_no-matrix', 5): 2.5207572809328243,
('sub_matrix', 1): 11.523882918360123,
('sub_matrix', 2): 57.336257883871156,
('sub_matrix', 3): 21.647716645835132,
('sub_matrix', 4): 23.491483569694733,
('sub_matrix', 5): 24.264706351480406,
('sub_no-matrix', 1): 11.523882918360123,
('sub_no-matrix', 2): 57.336257883871156,
('sub_no-matrix', 3): 21.647716645835132,
('sub_no-matrix', 4): 23.491462243846026,
('sub_no-matrix', 5): 9.714244661694366},
{('no-sub_matrix', 1): 11.992867568477442,
('no-sub_matrix', 2): 28.861638231250264,
('no-sub_matrix', 3): 24.222607873884137,
('no-sub_matrix', 4): 41.28280460012173,
('no-sub_matrix', 5): 56.6084264455065,
('no-sub_no-matrix', 1): 11.992867568477442,
('no-sub_no-matrix', 2): 28.861638231250264,
('no-sub_no-matrix', 3): 24.222607873884137,
('no-sub_no-matrix', 4): 41.28280460012173,
('no-sub_no-matrix', 5): 2.4980576348107437,
('sub_matrix', 1): 14.531057698832324,
('sub_matrix', 2): 31.280393934821902,
('sub_matrix', 3): 20.756528260470358,
('sub_matrix', 4): 42.15937712589425,
('sub_matrix', 5): 52.45767194621365,
('sub_no-matrix', 1): 14.531057698832324,
('sub_no-matrix', 2): 31.280393934821902,
('sub_no-matrix', 3): 20.756528260470358,
('sub_no-matrix', 4): 42.15937712589425,
('sub_no-matrix', 5): 4.819862633503057}]
def test_gpt_subordination_region_totals(syntaxgym_metric):
    """
    Check region-level surprisals against the original syntaxgym-core
    implementation, using the same underlying `gpt2` model.

    The metric's per-item region totals for the "subordination_src-src" suite
    must match ``GPT2_SUBORDINATION_SRC_REFERENCE`` item by item and key by
    key, within an absolute tolerance of 1e-3.
    """
    suite_name = "subordination_src-src"
    dataset = datasets.load_dataset("cpllab/syntaxgym", suite_name)
    # The reference values were computed with `gpt2`, so the model is pinned
    # here rather than taken from the `model_ref` fixture.
    result = syntaxgym_metric.compute(dataset=dataset["test"], model_id="gpt2")
    region_totals = result[suite_name].region_totals

    reference = GPT2_SUBORDINATION_SRC_REFERENCE
    assert len(region_totals) == len(reference), \
        f"expected {len(reference)} items, got {len(region_totals)}"

    # One deterministic (condition, region) ordering shared by both arrays.
    keys = sorted(reference[0].keys())
    for item_idx, item_totals in enumerate(region_totals):
        assert set(item_totals.keys()) == set(keys), \
            f"unexpected region keys in item {item_idx}"

    # Shape (n_items, n_keys): entry [i, j] is item i's total for keys[j].
    result_ndarray = np.array(
        [[item[key] for key in keys] for item in region_totals])
    reference_ndarray = np.array(
        [[item[key] for key in keys] for item in reference])

    # Report the largest deviations labelled by (item, key). The previous debug
    # print zipped item 0's keys against the flattened diffs of *all* items,
    # which mislabelled and truncated them.
    deviation = np.abs(result_ndarray - reference_ndarray)
    worst = sorted(
        ((float(deviation[i, j]), i, keys[j])
         for i, j in np.ndindex(deviation.shape)),
        reverse=True,
    )[:5]
    np.testing.assert_allclose(
        result_ndarray, reference_ndarray, atol=1e-3,
        err_msg=f"largest deviations (abs diff, item, key): {worst}",
    )
def test_evaluation_all_vs_single(syntaxgym_metric):
    """
    A suite's prediction results must not depend on whether it is evaluated
    inside the composite benchmark or on its own.
    """
    suite_name = "number_prep"
    model_id = "hf-internal-testing/tiny-xlm-roberta"

    benchmark_split = datasets.load_dataset("cpllab/syntaxgym")["test"]
    suite_split = datasets.load_dataset("cpllab/syntaxgym", suite_name)["test"]

    # Evaluate the composite benchmark first, then the standalone suite.
    benchmark_result, suite_result = (
        syntaxgym_metric.compute(dataset=split, model_id=model_id)
        for split in (benchmark_split, suite_split)
    )

    expected = suite_result[suite_name].prediction_results
    actual = benchmark_result[suite_name].prediction_results
    assert actual == expected