# RASP-Synthesis / abstract_syntax_tree.py
# Author: CSquid333 — commit 72cfe15 ("added a synthesizer on top"), 2.65 kB
'''
ABSTRACT SYNTAX TREE
This file contains the Python class (OperatorNode) that represents programs created by our
RASP synthesizer as abstract syntax trees (ASTs).
'''
from utils import *
class OperatorNode:
    '''
    Class to represent operator nodes (i.e., an operator and its operands) as an AST.

    Args:
        operator (object): operator object (e.g., Select, Aggregate, etc.). It must expose
            `weight`, `return_type`, `n_args`, `str(*operand_strings)` and
            `to_python(*operands)`.
        children (list): list of children nodes (operands). Each child must expose `weight`,
            `str()` and `to_python()`; plain operators (as in the example below) and nested
            OperatorNodes both do.

    Attributes:
        weight: the operator's weight plus the weights of all children. It is computed once
            in __init__ and is not updated if `children` is mutated afterwards.
        return_type: copied from `operator.return_type`.

    Example:
        select_node: OperatorNode(Select(), [Tokens(), Tokens(), Equal()])
        select_node.str() = "(select(tokens, tokens, ==))"
        select_node.evaluate("hi") = [[1, 0], [0, 1]]
        select_node.to_python() = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ)
    '''

    def __init__(self, operator, children):
        self.operator = operator
        self.children = children
        # Cost of the whole subtree: the operator's own weight plus its operands' weights.
        self.weight = operator.weight + sum(child.weight for child in children)
        self.return_type = operator.return_type

    def _check_arity(self):
        '''Raise ValueError unless the number of children equals operator.n_args.'''
        if len(self.children) != self.operator.n_args:
            raise ValueError("Improper number of arguments for operator.")

    def str(self):
        '''
        Return the RASP source string for this subtree, wrapped in parentheses.

        Raises:
            ValueError: if this node (or a nested OperatorNode child) has a number of
                children that does not match its operator's n_args.
        '''
        self._check_arity()
        operand_strings = [child.str() for child in self.children]
        return f"({self.operator.str(*operand_strings)})"

    def evaluate(self, input=None):
        '''
        Evaluate this subtree on `input` by calling its Python translation directly.

        Args:
            input: value passed unchanged to the object returned by to_python()
                (e.g. "hi"). The name shadows the builtin but is kept so existing
                keyword callers (`evaluate(input=...)`) keep working.

        Returns:
            Whatever the translated expression returns when called on `input`.

        Raises:
            ValueError: if this node (or a nested OperatorNode child) has the wrong
                number of children.
        '''
        # NOTE: this replaced an older approach that ran self.str() through the RASP REPL.
        exe = self.to_python()
        return exe(input)

    def to_python(self):
        '''
        Translate this subtree into a Python object: recursively translate every child,
        then hand the results to operator.to_python (e.g. yielding a tracr rasp.Select
        for a Select node, per the class example).

        Raises:
            ValueError: if this node (or a nested OperatorNode child) has the wrong
                number of children.
        '''
        self._check_arity()
        operands = [child.to_python() for child in self.children]
        return self.operator.to_python(*operands)
# TESTING: quick self-check of OperatorNode against the project's RASP operators and tracr.
if __name__ == "__main__":
    from python_embedded_rasp import *
    from tracr.rasp import rasp

    # NOTE: children may be plain operators (as here) or nested OperatorNodes; OperatorNode
    # only relies on each child's .weight, .str() and .to_python().
    select_op = OperatorNode(Select(), [Tokens(), Tokens(), Equal()])
    assert select_op.weight == 4, f"unexpected weight: {select_op.weight}"

    select_op_str = select_op.str()
    actual_so_str = "(select(tokens, tokens, ==))"
    assert select_op_str == actual_so_str, f"unexpected str(): {select_op_str!r}"

    select_op_res = select_op.evaluate("hi")
    actual_so_res = [[1, 0], [0, 1]]
    assert select_op_res == actual_so_res, f"unexpected evaluate(): {select_op_res!r}"

    select_op_python = select_op.to_python()
    assert isinstance(select_op_python, rasp.Select), (
        f"unexpected to_python() type: {type(select_op_python)!r}"
    )

    # A Select given too few operands must be rejected when rendered.
    short_select = OperatorNode(Select(), [Tokens(), Tokens()])
    try:
        short_select.str()
    except ValueError:
        pass
    else:
        raise AssertionError("expected ValueError for wrong operand count")

    print("all tests passed hooray!")