ayushnoori commited on
Commit
5b04db9
1 Parent(s): 3872a55

Complete on arithmetic DSL

Browse files
Files changed (5) hide show
  1. README.md +3 -3
  2. abstract_syntax_tree.py +4 -2
  3. arithmetic.py +2 -0
  4. demonstration.ipynb +67 -195
  5. synthesizer.py +94 -3
README.md CHANGED
@@ -1,7 +1,5 @@
1
  # Bottom-Up Enumerative Program Synthesis
2
 
3
- 🚨🚨PLEASE DO NOT GRADE YET🚨🚨
4
-
5
  Completed for [CS252R: Program Synthesis](https://synthesis.metareflection.club/) at the Harvard John A. Paulson School of Engineering and Applied Sciences, taught in Fall 2023 by Prof. Nada Amin.
6
 
7
  ## 🛠️ Background
@@ -51,7 +49,9 @@ To add additional input-output examples, modify `examples.py`. Add a new key to
51
 
52
  ## 🔎 Abstract Syntax Tree
53
 
54
- The most important data structure in this implementation is the abstract syntax tree (AST). The AST is a tree representation of a program, where each node is either a primitive or a compound expression. The AST is represented by the `OperatorNode` class in `abstract_syntax_tree.py`. My AST implementation includes functions to recursively evaluate the operator and its operands, and also to generate a string representation of the program.
 
 
55
 
56
  ## 🔮 Virtual Environment
57
 
 
1
  # Bottom-Up Enumerative Program Synthesis
2
 
 
 
3
  Completed for [CS252R: Program Synthesis](https://synthesis.metareflection.club/) at the Harvard John A. Paulson School of Engineering and Applied Sciences, taught in Fall 2023 by Prof. Nada Amin.
4
 
5
  ## 🛠️ Background
 
49
 
50
  ## 🔎 Abstract Syntax Tree
51
 
52
+ The most important data structure in this implementation is the abstract syntax tree (AST). The AST is a tree representation of a program, where each node is either a primitive or a compound expression. The AST is represented by the `OperatorNode` class in `abstract_syntax_tree.py`. My AST implementation includes functions to recursively evaluate the operator and its operands and also to generate a string representation of the program.
53
+
54
+ At program evaluation time, the AST is evaluated from the bottom up. That is, the operands are evaluated first, and then the operator is evaluated on the operands. This is implemented in the `evaluate` method of the `OperatorNode` class. In the case of integers, variable inputs are represented by the `IntegerVariable` class in `arithmetic.py`. When input is not `None`, input type checking and validation is performed by the `evaluate` function in this class.
55
 
56
  ## 🔮 Virtual Environment
57
 
abstract_syntax_tree.py CHANGED
@@ -24,10 +24,12 @@ class OperatorNode:
24
  add_node = OperatorNode(Add(), [IntegerVariable(0), IntegerConstant(5)])
25
  multiply_node.evaluate([7]) # returns 24
26
  '''
27
-
28
  def __init__(self, operator, children):
29
- self.operator = operator # Operator object (e.g., Add, Subtract, etc.)
30
  self.children = children # list of children nodes (operands)
 
 
31
 
32
  def evaluate(self, input = None):
33
 
 
24
  add_node = OperatorNode(Add(), [IntegerVariable(0), IntegerConstant(5)])
25
  multiply_node.evaluate([7]) # returns 24
26
  '''
27
+
28
  def __init__(self, operator, children):
29
+ self.operator = operator # operator object (e.g., Add, Subtract, etc.)
30
  self.children = children # list of children nodes (operands)
31
+ self.weight = operator.weight + sum([child.weight for child in children]) # weight of the node
32
+ self.type = operator.return_type # return type of the operator object
33
 
34
  def evaluate(self, input = None):
35
 
arithmetic.py CHANGED
@@ -16,6 +16,7 @@ class IntegerVariable:
16
  # self.value = None # value of the variable, initially None
17
  self.position = position # zero-indexed position of the variable in the arguments to program
18
  self.type = int # type of the variable
 
19
 
20
  # def assign(self, value):
21
  # self.value = value
@@ -57,6 +58,7 @@ class IntegerConstant:
57
  def __init__(self, value):
58
  self.value = value # value of the constant
59
  self.type = int # type of the constant
 
60
 
61
  def evaluate(self, input = None):
62
  return self.value
 
16
  # self.value = None # value of the variable, initially None
17
  self.position = position # zero-indexed position of the variable in the arguments to program
18
  self.type = int # type of the variable
19
+ self.weight = 1 # weight of the variable
20
 
21
  # def assign(self, value):
22
  # self.value = value
 
58
  def __init__(self, value):
59
  self.value = value # value of the constant
60
  self.type = int # type of the constant
61
+ self.weight = 1 # weight of the constant
62
 
63
  def evaluate(self, input = None):
64
  return self.value
demonstration.ipynb CHANGED
@@ -29,6 +29,7 @@
29
  "from arithmetic import *\n",
30
  "from abstract_syntax_tree import OperatorNode\n",
31
  "from examples import example_set, check_examples\n",
 
32
  "import config"
33
  ]
34
  },
