"""Streamlit front end for bottom-up synthesis of RASP programs.

Reads input/output examples from the user, runs the synthesizer up to a chosen
program weight, and, if an exact program is found, compiles it with tracr and
offers the compiled model parameters as a pickled download.
"""
import argparse
import ast
import base64
import inspect
import itertools
import pickle
import re
import time
from typing import get_args

import numpy as np
import streamlit as st
from tracr.compiler import compiling

# Project-local wildcard imports are kept last (as in the original file) so any
# names they export still take precedence; they supply `analyze_examples`,
# `get_vocabulary` and `run_synthesizer` among others.
from abstract_syntax_tree import *
from python_embedded_rasp import *
from rasp_synthesizer import *


# HELPER FUNCTIONS

# Matches the `layer_<n>` component of a compiled-model parameter name.
_LAYER_RE = re.compile(r'layer_(\d+)')


def download_model(model):
    """Render a link that downloads *model* as a pickled ``.pkl`` file.

    The object is serialized with :func:`pickle.dumps`, base64-encoded and
    embedded in a ``data:`` URI inside an HTML anchor rendered through
    ``st.markdown`` (which is why ``unsafe_allow_html=True`` is required).

    Args:
        model: Any picklable object; this app passes the compiled model's
            ``params``.
    """
    output_model = pickle.dumps(model)
    b64 = base64.b64encode(output_model).decode()
    # Embed the bytes directly in the link so the browser can save them without
    # a server endpoint; the `download` attribute supplies the file name.
    href = (
        f'<a href="data:application/octet-stream;base64,{b64}" '
        'download="model_params.pkl">'
        'Download Haiku Model Parameters in a .pkl File</a>'
    )
    st.markdown(href, unsafe_allow_html=True)


def _layer_count(param_names):
    """Return the layer count implied by ``layer_<n>`` parameter names.

    Uses the highest ``n`` found across *all* names, plus one (as the original
    helper did; indices presumably start at 0), rather than only the last name,
    whose position depends on dict insertion order.

    Returns:
        The count as an int, or ``None`` if no name contains ``layer_<n>``.
    """
    indices = []
    for name in param_names:
        match = _LAYER_RE.search(name)
        if match:
            indices.append(int(match.group(1)))
    return max(indices) + 1 if indices else None


# APP DRIVER CODE

st.title("Bottom Up Synthesis for RASP")

# st.slider(label, min_value, max_value, value, step): a default *value* of 1
# would lie below min_value=2 and Streamlit rejects it, so 1 is passed as the
# step and the default value falls back to min_value.
max_weight = st.slider(
    "Choose the maximum program weight to search for (~ size of transformer)",
    min_value=2,
    max_value=20,
    step=1,
)

default_example = "[[['h', 'e', 'l', 'l', 'o'], [1,1,2,2,1]]]"
example_text = st.text_input(
    label="Provide Input and Output Examples", value=default_example
)

inputs, outs = analyze_examples(example_text)
examples = list(zip(inputs, outs))
st.write("Received the following input and output examples:")
st.write(examples)

# Length of the longest input sequence (0 when there are no inputs).
max_seq_len = max((len(seq) for seq in inputs), default=0)
vocab = get_vocabulary(examples)

st.subheader("Synthesis Configuration")
st.write("Running synthesizer with")
st.write("Vocab: {}".format(vocab))
st.write("Max sequence length: {}".format(max_seq_len))
st.write("Max weight: {}".format(max_weight))

# Emit the results header and the "may take a while" hint *before* the slow
# synthesis/compilation steps so the user sees them while waiting.
st.subheader("Synthesis Results:")
st.caption("May take a while.")

program, approx_programs = run_synthesizer(examples, max_weight)

if program:
    algorithm = program.to_python()
    bos = "BOS"
    model = compiling.compile_rasp_to_model(
        algorithm,
        vocab=vocab,
        max_seq_len=max_seq_len,
        compiler_bos=bos,
    )

    layer_num = _layer_count(model.params.keys())
    st.write(
        f"The following program has been compiled to a transformer with {layer_num} layer(s):"
    )
    st.write(program.str())
    st.write("Here is a model download link: ")
    hk_model = model.params
    download_model(hk_model)
else:
    st.write(
        "No exact program found but here is some brainstorming fodder for you: {}".format(
            approx_programs
        )
    )