CSquid333 commited on
Commit
f0b559a
1 Parent(s): 72cfe15

made an app for the synthesizer

Browse files
__pycache__/rasp_synthesizer.cpython-39.pyc CHANGED
Binary files a/__pycache__/rasp_synthesizer.cpython-39.pyc and b/__pycache__/rasp_synthesizer.cpython-39.pyc differ
 
app.py CHANGED
@@ -1,18 +1,81 @@
1
- '''
2
- For future reference with downloading model files:
3
-
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
5
  import pickle
6
  import base64
7
 
8
- x = {"my": "data"}
 
 
9
 
 
10
  def download_model(model):
11
  output_model = pickle.dumps(model)
12
  b64 = base64.b64encode(output_model).decode()
13
- href = f'<a href="data:file/output_model;base64,{b64}" download="myfile.pkl">Download Trained Model .pkl File</a>'
14
  st.markdown(href, unsafe_allow_html=True)
15
 
16
 
17
- download_model(x)
18
- '''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+
3
+ import numpy as np
4
+ import argparse
5
+ import itertools
6
+ import time
7
+ import ast
8
+ import re
9
+ from tracr.compiler import compiling
10
+ from typing import get_args
11
+ import inspect
12
  import pickle
13
  import base64
14
 
15
+ from abstract_syntax_tree import *
16
+ from python_embedded_rasp import *
17
+ from rasp_synthesizer import *
18
 
19
+ # HELPER FUNCTIONS
20
  def download_model(model):
21
  output_model = pickle.dumps(model)
22
  b64 = base64.b64encode(output_model).decode()
23
+ href = f'<a href="data:file/output_model;base64,{b64}" download="model_params.pkl">Download Haiku Model Parameters in a .pkl File</a>'
24
  st.markdown(href, unsafe_allow_html=True)
25
 
26
 
27
+ # APP DRIVER CODE
28
+ st.title("Bottom Up Synthesis for RASP")
29
+
30
+ max_weight = st.slider("Choose the maximum program weight to search for (~ size of transformer)", 2, 20, 1)
31
+
32
+ default_example = "[[['h', 'e', 'l', 'l', 'o'], [1,1,2,2,1]]]"
33
+ example_text = st.text_input(label = "Provide Input and Output Examples", value = default_example)
34
+
35
+ inputs, outs = analyze_examples(example_text)
36
+ examples = list(zip(inputs, outs))
37
+ st.write("Received the following input and output examples:")
38
+ st.write(examples)
39
+ max_seq_len = 0
40
+ for i in inputs:
41
+ max_seq_len = max(len(i), max_seq_len)
42
+ vocab = get_vocabulary(examples)
43
+
44
+ st.subheader("Synthesis Configuration")
45
+ st.write("Running synthesizer with")
46
+ st.write("Vocab: {}".format(vocab))
47
+ st.write("Max sequence length: {}".format(max_seq_len))
48
+ st.write("Max weight: {}".format(max_weight))
49
+
50
+ program, approx_programs = run_synthesizer(examples, max_weight)
51
+
52
+ st.subheader("Synthesis Results:")
53
+ st.caption("May take a while.")
54
+ if program:
55
+ algorithm = program.to_python()
56
+
57
+ bos = "BOS"
58
+ model = compiling.compile_rasp_to_model(
59
+ algorithm,
60
+ vocab=vocab,
61
+ max_seq_len=max_seq_len,
62
+ compiler_bos=bos,
63
+ )
64
+
65
+
66
+ def extract_layer_number(s):
67
+ match = re.search(r'layer_(\d+)', s)
68
+ if match:
69
+ return int(match.group(1)) + 1
70
+ else:
71
+ return None
72
+
73
+ layer_num = extract_layer_number(list(model.params.keys())[-1])
74
+ st.write(f"The following program has been compiled to a transformer with {layer_num} layer(s):")
75
+ st.write(program.str())
76
+
77
+ st.write("Here is a model download link: ")
78
+ hk_model = model.params
79
+ download_model(hk_model)
80
+ else:
81
+ st.write("No exact program found but here is some brainstorming fodder for you: {}".format(approx_programs))
outtest.txt CHANGED
@@ -1,36 +1,940 @@
1
  Received the following input and output examples:
2
- [(['h', 'e', 'l', 'l', 'o'], [1, 1, 2, 2, 1])]
3
  Running synthesizer with
4
- Vocab: {'o', 'e', 'h', 'l'}
5
- Max sequence length: 5
6
- Max weight: 25
7
  (indices - indices)
8
- [[0, 0, 0, 0, 0]]
 
9
  (indices - 0)
10
- [[0, 1, 2, 3, 4]]
 
11
  (indices - 1)
12
- [[-1, 0, 1, 2, 3]]
 
13
  (0 - indices)
14
- [[0, -1, -2, -3, -4]]
 
15
  (1 - indices)
16
- [[1, 0, -1, -2, -3]]
 
17
  (select(tokens, tokens, ==))
18
- [[[True, False, False, False, False], [False, True, False, False, False], [False, False, True, True, False], [False, False, True, True, False], [False, False, False, False, True]]]
 
19
  (select(tokens, tokens, true))
20
- [[[True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True]]]
 
21
  (select(tokens, indices, ==))
22
- [[[False, False, False, False, False], [False, False, False, False, False], [False, False, False, False, False], [False, False, False, False, False], [False, False, False, False, False]]]
 
23
  (select(tokens, indices, true))
24
- [[[True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True]]]
 
25
  (select(indices, tokens, ==))
26
- [[[False, False, False, False, False], [False, False, False, False, False], [False, False, False, False, False], [False, False, False, False, False], [False, False, False, False, False]]]
 
27
  (select(indices, tokens, true))
28
- [[[True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True]]]
 
29
  (select(indices, indices, ==))
30
- [[[True, False, False, False, False], [False, True, False, False, False], [False, False, True, False, False], [False, False, False, True, False], [False, False, False, False, True]]]
 
31
  (select(indices, indices, true))
32
- [[[True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True]]]
33
- (select_width((select(tokens, tokens, ==))))
34
- [[1, 1, 2, 2, 1]]
35
- The following program has been compiled to a transformer with 1 layer(s):
36
  (select_width((select(tokens, tokens, ==))))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  Received the following input and output examples:
2
+ [(['h', 'i'], ['i', 'h'])]
3
  Running synthesizer with
