File size: 2,649 Bytes
72cfe15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
'''
ABSTRACT SYNTAX TREE
This file contains the Python class that represents programs created by our rasp synthesizer.
'''
from utils import *

class OperatorNode:
    '''
    Class to represent operator nodes (i.e., an operator and its operands) as an AST.

    Children may be leaf operator objects (e.g. Tokens()) or other OperatorNode
    instances: the node only relies on each child's `weight` attribute and its
    `str()` / `to_python()` methods.

    Args:
        operator (object): operator object (e.g., Select, Aggregate, etc.). It must
            expose `weight`, `return_type`, `n_args`, `str(*operand_strings)` and
            `to_python(*operands)`.
        children (list): list of children nodes (operands), in the order the
            operator expects them.

    Attributes:
        weight: operator.weight plus the weights of all children. It is computed
            once at construction and is not updated if `children` is mutated later.
        return_type: copied from operator.return_type.

    Example:
        select_node = OperatorNode(Select(), [Tokens(), Tokens(), Equal()])
        select_node.str() == "(select(tokens, tokens, ==))"
        select_node.evaluate("hi") == [[1, 0], [0, 1]]
        select_node.to_python()  # rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ)
    '''
    def __init__(self, operator, children):
        self.operator = operator
        self.children = children
        # Subtree weight: nested OperatorNode children contribute their own
        # already-summed weight, so this is the total over the whole subtree.
        self.weight = operator.weight + sum(child.weight for child in children)
        self.return_type = operator.return_type

    def _check_arity(self):
        '''Raise ValueError unless len(children) matches operator.n_args.'''
        if len(self.children) != self.operator.n_args:
            raise ValueError("Improper number of arguments for operator.")

    def str(self):
        '''
        Render this subtree as a RASP expression string.

        The operator's rendering of its operands is wrapped in an extra pair of
        parentheses, e.g. "(select(tokens, tokens, ==))".

        Raises:
            ValueError: if the number of children does not match operator.n_args.
        '''
        self._check_arity()
        operand_strings = [child.str() for child in self.children]
        return f"({self.operator.str(*operand_strings)})"

    def evaluate(self, input=None):
        '''
        Directly evaluate the python translation.

        Builds the program with to_python() and calls the result on `input`.
        (An earlier version instead ran self.str() through the RASP REPL via
        run_repl; that path is no longer used.)

        Raises:
            ValueError: if the number of children does not match operator.n_args.
        '''
        program = self.to_python()
        return program(input)

    def to_python(self):
        '''
        Translate this subtree into its Python form.

        The children are translated recursively, and the results are passed
        positionally to operator.to_python().

        Raises:
            ValueError: if the number of children does not match operator.n_args.
        '''
        self._check_arity()
        operands = [child.to_python() for child in self.children]
        return self.operator.to_python(*operands)
    
# TESTING
if __name__ == "__main__":
    from python_embedded_rasp import *
    from tracr.rasp import rasp

    # Children may be leaf operator objects (e.g. Tokens()) or other
    # OperatorNodes: OperatorNode only needs each child's weight, str() and
    # to_python().
    select_op = OperatorNode(Select(), [Tokens(), Tokens(), Equal()])
    assert select_op.weight == 4, select_op.weight

    select_op_str = select_op.str()
    actual_so_str = "(select(tokens, tokens, ==))"
    assert select_op_str == actual_so_str, select_op_str

    select_op_res = select_op.evaluate("hi")
    actual_so_res = [[1, 0], [0, 1]]
    assert select_op_res == actual_so_res, select_op_res

    # Expected translation: rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ)
    select_op_python = select_op.to_python()
    assert isinstance(select_op_python, rasp.Select), type(select_op_python)

    print("all tests passed hooray!")