Spaces:
Sleeping
Sleeping
added a synthesizer on top
Browse files- __pycache__/abstract_syntax_tree.cpython-39.pyc +0 -0
- __pycache__/python_embedded_rasp.cpython-39.pyc +0 -0
- __pycache__/rasp_synthesizer.cpython-39.pyc +0 -0
- __pycache__/utils.cpython-39.pyc +0 -0
- abstract_syntax_tree.py +72 -0
- app.py +18 -0
- comp_flows/( tokens_int . 1 )(.1. 2.).pdf +0 -0
- outtest.txt +36 -0
- python_embedded_rasp.py +308 -0
- rasp_synthesizer.py +257 -0
- reverse-viz.ipynb +0 -0
- testouts.txt +55 -0
- tracr/__pycache__/__init__.cpython-39.pyc +0 -0
- tracr/compiler/__pycache__/__init__.cpython-39.pyc +0 -0
- tracr/compiler/__pycache__/assemble.cpython-39.pyc +0 -0
- tracr/compiler/__pycache__/basis_inference.cpython-39.pyc +0 -0
- tracr/compiler/__pycache__/compiling.cpython-39.pyc +0 -0
- tracr/compiler/__pycache__/craft_graph_to_model.cpython-39.pyc +0 -0
- tracr/compiler/__pycache__/craft_model_to_transformer.cpython-39.pyc +0 -0
- tracr/compiler/__pycache__/expr_to_craft_graph.cpython-39.pyc +0 -0
- tracr/compiler/__pycache__/nodes.cpython-39.pyc +0 -0
- tracr/compiler/__pycache__/rasp_to_graph.cpython-39.pyc +0 -0
- tracr/craft/__pycache__/__init__.cpython-39.pyc +0 -0
- tracr/craft/__pycache__/bases.cpython-39.pyc +0 -0
- tracr/craft/__pycache__/transformers.cpython-39.pyc +0 -0
- tracr/craft/__pycache__/vectorspace_fns.cpython-39.pyc +0 -0
- tracr/craft/chamber/__pycache__/__init__.cpython-39.pyc +0 -0
- tracr/craft/chamber/__pycache__/categorical_attn.cpython-39.pyc +0 -0
- tracr/craft/chamber/__pycache__/categorical_mlp.cpython-39.pyc +0 -0
- tracr/craft/chamber/__pycache__/numerical_mlp.cpython-39.pyc +0 -0
- tracr/craft/chamber/__pycache__/selector_width.cpython-39.pyc +0 -0
- tracr/rasp/__pycache__/__init__.cpython-39.pyc +0 -0
- tracr/rasp/__pycache__/rasp.cpython-39.pyc +0 -0
- tracr/transformer/__pycache__/__init__.cpython-39.pyc +0 -0
- tracr/transformer/__pycache__/attention.cpython-39.pyc +0 -0
- tracr/transformer/__pycache__/encoder.cpython-39.pyc +0 -0
- tracr/transformer/__pycache__/model.cpython-39.pyc +0 -0
- tracr/utils/__pycache__/__init__.cpython-39.pyc +0 -0
- tracr/utils/__pycache__/errors.cpython-39.pyc +0 -0
- utils.py +80 -0
__pycache__/abstract_syntax_tree.cpython-39.pyc
ADDED
Binary file (2.94 kB). View file
|
|
__pycache__/python_embedded_rasp.cpython-39.pyc
ADDED
Binary file (9.03 kB). View file
|
|
__pycache__/rasp_synthesizer.cpython-39.pyc
ADDED
Binary file (9.09 kB). View file
|
|
__pycache__/utils.cpython-39.pyc
ADDED
Binary file (1.69 kB). View file
|
|
abstract_syntax_tree.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
ABSTRACT SYNTAX TREE
|
3 |
+
This file contains the Python class that represents programs created by our rasp synthesizer.
|
4 |
+
'''
|
5 |
+
from utils import *
|
6 |
+
|
7 |
+
class OperatorNode:
|
8 |
+
'''
|
9 |
+
Class to represent operator nodes (i.e., an operator and its operands) as an AST.
|
10 |
+
|
11 |
+
Args:
|
12 |
+
operator (object): operator object (e.g., Select, Aggregate, etc.)
|
13 |
+
children (list): list of children nodes (operands)
|
14 |
+
|
15 |
+
Example:
|
16 |
+
select_node: OperatorNode(Select(), [Tokens(), Tokens(), Equal()])
|
17 |
+
select_node.str() = "select(tokens, tokens, ==)"
|
18 |
+
select_node.evaluate("hi") = [[1, 0], [0, 1]]
|
19 |
+
select_node.to_python() = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ)
|
20 |
+
'''
|
21 |
+
def __init__(self, operator, children):
|
22 |
+
self.operator = operator
|
23 |
+
self.children = children
|
24 |
+
self.weight = operator.weight + sum([child.weight for child in children])
|
25 |
+
self.return_type = operator.return_type
|
26 |
+
|
27 |
+
def str(self):
|
28 |
+
if len(self.children) != self.operator.n_args:
|
29 |
+
raise ValueError("Improper number of arguments for operator.")
|
30 |
+
operand_strings = [child.str() for child in self.children]
|
31 |
+
return f"({self.operator.str(*operand_strings)})"
|
32 |
+
|
33 |
+
def evaluate(self, input=None):
|
34 |
+
'''
|
35 |
+
Directly evaluate the python translation.
|
36 |
+
'''
|
37 |
+
exe = self.to_python()
|
38 |
+
return exe(input)
|
39 |
+
|
40 |
+
# DEPRECATED VERSION: uses the actual rasp repl
|
41 |
+
# exe = f"({self.str()})" + f"({repr(input)});".replace("'", "\"")
|
42 |
+
# return run_repl(exe)
|
43 |
+
|
44 |
+
def to_python(self):
|
45 |
+
if len(self.children) != self.operator.n_args:
|
46 |
+
raise ValueError("Improper number of arguments for operator.")
|
47 |
+
operands = [child.to_python() for child in self.children]
|
48 |
+
return self.operator.to_python(*operands)
|
49 |
+
|
50 |
+
'''
|
51 |
+
TESTING
|
52 |
+
'''
|
53 |
+
if __name__ == "__main__":
|
54 |
+
from python_embedded_rasp import *
|
55 |
+
from tracr.rasp import rasp
|
56 |
+
|
57 |
+
select_op = OperatorNode(Select(), [Tokens(), Tokens(), Equal()]) # wait should children be operators or operator nodes? maybe can be either?
|
58 |
+
assert (select_op.weight == 4)
|
59 |
+
|
60 |
+
select_op_str = select_op.str()
|
61 |
+
actual_so_str = "(select(tokens, tokens, ==))"
|
62 |
+
assert select_op_str == actual_so_str
|
63 |
+
|
64 |
+
select_op_res = select_op.evaluate("hi")
|
65 |
+
actual_so_res = [[1, 0],[0, 1]]
|
66 |
+
assert select_op_res == actual_so_res
|
67 |
+
|
68 |
+
select_op_python = select_op.to_python()
|
69 |
+
actual_so_python = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ)
|
70 |
+
assert type(select_op_python) == type(actual_so_python)
|
71 |
+
|
72 |
+
print("all tests passed hooray!")
|
app.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
For future reference with downloading model files:
|
3 |
+
|
4 |
+
import streamlit as st
|
5 |
+
import pickle
|
6 |
+
import base64
|
7 |
+
|
8 |
+
x = {"my": "data"}
|
9 |
+
|
10 |
+
def download_model(model):
|
11 |
+
output_model = pickle.dumps(model)
|
12 |
+
b64 = base64.b64encode(output_model).decode()
|
13 |
+
href = f'<a href="data:file/output_model;base64,{b64}" download="myfile.pkl">Download Trained Model .pkl File</a>'
|
14 |
+
st.markdown(href, unsafe_allow_html=True)
|
15 |
+
|
16 |
+
|
17 |
+
download_model(x)
|
18 |
+
'''
|
comp_flows/( tokens_int . 1 )(.1. 2.).pdf
ADDED
Binary file (19.2 kB). View file
|
|
outtest.txt
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Received the following input and output examples:
|
2 |
+
[(['h', 'e', 'l', 'l', 'o'], [1, 1, 2, 2, 1])]
|
3 |
+
Running synthesizer with
|
4 |
+
Vocab: {'o', 'e', 'h', 'l'}
|
5 |
+
Max sequence length: 5
|
6 |
+
Max weight: 25
|
7 |
+
(indices - indices)
|
8 |
+
[[0, 0, 0, 0, 0]]
|
9 |
+
(indices - 0)
|
10 |
+
[[0, 1, 2, 3, 4]]
|
11 |
+
(indices - 1)
|
12 |
+
[[-1, 0, 1, 2, 3]]
|
13 |
+
(0 - indices)
|
14 |
+
[[0, -1, -2, -3, -4]]
|
15 |
+
(1 - indices)
|
16 |
+
[[1, 0, -1, -2, -3]]
|
17 |
+
(select(tokens, tokens, ==))
|
18 |
+
[[[True, False, False, False, False], [False, True, False, False, False], [False, False, True, True, False], [False, False, True, True, False], [False, False, False, False, True]]]
|
19 |
+
(select(tokens, tokens, true))
|
20 |
+
[[[True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True]]]
|
21 |
+
(select(tokens, indices, ==))
|
22 |
+
[[[False, False, False, False, False], [False, False, False, False, False], [False, False, False, False, False], [False, False, False, False, False], [False, False, False, False, False]]]
|
23 |
+
(select(tokens, indices, true))
|
24 |
+
[[[True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True]]]
|
25 |
+
(select(indices, tokens, ==))
|
26 |
+
[[[False, False, False, False, False], [False, False, False, False, False], [False, False, False, False, False], [False, False, False, False, False], [False, False, False, False, False]]]
|
27 |
+
(select(indices, tokens, true))
|
28 |
+
[[[True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True]]]
|
29 |
+
(select(indices, indices, ==))
|
30 |
+
[[[True, False, False, False, False], [False, True, False, False, False], [False, False, True, False, False], [False, False, False, True, False], [False, False, False, False, True]]]
|
31 |
+
(select(indices, indices, true))
|
32 |
+
[[[True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True]]]
|
33 |
+
(select_width((select(tokens, tokens, ==))))
|
34 |
+
[[1, 1, 2, 2, 1]]
|
35 |
+
The following program has been compiled to a transformer with 1 layer(s):
|
36 |
+
(select_width((select(tokens, tokens, ==))))
|
python_embedded_rasp.py
ADDED
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
RASP OPERATORS THAT ARE SUPPORTED BY TRACR'S PYTHON EMBEDDING
|
3 |
+
This file contains Python classes that define the rasp operators supported by TRACR's python embedding of the langauge.
|
4 |
+
This is subset of everything that TRACR supports in python, due to project time constraints.
|
5 |
+
'''
|
6 |
+
import random
|
7 |
+
from typing import (Any, Callable, Dict, Generic, List, Mapping, Optional,
|
8 |
+
Sequence, TypeVar, Union)
|
9 |
+
from tracr.rasp import rasp
|
10 |
+
import subprocess
|
11 |
+
import time
|
12 |
+
|
13 |
+
'''
|
14 |
+
CLASS DEFINITIONS
|
15 |
+
'''
|
16 |
+
class Tokens:
|
17 |
+
'''
|
18 |
+
Tokens constant.
|
19 |
+
'''
|
20 |
+
def __init__(self):
|
21 |
+
self.n_args = 0
|
22 |
+
self.arg_types = []
|
23 |
+
self.return_type = rasp.SOp
|
24 |
+
self.weight = 1
|
25 |
+
|
26 |
+
def to_python(self):
|
27 |
+
# return an object that can be compiled into a TRACR transformer
|
28 |
+
# arguments should be python objects
|
29 |
+
return rasp.tokens
|
30 |
+
|
31 |
+
def str(self):
|
32 |
+
# represent rasp operator in string form
|
33 |
+
# expects arguments to be strings
|
34 |
+
return "tokens"
|
35 |
+
|
36 |
+
class Indices:
|
37 |
+
def __init__(self):
|
38 |
+
self.n_args = 0
|
39 |
+
self.arg_types = []
|
40 |
+
self.return_type = rasp.SOp
|
41 |
+
self.weight = 1
|
42 |
+
|
43 |
+
def to_python(self):
|
44 |
+
# return an object that can be compiled into a TRACR transformer
|
45 |
+
# arguments should be python objects
|
46 |
+
return rasp.indices
|
47 |
+
|
48 |
+
def str(self):
|
49 |
+
# represent rasp operator in string form
|
50 |
+
# expects arguments to be strings
|
51 |
+
return "indices"
|
52 |
+
|
53 |
+
class Zero:
|
54 |
+
def __init__(self):
|
55 |
+
self.n_args = 0
|
56 |
+
self.arg_types = []
|
57 |
+
self.return_type = int
|
58 |
+
self.weight = 1
|
59 |
+
|
60 |
+
def to_python(self):
|
61 |
+
# return an object that can be compiled into a TRACR transformer
|
62 |
+
# arguments should be python objects
|
63 |
+
return 0
|
64 |
+
|
65 |
+
def str(self):
|
66 |
+
# represent rasp operator in string form
|
67 |
+
# expects arguments to be strings
|
68 |
+
return "0"
|
69 |
+
|
70 |
+
class One:
|
71 |
+
def __init__(self):
|
72 |
+
self.n_args = 0
|
73 |
+
self.arg_types = []
|
74 |
+
self.return_type = int
|
75 |
+
self.weight = 1
|
76 |
+
|
77 |
+
def to_python(self):
|
78 |
+
# return an object that can be compiled into a TRACR transformer
|
79 |
+
# arguments should be python objects
|
80 |
+
return 1
|
81 |
+
|
82 |
+
def str(self):
|
83 |
+
# represent rasp operator in string form
|
84 |
+
# expects arguments to be strings
|
85 |
+
return "1"
|
86 |
+
|
87 |
+
class Equal:
|
88 |
+
'''
|
89 |
+
Comparison Equal constant.
|
90 |
+
'''
|
91 |
+
def __init__(self):
|
92 |
+
self.n_args = 0
|
93 |
+
self.arg_types = []
|
94 |
+
self.return_type = rasp.Predicate
|
95 |
+
self.weight = 1
|
96 |
+
|
97 |
+
def to_python(self):
|
98 |
+
# return an object that can be compiled into a TRACR transformer
|
99 |
+
# arguments should be python objects
|
100 |
+
return rasp.Comparison.EQ
|
101 |
+
|
102 |
+
def str(self):
|
103 |
+
# represent rasp operator in string form
|
104 |
+
# expects arguments to be strings
|
105 |
+
return "=="
|
106 |
+
|
107 |
+
class GT:
|
108 |
+
'''
|
109 |
+
Greater Than comparison operator.
|
110 |
+
'''
|
111 |
+
pass
|
112 |
+
|
113 |
+
class LT:
|
114 |
+
'''
|
115 |
+
Less Than comparison operator
|
116 |
+
'''
|
117 |
+
pass
|
118 |
+
|
119 |
+
class LEQ:
|
120 |
+
pass
|
121 |
+
|
122 |
+
class GEQ:
|
123 |
+
pass
|
124 |
+
|
125 |
+
class TRUE:
|
126 |
+
'''
|
127 |
+
Comparison True constant.
|
128 |
+
'''
|
129 |
+
def __init__(self):
|
130 |
+
self.n_args = 0
|
131 |
+
self.arg_types = []
|
132 |
+
self.return_type = rasp.Predicate
|
133 |
+
self.weight = 1
|
134 |
+
|
135 |
+
def to_python(self):
|
136 |
+
# return an object that can be compiled into a TRACR transformer
|
137 |
+
# arguments should be python objects
|
138 |
+
return rasp.Comparison.TRUE
|
139 |
+
|
140 |
+
def str(self):
|
141 |
+
# represent rasp operator in string form
|
142 |
+
# expects arguments to be strings
|
143 |
+
return "true"
|
144 |
+
|
145 |
+
class FALSE:
|
146 |
+
pass
|
147 |
+
|
148 |
+
class Add:
|
149 |
+
'''
|
150 |
+
Element-wise.
|
151 |
+
Input can be either int, float or s-op.
|
152 |
+
'''
|
153 |
+
pass
|
154 |
+
|
155 |
+
class Subtract:
|
156 |
+
'''
|
157 |
+
Element-wise.
|
158 |
+
Input can be either int, float or s-op.
|
159 |
+
'''
|
160 |
+
def __init__(self):
|
161 |
+
self.n_args = 2
|
162 |
+
self.arg_types = [Union[rasp.SOp, float, int], Union[rasp.SOp, float, int]]
|
163 |
+
self.return_type = Union[rasp.SOp, int, float]
|
164 |
+
self.weight = 1
|
165 |
+
|
166 |
+
def to_python(self, x, y):
|
167 |
+
# return an object that can be compiled into a TRACR transformer
|
168 |
+
# arguments should be python objects
|
169 |
+
if type(x) == type(rasp.tokens):
|
170 |
+
return None
|
171 |
+
if type(y) == type(rasp.tokens):
|
172 |
+
return None
|
173 |
+
return x - y
|
174 |
+
|
175 |
+
def str(self, x, y):
|
176 |
+
# represent rasp operator in string form
|
177 |
+
# expects arguments to be strings
|
178 |
+
return f"{x} - {y}"
|
179 |
+
|
180 |
+
class Mult:
|
181 |
+
'''
|
182 |
+
Element-wise.
|
183 |
+
Input can be either int, float or s-op.
|
184 |
+
'''
|
185 |
+
pass
|
186 |
+
|
187 |
+
class Divide:
|
188 |
+
'''
|
189 |
+
Element-wise.
|
190 |
+
Input can be either int, float or s-op.
|
191 |
+
'''
|
192 |
+
pass
|
193 |
+
|
194 |
+
class Fill:
|
195 |
+
'''
|
196 |
+
Given fill value and length, returns Sop of that length with that fill value.
|
197 |
+
Fill value can be int, float, or char.
|
198 |
+
Length must be a positive integer.
|
199 |
+
'''
|
200 |
+
pass
|
201 |
+
|
202 |
+
class SelectorAnd:
|
203 |
+
'''
|
204 |
+
Input can be bool or s-op.
|
205 |
+
'''
|
206 |
+
pass
|
207 |
+
|
208 |
+
class SelectorOr:
|
209 |
+
'''
|
210 |
+
Input can be bool or s-op.
|
211 |
+
'''
|
212 |
+
pass
|
213 |
+
|
214 |
+
class SelectorNot:
|
215 |
+
'''
|
216 |
+
Input is an s-op of bools. (Or bool-convertible values.)
|
217 |
+
'''
|
218 |
+
pass
|
219 |
+
|
220 |
+
class Select:
|
221 |
+
'''
|
222 |
+
Select operator.
|
223 |
+
'''
|
224 |
+
def __init__(self):
|
225 |
+
self.n_args = 3
|
226 |
+
self.arg_types = [rasp.SOp, rasp.SOp, rasp.Predicate]
|
227 |
+
self.return_type = rasp.Selector
|
228 |
+
self.weight = 1
|
229 |
+
|
230 |
+
def to_python(self, sop1, sop2, comp):
|
231 |
+
# return an object that can be compiled into a TRACR transformer
|
232 |
+
# arguments should be python objects
|
233 |
+
return rasp.Select(sop1, sop2, comp)
|
234 |
+
|
235 |
+
def str(self, sop1, sop2, comp):
|
236 |
+
# represent rasp operator in string form
|
237 |
+
# expects arguments to be strings
|
238 |
+
return f"select({sop1}, {sop2}, {comp})"
|
239 |
+
|
240 |
+
class Aggregate:
|
241 |
+
'''
|
242 |
+
The Aggregate operator.
|
243 |
+
'''
|
244 |
+
def __init__(self):
|
245 |
+
self.n_args = 2
|
246 |
+
self.arg_types = [rasp.Selector, rasp.SOp]
|
247 |
+
self.return_type = rasp.SOp
|
248 |
+
self.weight = 1
|
249 |
+
|
250 |
+
def to_python(self, sel, sop):
|
251 |
+
# return an object that can be compiled into a TRACR transformer
|
252 |
+
# arguments should be python objects
|
253 |
+
return rasp.Aggregate(sel, sop)
|
254 |
+
|
255 |
+
def str(self, sel, sop):
|
256 |
+
# represent rasp operator in string form
|
257 |
+
# expects arguments to be strings
|
258 |
+
return f"aggregate({sel}, {sop})"
|
259 |
+
|
260 |
+
class SelectorWidth:
|
261 |
+
'''
|
262 |
+
The selector_width operator.
|
263 |
+
'''
|
264 |
+
def __init__(self):
|
265 |
+
self.n_args = 1
|
266 |
+
self.arg_types = [rasp.Selector]
|
267 |
+
self.return_type = rasp.SOp
|
268 |
+
self.weight = 1
|
269 |
+
|
270 |
+
def to_python(self, sel):
|
271 |
+
# return an object that can be compiled into a TRACR transformer
|
272 |
+
# arguments should be python objects
|
273 |
+
return rasp.SelectorWidth(sel)
|
274 |
+
|
275 |
+
def str(self, sel):
|
276 |
+
# represent rasp operator in string form
|
277 |
+
# expects arguments to be strings
|
278 |
+
return f"select_width({sel})"
|
279 |
+
|
280 |
+
'''
|
281 |
+
GLOBAL CONSTANTS
|
282 |
+
'''
|
283 |
+
|
284 |
+
# define operators
|
285 |
+
rasp_operators = [Select(), SelectorWidth(), Aggregate(), Subtract()]
|
286 |
+
rasp_consts = [Tokens(), Tokens(), Equal(), TRUE(), Indices(), Indices(), Zero(), One()]
|
287 |
+
'''
|
288 |
+
TESTING
|
289 |
+
'''
|
290 |
+
if __name__ == "__main__":
|
291 |
+
test_select = Select()
|
292 |
+
|
293 |
+
test_select_python = test_select.to_python(Tokens().to_python(), Tokens().to_python(), Equal().to_python())
|
294 |
+
actual_ts_python = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ)
|
295 |
+
assert type(Tokens().to_python()) == type(rasp.tokens)
|
296 |
+
assert type(Equal().to_python() == type(rasp.Comparison.EQ))
|
297 |
+
assert type(test_select_python) == type(actual_ts_python)
|
298 |
+
|
299 |
+
test_select_string = test_select.str(Tokens().str(), Tokens().str(), Equal().str())
|
300 |
+
actual_ts_string = "select(tokens, tokens, ==)"
|
301 |
+
assert(test_select_string == actual_ts_string)
|
302 |
+
|
303 |
+
|
304 |
+
test_aggregate = Aggregate()
|
305 |
+
print(rasp.Aggregate(rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ), rasp.tokens)("hi"))
|
306 |
+
|
307 |
+
print("all tests passed hooray!")
|
308 |
+
|
rasp_synthesizer.py
ADDED
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
BOTTOM-UP ENUMERATIVE SYTHESIS FOR RASP
|
3 |
+
|
4 |
+
Usage:
|
5 |
+
python rasp_synthesis.py --examples
|
6 |
+
'''
|
7 |
+
import numpy as np
|
8 |
+
import argparse
|
9 |
+
import itertools
|
10 |
+
import time
|
11 |
+
import ast
|
12 |
+
import re
|
13 |
+
from tracr.compiler import compiling
|
14 |
+
from typing import get_args
|
15 |
+
import inspect
|
16 |
+
|
17 |
+
from abstract_syntax_tree import *
|
18 |
+
from python_embedded_rasp import *
|
19 |
+
|
20 |
+
# PARSE ARGUMENTS
|
21 |
+
def parse_args():
|
22 |
+
'''
|
23 |
+
Parse command line arguments.
|
24 |
+
'''
|
25 |
+
parser = argparse.ArgumentParser(description="Bottom-up enumerative synthesis for RASP.")
|
26 |
+
parser.add_argument('--examples', required=True, help="input/output sequence examples for synthesis")
|
27 |
+
parser.add_argument('--max_weight', type=int, required=False, default=10, help="Maximum weight of programs to consider before terminating search.")
|
28 |
+
args = parser.parse_args()
|
29 |
+
return args
|
30 |
+
|
31 |
+
# ANALYZE EXAMPLES
|
32 |
+
def analyze_examples(inputs):
|
33 |
+
'''
|
34 |
+
Returns a list of unique (input_sequence, output_sequence) tuples of proper python types.
|
35 |
+
Ensures each example is only numeric values or only char values.
|
36 |
+
Returns useful constants given the input examples.
|
37 |
+
'''
|
38 |
+
example_ins = []
|
39 |
+
example_outs = []
|
40 |
+
try:
|
41 |
+
# Safely evaluate the string to a Python object
|
42 |
+
examples_lst = ast.literal_eval(inputs)
|
43 |
+
except (SyntaxError, ValueError) as e:
|
44 |
+
raise argparse.ArgumentTypeError(f"Invalid examples format: {e}")
|
45 |
+
|
46 |
+
if not isinstance(examples_lst, list):
|
47 |
+
raise ValueError("Input should be a list.")
|
48 |
+
for ex in examples_lst:
|
49 |
+
try:
|
50 |
+
ins, outs = ex[0], ex[1]
|
51 |
+
except:
|
52 |
+
raise argparse.ArgumentTypeError(f"Invalid examples format.")
|
53 |
+
|
54 |
+
def same_legal_type(lst):
|
55 |
+
return (all(isinstance(x, int) for x in lst) or
|
56 |
+
all(isinstance(x, float) for x in lst) or
|
57 |
+
all(isinstance(x, bool) for x in lst) or
|
58 |
+
all(isinstance(x, str) for x in lst))
|
59 |
+
|
60 |
+
if same_legal_type(ins) and same_legal_type(outs):
|
61 |
+
example_ins.append(ins)
|
62 |
+
example_outs.append(outs)
|
63 |
+
continue
|
64 |
+
raise argparse.ArgumentTypeError(f"Each example must have consistent types. Expected inputs to have type {first_in_type} and outputs to have {first_out_type} but instead inputs have types {[type(x) for x in ins]} and outputs have types {[type(x) for x in outs]}")
|
65 |
+
|
66 |
+
return example_ins, example_outs
|
67 |
+
|
68 |
+
# GET VOCABULARY
|
69 |
+
def get_vocabulary(examples):
|
70 |
+
'''
|
71 |
+
Returns vocabulary for later compiling the RASP model.
|
72 |
+
'''
|
73 |
+
vocab = []
|
74 |
+
for ex in examples:
|
75 |
+
ins, outs = ex[0], ex[1]
|
76 |
+
vocab.extend([obj for obj in ins])
|
77 |
+
return set(vocab)
|
78 |
+
|
79 |
+
# CHECK OBSERVATIONAL EQUIVALENCE
|
80 |
+
def check_obs_equivalence(examples, program_a, program_b):
|
81 |
+
try:
|
82 |
+
inputs = [example[0] for example in examples]
|
83 |
+
a_output = None
|
84 |
+
b_output = None
|
85 |
+
if program_a not in rasp_consts:
|
86 |
+
a_output = [program_a.evaluate(input) for input in inputs]
|
87 |
+
if program_b not in rasp_consts:
|
88 |
+
b_output = [program_b.evaluate(input) for input in inputs]
|
89 |
+
except:
|
90 |
+
return True # force the synthesizer to not consider this program
|
91 |
+
|
92 |
+
return a_output == b_output
|
93 |
+
|
94 |
+
# CHECK CORRECTNESS
|
95 |
+
def check_correctness(examples, program):
|
96 |
+
'''
|
97 |
+
Checks if the programs output matches expected output on all examples.
|
98 |
+
'''
|
99 |
+
try:
|
100 |
+
inputs = [example[0] for example in examples]
|
101 |
+
outputs = [example[1] for example in examples]
|
102 |
+
program_output = [program.evaluate(input) for input in inputs]
|
103 |
+
except:
|
104 |
+
return False
|
105 |
+
|
106 |
+
print(program.str())
|
107 |
+
print(program_output)
|
108 |
+
|
109 |
+
# TODO return number that match and return this
|
110 |
+
|
111 |
+
return program_output == outputs
|
112 |
+
|
113 |
+
# COMPARE TYPE SIGNATURES
|
114 |
+
def compare_types(list1, list2):
|
115 |
+
for idx, type1 in enumerate(list1):
|
116 |
+
if idx >= len(list2):
|
117 |
+
return False # The first list is longer than the second list
|
118 |
+
|
119 |
+
type2 = list2[idx]
|
120 |
+
|
121 |
+
# Check if type2 is a Union
|
122 |
+
if hasattr(type2, '__origin__') and type2.__origin__ is Union:
|
123 |
+
# Extract types from Union
|
124 |
+
types_in_union2 = get_args(type2)
|
125 |
+
# Check if type1 is a Union
|
126 |
+
if hasattr(type1, '__origin__') and type1.__origin__ is Union:
|
127 |
+
types_in_union1 = get_args(type1)
|
128 |
+
# Check if all types in type1's Union are in type2's Union
|
129 |
+
if not all(any(t1 == t2 for t2 in types_in_union2) for t1 in types_in_union1):
|
130 |
+
return False
|
131 |
+
else:
|
132 |
+
# Check if type1 is in type2's Union
|
133 |
+
if not any(type1 == t2 for t2 in types_in_union2):
|
134 |
+
return False
|
135 |
+
else:
|
136 |
+
# Direct type comparison
|
137 |
+
if type1 != type2:
|
138 |
+
return False
|
139 |
+
|
140 |
+
return True
|
141 |
+
|
142 |
+
# RUN SYNTHESIZER
|
143 |
+
def run_synthesizer(examples, max_weight):
|
144 |
+
'''
|
145 |
+
Run bottom-up enumerative synthesis.
|
146 |
+
'''
|
147 |
+
program_bank = rasp_consts
|
148 |
+
program_bank_str = [p.str() for p in program_bank]
|
149 |
+
|
150 |
+
# TODO: store approximate programs, measured by number of output examples that match
|
151 |
+
|
152 |
+
# iterate over each level
|
153 |
+
for weight in range(2, max_weight):
|
154 |
+
|
155 |
+
for op in rasp_operators:
|
156 |
+
combinations = itertools.permutations(program_bank, op.n_args)
|
157 |
+
|
158 |
+
for combination in combinations:
|
159 |
+
|
160 |
+
type_signature = [p.return_type for p in combination]
|
161 |
+
|
162 |
+
if not compare_types(type_signature, op.arg_types):
|
163 |
+
continue
|
164 |
+
|
165 |
+
if sum([p.weight for p in combination]) > weight:
|
166 |
+
continue
|
167 |
+
|
168 |
+
program = OperatorNode(op, combination)
|
169 |
+
|
170 |
+
if program.str() in program_bank_str:
|
171 |
+
continue
|
172 |
+
|
173 |
+
if any([check_obs_equivalence(examples, program, p) for p in program_bank]):
|
174 |
+
continue
|
175 |
+
|
176 |
+
program_bank.append(program)
|
177 |
+
program_bank_str.append(program.str())
|
178 |
+
|
179 |
+
if check_correctness(examples, program):
|
180 |
+
return(program)
|
181 |
+
|
182 |
+
return None
|
183 |
+
|
184 |
+
# COMPILE RASP MODEL
|
185 |
+
if __name__ == "__main__":
|
186 |
+
|
187 |
+
'''
|
188 |
+
Some examples:
|
189 |
+
Identify anagrams:
|
190 |
+
[[['V','I','W',',','W','I','V'], [True, True, True, True, True, True, True]],[['a','b',',','b','a'], [True, True, True, True, True]],[['e','l',',','s','t'], [False, False, False, False, False]]]
|
191 |
+
Output: times out
|
192 |
+
Calculate the median of a list of numbers:
|
193 |
+
[[[1,2,3,4,5], [3,3,3,3,3]], [[2,8,10,11], [9,9,9,9]], [[1,2,3],[2,2,2]]]
|
194 |
+
Output: times out
|
195 |
+
Identity function:
|
196 |
+
[[['h','i'], ['h','i']]]
|
197 |
+
Output: (aggregate((select(tokens, tokens, ==)), tokens))
|
198 |
+
Histogram:
|
199 |
+
[[['h', 'e', 'l', 'l', 'o'], [1,1,2,2,1]]]
|
200 |
+
Output: (select_width((select(tokens, tokens, ==))))
|
201 |
+
Length:
|
202 |
+
[[[7,2,5],[3,3,3]],[[1],[1]],[[2,0,1,7,3,6,8,20],[8,8,8,8,8,8,8,8]]]
|
203 |
+
Output: (select_width((select(tokens, tokens, true))))
|
204 |
+
Calculate mean of list of numbers:
|
205 |
+
[[[5,10,3,2,43], [12.6, 12.6, 12.6, 12.6, 12.6]],[[1,2], [1.5, 1.5]],[[3,3,3],[3,3,3]]]
|
206 |
+
Output: (aggregate((select(tokens, tokens, true)), tokens))
|
207 |
+
Reverse a string:
|
208 |
+
[[['h', 'i'], ['i', 'h']]]
|
209 |
+
Output: times out
|
210 |
+
Expected: aggregate(select(indices, (select_width((select(tokens, tokens, true)))) - indices - 1, ==), tokens);
|
211 |
+
PERSONAL TODOS:
|
212 |
+
- output several similar programs
|
213 |
+
-
|
214 |
+
|
215 |
+
'''
|
216 |
+
|
217 |
+
args = parse_args()
|
218 |
+
inputs, outs = analyze_examples(args.examples)
|
219 |
+
examples = list(zip(inputs, outs))
|
220 |
+
print("Received the following input and output examples:")
|
221 |
+
print(examples)
|
222 |
+
max_seq_len = 0
|
223 |
+
for i in inputs:
|
224 |
+
max_seq_len = max(len(i), max_seq_len)
|
225 |
+
vocab = get_vocabulary(examples)
|
226 |
+
|
227 |
+
print("Running synthesizer with")
|
228 |
+
print("Vocab: {}".format(vocab))
|
229 |
+
print("Max sequence length: {}".format(max_seq_len))
|
230 |
+
print("Max weight: {}".format(args.max_weight))
|
231 |
+
|
232 |
+
program = run_synthesizer(examples, args.max_weight)
|
233 |
+
|
234 |
+
if program:
|
235 |
+
algorithm = program.to_python()
|
236 |
+
|
237 |
+
bos = "BOS"
|
238 |
+
model = compiling.compile_rasp_to_model(
|
239 |
+
algorithm,
|
240 |
+
vocab=vocab,
|
241 |
+
max_seq_len=max_seq_len,
|
242 |
+
compiler_bos=bos,
|
243 |
+
)
|
244 |
+
|
245 |
+
|
246 |
+
def extract_layer_number(s):
|
247 |
+
match = re.search(r'layer_(\d+)', s)
|
248 |
+
if match:
|
249 |
+
return int(match.group(1)) + 1
|
250 |
+
else:
|
251 |
+
return None
|
252 |
+
|
253 |
+
layer_num = extract_layer_number(list(model.params.keys())[-1])
|
254 |
+
print(f"The following program has been compiled to a transformer with {layer_num} layer(s):")
|
255 |
+
print(program.str())
|
256 |
+
else:
|
257 |
+
print("No program found.")
|
reverse-viz.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
testouts.txt
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Received the following input and output examples:
|
2 |
+
[(['h', 'i'], ['h', 'h'])]
|
3 |
+
Running synthesizer with
|
4 |
+
Vocab: {'h', 'i'}
|
5 |
+
Max sequence length: 2
|
6 |
+
Max weight: 15
|
7 |
+
- Searching level 2 with 4 primitives.
|
8 |
+
- Searching level 3 with 4 primitives.
|
9 |
+
(select(tokens, tokens, ==))
|
10 |
+
[[[True, False], [False, True]]]
|
11 |
+
(select(tokens, tokens, true))
|
12 |
+
[[[True, True], [True, True]]]
|
13 |
+
- Searching level 4 with 6 primitives.
|
14 |
+
(select_width((select(tokens, tokens, ==))))
|
15 |
+
[[1, 1]]
|
16 |
+
(select_width((select(tokens, tokens, true))))
|
17 |
+
[[2, 2]]
|
18 |
+
- Searching level 5 with 8 primitives.
|
19 |
+
- Searching level 6 with 8 primitives.
|
20 |
+
- Searching level 7 with 8 primitives.
|
21 |
+
- Searching level 8 with 8 primitives.
|
22 |
+
- Searching level 9 with 8 primitives.
|
23 |
+
(aggregate((select(tokens, tokens, ==)), (select_width((select(tokens, tokens, ==))))))
|
24 |
+
[[1.0, 1.0]]
|
25 |
+
(aggregate((select(tokens, tokens, ==)), (select_width((select(tokens, tokens, true))))))
|
26 |
+
[[2.0, 2.0]]
|
27 |
+
(aggregate((select(tokens, tokens, true)), (select_width((select(tokens, tokens, ==))))))
|
28 |
+
[[1.0, 1.0]]
|
29 |
+
(aggregate((select(tokens, tokens, true)), (select_width((select(tokens, tokens, true))))))
|
30 |
+
[[2.0, 2.0]]
|
31 |
+
- Searching level 10 with 12 primitives.
|
32 |
+
- Searching level 11 with 12 primitives.
|
33 |
+
- Searching level 12 with 12 primitives.
|
34 |
+
- Searching level 13 with 12 primitives.
|
35 |
+
- Searching level 14 with 12 primitives.
|
36 |
+
(aggregate((select(tokens, tokens, ==)), (aggregate((select(tokens, tokens, ==)), (select_width((select(tokens, tokens, ==))))))))
|
37 |
+
[[1.0, 1.0]]
|
38 |
+
(aggregate((select(tokens, tokens, ==)), (aggregate((select(tokens, tokens, ==)), (select_width((select(tokens, tokens, true))))))))
|
39 |
+
[[2.0, 2.0]]
|
40 |
+
(aggregate((select(tokens, tokens, ==)), (aggregate((select(tokens, tokens, true)), (select_width((select(tokens, tokens, ==))))))))
|
41 |
+
[[1.0, 1.0]]
|
42 |
+
(aggregate((select(tokens, tokens, ==)), (aggregate((select(tokens, tokens, true)), (select_width((select(tokens, tokens, true))))))))
|
43 |
+
[[2.0, 2.0]]
|
44 |
+
> c:\users\18084\desktop\cs252r\final_project\tracr-synthesis\rasp_synthesizer.py(94)check_obs_equivalence()
|
45 |
+
-> return a_output == b_output
|
46 |
+
(Pdb) --KeyboardInterrupt--
|
47 |
+
(Pdb) --KeyboardInterrupt--
|
48 |
+
(Pdb) --KeyboardInterrupt--
|
49 |
+
(Pdb) *** SyntaxError: invalid syntax
|
50 |
+
(Pdb) --KeyboardInterrupt--
|
51 |
+
(Pdb) *** SyntaxError: invalid syntax
|
52 |
+
(Pdb) --KeyboardInterrupt--
|
53 |
+
(Pdb) --KeyboardInterrupt--
|
54 |
+
(Pdb) *** SyntaxError: invalid syntax
|
55 |
+
(Pdb)
|
tracr/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (168 Bytes). View file
|
|
tracr/compiler/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (341 Bytes). View file
|
|
tracr/compiler/__pycache__/assemble.cpython-39.pyc
ADDED
Binary file (9.98 kB). View file
|
|
tracr/compiler/__pycache__/basis_inference.cpython-39.pyc
ADDED
Binary file (2.97 kB). View file
|
|
tracr/compiler/__pycache__/compiling.cpython-39.pyc
ADDED
Binary file (2.48 kB). View file
|
|
tracr/compiler/__pycache__/craft_graph_to_model.cpython-39.pyc
ADDED
Binary file (6.71 kB). View file
|
|
tracr/compiler/__pycache__/craft_model_to_transformer.cpython-39.pyc
ADDED
Binary file (1.69 kB). View file
|
|
tracr/compiler/__pycache__/expr_to_craft_graph.cpython-39.pyc
ADDED
Binary file (7.4 kB). View file
|
|
tracr/compiler/__pycache__/nodes.cpython-39.pyc
ADDED
Binary file (442 Bytes). View file
|
|
tracr/compiler/__pycache__/rasp_to_graph.cpython-39.pyc
ADDED
Binary file (1.78 kB). View file
|
|
tracr/craft/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (174 Bytes). View file
|
|
tracr/craft/__pycache__/bases.cpython-39.pyc
ADDED
Binary file (10 kB). View file
|
|
tracr/craft/__pycache__/transformers.cpython-39.pyc
ADDED
Binary file (7.64 kB). View file
|
|
tracr/craft/__pycache__/vectorspace_fns.cpython-39.pyc
ADDED
Binary file (5.32 kB). View file
|
|
tracr/craft/chamber/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (182 Bytes). View file
|
|
tracr/craft/chamber/__pycache__/categorical_attn.cpython-39.pyc
ADDED
Binary file (4.25 kB). View file
|
|
tracr/craft/chamber/__pycache__/categorical_mlp.cpython-39.pyc
ADDED
Binary file (5.04 kB). View file
|
|
tracr/craft/chamber/__pycache__/numerical_mlp.cpython-39.pyc
ADDED
Binary file (10.6 kB). View file
|
|
tracr/craft/chamber/__pycache__/selector_width.cpython-39.pyc
ADDED
Binary file (4.54 kB). View file
|
|
tracr/rasp/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (173 Bytes). View file
|
|
tracr/rasp/__pycache__/rasp.cpython-39.pyc
ADDED
Binary file (36.6 kB). View file
|
|
tracr/transformer/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (180 Bytes). View file
|
|
tracr/transformer/__pycache__/attention.cpython-39.pyc
ADDED
Binary file (4.83 kB). View file
|
|
tracr/transformer/__pycache__/encoder.cpython-39.pyc
ADDED
Binary file (5.39 kB). View file
|
|
tracr/transformer/__pycache__/model.cpython-39.pyc
ADDED
Binary file (5.25 kB). View file
|
|
tracr/utils/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (174 Bytes). View file
|
|
tracr/utils/__pycache__/errors.cpython-39.pyc
ADDED
Binary file (928 Bytes). View file
|
|
utils.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import subprocess
|
2 |
+
import time
|
3 |
+
import re
|
4 |
+
import ast
|
5 |
+
|
6 |
+
# Start the REPL subprocess
|
7 |
+
python_exe = '/Users/18084/Desktop/CS252R/final_project/rasp-env-py3.9/Scripts/python.exe' #SETUP THING: replace with path to your python environment
|
8 |
+
|
9 |
+
'''
|
10 |
+
THE FOLLOWING FUNCTIONS ARE DEPRECATED
|
11 |
+
'''
|
12 |
+
def clean_carrots(text):
|
13 |
+
pattern = r">>(.*?)>>"
|
14 |
+
|
15 |
+
match = re.search(pattern, text)
|
16 |
+
if match:
|
17 |
+
result = match.group(1).strip() # .strip() is used to remove any leading/trailing whitespace
|
18 |
+
return result
|
19 |
+
|
20 |
+
def parse_output(out):
|
21 |
+
out = clean_carrots(out)
|
22 |
+
out = ast.literal_eval(out)
|
23 |
+
# can arrive as tuple, list, or dictionary
|
24 |
+
# ultimately want to convert everything to list form
|
25 |
+
if isinstance(out, dict):
|
26 |
+
return list(out.values())
|
27 |
+
if isinstance(out, tuple):
|
28 |
+
return list(out)
|
29 |
+
if isinstance(out, list):
|
30 |
+
return list
|
31 |
+
raise Exception("Error executing rasp program.")
|
32 |
+
|
33 |
+
def run_repl(command):
|
34 |
+
'''
|
35 |
+
Runs the RASP repl in a separate subprocess.
|
36 |
+
'''
|
37 |
+
process = subprocess.Popen([python_exe, 'RASP/RASP_support/REPL.py'],
|
38 |
+
stdin=subprocess.PIPE,
|
39 |
+
stdout=subprocess.PIPE,
|
40 |
+
stderr=subprocess.PIPE,
|
41 |
+
text=True)
|
42 |
+
|
43 |
+
# Send commands to the REPL
|
44 |
+
process.stdin.write(f'{command}\nexit()\n')
|
45 |
+
process.stdin.flush()
|
46 |
+
|
47 |
+
# Check periodically if the subprocess has terminated
|
48 |
+
while True:
|
49 |
+
if process.poll() is not None:
|
50 |
+
# The subprocess has terminated
|
51 |
+
break
|
52 |
+
time.sleep(0.1) # Wait for a short period (e.g., 0.1 seconds) before checking again
|
53 |
+
|
54 |
+
# Close the subprocess if still running
|
55 |
+
if process.poll() is None:
|
56 |
+
process.terminate()
|
57 |
+
|
58 |
+
# Read output and error
|
59 |
+
output = process.stdout.readlines()
|
60 |
+
error = process.stderr.readlines()
|
61 |
+
|
62 |
+
# Print output and error
|
63 |
+
str_output = ""
|
64 |
+
str_error = ""
|
65 |
+
for line in output:
|
66 |
+
str_output += line.strip() + " "
|
67 |
+
for line in error:
|
68 |
+
str_error += line.strip() + " "
|
69 |
+
|
70 |
+
str_output = parse_output(str_output)
|
71 |
+
return str_output, str_error
|
72 |
+
|
73 |
+
if __name__ == "__main__":
|
74 |
+
command = "select(tokens, tokens, ==)(\"hi\");"
|
75 |
+
res, _res_err = run_repl(command)
|
76 |
+
print(res)
|
77 |
+
|
78 |
+
command = "selector_width(select(tokens, tokens, ==))(\"hi\");"
|
79 |
+
res, _res_err = run_repl(command)
|
80 |
+
print(res)
|