4
+ Vocab: {'i', 'h'}
5
+ Max sequence length: 2
6
+ Max weight: 7
7
  (indices - indices)
8
+ [[0, 0]]
9
+ num_correct score: 0
10
  (indices - 0)
11
+ [[0, 1]]
12
+ num_correct score: 0
13
  (indices - 1)
14
+ [[-1, 0]]
15
+ num_correct score: 0
16
  (0 - indices)
17
+ [[0, -1]]
18
+ num_correct score: 0
19
  (1 - indices)
20
+ [[1, 0]]
21
+ num_correct score: 0
22
  (select(tokens, tokens, ==))
23
+ [[[True, False], [False, True]]]
24
+ num_correct score: 0
25
  (select(tokens, tokens, true))
26
+ [[[True, True], [True, True]]]
27
+ num_correct score: 0
28
  (select(tokens, indices, ==))
29
+ [[[False, False], [False, False]]]
30
+ num_correct score: 0
31
  (select(tokens, indices, true))
32
+ [[[True, True], [True, True]]]
33
+ num_correct score: 0
34
  (select(indices, tokens, ==))
35
+ [[[False, False], [False, False]]]
36
+ num_correct score: 0
37
  (select(indices, tokens, true))
38
+ [[[True, True], [True, True]]]
39
+ num_correct score: 0
40
  (select(indices, indices, ==))
41
+ [[[True, False], [False, True]]]
42
+ num_correct score: 0
43
  (select(indices, indices, true))
44
+ [[[True, True], [True, True]]]
45
+ num_correct score: 0
 
 
46
  (select_width((select(tokens, tokens, ==))))
