Spaces:
Sleeping
Sleeping
| ''' | |
| BOTTOM-UP ENUMERATIVE SYNTHESIS FOR RASP | |
| Usage: | |
| python rasp_synthesis.py --examples | |
| ''' | |
| import numpy as np | |
| import argparse | |
| import itertools | |
| import time | |
| import ast | |
| import re | |
| from tracr.compiler import compiling | |
| from typing import get_args | |
| import inspect | |
| import heapq | |
| from abstract_syntax_tree import * | |
| from python_embedded_rasp import * | |
| # PARSE ARGUMENTS | |
def parse_args():
    '''
    Build the command-line interface and return the parsed arguments.

    Options:
      --examples    (required) input/output sequence examples, passed as a string.
      --max_weight  (optional int, default 10) weight bound for the search.
    '''
    cli = argparse.ArgumentParser(description="Bottom-up enumerative synthesis for RASP.")
    cli.add_argument(
        '--examples',
        required=True,
        help="input/output sequence examples for synthesis",
    )
    cli.add_argument(
        '--max_weight',
        type=int,
        required=False,
        default=10,
        help="Maximum weight of programs to consider before terminating search.",
    )
    return cli.parse_args()
| # ANALYZE EXAMPLES | |
def analyze_examples(inputs):
    '''
    Parse and validate the examples string supplied on the command line.

    `inputs` is a string holding a Python literal shaped like
    [[input_sequence, output_sequence], ...]. Within one example, the input
    sequence must be homogeneous (all int, all float, all bool or all str),
    and so must the output sequence. Inputs and outputs may differ in type.

    Returns:
        (example_ins, example_outs): two parallel lists holding the input
        sequences and the output sequences, in the order given.

    Raises:
        argparse.ArgumentTypeError: if the string is not a valid literal, an
            example is not an indexable [ins, outs] pair, or a sequence mixes
            types.
        ValueError: if the top-level literal is not a list.
    '''
    legal_types = (int, float, bool, str)

    def same_legal_type(seq):
        # True when every element is an instance of one common legal type.
        # bool is a subclass of int, so a mix of ints and bools passes the
        # int check; an empty sequence is trivially accepted.
        return any(all(isinstance(x, t) for x in seq) for t in legal_types)

    try:
        # literal_eval only accepts Python literals, so arbitrary code in the
        # argument is never executed.
        examples_lst = ast.literal_eval(inputs)
    except (SyntaxError, ValueError, TypeError) as e:
        raise argparse.ArgumentTypeError(f"Invalid examples format: {e}") from e

    if not isinstance(examples_lst, list):
        raise ValueError("Input should be a list.")

    example_ins = []
    example_outs = []
    for ex in examples_lst:
        try:
            ins, outs = ex[0], ex[1]
        except (TypeError, IndexError, KeyError) as e:
            raise argparse.ArgumentTypeError("Invalid examples format.") from e

        try:
            consistent = same_legal_type(ins) and same_legal_type(outs)
        except TypeError as e:
            # ins or outs is not iterable (e.g. the example was [1, 2]).
            raise argparse.ArgumentTypeError("Invalid examples format.") from e

        if not consistent:
            # The previous message interpolated names that were never defined,
            # so this path raised NameError instead of the intended error.
            raise argparse.ArgumentTypeError(
                "Each example must have consistent types, but inputs have types "
                f"{[type(x).__name__ for x in ins]} and outputs have types "
                f"{[type(x).__name__ for x in outs]}"
            )

        example_ins.append(ins)
        example_outs.append(outs)

    return example_ins, example_outs
| # GET VOCABULARY | |
def get_vocabulary(examples):
    '''
    Returns vocabulary for later compiling the RASP model.

    `examples` is an iterable of (input_sequence, output_sequence) pairs.
    The vocabulary is the set of distinct tokens found in the input
    sequences; output sequences do not contribute to it.
    '''
    # Only index 0 (the input sequence) is read; the output sequence was
    # previously unpacked but never used.
    return {token for ex in examples for token in ex[0]}
