RASP-Synthesis / app.py
CSquid333's picture
made slider slightly more intuitive
275482b
import streamlit as st
import numpy as np
import argparse
import itertools
import time
import ast
import re
from tracr.compiler import compiling
from typing import get_args
import inspect
import pickle
import base64
from abstract_syntax_tree import *
from python_embedded_rasp import *
from rasp_synthesizer import *
# HELPER FUNCTIONS
def download_model(model):
    """Render an HTML link that downloads ``model`` as a pickle file.

    The object is serialized with ``pickle.dumps``, base64-encoded, and
    embedded directly in a ``data:`` URL. The whole payload therefore lives
    inside the rendered page, and no server-side file is written.

    Args:
        model: Any picklable object. The app passes the compiled model's
            ``params``.
    """
    payload = base64.b64encode(pickle.dumps(model)).decode()
    # "application/octet-stream" is the standard media type for arbitrary
    # binary data. The previous "file/output_model" is not a valid MIME type.
    href = (
        f'<a href="data:application/octet-stream;base64,{payload}" '
        f'download="model_params.pkl">Download Haiku Model Parameters in a .pkl File</a>'
    )
    # unsafe_allow_html is required so Streamlit renders the raw <a> tag.
    st.markdown(href, unsafe_allow_html=True)
# APP DRIVER CODE
st.title("Bottom Up Synthesis for RASP")

# Upper bound on program weight for the bottom-up search.
# The slider spans 2..20 and starts at 15.
max_weight = st.slider(
    "Choose the maximum program weight to search for (~ size of transformer)",
    min_value=2,
    max_value=20,
    value=15,
)

# Examples are typed as one string and parsed by analyze_examples (from a
# star-imported project module) into parallel input / output sequences.
default_example = "[[['h', 'e', 'l', 'l', 'o'], [1,1,2,2,1]]]"
example_text = st.text_input("Provide Input and Output Examples", default_example)
inputs, outs = analyze_examples(example_text)

# Pair every input with its expected output: a list of (input, output) tuples.
examples = list(zip(inputs, outs))

st.write("Received the following input and output examples:")
st.write(examples)
# The compiled model is built with this max_seq_len, so it must cover the
# longest example input. Falls back to 0 when there are no inputs.
max_seq_len = max((len(seq) for seq in inputs), default=0)
vocab = get_vocabulary(examples)

# Echo the synthesis parameters back to the user.
st.subheader("Synthesis Configuration")
st.write("Running synthesizer with")
st.write(f"Vocab: {vocab}")
st.write(f"Max sequence length: {max_seq_len}")
st.write(f"Max weight: {max_weight}")
def extract_layer_number(s):
    """Return ``N + 1`` for the first ``layer_N`` found in parameter name ``s``.

    Converts the 0-based ``layer_N`` index into a 1-based count. Returns
    ``None`` when ``s`` has no ``layer_N`` component.
    """
    match = re.search(r'layer_(\d+)', s)
    if match:
        return int(match.group(1)) + 1
    return None


# Warn that the search can be slow *before* starting it. Previously this
# caption was written only after run_synthesizer had already returned, so
# the user saw the warning once the wait was over.
st.subheader("Synthesis Results:")
st.caption("May take a while.")
with st.spinner("Synthesizing..."):
    program, approx_programs = run_synthesizer(examples, max_weight)

if program:
    # Compile the synthesized RASP program into a transformer with tracr.
    algorithm = program.to_python()
    bos = "BOS"
    model = compiling.compile_rasp_to_model(
        algorithm,
        vocab=vocab,
        max_seq_len=max_seq_len,
        compiler_bos=bos,
    )
    # Derive the layer count from *every* parameter name. Using only the last
    # key assumed it belonged to the deepest layer, and reported
    # "None layer(s)" whenever that key had no "layer_N" component.
    layer_numbers = [
        n for n in map(extract_layer_number, model.params.keys()) if n is not None
    ]
    layer_num = max(layer_numbers, default=0)
    st.write(f"The following program has been compiled to a transformer with {layer_num} layer(s):")
    st.write(program.str())
    st.write("Here is a model download link: ")
    hk_model = model.params
    download_model(hk_model)
else:
    st.write("No exact program found but here is some brainstorming fodder for you: {}".format(approx_programs))