47
+ [[1, 1]]
48
+ num_correct score: 0
49
+ (select_width((select(tokens, tokens, true))))
50
+ [[2, 2]]
51
+ num_correct score: 0
52
+ (select_width((select(tokens, indices, ==))))
53
+ [[0, 0]]
54
+ num_correct score: 0
55
+ (select_width((select(tokens, indices, true))))
56
+ [[2, 2]]
57
+ num_correct score: 0
58
+ (select_width((select(indices, tokens, ==))))
59
+ [[0, 0]]
60
+ num_correct score: 0
61
+ (select_width((select(indices, tokens, true))))
62
+ [[2, 2]]
63
+ num_correct score: 0
64
+ (select_width((select(indices, indices, ==))))
65
+ [[1, 1]]
66
+ num_correct score: 0
67
+ (select_width((select(indices, indices, true))))
68
+ [[2, 2]]
69
+ num_correct score: 0
70
+ (indices - (indices - indices))
71
+ [[0, 1]]
72
+ num_correct score: 0
73
+ (indices - (indices - 0))
74
+ [[0, 0]]
75
+ num_correct score: 0
76
+ (indices - (indices - 1))
77
+ [[1, 1]]
78
+ num_correct score: 0
79
+ (indices - (0 - indices))
80
+ [[0, 2]]
81
+ num_correct score: 0
82
+ (indices - (1 - indices))
83
+ [[-1, 1]]
84
+ num_correct score: 0
85
+ (0 - (indices - indices))
86
+ [[0, 0]]
87
+ num_correct score: 0
88
+ (0 - (indices - 0))
89
+ [[0, -1]]
90
+ num_correct score: 0
91
+ (0 - (indices - 1))
92
+ [[1, 0]]
93
+ num_correct score: 0
94
+ (0 - (0 - indices))
95
+ [[0, 1]]
96
+ num_correct score: 0
97
+ (0 - (1 - indices))
98
+ [[-1, 0]]
99
+ num_correct score: 0
100
+ (1 - (indices - indices))
101
+ [[1, 1]]
102
+ num_correct score: 0
103
+ (1 - (indices - 0))
104
+ [[1, 0]]
105
+ num_correct score: 0
106
+ (1 - (indices - 1))
107
+ [[2, 1]]
108
+ num_correct score: 0
109
+ (1 - (0 - indices))
110
+ [[1, 2]]
111
+ num_correct score: 0
112
+ (1 - (1 - indices))
113
+ [[0, 1]]
114
+ num_correct score: 0
115
+ ((indices - indices) - indices)
116
+ [[0, -1]]
117
+ num_correct score: 0
118
+ ((indices - indices) - 0)
119
+ [[0, 0]]
120
+ num_correct score: 0
121
+ ((indices - indices) - 1)
122
+ [[-1, -1]]
123
+ num_correct score: 0
124
+ ((indices - 0) - indices)
125
+ [[0, 0]]
126
+ num_correct score: 0
127
+ ((indices - 0) - 0)
128
+ [[0, 1]]
129
+ num_correct score: 0
130
+ ((indices - 0) - 1)
131
+ [[-1, 0]]
132
+ num_correct score: 0
133
+ ((indices - 1) - indices)
134
+ [[-1, -1]]
135
+ num_correct score: 0
136
+ ((indices - 1) - 0)
137
+ [[-1, 0]]
138
+ num_correct score: 0
139
+ ((indices - 1) - 1)
140
+ [[-2, -1]]
141
+ num_correct score: 0
142
+ ((0 - indices) - indices)
143
+ [[0, -2]]
144
+ num_correct score: 0
145
+ ((0 - indices) - 0)
146
+ [[0, -1]]
147
+ num_correct score: 0
148
+ ((0 - indices) - 1)
149
+ [[-1, -2]]
150
+ num_correct score: 0
151
+ ((1 - indices) - indices)
152
+ [[1, -1]]
153
+ num_correct score: 0
154
+ ((1 - indices) - 0)
155
+ [[1, 0]]
156
+ num_correct score: 0
157
+ ((1 - indices) - 1)
158
+ [[0, -1]]
159
+ num_correct score: 0
160
+ (aggregate((select(tokens, tokens, ==)), tokens))
161
+ [['h', 'i']]
162
+ num_correct score: 2
163
+ (aggregate((select(tokens, tokens, ==)), indices))
164
+ [[0.0, 1.0]]
165
+ num_correct score: 0
166
+ (aggregate((select(tokens, tokens, true)), indices))
167
+ [[0.5, 0.5]]
168
+ num_correct score: 0
169
+ (aggregate((select(tokens, indices, ==)), tokens))
170
+ [[None, None]]
171
+ num_correct score: 0
172
+ (aggregate((select(tokens, indices, ==)), indices))
173
+ [[None, None]]
174
+ num_correct score: 0
175
+ (aggregate((select(tokens, indices, true)), indices))
176
+ [[0.5, 0.5]]
177
+ num_correct score: 0
178
+ (aggregate((select(indices, tokens, ==)), tokens))
179
+ [[None, None]]
180
+ num_correct score: 0
181
+ (aggregate((select(indices, tokens, ==)), indices))
182
+ [[None, None]]
183
+ num_correct score: 0
184
+ (aggregate((select(indices, tokens, true)), indices))
185
+ [[0.5, 0.5]]
186
+ num_correct score: 0
187
+ (aggregate((select(indices, indices, ==)), tokens))
188
+ [['h', 'i']]
189
+ num_correct score: 2
190
+ (aggregate((select(indices, indices, ==)), indices))
191
+ [[0.0, 1.0]]
192
+ num_correct score: 0
193
+ (aggregate((select(indices, indices, true)), indices))
194
+ [[0.5, 0.5]]
195
+ num_correct score: 0
196
+ (indices - (select_width((select(tokens, tokens, ==)))))
197
+ [[-1, 0]]
198
+ num_correct score: 0
199
+ (indices - (select_width((select(tokens, tokens, true)))))
200
+ [[-2, -1]]
201
+ num_correct score: 0
202
+ (indices - (select_width((select(tokens, indices, ==)))))
203
+ [[0, 1]]
204
+ num_correct score: 0
205
+ (indices - (select_width((select(tokens, indices, true)))))
206
+ [[-2, -1]]
207
+ num_correct score: 0
208
+ (indices - (select_width((select(indices, tokens, ==)))))
209
+ [[0, 1]]
210
+ num_correct score: 0
211
+ (indices - (select_width((select(indices, tokens, true)))))
212
+ [[-2, -1]]
213
+ num_correct score: 0
214
+ (indices - (select_width((select(indices, indices, ==)))))
215
+ [[-1, 0]]
216
+ num_correct score: 0
217
+ (indices - (select_width((select(indices, indices, true)))))
218
+ [[-2, -1]]
219
+ num_correct score: 0
220
+ (indices - (indices - (indices - indices)))
221
+ [[0, 0]]
222
+ num_correct score: 0
223
+ (indices - (indices - (indices - 0)))
224
+ [[0, 1]]
225
+ num_correct score: 0
226
+ (indices - (indices - (indices - 1)))
227
+ [[-1, 0]]
228
+ num_correct score: 0
229
+ (indices - (indices - (0 - indices)))
230
+ [[0, -1]]
231
+ num_correct score: 0
232
+ (indices - (indices - (1 - indices)))
233
+ [[1, 0]]
234
+ num_correct score: 0
235
+ (indices - (0 - (indices - indices)))
236
+ [[0, 1]]
237
+ num_correct score: 0
238
+ (indices - (0 - (indices - 0)))
239
+ [[0, 2]]
240
+ num_correct score: 0
241
+ (indices - (0 - (indices - 1)))
242
+ [[-1, 1]]
243
+ num_correct score: 0
244
+ (indices - (0 - (0 - indices)))
245
+ [[0, 0]]
246
+ num_correct score: 0
247
+ (indices - (0 - (1 - indices)))
248
+ [[1, 1]]
249
+ num_correct score: 0
250
+ (indices - (1 - (indices - indices)))
251
+ [[-1, 0]]
252
+ num_correct score: 0
253
+ (indices - (1 - (indices - 0)))
254
+ [[-1, 1]]
255
+ num_correct score: 0
256
+ (indices - (1 - (indices - 1)))
257
+ [[-2, 0]]
258
+ num_correct score: 0
259
+ (indices - (1 - (0 - indices)))
260
+ [[-1, -1]]
261
+ num_correct score: 0
262
+ (indices - (1 - (1 - indices)))
263
+ [[0, 0]]
264
+ num_correct score: 0
265
+ (indices - ((indices - indices) - indices))
266
+ [[0, 2]]
267
+ num_correct score: 0
268
+ (indices - ((indices - indices) - 0))
269
+ [[0, 1]]
270
+ num_correct score: 0
271
+ (indices - ((indices - indices) - 1))
272
+ [[1, 2]]
273
+ num_correct score: 0
274
+ (indices - ((indices - 0) - indices))
275
+ [[0, 1]]
276
+ num_correct score: 0
277
+ (indices - ((indices - 0) - 0))
278
+ [[0, 0]]
279
+ num_correct score: 0
280
+ (indices - ((indices - 0) - 1))
281
+ [[1, 1]]
282
+ num_correct score: 0
283
+ (indices - ((indices - 1) - indices))
284
+ [[1, 2]]
285
+ num_correct score: 0
286
+ (indices - ((indices - 1) - 0))
287
+ [[1, 1]]
288
+ num_correct score: 0
289
+ (indices - ((indices - 1) - 1))
290
+ [[2, 2]]
291
+ num_correct score: 0
292
+ (indices - ((0 - indices) - indices))
293
+ [[0, 3]]
294
+ num_correct score: 0
295
+ (indices - ((0 - indices) - 0))
296
+ [[0, 2]]
297
+ num_correct score: 0
298
+ (indices - ((0 - indices) - 1))
299
+ [[1, 3]]
300
+ num_correct score: 0
301
+ (indices - ((1 - indices) - indices))
302
+ [[-1, 2]]
303
+ num_correct score: 0
304
+ (indices - ((1 - indices) - 0))
305
+ [[-1, 1]]
306
+ num_correct score: 0
307
+ (indices - ((1 - indices) - 1))
308
+ [[0, 2]]
309
+ num_correct score: 0
310
+ (0 - (select_width((select(tokens, tokens, ==)))))
311
+ [[-1, -1]]
312
+ num_correct score: 0
313
+ (0 - (select_width((select(tokens, tokens, true)))))
314
+ [[-2, -2]]
315
+ num_correct score: 0
316
+ (0 - (select_width((select(tokens, indices, ==)))))
317
+ [[0, 0]]
318
+ num_correct score: 0
319
+ (0 - (select_width((select(tokens, indices, true)))))
320
+ [[-2, -2]]
321
+ num_correct score: 0
322
+ (0 - (select_width((select(indices, tokens, ==)))))
323
+ [[0, 0]]
324
+ num_correct score: 0
325
+ (0 - (select_width((select(indices, tokens, true)))))
326
+ [[-2, -2]]
327
+ num_correct score: 0
328
+ (0 - (select_width((select(indices, indices, ==)))))
329
+ [[-1, -1]]
330
+ num_correct score: 0
331
+ (0 - (select_width((select(indices, indices, true)))))
332
+ [[-2, -2]]
333
+ num_correct score: 0
334
+ (0 - (indices - (indices - indices)))
335
+ [[0, -1]]
336
+ num_correct score: 0
337
+ (0 - (indices - (indices - 0)))
338
+ [[0, 0]]
339
+ num_correct score: 0
340
+ (0 - (indices - (indices - 1)))
341
+ [[-1, -1]]
342
+ num_correct score: 0
343
+ (0 - (indices - (0 - indices)))
344
+ [[0, -2]]
345
+ num_correct score: 0
346
+ (0 - (indices - (1 - indices)))
347
+ [[1, -1]]
348
+ num_correct score: 0
349
+ (0 - (0 - (indices - indices)))
350
+ [[0, 0]]
351
+ num_correct score: 0
352
+ (0 - (0 - (indices - 0)))
353
+ [[0, 1]]
354
+ num_correct score: 0
355
+ (0 - (0 - (indices - 1)))
356
+ [[-1, 0]]
357
+ num_correct score: 0
358
+ (0 - (0 - (0 - indices)))
359
+ [[0, -1]]
360
+ num_correct score: 0
361
+ (0 - (0 - (1 - indices)))
362
+ [[1, 0]]
363
+ num_correct score: 0
364
+ (0 - (1 - (indices - indices)))
365
+ [[-1, -1]]
366
+ num_correct score: 0
367
+ (0 - (1 - (indices - 0)))
368
+ [[-1, 0]]
369
+ num_correct score: 0
370
+ (0 - (1 - (indices - 1)))
371
+ [[-2, -1]]
372
+ num_correct score: 0
373
+ (0 - (1 - (0 - indices)))
374
+ [[-1, -2]]
375
+ num_correct score: 0
376
+ (0 - (1 - (1 - indices)))
377
+ [[0, -1]]
378
+ num_correct score: 0
379
+ (0 - ((indices - indices) - indices))
380
+ [[0, 1]]
381
+ num_correct score: 0
382
+ (0 - ((indices - indices) - 0))
383
+ [[0, 0]]
384
+ num_correct score: 0
385
+ (0 - ((indices - indices) - 1))
386
+ [[1, 1]]
387
+ num_correct score: 0
388
+ (0 - ((indices - 0) - indices))
389
+ [[0, 0]]
390
+ num_correct score: 0
391
+ (0 - ((indices - 0) - 0))
392
+ [[0, -1]]
393
+ num_correct score: 0
394
+ (0 - ((indices - 0) - 1))
395
+ [[1, 0]]
396
+ num_correct score: 0
397
+ (0 - ((indices - 1) - indices))
398
+ [[1, 1]]
399
+ num_correct score: 0
400
+ (0 - ((indices - 1) - 0))
401
+ [[1, 0]]
402
+ num_correct score: 0
403
+ (0 - ((indices - 1) - 1))
404
+ [[2, 1]]
405
+ num_correct score: 0
406
+ (0 - ((0 - indices) - indices))
407
+ [[0, 2]]
408
+ num_correct score: 0
409
+ (0 - ((0 - indices) - 0))
410
+ [[0, 1]]
411
+ num_correct score: 0
412
+ (0 - ((0 - indices) - 1))
413
+ [[1, 2]]
414
+ num_correct score: 0
415
+ (0 - ((1 - indices) - indices))
416
+ [[-1, 1]]
417
+ num_correct score: 0
418
+ (0 - ((1 - indices) - 0))
419
+ [[-1, 0]]
420
+ num_correct score: 0
421
+ (0 - ((1 - indices) - 1))
422
+ [[0, 1]]
423
+ num_correct score: 0
424
+ (1 - (select_width((select(tokens, tokens, ==)))))
425
+ [[0, 0]]
426
+ num_correct score: 0
427
+ (1 - (select_width((select(tokens, tokens, true)))))
428
+ [[-1, -1]]
429
+ num_correct score: 0
430
+ (1 - (select_width((select(tokens, indices, ==)))))
431
+ [[1, 1]]
432
+ num_correct score: 0
433
+ (1 - (select_width((select(tokens, indices, true)))))
434
+ [[-1, -1]]
435
+ num_correct score: 0
436
+ (1 - (select_width((select(indices, tokens, ==)))))
437
+ [[1, 1]]
438
+ num_correct score: 0
439
+ (1 - (select_width((select(indices, tokens, true)))))
440
+ [[-1, -1]]
441
+ num_correct score: 0
442
+ (1 - (select_width((select(indices, indices, ==)))))
443
+ [[0, 0]]
444
+ num_correct score: 0
445
+ (1 - (select_width((select(indices, indices, true)))))
446
+ [[-1, -1]]
447
+ num_correct score: 0
448
+ (1 - (indices - (indices - indices)))
449
+ [[1, 0]]
450
+ num_correct score: 0
451
+ (1 - (indices - (indices - 0)))
452
+ [[1, 1]]
453
+ num_correct score: 0
454
+ (1 - (indices - (indices - 1)))
455
+ [[0, 0]]
456
+ num_correct score: 0
457
+ (1 - (indices - (0 - indices)))
458
+ [[1, -1]]
459
+ num_correct score: 0
460
+ (1 - (indices - (1 - indices)))
461
+ [[2, 0]]
462
+ num_correct score: 0
463
+ (1 - (0 - (indices - indices)))
464
+ [[1, 1]]
465
+ num_correct score: 0
466
+ (1 - (0 - (indices - 0)))
467
+ [[1, 2]]
468
+ num_correct score: 0
469
+ (1 - (0 - (indices - 1)))
470
+ [[0, 1]]
471
+ num_correct score: 0
472
+ (1 - (0 - (0 - indices)))
473
+ [[1, 0]]
474
+ num_correct score: 0
475
+ (1 - (0 - (1 - indices)))
476
+ [[2, 1]]
477
+ num_correct score: 0
478
+ (1 - (1 - (indices - indices)))
479
+ [[0, 0]]
480
+ num_correct score: 0
481
+ (1 - (1 - (indices - 0)))
482
+ [[0, 1]]
483
+ num_correct score: 0
484
+ (1 - (1 - (indices - 1)))
485
+ [[-1, 0]]
486
+ num_correct score: 0
487
+ (1 - (1 - (0 - indices)))
488
+ [[0, -1]]
489
+ num_correct score: 0
490
+ (1 - (1 - (1 - indices)))
491
+ [[1, 0]]
492
+ num_correct score: 0
493
+ (1 - ((indices - indices) - indices))
494
+ [[1, 2]]
495
+ num_correct score: 0
496
+ (1 - ((indices - indices) - 0))
497
+ [[1, 1]]
498
+ num_correct score: 0
499
+ (1 - ((indices - indices) - 1))
500
+ [[2, 2]]
501
+ num_correct score: 0
502
+ (1 - ((indices - 0) - indices))
503
+ [[1, 1]]
504
+ num_correct score: 0
505
+ (1 - ((indices - 0) - 0))
506
+ [[1, 0]]
507
+ num_correct score: 0
508
+ (1 - ((indices - 0) - 1))
509
+ [[2, 1]]
510
+ num_correct score: 0
511
+ (1 - ((indices - 1) - indices))
512
+ [[2, 2]]
513
+ num_correct score: 0
514
+ (1 - ((indices - 1) - 0))
515
+ [[2, 1]]
516
+ num_correct score: 0
517
+ (1 - ((indices - 1) - 1))
518
+ [[3, 2]]
519
+ num_correct score: 0
520
+ (1 - ((0 - indices) - indices))
521
+ [[1, 3]]
522
+ num_correct score: 0
523
+ (1 - ((0 - indices) - 0))
524
+ [[1, 2]]
525
+ num_correct score: 0
526
+ (1 - ((0 - indices) - 1))
527
+ [[2, 3]]
528
+ num_correct score: 0
529
+ (1 - ((1 - indices) - indices))
530
+ [[0, 2]]
531
+ num_correct score: 0
532
+ (1 - ((1 - indices) - 0))
533
+ [[0, 1]]
534
+ num_correct score: 0
535
+ (1 - ((1 - indices) - 1))
536
+ [[1, 2]]
537
+ num_correct score: 0
538
+ ((indices - indices) - (indices - 0))
539
+ [[0, -1]]
540
+ num_correct score: 0
541
+ ((indices - indices) - (indices - 1))
542
+ [[1, 0]]
543
+ num_correct score: 0
544
+ ((indices - indices) - (0 - indices))
545
+ [[0, 1]]
546
+ num_correct score: 0
547
+ ((indices - indices) - (1 - indices))
548
+ [[-1, 0]]
549
+ num_correct score: 0
550
+ ((indices - 0) - (indices - indices))
551
+ [[0, 1]]
552
+ num_correct score: 0
553
+ ((indices - 0) - (indices - 1))
554
+ [[1, 1]]
555
+ num_correct score: 0
556
+ ((indices - 0) - (0 - indices))
557
+ [[0, 2]]
558
+ num_correct score: 0
559
+ ((indices - 0) - (1 - indices))
560
+ [[-1, 1]]
561
+ num_correct score: 0
562
+ ((indices - 1) - (indices - indices))
563
+ [[-1, 0]]
564
+ num_correct score: 0
565
+ ((indices - 1) - (indices - 0))
566
+ [[-1, -1]]
567
+ num_correct score: 0
568
+ ((indices - 1) - (0 - indices))
569
+ [[-1, 1]]
570
+ num_correct score: 0
571
+ ((indices - 1) - (1 - indices))
572
+ [[-2, 0]]
573
+ num_correct score: 0
574
+ ((0 - indices) - (indices - indices))
575
+ [[0, -1]]
576
+ num_correct score: 0
577
+ ((0 - indices) - (indices - 0))
578
+ [[0, -2]]
579
+ num_correct score: 0
580
+ ((0 - indices) - (indices - 1))
581
+ [[1, -1]]
582
+ num_correct score: 0
583
+ ((0 - indices) - (1 - indices))
584
+ [[-1, -1]]
585
+ num_correct score: 0
586
+ ((1 - indices) - (indices - indices))
587
+ [[1, 0]]
588
+ num_correct score: 0
589
+ ((1 - indices) - (indices - 0))
590
+ [[1, -1]]
591
+ num_correct score: 0
592
+ ((1 - indices) - (indices - 1))
593
+ [[2, 0]]
594
+ num_correct score: 0
595
+ ((1 - indices) - (0 - indices))
596
+ [[1, 1]]
597
+ num_correct score: 0
598
+ ((select_width((select(tokens, tokens, ==)))) - indices)
599
+ [[1, 0]]
600
+ num_correct score: 0
601
+ ((select_width((select(tokens, tokens, ==)))) - 0)
602
+ [[1, 1]]
603
+ num_correct score: 0
604
+ ((select_width((select(tokens, tokens, ==)))) - 1)
605
+ [[0, 0]]
606
+ num_correct score: 0
607
+ ((select_width((select(tokens, tokens, true)))) - indices)
608
+ [[2, 1]]
609
+ num_correct score: 0
610
+ ((select_width((select(tokens, tokens, true)))) - 0)
611
+ [[2, 2]]
612
+ num_correct score: 0
613
+ ((select_width((select(tokens, tokens, true)))) - 1)
614
+ [[1, 1]]
615
+ num_correct score: 0
616
+ ((select_width((select(tokens, indices, ==)))) - indices)
617
+ [[0, -1]]
618
+ num_correct score: 0
619
+ ((select_width((select(tokens, indices, ==)))) - 0)
620
+ [[0, 0]]
621
+ num_correct score: 0
622
+ ((select_width((select(tokens, indices, ==)))) - 1)
623
+ [[-1, -1]]
624
+ num_correct score: 0
625
+ ((select_width((select(tokens, indices, true)))) - indices)
626
+ [[2, 1]]
627
+ num_correct score: 0
628
+ ((select_width((select(tokens, indices, true)))) - 0)
629
+ [[2, 2]]
630
+ num_correct score: 0
631
+ ((select_width((select(tokens, indices, true)))) - 1)
632
+ [[1, 1]]
633
+ num_correct score: 0
634
+ ((select_width((select(indices, tokens, ==)))) - indices)
635
+ [[0, -1]]
636
+ num_correct score: 0
637
+ ((select_width((select(indices, tokens, ==)))) - 0)
638
+ [[0, 0]]
639
+ num_correct score: 0
640
+ ((select_width((select(indices, tokens, ==)))) - 1)
641
+ [[-1, -1]]
642
+ num_correct score: 0
643
+ ((select_width((select(indices, tokens, true)))) - indices)
644
+ [[2, 1]]
645
+ num_correct score: 0
646
+ ((select_width((select(indices, tokens, true)))) - 0)
647
+ [[2, 2]]
648
+ num_correct score: 0
649
+ ((select_width((select(indices, tokens, true)))) - 1)
650
+ [[1, 1]]
651
+ num_correct score: 0
652
+ ((select_width((select(indices, indices, ==)))) - indices)
653
+ [[1, 0]]
654
+ num_correct score: 0
655
+ ((select_width((select(indices, indices, ==)))) - 0)
656
+ [[1, 1]]
657
+ num_correct score: 0
658
+ ((select_width((select(indices, indices, ==)))) - 1)
659
+ [[0, 0]]
660
+ num_correct score: 0
661
+ ((select_width((select(indices, indices, true)))) - indices)
662
+ [[2, 1]]
663
+ num_correct score: 0
664
+ ((select_width((select(indices, indices, true)))) - 0)
665
+ [[2, 2]]
666
+ num_correct score: 0
667
+ ((select_width((select(indices, indices, true)))) - 1)
668
+ [[1, 1]]
669
+ num_correct score: 0
670
+ ((indices - (indices - indices)) - indices)
671
+ [[0, 0]]
672
+ num_correct score: 0
673
+ ((indices - (indices - indices)) - 0)
674
+ [[0, 1]]
675
+ num_correct score: 0
676
+ ((indices - (indices - indices)) - 1)
677
+ [[-1, 0]]
678
+ num_correct score: 0
679
+ ((indices - (indices - 0)) - indices)
680
+ [[0, -1]]
681
+ num_correct score: 0
682
+ ((indices - (indices - 0)) - 0)
683
+ [[0, 0]]
684
+ num_correct score: 0
685
+ ((indices - (indices - 0)) - 1)
686
+ [[-1, -1]]
687
+ num_correct score: 0
688
+ ((indices - (indices - 1)) - indices)
689
+ [[1, 0]]
690
+ num_correct score: 0
691
+ ((indices - (indices - 1)) - 0)
692
+ [[1, 1]]
693
+ num_correct score: 0
694
+ ((indices - (indices - 1)) - 1)
695
+ [[0, 0]]
696
+ num_correct score: 0
697
+ ((indices - (0 - indices)) - indices)
698
+ [[0, 1]]
699
+ num_correct score: 0
700
+ ((indices - (0 - indices)) - 0)
701
+ [[0, 2]]
702
+ num_correct score: 0
703
+ ((indices - (0 - indices)) - 1)
704
+ [[-1, 1]]
705
+ num_correct score: 0
706
+ ((indices - (1 - indices)) - indices)
707
+ [[-1, 0]]
708
+ num_correct score: 0
709
+ ((indices - (1 - indices)) - 0)
710
+ [[-1, 1]]
711
+ num_correct score: 0
712
+ ((indices - (1 - indices)) - 1)
713
+ [[-2, 0]]
714
+ num_correct score: 0
715
+ ((0 - (indices - indices)) - indices)
716
+ [[0, -1]]
717
+ num_correct score: 0
718
+ ((0 - (indices - indices)) - 0)
719
+ [[0, 0]]
720
+ num_correct score: 0
721
+ ((0 - (indices - indices)) - 1)
722
+ [[-1, -1]]
723
+ num_correct score: 0
724
+ ((0 - (indices - 0)) - indices)
725
+ [[0, -2]]
726
+ num_correct score: 0
727
+ ((0 - (indices - 0)) - 0)
728
+ [[0, -1]]
729
+ num_correct score: 0
730
+ ((0 - (indices - 0)) - 1)
731
+ [[-1, -2]]
732
+ num_correct score: 0
733
+ ((0 - (indices - 1)) - indices)
734
+ [[1, -1]]
735
+ num_correct score: 0
736
+ ((0 - (indices - 1)) - 0)
737
+ [[1, 0]]
738
+ num_correct score: 0
739
+ ((0 - (indices - 1)) - 1)
740
+ [[0, -1]]
741
+ num_correct score: 0
742
+ ((0 - (0 - indices)) - indices)
743
+ [[0, 0]]
744
+ num_correct score: 0
745
+ ((0 - (0 - indices)) - 0)
746
+ [[0, 1]]
747
+ num_correct score: 0
748
+ ((0 - (0 - indices)) - 1)
749
+ [[-1, 0]]
750
+ num_correct score: 0
751
+ ((0 - (1 - indices)) - indices)
752
+ [[-1, -1]]
753
+ num_correct score: 0
754
+ ((0 - (1 - indices)) - 0)
755
+ [[-1, 0]]
756
+ num_correct score: 0
757
+ ((0 - (1 - indices)) - 1)
758
+ [[-2, -1]]
759
+ num_correct score: 0
760
+ ((1 - (indices - indices)) - indices)
761
+ [[1, 0]]
762
+ num_correct score: 0
763
+ ((1 - (indices - indices)) - 0)
764
+ [[1, 1]]
765
+ num_correct score: 0
766
+ ((1 - (indices - indices)) - 1)
767
+ [[0, 0]]
768
+ num_correct score: 0
769
+ ((1 - (indices - 0)) - indices)
770
+ [[1, -1]]
771
+ num_correct score: 0
772
+ ((1 - (indices - 0)) - 0)
773
+ [[1, 0]]
774
+ num_correct score: 0
775
+ ((1 - (indices - 0)) - 1)
776
+ [[0, -1]]
777
+ num_correct score: 0
778
+ ((1 - (indices - 1)) - indices)
779
+ [[2, 0]]
780
+ num_correct score: 0
781
+ ((1 - (indices - 1)) - 0)
782
+ [[2, 1]]
783
+ num_correct score: 0
784
+ ((1 - (indices - 1)) - 1)
785
+ [[1, 0]]
786
+ num_correct score: 0
787
+ ((1 - (0 - indices)) - indices)
788
+ [[1, 1]]
789
+ num_correct score: 0
790
+ ((1 - (0 - indices)) - 0)
791
+ [[1, 2]]
792
+ num_correct score: 0
793
+ ((1 - (0 - indices)) - 1)
794
+ [[0, 1]]
795
+ num_correct score: 0
796
+ ((1 - (1 - indices)) - indices)
797
+ [[0, 0]]
798
+ num_correct score: 0
799
+ ((1 - (1 - indices)) - 0)
800
+ [[0, 1]]
801
+ num_correct score: 0
802
+ ((1 - (1 - indices)) - 1)
803
+ [[-1, 0]]
804
+ num_correct score: 0
805
+ (((indices - indices) - indices) - indices)
806
+ [[0, -2]]
807
+ num_correct score: 0
808
+ (((indices - indices) - indices) - 0)
809
+ [[0, -1]]
810
+ num_correct score: 0
811
+ (((indices - indices) - indices) - 1)
812
+ [[-1, -2]]
813
+ num_correct score: 0
814
+ (((indices - indices) - 0) - indices)
815
+ [[0, -1]]
816
+ num_correct score: 0
817
+ (((indices - indices) - 0) - 0)
818
+ [[0, 0]]
819
+ num_correct score: 0
820
+ (((indices - indices) - 0) - 1)
821
+ [[-1, -1]]
822
+ num_correct score: 0
823
+ (((indices - indices) - 1) - indices)
824
+ [[-1, -2]]
825
+ num_correct score: 0
826
+ (((indices - indices) - 1) - 0)
827
+ [[-1, -1]]
828
+ num_correct score: 0
829
+ (((indices - indices) - 1) - 1)
830
+ [[-2, -2]]
831
+ num_correct score: 0
832
+ (((indices - 0) - indices) - indices)
833
+ [[0, -1]]
834
+ num_correct score: 0
835
+ (((indices - 0) - indices) - 0)
836
+ [[0, 0]]
837
+ num_correct score: 0
838
+ (((indices - 0) - indices) - 1)
839
+ [[-1, -1]]
840
+ num_correct score: 0
841
+ (((indices - 0) - 0) - indices)
842
+ [[0, 0]]
843
+ num_correct score: 0
844
+ (((indices - 0) - 0) - 0)
845
+ [[0, 1]]
846
+ num_correct score: 0
847
+ (((indices - 0) - 0) - 1)
848
+ [[-1, 0]]
849
+ num_correct score: 0
850
+ (((indices - 0) - 1) - indices)
851
+ [[-1, -1]]
852
+ num_correct score: 0
853
+ (((indices - 0) - 1) - 0)
854
+ [[-1, 0]]
855
+ num_correct score: 0
856
+ (((indices - 0) - 1) - 1)
857
+ [[-2, -1]]
858
+ num_correct score: 0
859
+ (((indices - 1) - indices) - indices)
860
+ [[-1, -2]]
861
+ num_correct score: 0
862
+ (((indices - 1) - indices) - 0)
863
+ [[-1, -1]]
864
+ num_correct score: 0
865
+ (((indices - 1) - indices) - 1)
866
+ [[-2, -2]]
867
+ num_correct score: 0
868
+ (((indices - 1) - 0) - indices)
869
+ [[-1, -1]]
870
+ num_correct score: 0
871
+ (((indices - 1) - 0) - 0)
872
+ [[-1, 0]]
873
+ num_correct score: 0
874
+ (((indices - 1) - 0) - 1)
875
+ [[-2, -1]]
876
+ num_correct score: 0
877
+ (((indices - 1) - 1) - indices)
878
+ [[-2, -2]]
879
+ num_correct score: 0
880
+ (((indices - 1) - 1) - 0)
881
+ [[-2, -1]]
882
+ num_correct score: 0
883
+ (((indices - 1) - 1) - 1)
884
+ [[-3, -2]]
885
+ num_correct score: 0
886
+ (((0 - indices) - indices) - indices)
887
+ [[0, -3]]
888
+ num_correct score: 0
889
+ (((0 - indices) - indices) - 0)
890
+ [[0, -2]]
891
+ num_correct score: 0
892
+ (((0 - indices) - indices) - 1)
893
+ [[-1, -3]]
894
+ num_correct score: 0
895
+ (((0 - indices) - 0) - indices)
896
+ [[0, -2]]
897
+ num_correct score: 0
898
+ (((0 - indices) - 0) - 0)
899
+ [[0, -1]]
900
+ num_correct score: 0
901
+ (((0 - indices) - 0) - 1)
902
+ [[-1, -2]]
903
+ num_correct score: 0
904
+ (((0 - indices) - 1) - indices)
905
+ [[-1, -3]]
906
+ num_correct score: 0
907
+ (((0 - indices) - 1) - 0)
908
+ [[-1, -2]]
909
+ num_correct score: 0
910
+ (((0 - indices) - 1) - 1)
911
+ [[-2, -3]]
912
+ num_correct score: 0
913
+ (((1 - indices) - indices) - indices)
914
+ [[1, -2]]
915
+ num_correct score: 0
916
+ (((1 - indices) - indices) - 0)
917
+ [[1, -1]]
918
+ num_correct score: 0
919
+ (((1 - indices) - indices) - 1)
920
+ [[0, -2]]
921
+ num_correct score: 0
922
+ (((1 - indices) - 0) - indices)
923
+ [[1, -1]]
924
+ num_correct score: 0
925
+ (((1 - indices) - 0) - 0)
926
+ [[1, 0]]
927
+ num_correct score: 0
928
+ (((1 - indices) - 0) - 1)
929
+ [[0, -1]]
930
+ num_correct score: 0
931
+ (((1 - indices) - 1) - indices)
932
+ [[0, -2]]
933
+ num_correct score: 0
934
+ (((1 - indices) - 1) - 0)
935
+ [[0, -1]]
936
+ num_correct score: 0
937
+ (((1 - indices) - 1) - 1)
938
+ [[-1, -2]]
939
+ num_correct score: 0
940
+ No exact program found but here are some approximate ideas: ['(indices - 0)', '(indices - 1)', '(aggregate((select(indices, indices, ==)), tokens))', '(aggregate((select(tokens, tokens, ==)), tokens))', '(indices - indices)']
rasp_synthesizer.py CHANGED
@@ -13,6 +13,7 @@ import re
13
  from tracr.compiler import compiling
