# RASP-Synthesis / rasp_synthesizer.py
# Last change: "made an app for the synthesizer" (CSquid333, f0b559a)
'''
BOTTOM-UP ENUMERATIVE SYNTHESIS FOR RASP
Usage:
python rasp_synthesizer.py --examples "<list of [input, output] example pairs>" [--max_weight N]
'''
import numpy as np
import argparse
import itertools
import time
import ast
import re
from tracr.compiler import compiling
from typing import get_args
import inspect
import heapq
from abstract_syntax_tree import *
from python_embedded_rasp import *
# PARSE ARGUMENTS
def parse_args():
    '''
    Build the command line interface and parse the script's arguments.

    Returns the argparse namespace with:
        examples   -- required string holding the input/output examples
        max_weight -- optional int (default 10) bounding the program weight
    '''
    cli = argparse.ArgumentParser(description="Bottom-up enumerative synthesis for RASP.")
    cli.add_argument(
        '--examples',
        required=True,
        help="input/output sequence examples for synthesis",
    )
    cli.add_argument(
        '--max_weight',
        type=int,
        required=False,
        default=10,
        help="Maximum weight of programs to consider before terminating search.",
    )
    return cli.parse_args()
# ANALYZE EXAMPLES
def _same_legal_type(seq):
    '''
    Return True when every element of seq is an int, every element is a float,
    or every element is a str. bool is a subclass of int, so all-bool sequences
    are accepted through the int check. An empty sequence is accepted.
    '''
    return any(
        all(isinstance(item, kind) for item in seq)
        for kind in (int, float, str)
    )


def analyze_examples(inputs):
    '''
    Parse and validate the examples string given on the command line.

    inputs: string containing a Python literal of the form
        [[input_sequence, output_sequence], ...]

    Ensures each sequence holds only numeric values or only char values.

    Returns (example_ins, example_outs): two parallel lists holding the input
    sequences and the output sequences in the order they were given.

    Raises:
        argparse.ArgumentTypeError: the string is not a valid Python literal,
            an example is not an [input, output] pair of sequences, or a
            sequence mixes element types.
        ValueError: the literal is valid but is not a list.
    '''
    try:
        # Safely evaluate the string to a Python object
        examples_lst = ast.literal_eval(inputs)
    except (SyntaxError, ValueError, TypeError) as e:
        raise argparse.ArgumentTypeError(f"Invalid examples format: {e}") from e
    if not isinstance(examples_lst, list):
        raise ValueError("Input should be a list.")

    example_ins = []
    example_outs = []
    for ex in examples_lst:
        try:
            ins, outs = ex[0], ex[1]
            # Iterating here also rejects sides that are not sequences at all.
            in_types = [type(x) for x in ins]
            out_types = [type(x) for x in outs]
        except (TypeError, IndexError, KeyError) as e:
            raise argparse.ArgumentTypeError("Invalid examples format.") from e

        if not (_same_legal_type(ins) and _same_legal_type(outs)):
            # The type of the first element is what the rest were expected to match.
            first_in_type = in_types[0] if in_types else None
            first_out_type = out_types[0] if out_types else None
            raise argparse.ArgumentTypeError(
                "Each example must have consistent types. "
                f"Expected inputs to have type {first_in_type} and outputs to have {first_out_type} "
                f"but instead inputs have types {in_types} and outputs have types {out_types}"
            )
        example_ins.append(ins)
        example_outs.append(outs)
    return example_ins, example_outs
# GET VOCABULARY
def get_vocabulary(examples):
    '''
    Returns vocabulary for later compiling the RASP model.

    examples: iterable of (input_sequence, output_sequence) pairs.
    Returns the set of distinct tokens occurring in the input sequences; the
    output sequences do not contribute to the vocabulary.
    '''
    return {token for example in examples for token in example[0]}
# CHECK OBSERVATIONAL EQUIVALENCE
def check_obs_equivalence(examples, program_a, program_b):
    '''
    Decide whether two programs are observationally equivalent, i.e. whether
    they produce the same outputs on every example input.

    examples: sequence of (input_sequence, output_sequence) pairs; only the
        input sequences are used.
    program_a, program_b: program nodes. A program that is a member of
        rasp_consts is not evaluated and its output is represented by None,
        so a constant only compares equal to another constant.

    Returns True when the outputs match. Also returns True when anything goes
    wrong while evaluating, which makes the synthesizer skip the candidate.
    '''
    try:
        example_inputs = [example[0] for example in examples]
        a_output = None
        b_output = None
        if program_a not in rasp_consts:
            a_output = [program_a.evaluate(seq) for seq in example_inputs]
        if program_b not in rasp_consts:
            b_output = [program_b.evaluate(seq) for seq in example_inputs]
    except Exception:
        # Deliberately best-effort: a program that cannot be evaluated is
        # reported as "equivalent" so the synthesizer does not consider it.
        # Exception (not a bare except) keeps Ctrl-C and SystemExit working.
        return True
    return a_output == b_output
