syntaxgym / prediction.py
jgauthier's picture
move metric and tests from dataset repo
bcb8ccf
from typing import Union, Optional as TOptional, List as TList
from pyparsing import *
import numpy as np
METRICS = {
'sum': sum,
'mean': np.mean,
'median': np.median,
'range': np.ptp,
'max': max,
'min': min
}
# Enable parser packrat (caching)
ParserElement.enablePackrat()
# Relative and absolute tolerance thresholds for surprisal equality
EQUALITY_RTOL = 1e-5
EQUALITY_ATOL = 1e-3
#######
# Define a grammar for prediction formulae.
# References a surprisal region
lpar = Suppress("(")
rpar = Suppress(")")
region = lpar + (Word(nums) | "*") + Suppress(";%") + Word(alphanums + "_-") + Suppress("%") + rpar
literal_float = pyparsing_common.number
class Region(object):
def __init__(self, tokens):
self.region_number = tokens[0]
self.condition_name = tokens[1]
def __str__(self):
return "(%s;%%%s%%)" % (self.region_number, self.condition_name)
def __repr__(self):
return "Region(%s,%s)" % (self.condition_name, self.region_number)
def __call__(self, surprisal_dict):
if self.region_number == "*":
return sum(value for (condition, region), value in surprisal_dict.items()
if condition == self.condition_name)
return surprisal_dict[self.condition_name, int(self.region_number)]
class LiteralFloat(object):
def __init__(self, tokens):
self.value = float(tokens[0])
def __str__(self):
return "%f" % (self.value,)
def __repr__(self):
return "LiteralFloat(%f)" % (self.value,)
def __call__(self, surprisal_dict):
return self.value
class BinaryOp(object):
operators: TOptional[TList[str]]
def __init__(self, tokens):
self.operator = tokens[0][1]
if self.operators is not None and self.operator not in self.operators:
raise ValueError("Invalid %s operator %s" % (self.__class__.__name__,
self.operator))
self.operands = [tokens[0][0], tokens[0][2]]
def __str__(self):
return "(%s %s %s)" % (self.operands[0], self.operator, self.operands[1])
def __repr__(self):
return "%s(%s)(%s)" % (self.__class__.__name__, self.operator, ",".join(map(repr, self.operands)))
def __call__(self, surprisal_dict):
op_vals = [op(surprisal_dict) for op in self.operands]
return self._evaluate(op_vals, surprisal_dict)
def _evaluate(self, evaluated_operands, surprisal_dict):
raise NotImplementedError()
class BoolOp(BinaryOp):
operators = ["&", "|"]
def _evaluate(self, op_vals, surprisal_dict):
if self.operator == "&":
return op_vals[0] and op_vals[1]
elif self.operator == "|":
return op_vals[0] or op_vals[1]
class FloatOp(BinaryOp):
operators = ["-", "+"]
def _evaluate(self, op_vals, surprisal_dict):
if self.operator == "-":
return op_vals[0] - op_vals[1]
elif self.operator == "+":
return op_vals[0] + op_vals[1]
class ComparatorOp(BinaryOp):
operators = ["<", ">", "="]
def _evaluate(self, op_vals, surprisal_dict):
if self.operator == "<":
return op_vals[0] < op_vals[1]
elif self.operator == ">":
return op_vals[0] > op_vals[1]
elif self.operator == "=":
return np.isclose(op_vals[0], op_vals[1],
rtol=EQUALITY_RTOL,
atol=EQUALITY_ATOL)
def Chain(op_cls, left_assoc=True):
def chainer(tokens):
"""
Create a binary tree of BinaryOps from the given repeated application
of the op.
"""
operators = tokens[0][1::2]
args = tokens[0][0::2]
if not left_assoc:
raise NotImplementedError
arg1 = args.pop(0)
while len(args) > 0:
operator = operators.pop(0)
arg2 = args.pop(0)
arg1 = op_cls([[arg1, operator, arg2]])
return arg1
return chainer
atom = region.setParseAction(Region) | literal_float.setParseAction(LiteralFloat)
prediction_expr = infixNotation(
atom,
[
(oneOf("- +"), 2, opAssoc.LEFT, Chain(FloatOp)),
(oneOf("< > ="), 2, opAssoc.LEFT, ComparatorOp),
(oneOf("& |"), 2, opAssoc.LEFT, Chain(BoolOp)),
],
lpar=lpar, rpar=rpar
)
class Prediction(object):
"""
Predictions state expected relations between language model surprisal
measures in different regions and conditions of a test suite. For more
information, see :ref:`architecture`.
"""
def __init__(self, idx: int, formula: Union[str, BinaryOp], metric: str):
"""
Args:
idx: A unique prediction ID. This is only relevant for
serialization.
formula: A string representation of the prediction formula, or an
already parsed formula. For more information, see
:ref:`architecture`.
metric: Metric for aggregating surprisals within regions.
"""
if isinstance(formula, str):
try:
formula = prediction_expr.parseString(formula, parseAll=True)[0]
except ParseException as e:
raise ValueError("Invalid formula expression %r" % (formula,)) from e
self.idx = idx
self.formula = formula
if metric not in METRICS.keys():
raise ValueError("Unknown metric %s. Supported metrics: %s" %
(metric, " ".join(METRICS.keys())))
self.metric = metric
def __call__(self, item):
"""
Evaluate the prediction on the given item dict representation. For more
information on item representations, see :ref:`suite_json`.
"""
# Prepare relevant surprisal dict
surps = {(c["condition_name"], r["region_number"]): r["metric_value"][self.metric]
for c in item["conditions"]
for r in c["regions"]}
return self.formula(surps)
@classmethod
def from_dict(cls, pred_dict, idx: int, metric: str):
"""
Parse from a prediction dictionary representation (see
:ref:`suite_json`).
"""
if not pred_dict["type"] == "formula":
raise ValueError("Unknown prediction type %s" % (pred_dict["type"],))
return cls(formula=pred_dict["formula"], idx=idx, metric=metric)
@property
def referenced_regions(self):
"""
Get a set of the regions referenced by this formula.
Each item is a tuple of the form ``(condition_name, region_number)``.
"""
def traverse(x, acc):
if isinstance(x, BinaryOp):
for val in x.operands:
traverse(val, acc)
elif isinstance(x, Region):
acc.add((x.condition_name, int(x.region_number)))
return acc
return traverse(self.formula, set())
def as_dict(self):
"""
Serialize as a prediction dictionary representation (see
:ref:`suite_json`).
"""
return dict(type="formula", formula=str(self.formula))
def __str__(self):
return "Prediction(%s)" % (self.formula,)
__repr__ = __str__
def __hash__(self):
return hash(self.formula)
def __eq__(self, other):
return isinstance(other, Prediction) and hash(self) == hash(other)