14
  from typing import get_args
15
  import inspect
 
16
 
17
  from abstract_syntax_tree import *
18
  from python_embedded_rasp import *
@@ -103,12 +104,11 @@ def check_correctness(examples, program):
103
  except:
104
  return False
105
 
106
- print(program.str())
107
- print(program_output)
 
108
 
109
- # TODO return number that match and return this
110
-
111
- return program_output == outputs
112
 
113
  # COMPARE TYPE SIGNATURES
114
  def compare_types(list1, list2):
@@ -147,7 +147,7 @@ def run_synthesizer(examples, max_weight):
147
  program_bank = rasp_consts
148
  program_bank_str = [p.str() for p in program_bank]
149
 
150
- # TODO: store approximate programs, measured by number of output examples that match
151
 
152
  # iterate over each level
153
  for weight in range(2, max_weight):
@@ -176,10 +176,21 @@ def run_synthesizer(examples, max_weight):
176
  program_bank.append(program)
177
  program_bank_str.append(program.str())
178
 
179
- if check_correctness(examples, program):
180
- return(program)
 
 
 
 
 
 
 
 
 
 
 
181
 
182
- return None
183
 
184
  # COMPILE RASP MODEL
185
  if __name__ == "__main__":
@@ -229,7 +240,7 @@ if __name__ == "__main__":
229
  print("Max sequence length: {}".format(max_seq_len))
