Spaces:
Sleeping
Sleeping
''' | |
BOTTOM-UP ENUMERATIVE SYTHESIS FOR RASP | |
Usage: | |
python rasp_synthesis.py --examples | |
''' | |
import numpy as np | |
import argparse | |
import itertools | |
import time | |
import ast | |
import re | |
from tracr.compiler import compiling | |
from typing import get_args | |
import inspect | |
import heapq | |
from abstract_syntax_tree import * | |
from python_embedded_rasp import * | |
# PARSE ARGUMENTS | |
def parse_args(): | |
''' | |
Parse command line arguments. | |
''' | |
parser = argparse.ArgumentParser(description="Bottom-up enumerative synthesis for RASP.") | |
parser.add_argument('--examples', required=True, help="input/output sequence examples for synthesis") | |
parser.add_argument('--max_weight', type=int, required=False, default=10, help="Maximum weight of programs to consider before terminating search.") | |
args = parser.parse_args() | |
return args | |
# ANALYZE EXAMPLES | |
def analyze_examples(inputs): | |
''' | |
Returns a list of unique (input_sequence, output_sequence) tuples of proper python types. | |
Ensures each example is only numeric values or only char values. | |
Returns useful constants given the input examples. | |
''' | |
example_ins = [] | |
example_outs = [] | |
try: | |
# Safely evaluate the string to a Python object | |
examples_lst = ast.literal_eval(inputs) | |
except (SyntaxError, ValueError) as e: | |
raise argparse.ArgumentTypeError(f"Invalid examples format: {e}") | |
if not isinstance(examples_lst, list): | |
raise ValueError("Input should be a list.") | |
for ex in examples_lst: | |
try: | |
ins, outs = ex[0], ex[1] | |
except: | |
raise argparse.ArgumentTypeError(f"Invalid examples format.") | |
def same_legal_type(lst): | |
return (all(isinstance(x, int) for x in lst) or | |
all(isinstance(x, float) for x in lst) or | |
all(isinstance(x, bool) for x in lst) or | |
all(isinstance(x, str) for x in lst)) | |
if same_legal_type(ins) and same_legal_type(outs): | |
example_ins.append(ins) | |
example_outs.append(outs) | |
continue | |
raise argparse.ArgumentTypeError(f"Each example must have consistent types. Expected inputs to have type {first_in_type} and outputs to have {first_out_type} but instead inputs have types {[type(x) for x in ins]} and outputs have types {[type(x) for x in outs]}") | |
return example_ins, example_outs | |
# GET VOCABULARY | |
def get_vocabulary(examples): | |
''' | |
Returns vocabulary for later compiling the RASP model. | |
''' | |
vocab = [] | |
for ex in examples: | |
ins, outs = ex[0], ex[1] | |
vocab.extend([obj for obj in ins]) | |
return set(vocab) | |
# CHECK OBSERVATIONAL EQUIVALENCE | |
def check_obs_equivalence(examples, program_a, program_b): | |
try: | |
inputs = [example[0] for example in examples] | |
a_output = None | |
b_output = None | |
if program_a not in rasp_consts: | |
a_output = [program_a.evaluate(input) for input in inputs] | |
if program_b not in rasp_consts: | |
b_output = [program_b.evaluate(input) for input in inputs] | |
except: | |
return True # force the synthesizer to not consider this program | |
return a_output == b_output | |
# CHECK CORRECTNESS | |
def check_correctness(examples, program): | |
''' | |
Checks if the programs output matches expected output on all examples. | |
''' | |
try: | |
inputs = [example[0] for example in examples] | |
outputs = [example[1] for example in examples] | |
program_output = [program.evaluate(input) for input in inputs] | |
except: | |
return False | |
is_correct = (program_output == outputs) | |
num_correct = sum([int(p == o) for p,o in zip(program_output, outputs)]) + \ | |
sum([int(type(i) == type(j)) for p,o in zip(program_output, outputs) for i,j in zip(p, o)]) | |
return is_correct, num_correct | |
# COMPARE TYPE SIGNATURES | |
def compare_types(list1, list2): | |
for idx, type1 in enumerate(list1): | |
if idx >= len(list2): | |
return False # The first list is longer than the second list | |
type2 = list2[idx] | |
# Check if type2 is a Union | |
if hasattr(type2, '__origin__') and type2.__origin__ is Union: | |
# Extract types from Union | |
types_in_union2 = get_args(type2) | |
# Check if type1 is a Union | |
if hasattr(type1, '__origin__') and type1.__origin__ is Union: | |
types_in_union1 = get_args(type1) | |
# Check if all types in type1's Union are in type2's Union | |
if not all(any(t1 == t2 for t2 in types_in_union2) for t1 in types_in_union1): | |
return False | |
else: | |
# Check if type1 is in type2's Union | |
if not any(type1 == t2 for t2 in types_in_union2): | |
return False | |
else: | |
# Direct type comparison | |
if type1 != type2: | |
return False | |
return True | |
# RUN SYNTHESIZER | |
def run_synthesizer(examples, max_weight): | |
''' | |
Run bottom-up enumerative synthesis. | |
''' | |
program_bank = rasp_consts | |
program_bank_str = [p.str() for p in program_bank] | |
approx_progs = [] | |
# iterate over each level | |
for weight in range(2, max_weight): | |
for op in rasp_operators: | |
combinations = itertools.permutations(program_bank, op.n_args) | |
for combination in combinations: | |
type_signature = [p.return_type for p in combination] | |
if not compare_types(type_signature, op.arg_types): | |
continue | |
if sum([p.weight for p in combination]) > weight: | |
continue | |
program = OperatorNode(op, combination) | |
if program.str() in program_bank_str: | |
continue | |
if any([check_obs_equivalence(examples, program, p) for p in program_bank]): | |
continue | |
program_bank.append(program) | |
program_bank_str.append(program.str()) | |
is_correct, num_correct = check_correctness(examples, program) | |
if is_correct: | |
return(program), [ap[1] for ap in approx_progs] | |
if len(approx_progs) >= 3: | |
correct_cutoff, _prog = heapq.heappop(approx_progs) | |
if num_correct > correct_cutoff: | |
heapq.heappush(approx_progs, (num_correct, program.str())) | |
else: | |
heapq.heappush(approx_progs, (correct_cutoff, _prog)) | |
else: | |
heapq.heappush(approx_progs, (num_correct, program.str())) | |
return None, [ap[1] for ap in approx_progs] | |
# COMPILE RASP MODEL | |
if __name__ == "__main__": | |
''' | |
Some examples: | |
Identify anagrams: | |
[[['V','I','W',',','W','I','V'], [True, True, True, True, True, True, True]],[['a','b',',','b','a'], [True, True, True, True, True]],[['e','l',',','s','t'], [False, False, False, False, False]]] | |
Output: times out | |
Calculate the median of a list of numbers: | |
[[[1,2,3,4,5], [3,3,3,3,3]], [[2,8,10,11], [9,9,9,9]], [[1,2,3],[2,2,2]]] | |
Output: times out | |
Identity function: | |
[[['h','i'], ['h','i']]] | |
Output: (aggregate((select(tokens, tokens, ==)), tokens)) | |
Histogram: | |
[[['h', 'e', 'l', 'l', 'o'], [1,1,2,2,1]]] | |
Output: (select_width((select(tokens, tokens, ==)))) | |
Length: | |
[[[7,2,5],[3,3,3]],[[1],[1]],[[2,0,1,7,3,6,8,20],[8,8,8,8,8,8,8,8]]] | |
Output: (select_width((select(tokens, tokens, true)))) | |
Calculate mean of list of numbers: | |
[[[5,10,3,2,43], [12.6, 12.6, 12.6, 12.6, 12.6]],[[1,2], [1.5, 1.5]],[[3,3,3],[3,3,3]]] | |
Output: (aggregate((select(tokens, tokens, true)), tokens)) | |
Reverse a string: | |
[[['h', 'i'], ['i', 'h']]] | |
Output: times out | |
Expected: aggregate(select(indices, (select_width((select(tokens, tokens, true)))) - indices - 1, ==), tokens); | |
PERSONAL TODOS: | |
- output several similar programs | |
- | |
''' | |
args = parse_args() | |
inputs, outs = analyze_examples(args.examples) | |
examples = list(zip(inputs, outs)) | |
print("Received the following input and output examples:") | |
print(examples) | |
max_seq_len = 0 | |
for i in inputs: | |
max_seq_len = max(len(i), max_seq_len) | |
vocab = get_vocabulary(examples) | |
print("Running synthesizer with") | |
print("Vocab: {}".format(vocab)) | |
print("Max sequence length: {}".format(max_seq_len)) | |
print("Max weight: {}".format(args.max_weight)) | |
program, approx_programs = run_synthesizer(examples, args.max_weight) | |
if program: | |
algorithm = program.to_python() | |
bos = "BOS" | |
model = compiling.compile_rasp_to_model( | |
algorithm, | |
vocab=vocab, | |
max_seq_len=max_seq_len, | |
compiler_bos=bos, | |
) | |
def extract_layer_number(s): | |
match = re.search(r'layer_(\d+)', s) | |
if match: | |
return int(match.group(1)) + 1 | |
else: | |
return None | |
layer_num = extract_layer_number(list(model.params.keys())[-1]) | |
print(f"The following program has been compiled to a transformer with {layer_num} layer(s):") | |
print(program.str()) | |
else: | |
print("No exact program found but here is some brainstorming fodder for you: {}".format(approx_programs)) |