# CHECK CORRECTNESS
def check_correctness(examples, program):
    '''
    Checks if the program's output matches the expected output on all examples.

    examples: sequence of (input_sequence, output_sequence) pairs.
    program: object exposing evaluate(input_sequence).

    Returns (is_correct, num_correct):
        is_correct  -- True when every expected output is reproduced exactly.
        num_correct -- similarity score: the number of examples reproduced
            exactly plus the number of output positions whose element has the
            same type as the expected element.
    A program whose evaluation raises yields (False, 0), so callers can always
    unpack two values.
    '''
    try:
        inputs = [example[0] for example in examples]
        outputs = [example[1] for example in examples]
        program_output = [program.evaluate(seq) for seq in inputs]
    except Exception:
        return False, 0

    is_correct = (program_output == outputs)
    paired = list(zip(program_output, outputs))
    exact_matches = sum(int(produced == expected) for produced, expected in paired)
    type_matches = sum(
        int(type(got) is type(want))
        for produced, expected in paired
        for got, want in zip(produced, expected)
    )
    return is_correct, exact_matches + type_matches
# COMPARE TYPE SIGNATURES
def compare_types(list1, list2):
    '''
    Check a list of supplied types against an expected type signature.

    list1: the types being supplied, position by position.
    list2: the expected types; an entry may be a typing.Union.

    Returns True when every entry of list1 is accepted by the entry of list2
    at the same position. An entry is accepted when it equals the expected
    type or, if the expected type is a Union, when it (or each member of its
    own Union) is one of that Union's members. Returns False when list1 has
    more entries than list2; surplus entries of list2 are not checked.
    '''
    def union_members(candidate):
        # Members of a typing.Union, or None for any other kind of type.
        if hasattr(candidate, '__origin__') and candidate.__origin__ is Union:
            return get_args(candidate)
        return None

    for position, supplied in enumerate(list1):
        if position >= len(list2):
            return False  # more supplied types than the signature expects
        expected = list2[position]
        accepted = union_members(expected)
        if accepted is None:
            # Plain expected type: require an exact match.
            if supplied != expected:
                return False
            continue
        offered = union_members(supplied)
        if offered is None:
            offered = (supplied,)
        # Every offered type has to be one of the accepted Union members.
        if not all(any(one == allowed for allowed in accepted) for one in offered):
            return False
    return True
# RUN SYNTHESIZER
def run_synthesizer(examples, max_weight):
    '''
    Run bottom-up enumerative synthesis.

    Starting from the constants in rasp_consts, every operator in
    rasp_operators is applied to each well-typed tuple of programs from the
    bank whose combined weight fits the current level. A new program is kept
    unless its string form is already known or check_obs_equivalence reports
    it equivalent to a program in the bank.

    examples: list of (input_sequence, output_sequence) pairs.
    max_weight: exclusive upper bound on the weight levels searched
        (levels 2 .. max_weight - 1).

    Returns (program, approx_programs): program is the first program that
    reproduces every example, or None when the search is exhausted;
    approx_programs holds the string forms of up to three of the best-scoring
    inexact programs seen so far.
    '''
    # Work on a copy. Appending to rasp_consts itself would leak synthesized
    # programs into later calls, and check_obs_equivalence skips evaluating
    # members of rasp_consts, so the equivalence pruning would never fire.
    program_bank = list(rasp_consts)
    known_strs = {p.str() for p in program_bank}
    approx_progs = []  # min-heap of (num_correct, program string), at most 3 entries

    # iterate over each level
    for weight in range(2, max_weight):
        for op in rasp_operators:
            for combination in itertools.permutations(program_bank, op.n_args):
                type_signature = [p.return_type for p in combination]
                if not compare_types(type_signature, op.arg_types):
                    continue
                if sum(p.weight for p in combination) > weight:
                    continue
                program = OperatorNode(op, combination)
                program_str = program.str()
                if program_str in known_strs:
                    continue
                # Generator form stops at the first equivalent program.
                if any(check_obs_equivalence(examples, program, p) for p in program_bank):
                    continue
                program_bank.append(program)
                known_strs.add(program_str)

                result = check_correctness(examples, program)
                # Tolerate a bare False result (evaluation failure) as "incorrect, score 0".
                is_correct, num_correct = result if isinstance(result, tuple) else (False, 0)
                if is_correct:
                    return program, [ap[1] for ap in approx_progs]

                if len(approx_progs) >= 3:
                    # Keep only the best three: swap out the weakest entry
                    # when the new program scores strictly higher.
                    correct_cutoff, _prog = heapq.heappop(approx_progs)
                    if num_correct > correct_cutoff:
                        heapq.heappush(approx_progs, (num_correct, program_str))
                    else:
                        heapq.heappush(approx_progs, (correct_cutoff, _prog))
                else:
                    heapq.heappush(approx_progs, (num_correct, program_str))
    return None, [ap[1] for ap in approx_progs]