@@ -41,7 +42,7 @@
41
  },
42
  {
43
  "cell_type": "code",
44
- "execution_count": null,
45
  "metadata": {},
46
  "outputs": [],
47
  "source": [
@@ -55,233 +56,104 @@
55
  "cell_type": "markdown",
56
  "metadata": {},
57
  "source": [
58
- "First, I define a function to check that, across all input-output pairs, all inputs are of the same length and that argument types are consistent across inputs."
59
- ]
60
- },
61
- {
62
- "cell_type": "markdown",
63
- "metadata": {},
64
- "source": [
65
- "I provide examples of arithmetic operations."
66
  ]
67
  },
68
  {
69
  "cell_type": "code",
70
- "execution_count": null,
71
  "metadata": {},
72
  "outputs": [],
73
  "source": [
74
- "class IntegerVariable:\n",
75
- " '''\n",
76
- " Class to represent an integer variable. Note that position is the position of the variable in the input.\n",
77
- " For example, if the input is [4, 5, 6] and the variable is the third element (i.e., 6), then position = 2.\n",
78
- " '''\n",
79
- " def __init__(self, position):\n",
80
- " self.value = None # value of the variable, initially None\n",
81
- " self.position = position # position of the variable in the arguments to program\n",
82
- " self.type = int # type of the variable\n",
83
- "\n",
84
- " def assign(self, value):\n",
85
- " self.value = value\n",
86
- "\n",
87
- "class IntegerConstant:\n",
88
- " '''\n",
89
- " Class to represent an integer constant.\n",
90
- " '''\n",
91
- " def __init__(self, value):\n",
92
- " self.value = value # value of the constant\n",
93
- " self.type = int # type of the constant\n",
94
- "\n",
95
- "class Add:\n",
96
- " '''\n",
97
- " Operator to add two numerical values.\n",
98
- " '''\n",
99
- " def __init__(self):\n",
100
- " self.arity = 2 # number of arguments\n",
101
- " self.arg_types = [int, int] # argument types\n",
102
- " self.return_type = int # return type\n",
103
- " self.weight = 1 # weight\n",
104
- "\n",
105
- " def __call__(self, x, y):\n",
106
- " return x + y\n",
107
- " \n",
108
- " def str(x, y):\n",
109
- " return f\"{x} + {y}\"\n",
110
- "\n",
111
- "class Subtract:\n",
112
- " '''\n",
113
- " Operator to subtract two numerical values.\n",
114
- " '''\n",
115
- " def __init__(self):\n",
116
- " self.arity = 2 # number of arguments\n",
117
- " self.arg_types = [int, int] # argument types\n",
118
- " self.return_type = int # return type\n",
119
- " self.weight = 1 # weight\n",
120
- "\n",
121
- " def __call__(self, x, y):\n",
122
- " return x - y\n",
123
- " \n",
124
- " def str(x, y):\n",
125
- " return f\"{x} - {y}\"\n",
126
- " \n",
127
- "class Multiply:\n",
128
- " '''\n",
129
- " Operator to multiply two numerical values.\n",
130
- " '''\n",
131
- " def __init__(self):\n",
132
- " self.arity = 2 # number of arguments\n",
133
- " self.arg_types = [int, int] # argument types\n",
134
- " self.return_type = int # return type\n",
135
- " self.weight = 1 # weight\n",
136
- "\n",
137
- " def __call__(self, x, y):\n",
138
- " return x * y\n",
139
- " \n",
140
- " def str(x, y):\n",
141
- " return f\"{x} * {y}\" \n",
142
- "\n",
143
- "class Divide:\n",
144
- " '''\n",
145
- " Operator to divide two numerical values.\n",
146
- " '''\n",
147
- " def __init__(self):\n",
148
- " self.arity = 2 # number of arguments\n",
149
- " self.arg_types = [int, int] # argument types\n",
150
- " self.return_type = int # return type\n",
151
- " self.weight = 1 # weight\n",
152
- "\n",
153
- " def __call__(self, x, y):\n",
154
- " try: # check for division by zero error\n",
155
- " return x / y\n",
156
- " except ZeroDivisionError:\n",
157
- " return None\n",
158
- " \n",
159
- " def str(x, y):\n",
160
- " return f\"{x} / {y}\"\n",
161
- "\n",
162
- "\n",
163
- "'''\n",
164
- "GLOBAL CONSTANTS\n",
165
- "''' \n",
166
- "\n",
167
- "# define operators\n",
168
- "arithmetic_operators = [Add(), Subtract(), Multiply(), Divide()]"
169
  ]
170
  },
171
  {
172
  "cell_type": "markdown",
173
  "metadata": {},
174
  "source": [
175
- "I define a function to extract constants from examples."
176
  ]
177
  },
178
  {
179
  "cell_type": "code",
180
- "execution_count": null,
181
  "metadata": {},
182
- "outputs": [],
 
 
 
 
 
 
 
 
183
  "source": [
184
- "def extract_constants(examples):\n",
185
- " '''\n",
186
- " Extracts the constants from the input-output examples. Also constructs variables as needed\n",
187
- " based on the input-output examples, and adds them to the list of constants.\n",
188
- " '''\n",
189
  "\n",
190
- " # check validity of provided examples\n",
191
- " # if valid, extract arity and argument types\n",
192
- " arity, arg_types = check_examples(examples)\n",
193
  "\n",
194
- " # initialize list of constants\n",
195
- " constants = []\n",
196
  "\n",
197
- " # get unique set of inputs\n",
198
- " inputs = [input for example in examples for input in example[0]]\n",
199
- " inputs = set(inputs)\n",
200
  "\n",
201
- " # add 1 to the set of inputs\n",
202
- " inputs.add(1)\n",
203
  "\n",
204
- " # extract constants in input\n",
205
- " for input in inputs:\n",
206
  "\n",
207
- " if type(input) == int:\n",
208
- " constants.append(IntegerConstant(input))\n",
209
- " elif type(input) == str:\n",
210
- " # constants.append(StringConstant(input))\n",
211
- " pass\n",
212
- " else:\n",
213
- " raise Exception(\"Input of unknown type.\")\n",
214
- " \n",
215
- " # initialize list of variables\n",
216
- " variables = []\n",
217
  "\n",
218
- " # extract variables in input\n",
219
- " for position, arg in enumerate(arg_types):\n",
220
- " if arg == int:\n",
221
- " variables.append(IntegerVariable(position))\n",
222
- " elif arg == str:\n",
223
- " # variables.append(StringVariable(position))\n",
224
- " pass\n",
225
- " else:\n",
226
- " raise Exception(\"Input of unknown type.\")\n",
227
  "\n",
228
- " return constants + variables"
229
- ]
230
- },
231
- {
232
- "cell_type": "code",
233
- "execution_count": null,
234
- "metadata": {},
235
- "outputs": [],
236
- "source": [
237
- "# initialize program bank\n",
238
- "program_bank = extract_constants(examples)"
239
- ]
240
- },
241
- {
242
- "cell_type": "markdown",
243
- "metadata": {},
244
- "source": [
245
- "I define a function to determine observational equivalence."
246
- ]
247
- },
248
- {
249
- "cell_type": "code",
250
- "execution_count": null,
251
- "metadata": {},
252
- "outputs": [],
253
- "source": [
254
- "def observationally_equivalent(a, b):\n",
255
- " \"\"\"\n",
256
- " Returns True if a and b are observationally equivalent, False otherwise.\n",
257
- " \"\"\"\n",
258
  "\n",
259
- " pass"
260
- ]
261
- },
262
- {
263
- "cell_type": "markdown",
264
- "metadata": {},
265
- "source": [
266
- "Next, I define the bottom-up synthesis algorithm."
 
 
 
 
 
 
 
 
267
  ]
268
  },
269
  {
270
  "cell_type": "code",
271
- "execution_count": null,
272
  "metadata": {},
273
- "outputs": [],
274
- "source": [
275
- "# iterate over each level\n",
276
- "for i in range(2, max_weight):\n",
277
- "\n",
278
- " # define level program bank\n",
279
- " level_program_bank = []\n",
280
- "\n",
281
- " for op in arithmetic_operators:\n",
282
- "\n",
283
- " break"
284
- ]
 
285
  }
286
  ],