230
  print("Max weight: {}".format(args.max_weight))
231
 
232
- program = run_synthesizer(examples, args.max_weight)
233
 
234
  if program:
235
  algorithm = program.to_python()
@@ -254,4 +265,4 @@ if __name__ == "__main__":
254
  print(f"The following program has been compiled to a transformer with {layer_num} layer(s):")
255
  print(program.str())
256
  else:
257
- print("No program found.")
 
13
  from tracr.compiler import compiling
14
  from typing import get_args
15
  import inspect
16
+ import heapq
17
 
18
  from abstract_syntax_tree import *
19
  from python_embedded_rasp import *
 
104
  except:
105
  return False
106
 
107
+ is_correct = (program_output == outputs)
108
+ num_correct = sum([int(p == o) for p,o in zip(program_output, outputs)]) + \
109
+ sum([int(type(i) == type(j)) for p,o in zip(program_output, outputs) for i,j in zip(p, o)])
110
 
111
+ return is_correct, num_correct
 
 
112
 
113
  # COMPARE TYPE SIGNATURES
114
  def compare_types(list1, list2):
 
147
  program_bank = rasp_consts
148
  program_bank_str = [p.str() for p in program_bank]
149
 
150
+ approx_progs = []
151
 
152
  # iterate over each level
153
  for weight in range(2, max_weight):
 