| # CHECK OBSERVATIONAL EQUIVALENCE | |
def check_obs_equivalence(examples, program_a, program_b):
    '''
    Report whether two programs produce the same outputs on the example inputs.

    A program that is a member of `rasp_consts` is not evaluated; its output
    is taken to be None. Two such programs therefore compare as equivalent,
    and one of them never compares equal to an evaluated program.

    Returns True when the outputs match. Also returns True when anything
    fails while evaluating, so the caller treats a program that cannot be
    evaluated as redundant and discards it.
    '''
    try:
        inputs = [example[0] for example in examples]
        a_output = None
        b_output = None
        if program_a not in rasp_consts:
            a_output = [program_a.evaluate(seq) for seq in inputs]
        if program_b not in rasp_consts:
            b_output = [program_b.evaluate(seq) for seq in inputs]
    except Exception:
        # Catch Exception rather than everything: a bare except would also
        # swallow KeyboardInterrupt/SystemExit, making a long search
        # impossible to interrupt.
        return True  # force the synthesizer to not consider this program
    return a_output == b_output
| # CHECK CORRECTNESS | |
def check_correctness(examples, program):
    '''
    Checks if the program's output matches the expected output on all examples.

    `examples` is a sequence of (input_sequence, expected_output) pairs and
    `program` is any object exposing `evaluate(input_sequence)`.

    Returns:
        (is_correct, num_correct) where
        - is_correct is True only if every example's output equals the
          expected output;
        - num_correct is a closeness score: the number of examples matched
          exactly, plus the number of output positions whose element has
          exactly the same type as the expected element.
        If evaluating the program raises, returns (False, 0).
    '''
    try:
        inputs = [example[0] for example in examples]
        outputs = [example[1] for example in examples]
        program_output = [program.evaluate(seq) for seq in inputs]
    except Exception:
        # Always return a pair: callers unpack two values, so the previous
        # bare `return False` turned an evaluation failure into a TypeError.
        return False, 0

    is_correct = (program_output == outputs)
    exact_matches = sum(int(p == o) for p, o in zip(program_output, outputs))
    # Exact type identity is intentional here (e.g. bool vs int are distinct).
    type_matches = sum(
        int(type(i) == type(j))
        for p, o in zip(program_output, outputs)
        for i, j in zip(p, o)
    )
    return is_correct, exact_matches + type_matches
| # COMPARE TYPE SIGNATURES | |
def compare_types(list1, list2):
    '''
    Check that each type in `list1` is accepted by the type at the same
    position in `list2`.

    A position matches when:
      - the `list2` entry is a Union and the `list1` entry is one of its
        members, or is itself a Union whose members are all members of the
        `list2` Union;
      - otherwise, the two entries are equal.

    Returns False if `list1` is longer than `list2`. Extra trailing entries
    in `list2` are ignored.
    '''
    def is_union(tp):
        return getattr(tp, '__origin__', None) is Union

    if len(list1) > len(list2):
        return False  # more actual types than expected slots

    for actual, expected in zip(list1, list2):
        if not is_union(expected):
            # Plain type: require equality.
            if actual != expected:
                return False
            continue

        allowed = get_args(expected)
        if is_union(actual):
            # Every member of the actual Union must be allowed.
            if any(member not in allowed for member in get_args(actual)):
                return False
        elif actual not in allowed:
            return False

    return True
| # RUN SYNTHESIZER | |
def run_synthesizer(examples, max_weight):
    '''
    Run bottom-up enumerative synthesis.

    Starting from the programs in `rasp_consts`, repeatedly applies every
    operator in `rasp_operators` to well-typed permutations of programs
    already in the bank, for weight levels 2 .. max_weight - 1. A candidate
    is discarded if its argument weights exceed the current level, if its
    string form was already seen, or if `check_obs_equivalence` reports it
    equivalent to a program already in the bank.

    Returns:
        (program, approximations): `program` is the first candidate that
        reproduces every example, or None if the search space is exhausted.
        `approximations` holds the string forms of up to three of the
        highest-scoring non-matching candidates seen so far, in no
        particular order.
    '''
    max_approx = 3  # how many near-miss programs to remember

    # Work on a copy. Appending to the shared `rasp_consts` list would leak
    # synthesized programs into later calls, and would make every banked
    # program look like a constant to check_obs_equivalence (which skips
    # evaluating members of `rasp_consts`), silently disabling pruning.
    program_bank = list(rasp_consts)
    seen_strs = {p.str() for p in program_bank}
    approx_progs = []  # min-heap of (score, program string)

    # iterate over each level
    for weight in range(2, max_weight):
        for op in rasp_operators:
            # permutations() snapshots the bank, so appending to
            # program_bank inside the loop does not affect this iteration.
            for combination in itertools.permutations(program_bank, op.n_args):
                type_signature = [p.return_type for p in combination]
                if not compare_types(type_signature, op.arg_types):
                    continue
                if sum(p.weight for p in combination) > weight:
                    continue

                program = OperatorNode(op, combination)
                program_str = program.str()
                if program_str in seen_strs:
                    continue
                # Generator form stops at the first equivalent bank entry.
                if any(check_obs_equivalence(examples, program, p) for p in program_bank):
                    continue

                program_bank.append(program)
                seen_strs.add(program_str)

                is_correct, num_correct = check_correctness(examples, program)
                if is_correct:
                    return program, [ap[1] for ap in approx_progs]

                candidate = (num_correct, program_str)
                if len(approx_progs) < max_approx:
                    heapq.heappush(approx_progs, candidate)
                elif num_correct > approx_progs[0][0]:
                    # Strictly better than the weakest kept near-miss.
                    heapq.heapreplace(approx_progs, candidate)

    return None, [ap[1] for ap in approx_progs]