287
  "metadata": {
 
29
  "from arithmetic import *\n",
30
  "from abstract_syntax_tree import OperatorNode\n",
31
  "from examples import example_set, check_examples\n",
32
+ "from synthesizer import extract_constants, observationally_equivalent, check_program\n",
33
  "import config"
34
  ]
35
  },
 
42
  },
43
  {
44
  "cell_type": "code",
45
+ "execution_count": 2,
46
  "metadata": {},
47
  "outputs": [],
48
  "source": [
 
56
  "cell_type": "markdown",
57
  "metadata": {},
58
  "source": [
59
+ "I define a function to extract constants from examples."
 
 
 
 
 
 
 
60
  ]
61
  },
62
  {
63
  "cell_type": "code",
64
+ "execution_count": 3,
65
  "metadata": {},
66
  "outputs": [],
67
  "source": [
68
+ "# initialize program bank\n",
69
+ "program_bank = extract_constants(examples)\n",
70
+ "program_bank_str = [p.str() for p in program_bank]"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  ]
72
  },
73
  {
74
  "cell_type": "markdown",
75
  "metadata": {},
76
  "source": [
77
+ "Next, I define the bottom-up synthesis algorithm."
78
  ]
79
  },
80
  {
81
  "cell_type": "code",
82
+ "execution_count": 4,
83
  "metadata": {},
84
+ "outputs": [
85
+ {
86
+ "name": "stdout",
87
+ "output_type": "stream",
88
+ "text": [
89
+ "(x0 + x1)\n"
90
+ ]
91
+ }
92
+ ],
93
  "source": [
94
+ "# define operators\n",
95
+ "operators = arithmetic_operators\n",
 
 
 
96
  "\n",
97
+ "# iterate over each level\n",
98
+ "for weight in range(2, max_weight):\n",
 
99
  "\n",
100
+ " for op in operators:\n",
 
101
  "\n",
102
+ " # get all possible combinations of primitives in program bank\n",
103
+ " combinations = itertools.combinations(program_bank, op.arity)\n",
 
104
  "\n",
105
+ " # iterate over each combination\n",
106
+ " for combination in combinations:\n",
107
  "\n",
108
+ " # get type signature\n",
109
+ " type_signature = [p.type for p in combination]\n",
110
  "\n",
111
+ " # check if type signature matches operator\n",
112
+ " if type_signature != op.arg_types:\n",
113
+ " continue\n",
 
 
 
 
 
 
 
114
  "\n",
115
+ " # check that sum of weights of arguments <= w\n",
116
+ " if sum([p.weight for p in combination]) > weight:\n",
117
+ " continue\n",
 
 
 
 
 
 
118
  "\n",
119
+ " # create new program\n",
120
+ " program = OperatorNode(op, combination)\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  "\n",
122
+ " # check if program is in program bank using string representation\n",
123
+ " if program.str() in program_bank_str:\n",
124
+ " continue\n",
125
+ " \n",
126
+ " # check if program is observationally equivalent to any program in program bank\n",
127
+ " if any([observationally_equivalent(program, p, examples) for p in program_bank]):\n",
128
+ " continue\n",
129
+ "\n",
130
+ " # add program to program bank\n",
131
+ " program_bank.append(program)\n",
132
+ " program_bank_str.append(program.str())\n",
133
+ "\n",
134
+ " # check if program passes all examples\n",
135
+ " if check_program(program, examples):\n",
136
+ " # return(program)\n",
137
+ " print(program.str())"
138
  ]
139
  },
140
  {
141
  "cell_type": "code",
142
+ "execution_count": 1,
143
  "metadata": {},
144
+ "outputs": [
145
+ {
146
+ "data": {
147
+ "text/plain": [
148
+ "\"<class 'int'>\""
149
+ ]
150
+ },
151
+ "execution_count": 1,
152
+ "metadata": {},
153
+ "output_type": "execute_result"
154
+ }
155
+ ],
156
+ "source": []
157
  }
158
  ],
