'''
ABSTRACT SYNTAX TREE

This file contains the Python class that represents programs created by our
rasp synthesizer.
'''

from utils import *


class OperatorNode:
    '''
    Class to represent operator nodes (i.e., an operator and its operands) as an AST.

    Args:
        operator (object): operator object (e.g., Select, Aggregate, etc.). It must
            expose `weight`, `return_type`, `n_args`, `str(*operand_strings)` and
            `to_python(*operands)`.
        children (list): list of children nodes (operands). Each child may be another
            OperatorNode or a leaf operator object (the test below passes Tokens()
            and Equal() directly); anything exposing `weight`, `str()` and
            `to_python()` works.

    Example:
        select_node = OperatorNode(Select(), [Tokens(), Tokens(), Equal()])
        select_node.str() = "(select(tokens, tokens, ==))"
        select_node.evaluate("hi") = [[1, 0], [0, 1]]
        select_node.to_python() = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ)
    '''

    def __init__(self, operator, children):
        # The operator object this node applies.
        self.operator = operator
        # Operand nodes, stored by reference (not copied).
        self.children = children
        # Total weight of the subtree: this operator's weight plus all children's.
        self.weight = operator.weight + sum(child.weight for child in children)
        # A node produces whatever type its root operator produces.
        self.return_type = operator.return_type

    def _check_arity(self):
        '''
        Raise ValueError unless the number of children matches operator.n_args.

        Arity is checked lazily (when rendering/translating), not in __init__.
        '''
        if len(self.children) != self.operator.n_args:
            raise ValueError("Improper number of arguments for operator.")

    def str(self):
        '''
        Render this subtree as a parenthesised RASP source string.

        Returns:
            str: e.g. "(select(tokens, tokens, ==))".

        Raises:
            ValueError: if the number of children does not match operator.n_args.
        '''
        self._check_arity()
        operand_strings = [child.str() for child in self.children]
        return f"({self.operator.str(*operand_strings)})"

    def evaluate(self, input=None):
        '''
        Directly evaluate the python translation.

        Translates the subtree with to_python() and calls the resulting object on
        `input`. The translated object is assumed to be callable (e.g. a tracr
        rasp expression) — NOTE(review): confirm for every operator type.

        Args:
            input: value passed to the translated expression (e.g. "hi").
                The name shadows the builtin but is kept for caller compatibility.

        Raises:
            ValueError: if any node has the wrong number of children.
        '''
        exe = self.to_python()
        return exe(input)

    def to_python(self):
        '''
        Translate this subtree into a python-embedded RASP object.

        Children are translated first and passed positionally to
        operator.to_python().

        Raises:
            ValueError: if the number of children does not match operator.n_args.
        '''
        self._check_arity()
        operands = [child.to_python() for child in self.children]
        return self.operator.to_python(*operands)


'''
TESTING
'''
if __name__ == "__main__":
    from python_embedded_rasp import *
    from tracr.rasp import rasp

    # NOTE: children here are leaf operators rather than OperatorNodes; both work
    # because OperatorNode only relies on `weight`, `str()` and `to_python()`.
    select_op = OperatorNode(Select(), [Tokens(), Tokens(), Equal()])
    assert select_op.weight == 4

    select_op_str = select_op.str()
    actual_so_str = "(select(tokens, tokens, ==))"
    assert select_op_str == actual_so_str

    select_op_res = select_op.evaluate("hi")
    actual_so_res = [[1, 0], [0, 1]]
    assert select_op_res == actual_so_res

    select_op_python = select_op.to_python()
    actual_so_python = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ)
    assert isinstance(select_op_python, type(actual_so_python))

    print("all tests passed hooray!")