176
  program_bank.append(program)
177
  program_bank_str.append(program.str())
178
 
179
+ is_correct, num_correct = check_correctness(examples, program)
180
+
181
+ if is_correct:
182
+ return(program), [ap[1] for ap in approx_progs]
183
+
184
+ if len(approx_progs) >= 3:
185
+ correct_cutoff, _prog = heapq.heappop(approx_progs)
186
+ if num_correct > correct_cutoff:
187
+ heapq.heappush(approx_progs, (num_correct, program.str()))
188
+ else:
189
+ heapq.heappush(approx_progs, (correct_cutoff, _prog))
190
+ else:
191
+ heapq.heappush(approx_progs, (num_correct, program.str()))
192
 
193
+ return None, [ap[1] for ap in approx_progs]
194
 
195
  # COMPILE RASP MODEL
196
  if __name__ == "__main__":
 
240
  print("Max sequence length: {}".format(max_seq_len))
241
  print("Max weight: {}".format(args.max_weight))
242
 
243
+ program, approx_programs = run_synthesizer(examples, args.max_weight)
244
 
245
  if program:
246
  algorithm = program.to_python()
 
265
  print(f"The following program has been compiled to a transformer with {layer_num} layer(s):")
266
  print(program.str())
267
  else:
268
+ print("No exact program found but here is some brainstorming fodder for you: {}".format(approx_programs))
tracr/compiler/__pycache__/assemble.cpython-39.pyc CHANGED
Binary files a/tracr/compiler/__pycache__/assemble.cpython-39.pyc and b/tracr/compiler/__pycache__/assemble.cpython-39.pyc differ
 