159
  "metadata": {
synthesizer.py CHANGED
@@ -10,9 +10,12 @@ python synthesizer.py --domain arithmetic --examples addition
10
  # load libraries
11
  import numpy as np
12
  import argparse
 
 
13
 
14
  # import examples
15
  from arithmetic import *
 
16
  from examples import example_set, check_examples
17
  import config
18
 
@@ -92,6 +95,32 @@ def extract_constants(examples):
92
  return constants + variables
93
 
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  # RUN SYNTHESIZER
96
  def run_synthesizer(args):
97
  '''
@@ -103,9 +132,59 @@ def run_synthesizer(args):
103
 
104
  # extract constants from examples
105
  program_bank = extract_constants(examples)
106
- print(examples)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
- pass
 
109
 
110
 
111
  if __name__ == '__main__':
@@ -115,4 +194,16 @@ if __name__ == '__main__':
115
  # print(args)
116
 
117
  # run bottom-up enumerative synthesis
118
- run_synthesizer(args)
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  # load libraries
11
  import numpy as np
12
  import argparse
13
+ import itertools
14
+ import time
15
 
16
  # import examples
17
  from arithmetic import *
18
+ from abstract_syntax_tree import *
19
  from examples import example_set, check_examples
20
  import config
21
 
 
95
  return constants + variables
96
 
97
 
98
+ # CHECK OBSERVATIONAL EQUIVALENCE
99
+ def observationally_equivalent(program_a, program_b, examples):
100
+ """
101
+ Returns True if Program A and Program B are observationally equivalent, False otherwise.
102
+ """
103
+
104
+ inputs = [example[0] for example in examples]
105
+ a_output = [program_a.evaluate(input) for input in inputs]
106
+ b_output = [program_b.evaluate(input) for input in inputs]
107
+
108
+ return a_output == b_output
109
+
110
+
111
+ # CHECK CORRECTNESS
112
+ def check_program(program, examples):
113
+ '''
114
+ Check whether the program satisfies the input-output examples.
115
+ '''
116
+
117
+ inputs = [example[0] for example in examples]
118
+ outputs = [example[1] for example in examples]
119
+ program_output = [program.evaluate(input) for input in inputs]
120
+
121
+ return program_output == outputs
122
+
123
+
124
  # RUN SYNTHESIZER
125
  def run_synthesizer(args):
126
  '''
 
132
 
133
  # extract constants from examples
134
  program_bank = extract_constants(examples)
135
+ program_bank_str = [p.str() for p in program_bank]
136
+ print(f"- Extracted {len(program_bank)} constants from examples.")
137
+
138
+ # define operators
139
+ operators = arithmetic_operators
140
+
141
+ # iterate over each level
142
+ for weight in range(2, args.max_weight):
143
+
144
+ # print message
145
+ print(f"- Searching level {weight} with {len(program_bank)} primitives.")
146
+
147
+ # iterate over each operator
148
+ for op in operators:
149
+
150
+ # get all possible combinations of primitives in program bank
151
+ combinations = itertools.combinations(program_bank, op.arity)
152
+
153
+ # iterate over each combination
154
+ for combination in combinations:
155
+
156
+ # get type signature
157
+ type_signature = [p.type for p in combination]
158
+
159
+ # check if type signature matches operator
160
+ if type_signature != op.arg_types:
161
+ continue
162
+
163
+ # check that sum of weights of arguments <= w
164
+ if sum([p.weight for p in combination]) > weight:
165
+ continue
166
+
167
+ # create new program
168
+ program = OperatorNode(op, combination)
169
+
170
+ # check if program is in program bank using string representation
171
+ if program.str() in program_bank_str:
172
+ continue
173
+
174
+ # check if program is observationally equivalent to any program in program bank
175
+ if any([observationally_equivalent(program, p, examples) for p in program_bank]):
176
+ continue
177
+
178
+ # add program to program bank
179
+ program_bank.append(program)
180
+ program_bank_str.append(program.str())
181
+
182
+ # check if program passes all examples
183
+ if check_program(program, examples):
184
+ return(program)
185
 
186
+ # return None if no program is found
187
+ return None
188
 
189
 
190
  if __name__ == '__main__':
 
194
  # print(args)
195
 
196
  # run bottom-up enumerative synthesis
197
+ start_time = time.time()
198
+ program = run_synthesizer(args)
199
+ end_time = time.time()
200
+ elapsed_time = round(end_time - start_time, 4)
201
+
202
+ # check if program was found
203
+ if program is None:
204
+ print(f"Max weight of {args.max_weight} reached, no program found in {elapsed_time}s.")
205
+ else:
206
+ print(f"Program found in {elapsed_time}s.")
207
+ print(f"Program: {program.str()}")
208
+ print(f"Program weight: {program.weight}")
209
+ print(f"Program return type: {program.type.__name__}")