Spaces:
Sleeping
Sleeping
'''
ABSTRACT SYNTAX TREE
This file contains the Python class that represents programs created by our RASP synthesizer.
'''
from utils import * | |
class OperatorNode:
    '''
    AST node pairing an operator with its operands.

    Args:
        operator (object): operator object (e.g., Select, Aggregate, etc.).
            This class reads its ``weight``, ``return_type`` and ``n_args``
            attributes and calls its ``str(*operand_strings)`` and
            ``to_python(*operands)`` methods.
        children (list): operand nodes. Each child must expose ``weight``,
            ``str()`` and ``to_python()``; leaf operators (e.g. Tokens())
            and nested OperatorNodes both qualify.

    Attributes:
        weight: the operator's weight plus the summed weights of all children.
        return_type: copied from ``operator.return_type``.

    The number of children is checked against ``operator.n_args`` lazily, in
    str()/to_python() (and therefore evaluate()), not at construction time.

    Example:
        select_node = OperatorNode(Select(), [Tokens(), Tokens(), Equal()])
        select_node.str() == "(select(tokens, tokens, ==))"
        select_node.evaluate("hi") == [[1, 0], [0, 1]]
        select_node.to_python()  # rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ)
    '''

    def __init__(self, operator, children):
        self.operator = operator
        self.children = children
        # Total weight of this subtree: the operator plus every operand.
        self.weight = operator.weight + sum(child.weight for child in children)
        self.return_type = operator.return_type

    def _check_arity(self):
        '''Raise ValueError unless len(children) equals operator.n_args.'''
        if len(self.children) != self.operator.n_args:
            raise ValueError("Improper number of arguments for operator.")

    def str(self):
        '''
        Return the RASP source string for this subtree, wrapped in parentheses.

        Each child's str() is computed first and passed positionally to
        operator.str().

        Raises:
            ValueError: if the number of children does not match operator.n_args.
        '''
        self._check_arity()
        operand_strings = [child.str() for child in self.children]
        return f"({self.operator.str(*operand_strings)})"

    def evaluate(self, input=None):
        '''
        Evaluate this subtree by building its Python translation and calling it.

        Args:
            input: the program input (e.g. "hi"); passed unchanged to the
                object returned by to_python().
                NOTE(review): the default None is forwarded as-is; presumably
                the translated expression rejects it — confirm before relying
                on the default.

        Returns:
            Whatever the translated expression returns when called on ``input``.

        Raises:
            ValueError: propagated from to_python() on an arity mismatch.
        '''
        # An older, deprecated version ran self.str() through the RASP REPL
        # (run_repl); evaluating the Python translation directly replaced it.
        exe = self.to_python()
        return exe(input)

    def to_python(self):
        '''
        Translate this subtree via operator.to_python().

        Every child is translated first (recursively), and the results are
        passed positionally to operator.to_python().

        Raises:
            ValueError: if the number of children does not match operator.n_args.
        '''
        self._check_arity()
        operands = [child.to_python() for child in self.children]
        return self.operator.to_python(*operands)
# TESTING: quick self-check against the embedded RASP operators and tracr.
if __name__ == "__main__":
    from python_embedded_rasp import *
    from tracr.rasp import rasp

    # NOTE: children may be leaf operators (Tokens(), Equal()) or nested
    # OperatorNodes; OperatorNode only needs each child to expose weight,
    # str() and to_python().
    select_op = OperatorNode(Select(), [Tokens(), Tokens(), Equal()])
    assert select_op.weight == 4, select_op.weight

    select_op_str = select_op.str()
    actual_so_str = "(select(tokens, tokens, ==))"
    assert select_op_str == actual_so_str, select_op_str

    select_op_res = select_op.evaluate("hi")
    actual_so_res = [[1, 0], [0, 1]]
    assert select_op_res == actual_so_res, select_op_res

    select_op_python = select_op.to_python()
    actual_so_python = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ)
    # Exact type match is intended here; compare types by identity (E721).
    assert type(select_op_python) is type(actual_so_python), type(select_op_python)

    print("all tests passed hooray!")