| # COMPILE RASP MODEL | |
if __name__ == "__main__":
    '''
    Some examples:
    Identify anagrams:
    [[['V','I','W',',','W','I','V'], [True, True, True, True, True, True, True]],[['a','b',',','b','a'], [True, True, True, True, True]],[['e','l',',','s','t'], [False, False, False, False, False]]]
    Output: times out
    Calculate the median of a list of numbers:
    [[[1,2,3,4,5], [3,3,3,3,3]], [[2,8,10,11], [9,9,9,9]], [[1,2,3],[2,2,2]]]
    Output: times out
    Identity function:
    [[['h','i'], ['h','i']]]
    Output: (aggregate((select(tokens, tokens, ==)), tokens))
    Histogram:
    [[['h', 'e', 'l', 'l', 'o'], [1,1,2,2,1]]]
    Output: (select_width((select(tokens, tokens, ==))))
    Length:
    [[[7,2,5],[3,3,3]],[[1],[1]],[[2,0,1,7,3,6,8,20],[8,8,8,8,8,8,8,8]]]
    Output: (select_width((select(tokens, tokens, true))))
    Calculate mean of list of numbers:
    [[[5,10,3,2,43], [12.6, 12.6, 12.6, 12.6, 12.6]],[[1,2], [1.5, 1.5]],[[3,3,3],[3,3,3]]]
    Output: (aggregate((select(tokens, tokens, true)), tokens))
    Reverse a string:
    [[['h', 'i'], ['i', 'h']]]
    Output: times out
    Expected: aggregate(select(indices, (select_width((select(tokens, tokens, true)))) - indices - 1, ==), tokens);
    PERSONAL TODOS:
    - output several similar programs
    -
    '''
    # Read and validate the examples given on the command line.
    args = parse_args()
    inputs, outs = analyze_examples(args.examples)
    examples = list(zip(inputs, outs))
    print("Received the following input and output examples:")
    print(examples)

    # Parameters handed to the compiler: longest input and the token set.
    max_seq_len = max((len(seq) for seq in inputs), default=0)
    vocab = get_vocabulary(examples)

    print("Running synthesizer with")
    print("Vocab: {}".format(vocab))
    print("Max sequence length: {}".format(max_seq_len))
    print("Max weight: {}".format(args.max_weight))

    program, approx_programs = run_synthesizer(examples, args.max_weight)

    if not program:
        print("No exact program found but here is some brainstorming fodder for you: {}".format(approx_programs))
    else:
        # Compile the synthesized program to a transformer model.
        algorithm = program.to_python()
        bos = "BOS"
        model = compiling.compile_rasp_to_model(
            algorithm,
            vocab=vocab,
            max_seq_len=max_seq_len,
            compiler_bos=bos,
        )

        def extract_layer_number(s):
            # Layer index found in the name, plus one; None when the name
            # contains no "layer_<digits>" part.
            found = re.search(r'layer_(\d+)', s)
            return int(found.group(1)) + 1 if found else None

        # The layer count is read off the name of the last parameter key.
        last_param_name = list(model.params.keys())[-1]
        layer_num = extract_layer_number(last_param_name)
        print(f"The following program has been compiled to a transformer with {layer_num} layer(s):")
        print(program.str())