Spaces:
Sleeping
Sleeping
ayushnoori
commited on
Commit
•
5b04db9
1
Parent(s):
3872a55
Complete on arithmetic DSL
Browse files- README.md +3 -3
- abstract_syntax_tree.py +4 -2
- arithmetic.py +2 -0
- demonstration.ipynb +67 -195
- synthesizer.py +94 -3
README.md
CHANGED
@@ -1,7 +1,5 @@
|
|
1 |
# Bottom-Up Enumerative Program Synthesis
|
2 |
|
3 |
-
🚨🚨PLEASE DO NOT GRADE YET🚨🚨
|
4 |
-
|
5 |
Completed for [CS252R: Program Synthesis](https://synthesis.metareflection.club/) at the Harvard John A. Paulson School of Engineering and Applied Sciences, taught in Fall 2023 by Prof. Nada Amin.
|
6 |
|
7 |
## 🛠️ Background
|
@@ -51,7 +49,9 @@ To add additional input-output examples, modify `examples.py`. Add a new key to
|
|
51 |
|
52 |
## 🔎 Abstract Syntax Tree
|
53 |
|
54 |
-
The most important data structure in this implementation is the abstract syntax tree (AST). The AST is a tree representation of a program, where each node is either a primitive or a compound expression. The AST is represented by the `OperatorNode` class in `abstract_syntax_tree.py`. My AST implementation includes functions to recursively evaluate the operator and its operands
|
|
|
|
|
55 |
|
56 |
## 🔮 Virtual Environment
|
57 |
|
|
|
1 |
# Bottom-Up Enumerative Program Synthesis
|
2 |
|
|
|
|
|
3 |
Completed for [CS252R: Program Synthesis](https://synthesis.metareflection.club/) at the Harvard John A. Paulson School of Engineering and Applied Sciences, taught in Fall 2023 by Prof. Nada Amin.
|
4 |
|
5 |
## 🛠️ Background
|
|
|
49 |
|
50 |
## 🔎 Abstract Syntax Tree
|
51 |
|
52 |
+
The most important data structure in this implementation is the abstract syntax tree (AST). The AST is a tree representation of a program, where each node is either a primitive or a compound expression. The AST is represented by the `OperatorNode` class in `abstract_syntax_tree.py`. My AST implementation includes functions to recursively evaluate the operator and its operands and also to generate a string representation of the program.
|
53 |
+
|
54 |
+
At program evaluation time, the AST is evaluated from the bottom up. That is, the operands are evaluated first, and then the operator is evaluated on the operands. This is implemented in the `evaluate` method of the `OperatorNode` class. In the case of integers, variable inputs are represented by the `IntegerVariable` class in `arithmetic.py`. When input is not `None`, input type checking and validation is performed by the `evaluate` function in this class.
|
55 |
|
56 |
## 🔮 Virtual Environment
|
57 |
|
abstract_syntax_tree.py
CHANGED
@@ -24,10 +24,12 @@ class OperatorNode:
|
|
24 |
add_node = OperatorNode(Add(), [IntegerVariable(0), IntegerConstant(5)])
|
25 |
multiply_node.evaluate([7]) # returns 24
|
26 |
'''
|
27 |
-
|
28 |
def __init__(self, operator, children):
|
29 |
-
self.operator = operator #
|
30 |
self.children = children # list of children nodes (operands)
|
|
|
|
|
31 |
|
32 |
def evaluate(self, input = None):
|
33 |
|
|
|
24 |
add_node = OperatorNode(Add(), [IntegerVariable(0), IntegerConstant(5)])
|
25 |
multiply_node.evaluate([7]) # returns 24
|
26 |
'''
|
27 |
+
|
28 |
def __init__(self, operator, children):
|
29 |
+
self.operator = operator # operator object (e.g., Add, Subtract, etc.)
|
30 |
self.children = children # list of children nodes (operands)
|
31 |
+
self.weight = operator.weight + sum([child.weight for child in children]) # weight of the node
|
32 |
+
self.type = operator.return_type # return type of the operator object
|
33 |
|
34 |
def evaluate(self, input = None):
|
35 |
|
arithmetic.py
CHANGED
@@ -16,6 +16,7 @@ class IntegerVariable:
|
|
16 |
# self.value = None # value of the variable, initially None
|
17 |
self.position = position # zero-indexed position of the variable in the arguments to program
|
18 |
self.type = int # type of the variable
|
|
|
19 |
|
20 |
# def assign(self, value):
|
21 |
# self.value = value
|
@@ -57,6 +58,7 @@ class IntegerConstant:
|
|
57 |
def __init__(self, value):
|
58 |
self.value = value # value of the constant
|
59 |
self.type = int # type of the constant
|
|
|
60 |
|
61 |
def evaluate(self, input = None):
|
62 |
return self.value
|
|
|
16 |
# self.value = None # value of the variable, initially None
|
17 |
self.position = position # zero-indexed position of the variable in the arguments to program
|
18 |
self.type = int # type of the variable
|
19 |
+
self.weight = 1 # weight of the variable
|
20 |
|
21 |
# def assign(self, value):
|
22 |
# self.value = value
|
|
|
58 |
def __init__(self, value):
|
59 |
self.value = value # value of the constant
|
60 |
self.type = int # type of the constant
|
61 |
+
self.weight = 1 # weight of the constant
|
62 |
|
63 |
def evaluate(self, input = None):
|
64 |
return self.value
|
demonstration.ipynb
CHANGED
@@ -29,6 +29,7 @@
|
|
29 |
"from arithmetic import *\n",
|
30 |
"from abstract_syntax_tree import OperatorNode\n",
|
31 |
"from examples import example_set, check_examples\n",
|
|
|
32 |
"import config"
|
33 |
]
|
34 |
},
|
@@ -41,7 +42,7 @@
|
|
41 |
},
|
42 |
{
|
43 |
"cell_type": "code",
|
44 |
-
"execution_count":
|
45 |
"metadata": {},
|
46 |
"outputs": [],
|
47 |
"source": [
|
@@ -55,233 +56,104 @@
|
|
55 |
"cell_type": "markdown",
|
56 |
"metadata": {},
|
57 |
"source": [
|
58 |
-
"
|
59 |
-
]
|
60 |
-
},
|
61 |
-
{
|
62 |
-
"cell_type": "markdown",
|
63 |
-
"metadata": {},
|
64 |
-
"source": [
|
65 |
-
"I provide examples of arithmetic operations."
|
66 |
]
|
67 |
},
|
68 |
{
|
69 |
"cell_type": "code",
|
70 |
-
"execution_count":
|
71 |
"metadata": {},
|
72 |
"outputs": [],
|
73 |
"source": [
|
74 |
-
"
|
75 |
-
"
|
76 |
-
"
|
77 |
-
" For example, if the input is [4, 5, 6] and the variable is the third element (i.e., 6), then position = 2.\n",
|
78 |
-
" '''\n",
|
79 |
-
" def __init__(self, position):\n",
|
80 |
-
" self.value = None # value of the variable, initially None\n",
|
81 |
-
" self.position = position # position of the variable in the arguments to program\n",
|
82 |
-
" self.type = int # type of the variable\n",
|
83 |
-
"\n",
|
84 |
-
" def assign(self, value):\n",
|
85 |
-
" self.value = value\n",
|
86 |
-
"\n",
|
87 |
-
"class IntegerConstant:\n",
|
88 |
-
" '''\n",
|
89 |
-
" Class to represent an integer constant.\n",
|
90 |
-
" '''\n",
|
91 |
-
" def __init__(self, value):\n",
|
92 |
-
" self.value = value # value of the constant\n",
|
93 |
-
" self.type = int # type of the constant\n",
|
94 |
-
"\n",
|
95 |
-
"class Add:\n",
|
96 |
-
" '''\n",
|
97 |
-
" Operator to add two numerical values.\n",
|
98 |
-
" '''\n",
|
99 |
-
" def __init__(self):\n",
|
100 |
-
" self.arity = 2 # number of arguments\n",
|
101 |
-
" self.arg_types = [int, int] # argument types\n",
|
102 |
-
" self.return_type = int # return type\n",
|
103 |
-
" self.weight = 1 # weight\n",
|
104 |
-
"\n",
|
105 |
-
" def __call__(self, x, y):\n",
|
106 |
-
" return x + y\n",
|
107 |
-
" \n",
|
108 |
-
" def str(x, y):\n",
|
109 |
-
" return f\"{x} + {y}\"\n",
|
110 |
-
"\n",
|
111 |
-
"class Subtract:\n",
|
112 |
-
" '''\n",
|
113 |
-
" Operator to subtract two numerical values.\n",
|
114 |
-
" '''\n",
|
115 |
-
" def __init__(self):\n",
|
116 |
-
" self.arity = 2 # number of arguments\n",
|
117 |
-
" self.arg_types = [int, int] # argument types\n",
|
118 |
-
" self.return_type = int # return type\n",
|
119 |
-
" self.weight = 1 # weight\n",
|
120 |
-
"\n",
|
121 |
-
" def __call__(self, x, y):\n",
|
122 |
-
" return x - y\n",
|
123 |
-
" \n",
|
124 |
-
" def str(x, y):\n",
|
125 |
-
" return f\"{x} - {y}\"\n",
|
126 |
-
" \n",
|
127 |
-
"class Multiply:\n",
|
128 |
-
" '''\n",
|
129 |
-
" Operator to multiply two numerical values.\n",
|
130 |
-
" '''\n",
|
131 |
-
" def __init__(self):\n",
|
132 |
-
" self.arity = 2 # number of arguments\n",
|
133 |
-
" self.arg_types = [int, int] # argument types\n",
|
134 |
-
" self.return_type = int # return type\n",
|
135 |
-
" self.weight = 1 # weight\n",
|
136 |
-
"\n",
|
137 |
-
" def __call__(self, x, y):\n",
|
138 |
-
" return x * y\n",
|
139 |
-
" \n",
|
140 |
-
" def str(x, y):\n",
|
141 |
-
" return f\"{x} * {y}\" \n",
|
142 |
-
"\n",
|
143 |
-
"class Divide:\n",
|
144 |
-
" '''\n",
|
145 |
-
" Operator to divide two numerical values.\n",
|
146 |
-
" '''\n",
|
147 |
-
" def __init__(self):\n",
|
148 |
-
" self.arity = 2 # number of arguments\n",
|
149 |
-
" self.arg_types = [int, int] # argument types\n",
|
150 |
-
" self.return_type = int # return type\n",
|
151 |
-
" self.weight = 1 # weight\n",
|
152 |
-
"\n",
|
153 |
-
" def __call__(self, x, y):\n",
|
154 |
-
" try: # check for division by zero error\n",
|
155 |
-
" return x / y\n",
|
156 |
-
" except ZeroDivisionError:\n",
|
157 |
-
" return None\n",
|
158 |
-
" \n",
|
159 |
-
" def str(x, y):\n",
|
160 |
-
" return f\"{x} / {y}\"\n",
|
161 |
-
"\n",
|
162 |
-
"\n",
|
163 |
-
"'''\n",
|
164 |
-
"GLOBAL CONSTANTS\n",
|
165 |
-
"''' \n",
|
166 |
-
"\n",
|
167 |
-
"# define operators\n",
|
168 |
-
"arithmetic_operators = [Add(), Subtract(), Multiply(), Divide()]"
|
169 |
]
|
170 |
},
|
171 |
{
|
172 |
"cell_type": "markdown",
|
173 |
"metadata": {},
|
174 |
"source": [
|
175 |
-
"I define
|
176 |
]
|
177 |
},
|
178 |
{
|
179 |
"cell_type": "code",
|
180 |
-
"execution_count":
|
181 |
"metadata": {},
|
182 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
"source": [
|
184 |
-
"
|
185 |
-
"
|
186 |
-
" Extracts the constants from the input-output examples. Also constructs variables as needed\n",
|
187 |
-
" based on the input-output examples, and adds them to the list of constants.\n",
|
188 |
-
" '''\n",
|
189 |
"\n",
|
190 |
-
"
|
191 |
-
"
|
192 |
-
" arity, arg_types = check_examples(examples)\n",
|
193 |
"\n",
|
194 |
-
"
|
195 |
-
" constants = []\n",
|
196 |
"\n",
|
197 |
-
"
|
198 |
-
"
|
199 |
-
" inputs = set(inputs)\n",
|
200 |
"\n",
|
201 |
-
"
|
202 |
-
"
|
203 |
"\n",
|
204 |
-
"
|
205 |
-
"
|
206 |
"\n",
|
207 |
-
"
|
208 |
-
"
|
209 |
-
"
|
210 |
-
" # constants.append(StringConstant(input))\n",
|
211 |
-
" pass\n",
|
212 |
-
" else:\n",
|
213 |
-
" raise Exception(\"Input of unknown type.\")\n",
|
214 |
-
" \n",
|
215 |
-
" # initialize list of variables\n",
|
216 |
-
" variables = []\n",
|
217 |
"\n",
|
218 |
-
"
|
219 |
-
"
|
220 |
-
"
|
221 |
-
" variables.append(IntegerVariable(position))\n",
|
222 |
-
" elif arg == str:\n",
|
223 |
-
" # variables.append(StringVariable(position))\n",
|
224 |
-
" pass\n",
|
225 |
-
" else:\n",
|
226 |
-
" raise Exception(\"Input of unknown type.\")\n",
|
227 |
"\n",
|
228 |
-
"
|
229 |
-
|
230 |
-
},
|
231 |
-
{
|
232 |
-
"cell_type": "code",
|
233 |
-
"execution_count": null,
|
234 |
-
"metadata": {},
|
235 |
-
"outputs": [],
|
236 |
-
"source": [
|
237 |
-
"# initialize program bank\n",
|
238 |
-
"program_bank = extract_constants(examples)"
|
239 |
-
]
|
240 |
-
},
|
241 |
-
{
|
242 |
-
"cell_type": "markdown",
|
243 |
-
"metadata": {},
|
244 |
-
"source": [
|
245 |
-
"I define a function to determine observational equivalence."
|
246 |
-
]
|
247 |
-
},
|
248 |
-
{
|
249 |
-
"cell_type": "code",
|
250 |
-
"execution_count": null,
|
251 |
-
"metadata": {},
|
252 |
-
"outputs": [],
|
253 |
-
"source": [
|
254 |
-
"def observationally_equivalent(a, b):\n",
|
255 |
-
" \"\"\"\n",
|
256 |
-
" Returns True if a and b are observationally equivalent, False otherwise.\n",
|
257 |
-
" \"\"\"\n",
|
258 |
"\n",
|
259 |
-
"
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
267 |
]
|
268 |
},
|
269 |
{
|
270 |
"cell_type": "code",
|
271 |
-
"execution_count":
|
272 |
"metadata": {},
|
273 |
-
"outputs": [
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
]
|
|
|
285 |
}
|
286 |
],
|
287 |
"metadata": {
|
|
|
29 |
"from arithmetic import *\n",
|
30 |
"from abstract_syntax_tree import OperatorNode\n",
|
31 |
"from examples import example_set, check_examples\n",
|
32 |
+
"from synthesizer import extract_constants, observationally_equivalent, check_program\n",
|
33 |
"import config"
|
34 |
]
|
35 |
},
|
|
|
42 |
},
|
43 |
{
|
44 |
"cell_type": "code",
|
45 |
+
"execution_count": 2,
|
46 |
"metadata": {},
|
47 |
"outputs": [],
|
48 |
"source": [
|
|
|
56 |
"cell_type": "markdown",
|
57 |
"metadata": {},
|
58 |
"source": [
|
59 |
+
"I define a function to extract constants from examples."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
]
|
61 |
},
|
62 |
{
|
63 |
"cell_type": "code",
|
64 |
+
"execution_count": 3,
|
65 |
"metadata": {},
|
66 |
"outputs": [],
|
67 |
"source": [
|
68 |
+
"# initialize program bank\n",
|
69 |
+
"program_bank = extract_constants(examples)\n",
|
70 |
+
"program_bank_str = [p.str() for p in program_bank]"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
]
|
72 |
},
|
73 |
{
|
74 |
"cell_type": "markdown",
|
75 |
"metadata": {},
|
76 |
"source": [
|
77 |
+
"Next, I define the bottom-up synthesis algorithm."
|
78 |
]
|
79 |
},
|
80 |
{
|
81 |
"cell_type": "code",
|
82 |
+
"execution_count": 4,
|
83 |
"metadata": {},
|
84 |
+
"outputs": [
|
85 |
+
{
|
86 |
+
"name": "stdout",
|
87 |
+
"output_type": "stream",
|
88 |
+
"text": [
|
89 |
+
"(x0 + x1)\n"
|
90 |
+
]
|
91 |
+
}
|
92 |
+
],
|
93 |
"source": [
|
94 |
+
"# define operators\n",
|
95 |
+
"operators = arithmetic_operators\n",
|
|
|
|
|
|
|
96 |
"\n",
|
97 |
+
"# iterate over each level\n",
|
98 |
+
"for weight in range(2, max_weight):\n",
|
|
|
99 |
"\n",
|
100 |
+
" for op in operators:\n",
|
|
|
101 |
"\n",
|
102 |
+
" # get all possible combinations of primitives in program bank\n",
|
103 |
+
" combinations = itertools.combinations(program_bank, op.arity)\n",
|
|
|
104 |
"\n",
|
105 |
+
" # iterate over each combination\n",
|
106 |
+
" for combination in combinations:\n",
|
107 |
"\n",
|
108 |
+
" # get type signature\n",
|
109 |
+
" type_signature = [p.type for p in combination]\n",
|
110 |
"\n",
|
111 |
+
" # check if type signature matches operator\n",
|
112 |
+
" if type_signature != op.arg_types:\n",
|
113 |
+
" continue\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
"\n",
|
115 |
+
" # check that sum of weights of arguments <= w\n",
|
116 |
+
" if sum([p.weight for p in combination]) > weight:\n",
|
117 |
+
" continue\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
"\n",
|
119 |
+
" # create new program\n",
|
120 |
+
" program = OperatorNode(op, combination)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
"\n",
|
122 |
+
" # check if program is in program bank using string representation\n",
|
123 |
+
" if program.str() in program_bank_str:\n",
|
124 |
+
" continue\n",
|
125 |
+
" \n",
|
126 |
+
" # check if program is observationally equivalent to any program in program bank\n",
|
127 |
+
" if any([observationally_equivalent(program, p, examples) for p in program_bank]):\n",
|
128 |
+
" continue\n",
|
129 |
+
"\n",
|
130 |
+
" # add program to program bank\n",
|
131 |
+
" program_bank.append(program)\n",
|
132 |
+
" program_bank_str.append(program.str())\n",
|
133 |
+
"\n",
|
134 |
+
" # check if program passes all examples\n",
|
135 |
+
" if check_program(program, examples):\n",
|
136 |
+
" # return(program)\n",
|
137 |
+
" print(program.str())"
|
138 |
]
|
139 |
},
|
140 |
{
|
141 |
"cell_type": "code",
|
142 |
+
"execution_count": 1,
|
143 |
"metadata": {},
|
144 |
+
"outputs": [
|
145 |
+
{
|
146 |
+
"data": {
|
147 |
+
"text/plain": [
|
148 |
+
"\"<class 'int'>\""
|
149 |
+
]
|
150 |
+
},
|
151 |
+
"execution_count": 1,
|
152 |
+
"metadata": {},
|
153 |
+
"output_type": "execute_result"
|
154 |
+
}
|
155 |
+
],
|
156 |
+
"source": []
|
157 |
}
|
158 |
],
|
159 |
"metadata": {
|
synthesizer.py
CHANGED
@@ -10,9 +10,12 @@ python synthesizer.py --domain arithmetic --examples addition
|
|
10 |
# load libraries
|
11 |
import numpy as np
|
12 |
import argparse
|
|
|
|
|
13 |
|
14 |
# import examples
|
15 |
from arithmetic import *
|
|
|
16 |
from examples import example_set, check_examples
|
17 |
import config
|
18 |
|
@@ -92,6 +95,32 @@ def extract_constants(examples):
|
|
92 |
return constants + variables
|
93 |
|
94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
# RUN SYNTHESIZER
|
96 |
def run_synthesizer(args):
|
97 |
'''
|
@@ -103,9 +132,59 @@ def run_synthesizer(args):
|
|
103 |
|
104 |
# extract constants from examples
|
105 |
program_bank = extract_constants(examples)
|
106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
|
108 |
-
|
|
|
109 |
|
110 |
|
111 |
if __name__ == '__main__':
|
@@ -115,4 +194,16 @@ if __name__ == '__main__':
|
|
115 |
# print(args)
|
116 |
|
117 |
# run bottom-up enumerative synthesis
|
118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
# load libraries
|
11 |
import numpy as np
|
12 |
import argparse
|
13 |
+
import itertools
|
14 |
+
import time
|
15 |
|
16 |
# import examples
|
17 |
from arithmetic import *
|
18 |
+
from abstract_syntax_tree import *
|
19 |
from examples import example_set, check_examples
|
20 |
import config
|
21 |
|
|
|
95 |
return constants + variables
|
96 |
|
97 |
|
98 |
+
# CHECK OBSERVATIONAL EQUIVALENCE
|
99 |
+
def observationally_equivalent(program_a, program_b, examples):
|
100 |
+
"""
|
101 |
+
Returns True if Program A and Program B are observationally equivalent, False otherwise.
|
102 |
+
"""
|
103 |
+
|
104 |
+
inputs = [example[0] for example in examples]
|
105 |
+
a_output = [program_a.evaluate(input) for input in inputs]
|
106 |
+
b_output = [program_b.evaluate(input) for input in inputs]
|
107 |
+
|
108 |
+
return a_output == b_output
|
109 |
+
|
110 |
+
|
111 |
+
# CHECK CORRECTNESS
|
112 |
+
def check_program(program, examples):
|
113 |
+
'''
|
114 |
+
Check whether the program satisfies the input-output examples.
|
115 |
+
'''
|
116 |
+
|
117 |
+
inputs = [example[0] for example in examples]
|
118 |
+
outputs = [example[1] for example in examples]
|
119 |
+
program_output = [program.evaluate(input) for input in inputs]
|
120 |
+
|
121 |
+
return program_output == outputs
|
122 |
+
|
123 |
+
|
124 |
# RUN SYNTHESIZER
|
125 |
def run_synthesizer(args):
|
126 |
'''
|
|
|
132 |
|
133 |
# extract constants from examples
|
134 |
program_bank = extract_constants(examples)
|
135 |
+
program_bank_str = [p.str() for p in program_bank]
|
136 |
+
print(f"- Extracted {len(program_bank)} constants from examples.")
|
137 |
+
|
138 |
+
# define operators
|
139 |
+
operators = arithmetic_operators
|
140 |
+
|
141 |
+
# iterate over each level
|
142 |
+
for weight in range(2, args.max_weight):
|
143 |
+
|
144 |
+
# print message
|
145 |
+
print(f"- Searching level {weight} with {len(program_bank)} primitives.")
|
146 |
+
|
147 |
+
# iterate over each operator
|
148 |
+
for op in operators:
|
149 |
+
|
150 |
+
# get all possible combinations of primitives in program bank
|
151 |
+
combinations = itertools.combinations(program_bank, op.arity)
|
152 |
+
|
153 |
+
# iterate over each combination
|
154 |
+
for combination in combinations:
|
155 |
+
|
156 |
+
# get type signature
|
157 |
+
type_signature = [p.type for p in combination]
|
158 |
+
|
159 |
+
# check if type signature matches operator
|
160 |
+
if type_signature != op.arg_types:
|
161 |
+
continue
|
162 |
+
|
163 |
+
# check that sum of weights of arguments <= w
|
164 |
+
if sum([p.weight for p in combination]) > weight:
|
165 |
+
continue
|
166 |
+
|
167 |
+
# create new program
|
168 |
+
program = OperatorNode(op, combination)
|
169 |
+
|
170 |
+
# check if program is in program bank using string representation
|
171 |
+
if program.str() in program_bank_str:
|
172 |
+
continue
|
173 |
+
|
174 |
+
# check if program is observationally equivalent to any program in program bank
|
175 |
+
if any([observationally_equivalent(program, p, examples) for p in program_bank]):
|
176 |
+
continue
|
177 |
+
|
178 |
+
# add program to program bank
|
179 |
+
program_bank.append(program)
|
180 |
+
program_bank_str.append(program.str())
|
181 |
+
|
182 |
+
# check if program passes all examples
|
183 |
+
if check_program(program, examples):
|
184 |
+
return(program)
|
185 |
|
186 |
+
# return None if no program is found
|
187 |
+
return None
|
188 |
|
189 |
|
190 |
if __name__ == '__main__':
|
|
|
194 |
# print(args)
|
195 |
|
196 |
# run bottom-up enumerative synthesis
|
197 |
+
start_time = time.time()
|
198 |
+
program = run_synthesizer(args)
|
199 |
+
end_time = time.time()
|
200 |
+
elapsed_time = round(end_time - start_time, 4)
|
201 |
+
|
202 |
+
# check if program was found
|
203 |
+
if program is None:
|
204 |
+
print(f"Max weight of {args.max_weight} reached, no program found in {elapsed_time}s.")
|
205 |
+
else:
|
206 |
+
print(f"Program found in {elapsed_time}s.")
|
207 |
+
print(f"Program: {program.str()}")
|
208 |
+
print(f"Program weight: {program.weight}")
|
209 |
+
print(f"Program return type: {program.type.__name__}")
|