# COMPILE RASP MODEL
# Some examples:
#   Identify anagrams:
#     [[['V','I','W',',','W','I','V'], [True, True, True, True, True, True, True]],[['a','b',',','b','a'], [True, True, True, True, True]],[['e','l',',','s','t'], [False, False, False, False, False]]]
#     Output: times out
#   Calculate the median of a list of numbers:
#     [[[1,2,3,4,5], [3,3,3,3,3]], [[2,8,10,11], [9,9,9,9]], [[1,2,3],[2,2,2]]]
#     Output: times out
#   Identity function:
#     [[['h','i'], ['h','i']]]
#     Output: (aggregate((select(tokens, tokens, ==)), tokens))
#   Histogram:
#     [[['h', 'e', 'l', 'l', 'o'], [1,1,2,2,1]]]
#     Output: (select_width((select(tokens, tokens, ==))))
#   Length:
#     [[[7,2,5],[3,3,3]],[[1],[1]],[[2,0,1,7,3,6,8,20],[8,8,8,8,8,8,8,8]]]
#     Output: (select_width((select(tokens, tokens, true))))
#   Calculate mean of list of numbers:
#     [[[5,10,3,2,43], [12.6, 12.6, 12.6, 12.6, 12.6]],[[1,2], [1.5, 1.5]],[[3,3,3],[3,3,3]]]
#     Output: (aggregate((select(tokens, tokens, true)), tokens))
#   Reverse a string:
#     [[['h', 'i'], ['i', 'h']]]
#     Output: times out
#     Expected: aggregate(select(indices, (select_width((select(tokens, tokens, true)))) - indices - 1, ==), tokens);
# PERSONAL TODOS:
#   - output several similar programs


def count_model_layers(param_names):
    '''
    Return the number of layers referred to by the given parameter names.

    param_names: iterable of parameter name strings. Names containing
        "layer_<N>" contribute the index N; other names are ignored.

    Returns the highest layer index plus one, or 0 when no name mentions a
    layer. All names are inspected so the result does not depend on the order
    in which the parameters are listed.
    '''
    layer_indices = [
        int(match.group(1))
        for match in (re.search(r'layer_(\d+)', name) for name in param_names)
        if match
    ]
    return max(layer_indices) + 1 if layer_indices else 0


def main():
    '''
    Parse the command line, run the synthesizer on the given examples and, when
    an exact program is found, compile it to a transformer and report it.
    '''
    args = parse_args()
    inputs, outs = analyze_examples(args.examples)
    examples = list(zip(inputs, outs))
    print("Received the following input and output examples:")
    print(examples)

    max_seq_len = max((len(seq) for seq in inputs), default=0)
    vocab = get_vocabulary(examples)
    print("Running synthesizer with")
    print("Vocab: {}".format(vocab))
    print("Max sequence length: {}".format(max_seq_len))
    print("Max weight: {}".format(args.max_weight))

    program, approx_programs = run_synthesizer(examples, args.max_weight)
    # run_synthesizer signals "nothing found" with None; test for exactly that
    # instead of relying on the truthiness of a program node.
    if program is None:
        print("No exact program found but here is some brainstorming fodder for you: {}".format(approx_programs))
        return

    model = compiling.compile_rasp_to_model(
        program.to_python(),
        vocab=vocab,
        max_seq_len=max_seq_len,
        compiler_bos="BOS",
    )
    layer_num = count_model_layers(model.params.keys())
    print(f"The following program has been compiled to a transformer with {layer_num} layer(s):")
    print(program.str())


if __name__ == "__main__":
    main()