tracr/compiler/__pycache__/basis_inference.cpython-39.pyc CHANGED
Binary files a/tracr/compiler/__pycache__/basis_inference.cpython-39.pyc and b/tracr/compiler/__pycache__/basis_inference.cpython-39.pyc differ
 
tracr/compiler/__pycache__/compiling.cpython-39.pyc CHANGED
Binary files a/tracr/compiler/__pycache__/compiling.cpython-39.pyc and b/tracr/compiler/__pycache__/compiling.cpython-39.pyc differ
 
tracr/compiler/__pycache__/expr_to_craft_graph.cpython-39.pyc CHANGED
Binary files a/tracr/compiler/__pycache__/expr_to_craft_graph.cpython-39.pyc and b/tracr/compiler/__pycache__/expr_to_craft_graph.cpython-39.pyc differ
 
tracr/compiler/__pycache__/validating.cpython-39.pyc ADDED
Binary file (6 kB). View file
 
tracr/craft/__pycache__/transformers.cpython-39.pyc CHANGED
Binary files a/tracr/craft/__pycache__/transformers.cpython-39.pyc and b/tracr/craft/__pycache__/transformers.cpython-39.pyc differ
 
tracr/craft/chamber/__pycache__/numerical_mlp.cpython-39.pyc CHANGED
Binary files a/tracr/craft/chamber/__pycache__/numerical_mlp.cpython-39.pyc and b/tracr/craft/chamber/__pycache__/numerical_mlp.cpython-39.pyc differ
 
tracr/craft/chamber/__pycache__/selector_width.cpython-39.pyc CHANGED
Binary files a/tracr/craft/chamber/__pycache__/selector_width.cpython-39.pyc and b/tracr/craft/chamber/__pycache__/selector_width.cpython-39.pyc differ
 
tracr/rasp/__pycache__/rasp.cpython-39.pyc CHANGED
Binary files a/tracr/rasp/__pycache__/rasp.cpython-39.pyc and b/tracr/rasp/__pycache__/rasp.cpython-39.pyc differ